Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3d9a1d2d
Unverified
Commit
3d9a1d2d
authored
Sep 20, 2025
by
Cyrus Leung
Committed by
GitHub
Sep 20, 2025
Browse files
[V1] Support `LLM.apply_model` (#18465)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
be874c02
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
194 additions
and
169 deletions
+194
-169
tests/conftest.py
tests/conftest.py
+1
-11
tests/kernels/moe/test_mxfp4_moe.py
tests/kernels/moe/test_mxfp4_moe.py
+22
-15
tests/models/multimodal/generation/test_qwen2_vl.py
tests/models/multimodal/generation/test_qwen2_vl.py
+23
-23
tests/models/quantization/test_awq.py
tests/models/quantization/test_awq.py
+1
-1
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+8
-10
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+4
-4
tests/quantization/test_gptq_dynamic.py
tests/quantization/test_gptq_dynamic.py
+38
-33
tests/quantization/test_lm_head.py
tests/quantization/test_lm_head.py
+2
-2
tests/quantization/test_modelopt.py
tests/quantization/test_modelopt.py
+3
-7
tests/quantization/test_ptpc_fp8.py
tests/quantization/test_ptpc_fp8.py
+28
-19
tests/quantization/test_quark.py
tests/quantization/test_quark.py
+11
-15
tests/quantization/test_register_quantization_config.py
tests/quantization/test_register_quantization_config.py
+10
-7
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+6
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+7
-2
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+16
-17
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+6
-1
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+8
-1
No files found.
tests/conftest.py
View file @
3d9a1d2d
...
@@ -987,17 +987,7 @@ class VllmRunner:
...
@@ -987,17 +987,7 @@ class VllmRunner:
return
[
req_output
.
outputs
.
score
for
req_output
in
req_outputs
]
return
[
req_output
.
outputs
.
score
for
req_output
in
req_outputs
]
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
if
hasattr
(
self
.
llm
.
llm_engine
,
"model_executor"
):
return
self
.
llm
.
apply_model
(
func
)
# This works either in V0 or in V1 with
# VLLM_ENABLE_V1_MULTIPROCESSING=0
executor
=
self
.
llm
.
llm_engine
.
model_executor
return
executor
.
apply_model
(
func
)
# This works in V1 with VLLM_ALLOW_INSECURE_SERIALIZATION=1
def
_apply_model
(
self
):
return
func
(
self
.
get_model
())
return
self
.
llm
.
llm_engine
.
collective_rpc
(
_apply_model
)
def
get_llm
(
self
)
->
LLM
:
def
get_llm
(
self
)
->
LLM
:
return
self
.
llm
return
self
.
llm
...
...
tests/kernels/moe/test_mxfp4_moe.py
View file @
3d9a1d2d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
importlib
import
importlib.metadata
import
importlib.metadata
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
importlib.util
import
find_spec
from
typing
import
Optional
from
typing
import
Optional
import
pytest
import
pytest
import
torch
import
torch
from
packaging
import
version
from
packaging
import
version
from
vllm.model_executor.layers.quantization.quark.quark
import
(
# noqa: E501
QuarkLinearMethod
,
QuarkW4A4MXFP4
)
from
vllm.model_executor.layers.quantization.quark.quark_moe
import
(
# noqa: E501
QuarkW4A4MXFp4MoEMethod
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
has_flashinfer
from
vllm.utils.flashinfer
import
has_flashinfer
QUARK_MXFP4_AVAILABLE
=
importlib
.
util
.
find_spec
(
QUARK_MXFP4_AVAILABLE
=
find_spec
(
"quark"
)
is
not
None
and
version
.
parse
(
"quark"
)
is
not
None
and
version
.
parse
(
importlib
.
metadata
.
version
(
"amd-quark"
))
>=
version
.
parse
(
'0.8.99'
)
importlib
.
metadata
.
version
(
"amd-quark"
))
>=
version
.
parse
(
'0.8.99'
)
TRTLLM_GEN_MXFP4_AVAILABLE
=
current_platform
.
is_cuda
(
TRTLLM_GEN_MXFP4_AVAILABLE
=
current_platform
.
is_cuda
(
)
and
current_platform
.
is_device_capability
(
100
)
)
and
current_platform
.
is_device_capability
(
100
)
...
@@ -39,6 +42,12 @@ class ModelCase:
...
@@ -39,6 +42,12 @@ class ModelCase:
tp
:
int
tp
:
int
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
enable_pickle
(
monkeypatch
):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
@
pytest
.
mark
.
parametrize
(
'model_case'
,
[
@
pytest
.
mark
.
parametrize
(
'model_case'
,
[
ModelCase
(
"fxmarty/qwen_1.5-moe-a2.7b-mxfp4"
,
tp
=
1
),
ModelCase
(
"fxmarty/qwen_1.5-moe-a2.7b-mxfp4"
,
tp
=
1
),
ModelCase
(
"fxmarty/deepseek_r1_3_layers_mxfp4"
,
tp
=
8
),
ModelCase
(
"fxmarty/deepseek_r1_3_layers_mxfp4"
,
tp
=
8
),
...
@@ -55,21 +64,19 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
...
@@ -55,21 +64,19 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
tensor_parallel_size
=
model_case
.
tp
,
tensor_parallel_size
=
model_case
.
tp
,
load_format
=
"dummy"
)
as
llm
:
load_format
=
"dummy"
)
as
llm
:
# TODO: llm.apply_model(check_model) currently relies on V0 internals.
def
check_model
(
model
):
# Re-enable this later.
layer
=
model
.
model
.
layers
[
0
]
# def check_model(model):
# layer = model.model.layers[0]
#
qkv_proj = layer.self_attn.qkv_proj
qkv_proj
=
layer
.
self_attn
.
qkv_proj
#
assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
assert
isinstance
(
qkv_proj
.
quant_method
,
QuarkLinearMethod
)
#
assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
assert
isinstance
(
qkv_proj
.
scheme
,
QuarkW4A4MXFP4
)
#
assert isinstance(layer.mlp.experts.quant_method,
assert
isinstance
(
layer
.
mlp
.
experts
.
quant_method
,
#
QuarkW4A4MXFp4MoEMethod)
QuarkW4A4MXFp4MoEMethod
)
#
if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
if
model_case
.
model_id
==
"fxmarty/qwen_1.5-moe-a2.7b-mxfp4"
:
#
llm.apply_model(check_model)
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Today I am in the French Alps and"
,
output
=
llm
.
generate_greedy
(
"Today I am in the French Alps and"
,
max_tokens
=
20
)
max_tokens
=
20
)
...
...
tests/models/multimodal/generation/test_qwen2_vl.py
View file @
3d9a1d2d
...
@@ -10,6 +10,7 @@ from PIL import Image
...
@@ -10,6 +10,7 @@ from PIL import Image
from
vllm.multimodal.image
import
rescale_image_size
from
vllm.multimodal.image
import
rescale_image_size
from
vllm.multimodal.video
import
rescale_video_size
,
sample_frames_from_video
from
vllm.multimodal.video
import
rescale_video_size
,
sample_frames_from_video
from
vllm.utils
import
set_default_torch_num_threads
from
....conftest
import
(
IMAGE_ASSETS
,
VIDEO_ASSETS
,
PromptImageInput
,
from
....conftest
import
(
IMAGE_ASSETS
,
VIDEO_ASSETS
,
PromptImageInput
,
PromptVideoInput
,
VllmRunner
)
PromptVideoInput
,
VllmRunner
)
...
@@ -17,11 +18,9 @@ from ...utils import check_logprobs_close
...
@@ -17,11 +18,9 @@ from ...utils import check_logprobs_close
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
def
enable_pickle
(
monkeypatch
):
"""
"""`LLM.apply_model` requires pickling a function."""
V1 Test: batch_make_xxxxx_embeddings calls a V0 internal
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
models
=
[
"Qwen/Qwen2-VL-2B-Instruct"
]
models
=
[
"Qwen/Qwen2-VL-2B-Instruct"
]
...
@@ -126,9 +125,8 @@ def batch_make_image_embeddings(
...
@@ -126,9 +125,8 @@ def batch_make_image_embeddings(
image_grid_thw_on_device
=
image_grid_thw
.
to
(
visual
.
device
,
image_grid_thw_on_device
=
image_grid_thw
.
to
(
visual
.
device
,
dtype
=
torch
.
int64
)
dtype
=
torch
.
int64
)
return
visual
(
pixel_values_on_device
,
return
visual
(
pixel_values_on_device
,
grid_thw
=
image_grid_thw_on_device
)
grid_thw
=
image_grid_thw_on_device
)
.
cpu
()
# V1 Test: this calls a V0 internal.
image_embeds
=
torch
.
concat
(
llm
.
apply_model
(
get_image_embeds
))
image_embeds
=
torch
.
concat
(
llm
.
apply_model
(
get_image_embeds
))
# split into original batches
# split into original batches
...
@@ -210,7 +208,7 @@ def batch_make_video_embeddings(
...
@@ -210,7 +208,7 @@ def batch_make_video_embeddings(
video_grid_thw_on_device
=
video_grid_thw
.
to
(
visual
.
device
,
video_grid_thw_on_device
=
video_grid_thw
.
to
(
visual
.
device
,
dtype
=
torch
.
int64
)
dtype
=
torch
.
int64
)
return
visual
(
pixel_values_on_device
,
return
visual
(
pixel_values_on_device
,
grid_thw
=
video_grid_thw_on_device
)
grid_thw
=
video_grid_thw_on_device
)
.
cpu
()
# V1 Test: this calls a V0 internal.
# V1 Test: this calls a V0 internal.
video_embeds
=
torch
.
concat
(
llm
.
apply_model
(
get_image_embeds
))
video_embeds
=
torch
.
concat
(
llm
.
apply_model
(
get_image_embeds
))
...
@@ -266,19 +264,22 @@ def run_embedding_input_test(
...
@@ -266,19 +264,22 @@ def run_embedding_input_test(
processor
=
AutoProcessor
.
from_pretrained
(
model
)
processor
=
AutoProcessor
.
from_pretrained
(
model
)
# max_model_len should be greater than image_feature_size
# max_model_len should be greater than image_feature_size
with
vllm_runner
(
model
,
with
set_default_torch_num_threads
(
1
):
runner
=
"generate"
,
vllm_model
=
vllm_runner
(
max_model_len
=
4000
,
model
,
max_num_seqs
=
3
,
runner
=
"generate"
,
dtype
=
dtype
,
max_model_len
=
4000
,
limit_mm_per_prompt
=
{
max_num_seqs
=
3
,
"image"
:
mm_limit
,
dtype
=
dtype
,
"video"
:
mm_limit
limit_mm_per_prompt
=
{
},
"image"
:
mm_limit
,
tensor_parallel_size
=
tensor_parallel_size
,
"video"
:
mm_limit
distributed_executor_backend
=
distributed_executor_backend
},
)
as
vllm_model
:
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
)
with
vllm_model
:
outputs_per_case_for_original_input
=
[
outputs_per_case_for_original_input
=
[
vllm_model
.
generate_greedy_logprobs
(
prompts
,
vllm_model
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
max_tokens
,
...
@@ -329,9 +330,8 @@ def run_embedding_input_test(
...
@@ -329,9 +330,8 @@ def run_embedding_input_test(
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
10
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
10
])
def
test_qwen2_vl_image_embeddings_input
(
vllm_runner
,
image_assets
,
model
,
def
test_qwen2_vl_image_embeddings_input
(
vllm_runner
,
image_assets
,
model
,
size_factors
,
dtype
:
str
,
size_factors
,
dtype
,
max_tokens
,
max_tokens
:
int
,
num_logprobs
,
monkeypatch
)
->
None
:
num_logprobs
:
int
)
->
None
:
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
inputs_per_case
:
list
[
tuple
[
inputs_per_case
:
list
[
tuple
[
...
...
tests/models/quantization/test_awq.py
View file @
3d9a1d2d
...
@@ -112,7 +112,7 @@ def test_awq_models(vllm_runner, image_assets, source_model, quant_model,
...
@@ -112,7 +112,7 @@ def test_awq_models(vllm_runner, image_assets, source_model, quant_model,
monkeypatch
)
->
None
:
monkeypatch
)
->
None
:
# Test V1: this test hangs during setup on single-scale input.
# Test V1: this test hangs during setup on single-scale input.
# TODO: fi
x
ure out why and re-enable this on V1.
# TODO: fi
g
ure out why and re-enable this on V1.
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
run_awq_test
(
run_awq_test
(
vllm_runner
,
vllm_runner
,
...
...
tests/quantization/test_compressed_tensors.py
View file @
3d9a1d2d
...
@@ -43,12 +43,9 @@ ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
...
@@ -43,12 +43,9 @@ ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
def
enable_pickle
(
monkeypatch
):
"""
"""`LLM.apply_model` requires pickling a function."""
This module relies on V0 internals, so set VLLM_USE_V1=0.
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
"""
if
not
current_platform
.
is_cpu
():
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -176,10 +173,11 @@ def test_compressed_tensors_w8a8_logprobs(
...
@@ -176,10 +173,11 @@ def test_compressed_tensors_w8a8_logprobs(
dtype
=
"bfloat16"
dtype
=
"bfloat16"
# skip language translation prompt for the static per tensor asym model
# skip language translation prompt for the static per tensor models
if
(
model_path
==
if
model_path
in
(
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym"
,
):
# noqa: E501
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
,
):
example_prompts
=
example_prompts
[
0
:
-
1
]
example_prompts
=
example_prompts
[
0
:
-
1
]
with
hf_runner
(
model_path
,
dtype
=
dtype
)
as
hf_model
:
with
hf_runner
(
model_path
,
dtype
=
dtype
)
as
hf_model
:
...
...
tests/quantization/test_fp8.py
View file @
3d9a1d2d
...
@@ -60,8 +60,8 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
...
@@ -60,8 +60,8 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
if
use_rocm_aiter
:
if
use_rocm_aiter
:
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
#
vllm_runner
.apply_model
()
re
lies on V0 internals
.
#
`LLM
.apply_model
`
re
quires pickling a function
.
monkeypatch
.
setenv
(
"VLLM_
USE_V1
"
,
"
0
"
)
monkeypatch
.
setenv
(
"VLLM_
ALLOW_INSECURE_SERIALIZATION
"
,
"
1
"
)
with
vllm_runner
(
model_id
,
kv_cache_dtype
=
"fp8"
)
as
llm
:
with
vllm_runner
(
model_id
,
kv_cache_dtype
=
"fp8"
)
as
llm
:
def
check_model
(
model
):
def
check_model
(
model
):
...
@@ -104,8 +104,8 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
...
@@ -104,8 +104,8 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
if
use_rocm_aiter
:
if
use_rocm_aiter
:
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
#
vllm_runner
.apply_model
()
re
lies on V0 internals
.
#
`LLM
.apply_model
`
re
quires pickling a function
.
monkeypatch
.
setenv
(
"VLLM_
USE_V1
"
,
"
0
"
)
monkeypatch
.
setenv
(
"VLLM_
ALLOW_INSECURE_SERIALIZATION
"
,
"
1
"
)
if
force_marlin
:
if
force_marlin
:
monkeypatch
.
setenv
(
"VLLM_TEST_FORCE_FP8_MARLIN"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_TEST_FORCE_FP8_MARLIN"
,
"1"
)
...
...
tests/quantization/test_gptq_dynamic.py
View file @
3d9a1d2d
...
@@ -31,41 +31,46 @@ MODEL_QUANT = [
...
@@ -31,41 +31,46 @@ MODEL_QUANT = [
@
pytest
.
mark
.
parametrize
(
"model_id, use_marlin_kernel"
,
MODEL_QUANT
)
@
pytest
.
mark
.
parametrize
(
"model_id, use_marlin_kernel"
,
MODEL_QUANT
)
def
test_gptq_with_dynamic
(
vllm_runner
,
model_id
:
str
,
use_marlin_kernel
:
bool
,
def
test_gptq_with_dynamic
(
vllm_runner
,
model_id
:
str
,
use_marlin_kernel
:
bool
,
monkeypatch
):
monkeypatch
):
# vllm_runner.apply_model() relies on V0 internals.
# `LLM.apply_model` requires pickling a function.
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
vllm_model
=
vllm_runner
(
model_id
,
dtype
=
torch
.
float16
,
max_model_len
=
2048
)
linear_method_cls
=
GPTQMarlinLinearMethod
if
use_marlin_kernel
else
(
linear_method_cls
=
GPTQMarlinLinearMethod
if
use_marlin_kernel
else
(
GPTQLinearMethod
)
GPTQLinearMethod
)
for
name
,
submodule
in
(
vllm_model
.
llm
.
llm_engine
.
model_executor
.
with
vllm_runner
(
model_id
,
dtype
=
torch
.
float16
,
max_model_len
=
2048
)
as
llm
:
driver_worker
.
model_runner
.
model
.
named_modules
()):
if
name
==
"lm_head"
:
def
check_model
(
model
):
assert
isinstance
(
submodule
.
quant_method
,
linear_method_cls
)
for
name
,
submodule
in
model
.
named_modules
():
elif
name
==
'model.layers.0.self_attn.qkv_proj'
:
if
name
==
"lm_head"
:
# The first layer is quantized using bits=4, group_size=128
assert
isinstance
(
submodule
.
quant_method
,
# desc_act=True
linear_method_cls
)
assert
isinstance
(
submodule
.
quant_method
,
linear_method_cls
)
elif
name
==
'model.layers.0.self_attn.qkv_proj'
:
config
=
submodule
.
quant_method
.
quant_config
# The first layer is quantized using bits=4, group_size=128
assert
config
.
weight_bits
==
4
# desc_act=True
assert
config
.
group_size
==
128
assert
isinstance
(
submodule
.
quant_method
,
assert
config
.
desc_act
linear_method_cls
)
elif
name
==
'model.layers.1.self_attn.qkv_proj'
:
config
=
submodule
.
quant_method
.
quant_config
# The second layer is quantized using bits=8, group_size=32
assert
config
.
weight_bits
==
4
# desc_act=False
assert
config
.
group_size
==
128
assert
isinstance
(
submodule
.
quant_method
,
linear_method_cls
)
assert
config
.
desc_act
config
=
submodule
.
quant_method
.
quant_config
elif
name
==
'model.layers.1.self_attn.qkv_proj'
:
assert
get_dynamic_override
(
config
,
layer_name
=
name
,
# The second layer is quantized using bits=8, group_size=32
key
=
"bits"
)
==
8
# desc_act=False
assert
get_dynamic_override
(
config
,
assert
isinstance
(
submodule
.
quant_method
,
layer_name
=
name
,
linear_method_cls
)
key
=
"group_size"
)
==
32
config
=
submodule
.
quant_method
.
quant_config
assert
not
get_dynamic_override
(
assert
get_dynamic_override
(
config
,
config
,
layer_name
=
name
,
key
=
"desc_act"
)
layer_name
=
name
,
elif
(
name
==
'model.layers.2.self_attn.qkv_proj'
key
=
"bits"
)
==
8
or
name
==
'model.layers.2.mlp.gate_up_proj'
):
assert
get_dynamic_override
(
config
,
# All other layers (layer index >= 2) are not quantized
layer_name
=
name
,
assert
isinstance
(
submodule
.
quant_method
,
UnquantizedLinearMethod
)
key
=
"group_size"
)
==
32
assert
not
get_dynamic_override
(
config
,
layer_name
=
name
,
key
=
"desc_act"
)
elif
(
name
==
'model.layers.2.self_attn.qkv_proj'
or
name
==
'model.layers.2.mlp.gate_up_proj'
):
# All other layers (layer index >= 2) are not quantized
assert
isinstance
(
submodule
.
quant_method
,
UnquantizedLinearMethod
)
del
vllm_model
llm
.
apply_model
(
check_model
)
tests/quantization/test_lm_head.py
View file @
3d9a1d2d
...
@@ -29,8 +29,8 @@ def test_lm_head(
...
@@ -29,8 +29,8 @@ def test_lm_head(
lm_head_quantized
:
bool
,
lm_head_quantized
:
bool
,
monkeypatch
,
monkeypatch
,
)
->
None
:
)
->
None
:
#
vllm_runner
.apply_model
()
re
lies on V0 internals
.
#
`LLM
.apply_model
`
re
quires pickling a function
.
monkeypatch
.
setenv
(
"VLLM_
USE_V1
"
,
"
0
"
)
monkeypatch
.
setenv
(
"VLLM_
ALLOW_INSECURE_SERIALIZATION
"
,
"
1
"
)
with
vllm_runner
(
model_id
,
dtype
=
torch
.
float16
,
with
vllm_runner
(
model_id
,
dtype
=
torch
.
float16
,
max_model_len
=
2048
)
as
vllm_model
:
max_model_len
=
2048
)
as
vllm_model
:
...
...
tests/quantization/test_modelopt.py
View file @
3d9a1d2d
...
@@ -11,16 +11,12 @@ import pytest
...
@@ -11,16 +11,12 @@ import pytest
import
torch
import
torch
from
tests.quantization.utils
import
is_quant_method_supported
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm.platforms
import
current_platform
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
def
enable_pickle
(
monkeypatch
):
"""
"""`LLM.apply_model` requires pickling a function."""
This module relies on V0 internals, so set VLLM_USE_V1=0.
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
"""
if
not
current_platform
.
is_cpu
():
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"modelopt"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"modelopt"
),
...
...
tests/quantization/test_ptpc_fp8.py
View file @
3d9a1d2d
...
@@ -13,6 +13,16 @@ from vllm.model_executor.layers.quantization.ptpc_fp8 import (
...
@@ -13,6 +13,16 @@ from vllm.model_executor.layers.quantization.ptpc_fp8 import (
PTPCFp8LinearMethod
)
PTPCFp8LinearMethod
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
UNSUPPORTED_STR
=
(
"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only "
"support output dtype of bfloat16. torch.float16 is specified."
)
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
enable_pickle
(
monkeypatch
):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"ptpc_fp8"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"ptpc_fp8"
),
reason
=
"PTPC FP8 is not supported on this GPU type."
)
reason
=
"PTPC FP8 is not supported on this GPU type."
)
...
@@ -21,14 +31,22 @@ from vllm.platforms import current_platform
...
@@ -21,14 +31,22 @@ from vllm.platforms import current_platform
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"auto"
,
"bfloat16"
,
"float16"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"auto"
,
"bfloat16"
,
"float16"
])
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
[
"auto"
,
"fp8"
,
"fp8_e4m3"
])
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
[
"auto"
,
"fp8"
,
"fp8_e4m3"
])
def
test_ptpc_fp8_rocm
(
vllm_runner
,
dtype
:
str
,
kv_cache_dtype
:
str
)
->
None
:
def
test_ptpc_fp8_rocm
(
vllm_runner
,
dtype
:
str
,
kv_cache_dtype
:
str
)
->
None
:
try
:
try
:
with
vllm_runner
(
"facebook/opt-125m"
,
llm
=
vllm_runner
(
"facebook/opt-125m"
,
dtype
=
dtype
,
dtype
=
dtype
,
quantization
=
"ptpc_fp8"
,
quantization
=
"ptpc_fp8"
,
kv_cache_dtype
=
kv_cache_dtype
)
as
llm
:
kv_cache_dtype
=
kv_cache_dtype
)
except
AssertionError
as
e
:
if
str
(
e
)
==
UNSUPPORTED_STR
:
# If the error message matches, the test passes
return
else
:
# If the error message does not match, re-raise the exception
raise
with
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
def
check_model
(
model
):
fc1
=
model
.
model
.
decoder
.
layers
[
0
].
fc1
fc1
=
model
.
model
.
decoder
.
layers
[
0
].
fc1
assert
isinstance
(
fc1
.
quant_method
,
PTPCFp8LinearMethod
)
assert
isinstance
(
fc1
.
quant_method
,
PTPCFp8LinearMethod
)
if
kv_cache_dtype
==
"ptpc_fp8"
:
if
kv_cache_dtype
==
"ptpc_fp8"
:
...
@@ -40,17 +58,8 @@ def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
...
@@ -40,17 +58,8 @@ def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
if
current_platform
.
has_device_capability
(
94
):
if
current_platform
.
has_device_capability
(
94
):
# For GPUs with hardware support, we keep weights in fp8
# For GPUs with hardware support, we keep weights in fp8
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fnuz
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fnuz
else
:
pytest
.
skip
()
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
llm
.
apply_model
(
check_model
)
assert
output
except
AssertionError
as
e
:
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
if
str
(
assert
output
e
)
==
"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. torch.float16 is specified."
:
# noqa: E501
# If the error message matches, the test passes
pass
else
:
# If the error message does not match, re-raise the exception
raise
tests/quantization/test_quark.py
View file @
3d9a1d2d
...
@@ -7,10 +7,10 @@ Run `pytest tests/quantization/test_quark.py`.
...
@@ -7,10 +7,10 @@ Run `pytest tests/quantization/test_quark.py`.
See also `tests/kernels/moe/test_mxfp4_moe.py`.
See also `tests/kernels/moe/test_mxfp4_moe.py`.
"""
"""
import
importlib
import
importlib.metadata
import
importlib.metadata
import
os
import
os
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
importlib.util
import
find_spec
import
huggingface_hub
import
huggingface_hub
import
lm_eval
import
lm_eval
...
@@ -24,9 +24,8 @@ from vllm.platforms import current_platform
...
@@ -24,9 +24,8 @@ from vllm.platforms import current_platform
from
.reference_mxfp4
import
dq_mxfp4_torch
,
qdq_mxfp4_torch
from
.reference_mxfp4
import
dq_mxfp4_torch
,
qdq_mxfp4_torch
QUARK_MXFP4_AVAILABLE
=
importlib
.
util
.
find_spec
(
QUARK_MXFP4_AVAILABLE
=
find_spec
(
"quark"
)
is
not
None
and
version
.
parse
(
"quark"
)
is
not
None
and
version
.
parse
(
importlib
.
metadata
.
version
(
"amd-quark"
))
>=
version
.
parse
(
'0.8.99'
)
importlib
.
metadata
.
version
(
"amd-quark"
))
>=
version
.
parse
(
'0.8.99'
)
if
QUARK_MXFP4_AVAILABLE
:
if
QUARK_MXFP4_AVAILABLE
:
from
quark.torch.export.nn.modules.realquantizer
import
(
from
quark.torch.export.nn.modules.realquantizer
import
(
...
@@ -43,11 +42,9 @@ except huggingface_hub.errors.RepositoryNotFoundError:
...
@@ -43,11 +42,9 @@ except huggingface_hub.errors.RepositoryNotFoundError:
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
def
enable_pickle
(
monkeypatch
):
"""
"""`LLM.apply_model` requires pickling a function."""
This module relies on V0 internals, so set VLLM_USE_V1=0.
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
@
pytest
.
mark
.
parametrize
(
'kv_cache_dtype'
,
[
'auto'
,
'fp8'
])
@
pytest
.
mark
.
parametrize
(
'kv_cache_dtype'
,
[
'auto'
,
'fp8'
])
...
@@ -132,13 +129,12 @@ def test_quark_fp8_parity(vllm_runner):
...
@@ -132,13 +129,12 @@ def test_quark_fp8_parity(vllm_runner):
}
}
with
(
vllm_runner
(
quark_model_id
,
**
llm_kwargs
)
as
with
(
vllm_runner
(
quark_model_id
,
**
llm_kwargs
)
as
quark_handle
,
vllm_runner
(
fp8_model_id
,
**
llm_kwargs
)
as
fp8_handle
):
quark_handle
,
vllm_runner
(
fp8_model_id
,
**
llm_kwargs
)
as
fp8_handle
):
quark_model
=
(
quark_handle
.
llm
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
)
quark_state_dict
=
quark_model
.
state_dict
()
fp8_model
=
(
fp8_handle
.
llm
.
llm_engine
.
model_executor
.
driver_worker
.
def
get_state_dict
(
model
):
model_runner
.
model
)
return
{
k
:
v
.
cpu
()
for
k
,
v
in
model
.
state_dict
().
items
()}
fp8_state_dict
=
fp8_model
.
state_dict
()
quark_state_dict
,
=
quark_handle
.
apply_model
(
get_state_dict
)
fp8_state_dict
,
=
fp8_handle
.
apply_model
(
get_state_dict
)
assert
fp8_state_dict
.
keys
()
==
quark_state_dict
.
keys
()
assert
fp8_state_dict
.
keys
()
==
quark_state_dict
.
keys
()
...
...
tests/quantization/test_register_quantization_config.py
View file @
3d9a1d2d
...
@@ -105,18 +105,21 @@ def test_register_quantization_config():
...
@@ -105,18 +105,21 @@ def test_register_quantization_config():
])
])
def
test_custom_quant
(
vllm_runner
,
model
,
monkeypatch
):
def
test_custom_quant
(
vllm_runner
,
model
,
monkeypatch
):
"""Test infer with the custom quantization method."""
"""Test infer with the custom quantization method."""
# vllm_runner.apply_model() relies on V0 internals.
# `LLM.apply_model` requires pickling a function.
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
with
vllm_runner
(
model_name
=
model
,
with
vllm_runner
(
model_name
=
model
,
quantization
=
"custom_quant"
,
quantization
=
"custom_quant"
,
enforce_eager
=
True
)
as
llm
:
enforce_eager
=
True
)
as
llm
:
model
=
llm
.
llm
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
qkv_proj
=
layer
.
self_attn
.
qkv_proj
# Check the quantization method is FakeQuantLinearMethod
assert
isinstance
(
qkv_proj
.
quant_method
,
FakeQuantLinearMethod
)
# Check the quantization method is FakeQuantLinearMethod
llm
.
apply_model
(
check_model
)
assert
isinstance
(
qkv_proj
.
quant_method
,
FakeQuantLinearMethod
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
assert
output
vllm/engine/llm_engine.py
View file @
3d9a1d2d
...
@@ -13,6 +13,7 @@ from typing import Sequence as GenericSequence
...
@@ -13,6 +13,7 @@ from typing import Sequence as GenericSequence
from
typing
import
Set
,
Type
,
Union
,
cast
from
typing
import
Set
,
Type
,
Union
,
cast
import
torch
import
torch
import
torch.nn
as
nn
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeVar
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -55,6 +56,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
...
@@ -55,6 +56,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
from
vllm.utils
import
Counter
,
Device
,
resolve_obj_by_qualname
,
weak_bind
from
vllm.utils
import
Counter
,
Device
,
resolve_obj_by_qualname
,
weak_bind
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.worker.model_runner_base
import
InputProcessingError
from
vllm.worker.model_runner_base
import
InputProcessingError
from
vllm.worker.worker_base
import
WorkerBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_LOCAL_LOGGING_INTERVAL_SEC
=
5
_LOCAL_LOGGING_INTERVAL_SEC
=
5
...
@@ -1817,13 +1819,16 @@ class LLMEngine:
...
@@ -1817,13 +1819,16 @@ class LLMEngine:
return
sampling_params
return
sampling_params
def
collective_rpc
(
self
,
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[
...
,
_R
]],
method
:
Union
[
str
,
Callable
[
[
WorkerBase
]
,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
model_executor
.
collective_rpc
(
method
,
timeout
,
args
,
return
self
.
model_executor
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
kwargs
)
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
return
self
.
collective_rpc
(
"apply_model"
,
args
=
(
func
,
))
if
envs
.
is_set
(
"VLLM_USE_V1"
)
and
envs
.
VLLM_USE_V1
:
if
envs
.
is_set
(
"VLLM_USE_V1"
)
and
envs
.
VLLM_USE_V1
:
from
vllm.v1.engine.llm_engine
import
LLMEngine
as
V1LLMEngine
from
vllm.v1.engine.llm_engine
import
LLMEngine
as
V1LLMEngine
...
...
vllm/entrypoints/llm.py
View file @
3d9a1d2d
...
@@ -522,9 +522,14 @@ class LLM:
...
@@ -522,9 +522,14 @@ class LLM:
"""
"""
Run a function directly on the model inside each worker,
Run a function directly on the model inside each worker,
returning the result for each of them.
returning the result for each of them.
!!! warning
To reduce the overhead of data transfer, avoid returning large
arrays or tensors from this method. If you must return them,
make sure you move them to CPU first to avoid taking up additional
VRAM!
"""
"""
executor
=
self
.
llm_engine
.
model_executor
return
self
.
llm_engine
.
apply_model
(
func
)
return
executor
.
apply_model
(
func
)
def
_get_beam_search_lora_requests
(
def
_get_beam_search_lora_requests
(
self
,
self
,
...
...
vllm/executor/executor_base.py
View file @
3d9a1d2d
...
@@ -5,11 +5,10 @@ import asyncio
...
@@ -5,11 +5,10 @@ import asyncio
import
time
import
time
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
from
typing
import
Any
,
Awaitable
,
Callable
,
List
,
Optional
,
Set
,
Union
Union
)
import
torch.nn
as
nn
import
torch.nn
as
nn
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeVar
,
deprecated
import
vllm.platforms
import
vllm.platforms
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
...
@@ -63,10 +62,10 @@ class ExecutorBase(ABC):
...
@@ -63,10 +62,10 @@ class ExecutorBase(ABC):
@
abstractmethod
@
abstractmethod
def
collective_rpc
(
self
,
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[
...
,
_R
]],
method
:
Union
[
str
,
Callable
[
[
WorkerBase
]
,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
timeout
:
Optional
[
float
]
=
None
,
args
:
T
uple
=
(),
args
:
t
uple
=
(),
kwargs
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
)
->
L
ist
[
_R
]:
kwargs
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
)
->
l
ist
[
_R
]:
"""
"""
Execute an RPC call on all workers.
Execute an RPC call on all workers.
...
@@ -91,7 +90,7 @@ class ExecutorBase(ABC):
...
@@ -91,7 +90,7 @@ class ExecutorBase(ABC):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
determine_num_available_blocks
(
self
)
->
T
uple
[
int
,
int
]:
def
determine_num_available_blocks
(
self
)
->
t
uple
[
int
,
int
]:
"""Determine the number of available blocks for the GPU KV cache and
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
swappable CPU KV cache.
...
@@ -99,9 +98,10 @@ class ExecutorBase(ABC):
...
@@ -99,9 +98,10 @@ class ExecutorBase(ABC):
ExecutorBase may require modification of the result, e.g. to ensure the
ExecutorBase may require modification of the result, e.g. to ensure the
selected cache sizes are compatible with all workers.
selected cache sizes are compatible with all workers.
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
Returns a tuple `(num_gpu_blocks, num_cpu_blocks)`, where
are blocks that are "active" on the device and can be appended to.
`num_gpu_blocks` are blocks that are "active" on the device and can be
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
`num_cpu_blocks` refers to "swapped" blocks in CPU memory and cannot be
appended to.
appended to.
"""
"""
results
=
self
.
collective_rpc
(
"determine_num_available_blocks"
)
results
=
self
.
collective_rpc
(
"determine_num_available_blocks"
)
...
@@ -127,16 +127,15 @@ class ExecutorBase(ABC):
...
@@ -127,16 +127,15 @@ class ExecutorBase(ABC):
self
.
collective_rpc
(
"initialize_cache"
,
self
.
collective_rpc
(
"initialize_cache"
,
args
=
(
num_gpu_blocks
,
num_cpu_blocks
))
args
=
(
num_gpu_blocks
,
num_cpu_blocks
))
@
deprecated
(
"`llm_engine.model_executor.apply_model` will no longer work "
"in V1 Engine. Please replace with `llm_engine.apply_model` "
"and set `VLLM_ALLOW_INSECURE_SERIALIZATION=1`."
)
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
"""
"""
Run a function directly on the model inside each worker,
Run a function directly on the model inside each worker,
returning the result for each of them.
returning the result for each of them.
"""
"""
return
self
.
collective_rpc
(
"apply_model"
,
args
=
(
func
,
))
def
rpc_func
(
worker
:
WorkerBase
)
->
_R
:
return
func
(
worker
.
get_model
())
return
self
.
collective_rpc
(
rpc_func
)
@
cached_property
# Avoid unnecessary RPC calls
@
cached_property
# Avoid unnecessary RPC calls
def
supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
def
supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
...
@@ -308,8 +307,8 @@ class DistributedExecutorBase(ExecutorBase):
...
@@ -308,8 +307,8 @@ class DistributedExecutorBase(ExecutorBase):
def
collective_rpc
(
self
,
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
],
method
:
Union
[
str
,
Callable
],
timeout
:
Optional
[
float
]
=
None
,
timeout
:
Optional
[
float
]
=
None
,
args
:
T
uple
=
(),
args
:
t
uple
=
(),
kwargs
:
Optional
[
D
ict
]
=
None
)
->
L
ist
[
Any
]:
kwargs
:
Optional
[
d
ict
[
str
,
Any
]
]
=
None
)
->
l
ist
[
Any
]:
return
self
.
_run_workers
(
method
,
*
args
,
**
(
kwargs
or
{}))
return
self
.
_run_workers
(
method
,
*
args
,
**
(
kwargs
or
{}))
@
abstractmethod
@
abstractmethod
...
...
vllm/v1/engine/llm_engine.py
View file @
3d9a1d2d
...
@@ -5,6 +5,7 @@ from collections.abc import Mapping
...
@@ -5,6 +5,7 @@ from collections.abc import Mapping
from
copy
import
copy
from
copy
import
copy
from
typing
import
Any
,
Callable
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch.nn
as
nn
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeVar
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -33,6 +34,7 @@ from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
...
@@ -33,6 +34,7 @@ from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
StatLoggerFactory
)
StatLoggerFactory
)
from
vllm.v1.metrics.reader
import
Metric
,
get_metrics_snapshot
from
vllm.v1.metrics.reader
import
Metric
,
get_metrics_snapshot
from
vllm.v1.metrics.stats
import
IterationStats
from
vllm.v1.metrics.stats
import
IterationStats
from
vllm.v1.worker.worker_base
import
WorkerBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -319,12 +321,15 @@ class LLMEngine:
...
@@ -319,12 +321,15 @@ class LLMEngine:
return
self
.
engine_core
.
pin_lora
(
lora_id
)
return
self
.
engine_core
.
pin_lora
(
lora_id
)
def
collective_rpc
(
self
,
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[
...
,
_R
]],
method
:
Union
[
str
,
Callable
[
[
WorkerBase
]
,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
return
self
.
collective_rpc
(
"apply_model"
,
args
=
(
func
,
))
def
__del__
(
self
):
def
__del__
(
self
):
if
dp_group
:
=
getattr
(
self
,
"dp_group"
,
None
):
if
dp_group
:
=
getattr
(
self
,
"dp_group"
,
None
):
stateless_destroy_torch_distributed_process_group
(
dp_group
)
stateless_destroy_torch_distributed_process_group
(
dp_group
)
vllm/worker/worker_base.py
View file @
3d9a1d2d
...
@@ -5,7 +5,8 @@ import dataclasses
...
@@ -5,7 +5,8 @@ import dataclasses
import
os
import
os
import
time
import
time
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
,
Union
)
import
cloudpickle
import
cloudpickle
import
torch
import
torch
...
@@ -28,6 +29,8 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,
...
@@ -28,6 +29,8 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_R
=
TypeVar
(
"_R"
)
@
warn_for_unimplemented_methods
@
warn_for_unimplemented_methods
class
WorkerBase
:
class
WorkerBase
:
...
@@ -70,6 +73,10 @@ class WorkerBase:
...
@@ -70,6 +73,10 @@ class WorkerBase:
def
get_model
(
self
)
->
nn
.
Module
:
def
get_model
(
self
)
->
nn
.
Module
:
raise
NotImplementedError
raise
NotImplementedError
def
apply_model
(
self
,
fn
:
Callable
[[
nn
.
Module
],
_R
])
->
_R
:
"""Apply a function on the model inside this worker."""
return
fn
(
self
.
get_model
())
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
"""Load model onto target device."""
"""Load model onto target device."""
raise
NotImplementedError
raise
NotImplementedError
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment