Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3d9a1d2d
Unverified
Commit
3d9a1d2d
authored
Sep 20, 2025
by
Cyrus Leung
Committed by
GitHub
Sep 20, 2025
Browse files
[V1] Support `LLM.apply_model` (#18465)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
be874c02
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
194 additions
and
169 deletions
+194
-169
tests/conftest.py
tests/conftest.py
+1
-11
tests/kernels/moe/test_mxfp4_moe.py
tests/kernels/moe/test_mxfp4_moe.py
+22
-15
tests/models/multimodal/generation/test_qwen2_vl.py
tests/models/multimodal/generation/test_qwen2_vl.py
+23
-23
tests/models/quantization/test_awq.py
tests/models/quantization/test_awq.py
+1
-1
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+8
-10
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+4
-4
tests/quantization/test_gptq_dynamic.py
tests/quantization/test_gptq_dynamic.py
+38
-33
tests/quantization/test_lm_head.py
tests/quantization/test_lm_head.py
+2
-2
tests/quantization/test_modelopt.py
tests/quantization/test_modelopt.py
+3
-7
tests/quantization/test_ptpc_fp8.py
tests/quantization/test_ptpc_fp8.py
+28
-19
tests/quantization/test_quark.py
tests/quantization/test_quark.py
+11
-15
tests/quantization/test_register_quantization_config.py
tests/quantization/test_register_quantization_config.py
+10
-7
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+6
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+7
-2
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+16
-17
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+6
-1
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+8
-1
No files found.
tests/conftest.py
View file @
3d9a1d2d
...
...
@@ -987,17 +987,7 @@ class VllmRunner:
return
[
req_output
.
outputs
.
score
for
req_output
in
req_outputs
]
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
if
hasattr
(
self
.
llm
.
llm_engine
,
"model_executor"
):
# This works either in V0 or in V1 with
# VLLM_ENABLE_V1_MULTIPROCESSING=0
executor
=
self
.
llm
.
llm_engine
.
model_executor
return
executor
.
apply_model
(
func
)
# This works in V1 with VLLM_ALLOW_INSECURE_SERIALIZATION=1
def
_apply_model
(
self
):
return
func
(
self
.
get_model
())
return
self
.
llm
.
llm_engine
.
collective_rpc
(
_apply_model
)
return
self
.
llm
.
apply_model
(
func
)
def
get_llm
(
self
)
->
LLM
:
return
self
.
llm
...
...
tests/kernels/moe/test_mxfp4_moe.py
View file @
3d9a1d2d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
importlib
import
importlib.metadata
from
dataclasses
import
dataclass
from
importlib.util
import
find_spec
from
typing
import
Optional
import
pytest
import
torch
from
packaging
import
version
from
vllm.model_executor.layers.quantization.quark.quark
import
(
# noqa: E501
QuarkLinearMethod
,
QuarkW4A4MXFP4
)
from
vllm.model_executor.layers.quantization.quark.quark_moe
import
(
# noqa: E501
QuarkW4A4MXFp4MoEMethod
)
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
has_flashinfer
QUARK_MXFP4_AVAILABLE
=
importlib
.
util
.
find_spec
(
"quark"
)
is
not
None
and
version
.
parse
(
importlib
.
metadata
.
version
(
"amd-quark"
))
>=
version
.
parse
(
'0.8.99'
)
QUARK_MXFP4_AVAILABLE
=
find_spec
(
"quark"
)
is
not
None
and
version
.
parse
(
importlib
.
metadata
.
version
(
"amd-quark"
))
>=
version
.
parse
(
'0.8.99'
)
TRTLLM_GEN_MXFP4_AVAILABLE
=
current_platform
.
is_cuda
(
)
and
current_platform
.
is_device_capability
(
100
)
...
...
@@ -39,6 +42,12 @@ class ModelCase:
tp
:
int
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
enable_pickle
(
monkeypatch
):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
@
pytest
.
mark
.
parametrize
(
'model_case'
,
[
ModelCase
(
"fxmarty/qwen_1.5-moe-a2.7b-mxfp4"
,
tp
=
1
),
ModelCase
(
"fxmarty/deepseek_r1_3_layers_mxfp4"
,
tp
=
8
),
...
...
@@ -55,21 +64,19 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
tensor_parallel_size
=
model_case
.
tp
,
load_format
=
"dummy"
)
as
llm
:
# TODO: llm.apply_model(check_model) currently relies on V0 internals.
# Re-enable this later.
# def check_model(model):
# layer = model.model.layers[0]
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
#
qkv_proj = layer.self_attn.qkv_proj
qkv_proj
=
layer
.
self_attn
.
qkv_proj
#
assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
#
assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
assert
isinstance
(
qkv_proj
.
quant_method
,
QuarkLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
QuarkW4A4MXFP4
)
#
assert isinstance(layer.mlp.experts.quant_method,
#
QuarkW4A4MXFp4MoEMethod)
assert
isinstance
(
layer
.
mlp
.
experts
.
quant_method
,
QuarkW4A4MXFp4MoEMethod
)
#
if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
#
llm.apply_model(check_model)
if
model_case
.
model_id
==
"fxmarty/qwen_1.5-moe-a2.7b-mxfp4"
:
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Today I am in the French Alps and"
,
max_tokens
=
20
)
...
...
tests/models/multimodal/generation/test_qwen2_vl.py
View file @
3d9a1d2d
...
...
@@ -10,6 +10,7 @@ from PIL import Image
from
vllm.multimodal.image
import
rescale_image_size
from
vllm.multimodal.video
import
rescale_video_size
,
sample_frames_from_video
from
vllm.utils
import
set_default_torch_num_threads
from
....conftest
import
(
IMAGE_ASSETS
,
VIDEO_ASSETS
,
PromptImageInput
,
PromptVideoInput
,
VllmRunner
)
...
...
@@ -17,11 +18,9 @@ from ...utils import check_logprobs_close
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
V1 Test: batch_make_xxxxx_embeddings calls a V0 internal
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
def
enable_pickle
(
monkeypatch
):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
models
=
[
"Qwen/Qwen2-VL-2B-Instruct"
]
...
...
@@ -126,9 +125,8 @@ def batch_make_image_embeddings(
image_grid_thw_on_device
=
image_grid_thw
.
to
(
visual
.
device
,
dtype
=
torch
.
int64
)
return
visual
(
pixel_values_on_device
,
grid_thw
=
image_grid_thw_on_device
)
grid_thw
=
image_grid_thw_on_device
)
.
cpu
()
# V1 Test: this calls a V0 internal.
image_embeds
=
torch
.
concat
(
llm
.
apply_model
(
get_image_embeds
))
# split into original batches
...
...
@@ -210,7 +208,7 @@ def batch_make_video_embeddings(
video_grid_thw_on_device
=
video_grid_thw
.
to
(
visual
.
device
,
dtype
=
torch
.
int64
)
return
visual
(
pixel_values_on_device
,
grid_thw
=
video_grid_thw_on_device
)
grid_thw
=
video_grid_thw_on_device
)
.
cpu
()
# V1 Test: this calls a V0 internal.
video_embeds
=
torch
.
concat
(
llm
.
apply_model
(
get_image_embeds
))
...
...
@@ -266,19 +264,22 @@ def run_embedding_input_test(
processor
=
AutoProcessor
.
from_pretrained
(
model
)
# max_model_len should be greater than image_feature_size
with
vllm_runner
(
model
,
runner
=
"generate"
,
max_model_len
=
4000
,
max_num_seqs
=
3
,
dtype
=
dtype
,
limit_mm_per_prompt
=
{
"image"
:
mm_limit
,
"video"
:
mm_limit
},
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
)
as
vllm_model
:
with
set_default_torch_num_threads
(
1
):
vllm_model
=
vllm_runner
(
model
,
runner
=
"generate"
,
max_model_len
=
4000
,
max_num_seqs
=
3
,
dtype
=
dtype
,
limit_mm_per_prompt
=
{
"image"
:
mm_limit
,
"video"
:
mm_limit
},
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
)
with
vllm_model
:
outputs_per_case_for_original_input
=
[
vllm_model
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
...
...
@@ -329,9 +330,8 @@ def run_embedding_input_test(
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
10
])
def
test_qwen2_vl_image_embeddings_input
(
vllm_runner
,
image_assets
,
model
,
size_factors
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
)
->
None
:
size_factors
,
dtype
,
max_tokens
,
num_logprobs
,
monkeypatch
)
->
None
:
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
inputs_per_case
:
list
[
tuple
[
...
...
tests/models/quantization/test_awq.py
View file @
3d9a1d2d
...
...
@@ -112,7 +112,7 @@ def test_awq_models(vllm_runner, image_assets, source_model, quant_model,
monkeypatch
)
->
None
:
# Test V1: this test hangs during setup on single-scale input.
# TODO: fi
x
ure out why and re-enable this on V1.
# TODO: fi
g
ure out why and re-enable this on V1.
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
run_awq_test
(
vllm_runner
,
...
...
tests/quantization/test_compressed_tensors.py
View file @
3d9a1d2d
...
...
@@ -43,12 +43,9 @@ ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
This module relies on V0 internals, so set VLLM_USE_V1=0.
"""
if
not
current_platform
.
is_cpu
():
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
def
enable_pickle
(
monkeypatch
):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -176,10 +173,11 @@ def test_compressed_tensors_w8a8_logprobs(
dtype
=
"bfloat16"
# skip language translation prompt for the static per tensor asym model
if
(
model_path
==
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
):
# noqa: E501
# skip language translation prompt for the static per tensor models
if
model_path
in
(
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym"
,
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
,
):
example_prompts
=
example_prompts
[
0
:
-
1
]
with
hf_runner
(
model_path
,
dtype
=
dtype
)
as
hf_model
:
...
...
tests/quantization/test_fp8.py
View file @
3d9a1d2d
...
...
@@ -60,8 +60,8 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
if
use_rocm_aiter
:
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
#
vllm_runner
.apply_model
()
re
lies on V0 internals
.
monkeypatch
.
setenv
(
"VLLM_
USE_V1
"
,
"
0
"
)
#
`LLM
.apply_model
`
re
quires pickling a function
.
monkeypatch
.
setenv
(
"VLLM_
ALLOW_INSECURE_SERIALIZATION
"
,
"
1
"
)
with
vllm_runner
(
model_id
,
kv_cache_dtype
=
"fp8"
)
as
llm
:
def
check_model
(
model
):
...
...
@@ -104,8 +104,8 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
if
use_rocm_aiter
:
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
#
vllm_runner
.apply_model
()
re
lies on V0 internals
.
monkeypatch
.
setenv
(
"VLLM_
USE_V1
"
,
"
0
"
)
#
`LLM
.apply_model
`
re
quires pickling a function
.
monkeypatch
.
setenv
(
"VLLM_
ALLOW_INSECURE_SERIALIZATION
"
,
"
1
"
)
if
force_marlin
:
monkeypatch
.
setenv
(
"VLLM_TEST_FORCE_FP8_MARLIN"
,
"1"
)
...
...
tests/quantization/test_gptq_dynamic.py
View file @
3d9a1d2d
...
...
@@ -31,41 +31,46 @@ MODEL_QUANT = [
@
pytest
.
mark
.
parametrize
(
"model_id, use_marlin_kernel"
,
MODEL_QUANT
)
def
test_gptq_with_dynamic
(
vllm_runner
,
model_id
:
str
,
use_marlin_kernel
:
bool
,
monkeypatch
):
# vllm_runner.apply_model() relies on V0 internals.
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
vllm_model
=
vllm_runner
(
model_id
,
dtype
=
torch
.
float16
,
max_model_len
=
2048
)
# `LLM.apply_model` requires pickling a function.
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
linear_method_cls
=
GPTQMarlinLinearMethod
if
use_marlin_kernel
else
(
GPTQLinearMethod
)
for
name
,
submodule
in
(
vllm_model
.
llm
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
.
named_modules
()):
if
name
==
"lm_head"
:
assert
isinstance
(
submodule
.
quant_method
,
linear_method_cls
)
elif
name
==
'model.layers.0.self_attn.qkv_proj'
:
# The first layer is quantized using bits=4, group_size=128
# desc_act=True
assert
isinstance
(
submodule
.
quant_method
,
linear_method_cls
)
config
=
submodule
.
quant_method
.
quant_config
assert
config
.
weight_bits
==
4
assert
config
.
group_size
==
128
assert
config
.
desc_act
elif
name
==
'model.layers.1.self_attn.qkv_proj'
:
# The second layer is quantized using bits=8, group_size=32
# desc_act=False
assert
isinstance
(
submodule
.
quant_method
,
linear_method_cls
)
config
=
submodule
.
quant_method
.
quant_config
assert
get_dynamic_override
(
config
,
layer_name
=
name
,
key
=
"bits"
)
==
8
assert
get_dynamic_override
(
config
,
layer_name
=
name
,
key
=
"group_size"
)
==
32
assert
not
get_dynamic_override
(
config
,
layer_name
=
name
,
key
=
"desc_act"
)
elif
(
name
==
'model.layers.2.self_attn.qkv_proj'
or
name
==
'model.layers.2.mlp.gate_up_proj'
):
# All other layers (layer index >= 2) are not quantized
assert
isinstance
(
submodule
.
quant_method
,
UnquantizedLinearMethod
)
with
vllm_runner
(
model_id
,
dtype
=
torch
.
float16
,
max_model_len
=
2048
)
as
llm
:
def
check_model
(
model
):
for
name
,
submodule
in
model
.
named_modules
():
if
name
==
"lm_head"
:
assert
isinstance
(
submodule
.
quant_method
,
linear_method_cls
)
elif
name
==
'model.layers.0.self_attn.qkv_proj'
:
# The first layer is quantized using bits=4, group_size=128
# desc_act=True
assert
isinstance
(
submodule
.
quant_method
,
linear_method_cls
)
config
=
submodule
.
quant_method
.
quant_config
assert
config
.
weight_bits
==
4
assert
config
.
group_size
==
128
assert
config
.
desc_act
elif
name
==
'model.layers.1.self_attn.qkv_proj'
:
# The second layer is quantized using bits=8, group_size=32
# desc_act=False
assert
isinstance
(
submodule
.
quant_method
,
linear_method_cls
)
config
=
submodule
.
quant_method
.
quant_config
assert
get_dynamic_override
(
config
,
layer_name
=
name
,
key
=
"bits"
)
==
8
assert
get_dynamic_override
(
config
,
layer_name
=
name
,
key
=
"group_size"
)
==
32
assert
not
get_dynamic_override
(
config
,
layer_name
=
name
,
key
=
"desc_act"
)
elif
(
name
==
'model.layers.2.self_attn.qkv_proj'
or
name
==
'model.layers.2.mlp.gate_up_proj'
):
# All other layers (layer index >= 2) are not quantized
assert
isinstance
(
submodule
.
quant_method
,
UnquantizedLinearMethod
)
del
vllm_model
llm
.
apply_model
(
check_model
)
tests/quantization/test_lm_head.py
View file @
3d9a1d2d
...
...
@@ -29,8 +29,8 @@ def test_lm_head(
lm_head_quantized
:
bool
,
monkeypatch
,
)
->
None
:
#
vllm_runner
.apply_model
()
re
lies on V0 internals
.
monkeypatch
.
setenv
(
"VLLM_
USE_V1
"
,
"
0
"
)
#
`LLM
.apply_model
`
re
quires pickling a function
.
monkeypatch
.
setenv
(
"VLLM_
ALLOW_INSECURE_SERIALIZATION
"
,
"
1
"
)
with
vllm_runner
(
model_id
,
dtype
=
torch
.
float16
,
max_model_len
=
2048
)
as
vllm_model
:
...
...
tests/quantization/test_modelopt.py
View file @
3d9a1d2d
...
...
@@ -11,16 +11,12 @@ import pytest
import
torch
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm.platforms
import
current_platform
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
This module relies on V0 internals, so set VLLM_USE_V1=0.
"""
if
not
current_platform
.
is_cpu
():
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
def
enable_pickle
(
monkeypatch
):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"modelopt"
),
...
...
tests/quantization/test_ptpc_fp8.py
View file @
3d9a1d2d
...
...
@@ -13,6 +13,16 @@ from vllm.model_executor.layers.quantization.ptpc_fp8 import (
PTPCFp8LinearMethod
)
from
vllm.platforms
import
current_platform
UNSUPPORTED_STR
=
(
"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only "
"support output dtype of bfloat16. torch.float16 is specified."
)
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
enable_pickle
(
monkeypatch
):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"ptpc_fp8"
),
reason
=
"PTPC FP8 is not supported on this GPU type."
)
...
...
@@ -21,14 +31,22 @@ from vllm.platforms import current_platform
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"auto"
,
"bfloat16"
,
"float16"
])
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
[
"auto"
,
"fp8"
,
"fp8_e4m3"
])
def
test_ptpc_fp8_rocm
(
vllm_runner
,
dtype
:
str
,
kv_cache_dtype
:
str
)
->
None
:
try
:
with
vllm_runner
(
"facebook/opt-125m"
,
dtype
=
dtype
,
quantization
=
"ptpc_fp8"
,
kv_cache_dtype
=
kv_cache_dtype
)
as
llm
:
llm
=
vllm_runner
(
"facebook/opt-125m"
,
dtype
=
dtype
,
quantization
=
"ptpc_fp8"
,
kv_cache_dtype
=
kv_cache_dtype
)
except
AssertionError
as
e
:
if
str
(
e
)
==
UNSUPPORTED_STR
:
# If the error message matches, the test passes
return
else
:
# If the error message does not match, re-raise the exception
raise
with
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
def
check_model
(
model
):
fc1
=
model
.
model
.
decoder
.
layers
[
0
].
fc1
assert
isinstance
(
fc1
.
quant_method
,
PTPCFp8LinearMethod
)
if
kv_cache_dtype
==
"ptpc_fp8"
:
...
...
@@ -40,17 +58,8 @@ def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
if
current_platform
.
has_device_capability
(
94
):
# For GPUs with hardware support, we keep weights in fp8
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fnuz
else
:
pytest
.
skip
()
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
except
AssertionError
as
e
:
if
str
(
e
)
==
"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. torch.float16 is specified."
:
# noqa: E501
# If the error message matches, the test passes
pass
else
:
# If the error message does not match, re-raise the exception
raise
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
tests/quantization/test_quark.py
View file @
3d9a1d2d
...
...
@@ -7,10 +7,10 @@ Run `pytest tests/quantization/test_quark.py`.
See also `tests/kernels/moe/test_mxfp4_moe.py`.
"""
import
importlib
import
importlib.metadata
import
os
from
dataclasses
import
dataclass
from
importlib.util
import
find_spec
import
huggingface_hub
import
lm_eval
...
...
@@ -24,9 +24,8 @@ from vllm.platforms import current_platform
from
.reference_mxfp4
import
dq_mxfp4_torch
,
qdq_mxfp4_torch
QUARK_MXFP4_AVAILABLE
=
importlib
.
util
.
find_spec
(
"quark"
)
is
not
None
and
version
.
parse
(
importlib
.
metadata
.
version
(
"amd-quark"
))
>=
version
.
parse
(
'0.8.99'
)
QUARK_MXFP4_AVAILABLE
=
find_spec
(
"quark"
)
is
not
None
and
version
.
parse
(
importlib
.
metadata
.
version
(
"amd-quark"
))
>=
version
.
parse
(
'0.8.99'
)
if
QUARK_MXFP4_AVAILABLE
:
from
quark.torch.export.nn.modules.realquantizer
import
(
...
...
@@ -43,11 +42,9 @@ except huggingface_hub.errors.RepositoryNotFoundError:
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
This module relies on V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
def
enable_pickle
(
monkeypatch
):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
@
pytest
.
mark
.
parametrize
(
'kv_cache_dtype'
,
[
'auto'
,
'fp8'
])
...
...
@@ -132,13 +129,12 @@ def test_quark_fp8_parity(vllm_runner):
}
with
(
vllm_runner
(
quark_model_id
,
**
llm_kwargs
)
as
quark_handle
,
vllm_runner
(
fp8_model_id
,
**
llm_kwargs
)
as
fp8_handle
):
quark_model
=
(
quark_handle
.
llm
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
)
quark_state_dict
=
quark_model
.
state_dict
()
fp8_model
=
(
fp8_handle
.
llm
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
)
fp8_state_dict
=
fp8_model
.
state_dict
()
def
get_state_dict
(
model
):
return
{
k
:
v
.
cpu
()
for
k
,
v
in
model
.
state_dict
().
items
()}
quark_state_dict
,
=
quark_handle
.
apply_model
(
get_state_dict
)
fp8_state_dict
,
=
fp8_handle
.
apply_model
(
get_state_dict
)
assert
fp8_state_dict
.
keys
()
==
quark_state_dict
.
keys
()
...
...
tests/quantization/test_register_quantization_config.py
View file @
3d9a1d2d
...
...
@@ -105,18 +105,21 @@ def test_register_quantization_config():
])
def
test_custom_quant
(
vllm_runner
,
model
,
monkeypatch
):
"""Test infer with the custom quantization method."""
# vllm_runner.apply_model() relies on V0 internals.
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
# `LLM.apply_model` requires pickling a function.
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
with
vllm_runner
(
model_name
=
model
,
quantization
=
"custom_quant"
,
enforce_eager
=
True
)
as
llm
:
model
=
llm
.
llm
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
# Check the quantization method is FakeQuantLinearMethod
assert
isinstance
(
qkv_proj
.
quant_method
,
FakeQuantLinearMethod
)
# Check the quantization method is FakeQuantLinearMethod
assert
isinstance
(
qkv_proj
.
quant_method
,
FakeQuantLinearMethod
)
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
vllm/engine/llm_engine.py
View file @
3d9a1d2d
...
...
@@ -13,6 +13,7 @@ from typing import Sequence as GenericSequence
from
typing
import
Set
,
Type
,
Union
,
cast
import
torch
import
torch.nn
as
nn
from
typing_extensions
import
TypeVar
import
vllm.envs
as
envs
...
...
@@ -55,6 +56,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
from
vllm.utils
import
Counter
,
Device
,
resolve_obj_by_qualname
,
weak_bind
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.worker.model_runner_base
import
InputProcessingError
from
vllm.worker.worker_base
import
WorkerBase
logger
=
init_logger
(
__name__
)
_LOCAL_LOGGING_INTERVAL_SEC
=
5
...
...
@@ -1817,13 +1819,16 @@ class LLMEngine:
return
sampling_params
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[
...
,
_R
]],
method
:
Union
[
str
,
Callable
[
[
WorkerBase
]
,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
model_executor
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
return
self
.
collective_rpc
(
"apply_model"
,
args
=
(
func
,
))
if
envs
.
is_set
(
"VLLM_USE_V1"
)
and
envs
.
VLLM_USE_V1
:
from
vllm.v1.engine.llm_engine
import
LLMEngine
as
V1LLMEngine
...
...
vllm/entrypoints/llm.py
View file @
3d9a1d2d
...
...
@@ -522,9 +522,14 @@ class LLM:
"""
Run a function directly on the model inside each worker,
returning the result for each of them.
!!! warning
To reduce the overhead of data transfer, avoid returning large
arrays or tensors from this method. If you must return them,
make sure you move them to CPU first to avoid taking up additional
VRAM!
"""
executor
=
self
.
llm_engine
.
model_executor
return
executor
.
apply_model
(
func
)
return
self
.
llm_engine
.
apply_model
(
func
)
def
_get_beam_search_lora_requests
(
self
,
...
...
vllm/executor/executor_base.py
View file @
3d9a1d2d
...
...
@@ -5,11 +5,10 @@ import asyncio
import
time
from
abc
import
ABC
,
abstractmethod
from
functools
import
cached_property
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
)
from
typing
import
Any
,
Awaitable
,
Callable
,
List
,
Optional
,
Set
,
Union
import
torch.nn
as
nn
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeVar
,
deprecated
import
vllm.platforms
from
vllm.config
import
VllmConfig
...
...
@@ -63,10 +62,10 @@ class ExecutorBase(ABC):
@
abstractmethod
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[
...
,
_R
]],
method
:
Union
[
str
,
Callable
[
[
WorkerBase
]
,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
T
uple
=
(),
kwargs
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
)
->
L
ist
[
_R
]:
args
:
t
uple
=
(),
kwargs
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
)
->
l
ist
[
_R
]:
"""
Execute an RPC call on all workers.
...
...
@@ -91,7 +90,7 @@ class ExecutorBase(ABC):
"""
raise
NotImplementedError
def
determine_num_available_blocks
(
self
)
->
T
uple
[
int
,
int
]:
def
determine_num_available_blocks
(
self
)
->
t
uple
[
int
,
int
]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
...
...
@@ -99,9 +98,10 @@ class ExecutorBase(ABC):
ExecutorBase may require modification of the result, e.g. to ensure the
selected cache sizes are compatible with all workers.
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
Returns a tuple `(num_gpu_blocks, num_cpu_blocks)`, where
`num_gpu_blocks` are blocks that are "active" on the device and can be
appended to.
`num_cpu_blocks` refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
results
=
self
.
collective_rpc
(
"determine_num_available_blocks"
)
...
...
@@ -127,16 +127,15 @@ class ExecutorBase(ABC):
self
.
collective_rpc
(
"initialize_cache"
,
args
=
(
num_gpu_blocks
,
num_cpu_blocks
))
@
deprecated
(
"`llm_engine.model_executor.apply_model` will no longer work "
"in V1 Engine. Please replace with `llm_engine.apply_model` "
"and set `VLLM_ALLOW_INSECURE_SERIALIZATION=1`."
)
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
"""
Run a function directly on the model inside each worker,
returning the result for each of them.
"""
def
rpc_func
(
worker
:
WorkerBase
)
->
_R
:
return
func
(
worker
.
get_model
())
return
self
.
collective_rpc
(
rpc_func
)
return
self
.
collective_rpc
(
"apply_model"
,
args
=
(
func
,
))
@
cached_property
# Avoid unnecessary RPC calls
def
supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
...
...
@@ -308,8 +307,8 @@ class DistributedExecutorBase(ExecutorBase):
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
],
timeout
:
Optional
[
float
]
=
None
,
args
:
T
uple
=
(),
kwargs
:
Optional
[
D
ict
]
=
None
)
->
L
ist
[
Any
]:
args
:
t
uple
=
(),
kwargs
:
Optional
[
d
ict
[
str
,
Any
]
]
=
None
)
->
l
ist
[
Any
]:
return
self
.
_run_workers
(
method
,
*
args
,
**
(
kwargs
or
{}))
@
abstractmethod
...
...
vllm/v1/engine/llm_engine.py
View file @
3d9a1d2d
...
...
@@ -5,6 +5,7 @@ from collections.abc import Mapping
from
copy
import
copy
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch.nn
as
nn
from
typing_extensions
import
TypeVar
import
vllm.envs
as
envs
...
...
@@ -33,6 +34,7 @@ from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
StatLoggerFactory
)
from
vllm.v1.metrics.reader
import
Metric
,
get_metrics_snapshot
from
vllm.v1.metrics.stats
import
IterationStats
from
vllm.v1.worker.worker_base
import
WorkerBase
logger
=
init_logger
(
__name__
)
...
...
@@ -319,12 +321,15 @@ class LLMEngine:
return
self
.
engine_core
.
pin_lora
(
lora_id
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[
...
,
_R
]],
method
:
Union
[
str
,
Callable
[
[
WorkerBase
]
,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
return
self
.
collective_rpc
(
"apply_model"
,
args
=
(
func
,
))
def
__del__
(
self
):
if
dp_group
:
=
getattr
(
self
,
"dp_group"
,
None
):
stateless_destroy_torch_distributed_process_group
(
dp_group
)
vllm/worker/worker_base.py
View file @
3d9a1d2d
...
...
@@ -5,7 +5,8 @@ import dataclasses
import
os
import
time
from
abc
import
abstractmethod
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
,
Union
)
import
cloudpickle
import
torch
...
...
@@ -28,6 +29,8 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,
logger
=
init_logger
(
__name__
)
_R
=
TypeVar
(
"_R"
)
@
warn_for_unimplemented_methods
class
WorkerBase
:
...
...
@@ -70,6 +73,10 @@ class WorkerBase:
def
get_model
(
self
)
->
nn
.
Module
:
raise
NotImplementedError
def
apply_model
(
self
,
fn
:
Callable
[[
nn
.
Module
],
_R
])
->
_R
:
"""Apply a function on the model inside this worker."""
return
fn
(
self
.
get_model
())
def
load_model
(
self
)
->
None
:
"""Load model onto target device."""
raise
NotImplementedError
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment