Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
994acec0
Unverified
Commit
994acec0
authored
Dec 14, 2025
by
ElizaWszola
Committed by
GitHub
Dec 14, 2025
Browse files
[Bugfix] Fix fusion for VL models (#30244)
Signed-off-by:
ElizaWszola
<
ewszola@redhat.com
>
parent
48b8456f
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
143 additions
and
72 deletions
+143
-72
tests/compile/distributed/test_fusions_e2e.py
tests/compile/distributed/test_fusions_e2e.py
+78
-0
vllm/compilation/fusion.py
vllm/compilation/fusion.py
+52
-48
vllm/compilation/matcher_utils.py
vllm/compilation/matcher_utils.py
+13
-7
vllm/utils/deep_gemm.py
vllm/utils/deep_gemm.py
+0
-17
No files found.
tests/compile/distributed/test_fusions_e2e.py
View file @
994acec0
...
...
@@ -27,6 +27,7 @@ is_blackwell = lambda: current_platform.is_device_capability_family(100)
class
Matches
(
NamedTuple
):
attention_fusion
:
int
=
0
allreduce_fusion
:
int
=
0
rms_quant_norm_fusion
:
int
=
0
sequence_parallel
:
int
=
0
async_tp
:
int
=
0
...
...
@@ -40,6 +41,7 @@ class ModelBackendTestCase(NamedTuple):
MODELS_FP8
:
list
[
ModelBackendTestCase
]
=
[]
MODELS_FP4
:
list
[
ModelBackendTestCase
]
=
[]
MODELS_GROUP_FP8
:
list
[
ModelBackendTestCase
]
=
[]
MODELS
:
list
[
ModelBackendTestCase
]
=
[]
# tp-only
if
current_platform
.
is_cuda
():
...
...
@@ -498,3 +500,79 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
compilation_config
.
compile_ranges_split_points
=
(
llm
.
llm_engine
.
vllm_config
.
compilation_config
.
compile_ranges_split_points
)
if
current_platform
.
is_cuda
():
MODELS_GROUP_FP8
=
[
ModelBackendTestCase
(
model_name
=
"Qwen/Qwen3-30B-A3B-FP8"
,
model_kwargs
=
dict
(
max_model_len
=
1024
,
kv_cache_dtype
=
"fp8"
),
backend
=
AttentionBackendEnum
.
TRITON_ATTN
,
matches
=
Matches
(
rms_quant_norm_fusion
=
48
,
),
),
]
CUSTOM_OPS_QUANT_RMS_NORM
=
[
"+quant_fp8,+rms_norm"
]
@
pytest
.
mark
.
parametrize
(
"model_name, model_kwargs, backend, matches, custom_ops"
,
# Test rms norm+group quant_fp8 fusion
list
[
tuple
[
Any
,
...]](
flat_product
(
MODELS_GROUP_FP8
,
CUSTOM_OPS_QUANT_RMS_NORM
)),
)
@
pytest
.
mark
.
parametrize
(
"inductor_graph_partition"
,
[
True
,
False
])
def
test_rms_group_quant
(
model_name
:
str
,
model_kwargs
:
dict
[
str
,
Any
],
backend
:
AttentionBackendEnum
,
matches
:
Matches
,
custom_ops
:
str
,
inductor_graph_partition
:
bool
,
caplog_mp_spawn
,
monkeypatch
,
):
if
inductor_graph_partition
and
not
is_torch_equal_or_newer
(
"2.9.0.dev"
):
pytest
.
skip
(
"Inductor graph partition requires torch>=2.9"
)
custom_ops_list
=
custom_ops
.
split
(
","
)
if
custom_ops
else
[]
if
inductor_graph_partition
:
mode
=
CUDAGraphMode
.
FULL_AND_PIECEWISE
splitting_ops
:
list
[
str
]
|
None
=
None
else
:
mode
=
CUDAGraphMode
.
FULL_DECODE_ONLY
splitting_ops
=
[]
# Disable, compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
monkeypatch
.
setenv
(
"VLLM_DISABLE_COMPILE_CACHE"
,
"1"
)
# To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general.
monkeypatch
.
setenv
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
"spawn"
)
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
.
name
)
compilation_config
=
CompilationConfig
(
# Testing properties
custom_ops
=
custom_ops_list
,
use_inductor_graph_partition
=
inductor_graph_partition
,
cudagraph_mode
=
mode
,
splitting_ops
=
splitting_ops
,
# Common
mode
=
CompilationMode
.
VLLM_COMPILE
,
pass_config
=
PassConfig
(
eliminate_noops
=
True
,
enable_fusion
=
True
),
# Inductor caches custom passes by default as well via uuid
inductor_compile_config
=
{
"force_disable_caches"
:
True
},
)
with
caplog_mp_spawn
(
logging
.
DEBUG
)
as
log_holder
:
run_model
(
compilation_config
,
model_name
,
**
model_kwargs
)
log_matches
=
re
.
findall
(
r
"\[fusion.py:\d+] Replaced (\d+) patterns"
,
log_holder
.
text
,
)
assert
len
(
log_matches
)
==
1
,
log_holder
.
text
assert
int
(
log_matches
[
0
])
==
matches
.
rms_quant_norm_fusion
vllm/compilation/fusion.py
View file @
994acec0
...
...
@@ -23,17 +23,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kNvfp4Quant
,
kStaticTensorScale
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
cutlass_block_fp8_supported
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
(
is_deep_gemm_e8m0_used
,
should_use_deepgemm_for_fp8_linear_for_nk
,
)
from
.inductor_pass
import
enable_fake_mode
from
.matcher_utils
import
MatcherFusedAddRMSNorm
,
MatcherQuantFP8
,
MatcherRMSNorm
from
.matcher_utils
import
(
MatcherFusedAddRMSNorm
,
MatcherQuantFP8
,
MatcherRMSNorm
,
)
from
.vllm_inductor_pass
import
VllmInductorPass
,
VllmPatternMatcherPass
logger
=
init_logger
(
__name__
)
...
...
@@ -118,21 +115,18 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
class
RMSNormQuantPattern
:
def
__init__
(
self
,
epsilon
:
float
,
key
:
FusedRMSQuantKey
):
def
__init__
(
self
,
epsilon
:
float
,
key
:
FusedRMSQuantKey
,
has_col_major_scales
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
):
self
.
epsilon
=
epsilon
self
.
quant_dtype
=
key
.
quant
.
dtype
config
=
get_current_vllm_config
()
self
.
model_dtype
=
config
.
model_config
.
dtype
if
config
.
model_config
else
None
# groupwise FP8 linear uses col major scales if deepgemm and cutlass
using_deepgemm
=
should_use_deepgemm_for_fp8_linear_for_nk
(
self
.
model_dtype
,
config
.
model_config
.
hf_config
.
intermediate_size
,
config
.
model_config
.
hf_config
.
hidden_size
,
)
use_col_major_scales
=
using_deepgemm
or
cutlass_block_fp8_supported
()
use_e8m0
=
is_deep_gemm_e8m0_used
()
if
using_deepgemm
else
False
assert
key
in
FUSED_OPS
,
f
"unsupported fused rmsnorm+quant op for
{
key
}
"
self
.
FUSED_OP
=
FUSED_OPS
[
key
]
...
...
@@ -142,7 +136,7 @@ class RMSNormQuantPattern:
else
MatcherFusedAddRMSNorm
(
epsilon
)
)
self
.
quant_matcher
=
MatcherQuantFP8
(
key
.
quant
,
use
_col_major_scales
=
use
_col_major_scales
,
use
_e8m0
=
use
_e8m0
key
.
quant
,
has
_col_major_scales
=
has
_col_major_scales
,
is
_e8m0
=
is
_e8m0
)
...
...
@@ -260,6 +254,8 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
,
symmetric
=
True
,
has_col_major_scales
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
):
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
...
...
@@ -267,7 +263,11 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
quant
=
QuantKey
(
dtype
=
quant_dtype
,
scale
=
scale
,
symmetric
=
symmetric
),
)
self
.
group_shape
=
group_shape
super
().
__init__
(
epsilon
,
key
)
self
.
has_col_major_scales
=
has_col_major_scales
self
.
is_e8m0
=
is_e8m0
super
().
__init__
(
epsilon
,
key
,
has_col_major_scales
=
has_col_major_scales
,
is_e8m0
=
is_e8m0
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
):
...
...
@@ -283,9 +283,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
input
=
input
.
to
(
dtype
=
self
.
model_dtype
)
result
=
torch
.
empty_like
(
input
,
dtype
=
self
.
quant_dtype
)
scale
=
self
.
quant_matcher
.
make_scale
(
input
,
transposed
=
self
.
quant_matcher
.
use_col_major_scales
)
scale
=
self
.
quant_matcher
.
make_scale
(
input
,
self
.
has_col_major_scales
)
at
=
auto_functionalized
(
self
.
FUSED_OP
,
result
=
result
,
...
...
@@ -296,7 +294,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
scale_ub
=
None
,
residual
=
residual
,
group_size
=
self
.
group_shape
[
1
],
is_scale_transposed
=
self
.
quant_matcher
.
use
_col_major_scales
,
is_scale_transposed
=
self
.
has
_col_major_scales
,
)
# result, residual, scale
...
...
@@ -318,6 +316,8 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
quant_dtype
:
torch
.
dtype
,
group_shape
:
GroupShape
,
symmetric
=
True
,
has_col_major_scales
:
bool
=
False
,
is_e8m0
:
bool
=
False
,
):
scale
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
key
=
FusedRMSQuantKey
(
...
...
@@ -325,7 +325,9 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
quant
=
QuantKey
(
dtype
=
quant_dtype
,
scale
=
scale
,
symmetric
=
symmetric
),
)
self
.
group_shape
=
group_shape
super
().
__init__
(
epsilon
,
key
)
super
().
__init__
(
epsilon
,
key
,
has_col_major_scales
=
has_col_major_scales
,
is_e8m0
=
is_e8m0
)
def
register
(
self
,
pm_pass
:
PatternMatcherPass
):
def
pattern
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
):
...
...
@@ -340,7 +342,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
result
=
torch
.
empty_like
(
input
,
dtype
=
self
.
quant_dtype
)
scale
=
self
.
quant_matcher
.
make_scale
(
input
,
transposed
=
self
.
quant_matcher
.
use
_col_major_scales
input
,
transposed
=
self
.
quant_matcher
.
has
_col_major_scales
)
at
=
auto_functionalized
(
self
.
FUSED_OP
,
...
...
@@ -352,7 +354,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
scale_ub
=
None
,
residual
=
None
,
group_size
=
self
.
group_shape
[
1
],
is_scale_transposed
=
self
.
quant_matcher
.
use
_col_major_scales
,
is_scale_transposed
=
self
.
quant_matcher
.
has
_col_major_scales
,
)
# result, scale
...
...
@@ -489,27 +491,6 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for
epsilon
in
[
1e-5
,
1e-6
]:
# Fuse fused_add_rms_norm + fp8 group quant
# Only register group quant patterns on CUDA where the C++ op exists
if
current_platform
.
is_cuda
():
FusedAddRMSNormGroupQuantPattern
(
epsilon
,
FP8_DTYPE
,
group_shape
=
GroupShape
(
1
,
128
)
).
register
(
self
.
patterns
)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern
(
epsilon
,
FP8_DTYPE
,
group_shape
=
GroupShape
(
1
,
128
)
).
register
(
self
.
patterns
)
FusedAddRMSNormGroupQuantPattern
(
epsilon
,
FP8_DTYPE
,
group_shape
=
GroupShape
(
1
,
64
)
).
register
(
self
.
patterns
)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern
(
epsilon
,
FP8_DTYPE
,
group_shape
=
GroupShape
(
1
,
64
)
).
register
(
self
.
patterns
)
# Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern
(
epsilon
,
FP8_DTYPE
).
register
(
self
.
patterns
...
...
@@ -526,6 +507,29 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern
(
epsilon
,
FP8_DTYPE
).
register
(
self
.
patterns
)
# Only register group quant patterns on CUDA where the C++ op exists
if
current_platform
.
is_cuda
():
for
group_shape
in
[
GroupShape
(
1
,
128
),
GroupShape
(
1
,
64
)]:
for
has_col_major_scales
in
[
True
,
False
]:
for
is_e8m0
in
[
True
,
False
]:
# Fuse fused_add_rms_norm + fp8 group quant
FusedAddRMSNormGroupQuantPattern
(
epsilon
,
FP8_DTYPE
,
group_shape
=
group_shape
,
has_col_major_scales
=
has_col_major_scales
,
is_e8m0
=
is_e8m0
,
).
register
(
self
.
patterns
)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern
(
epsilon
,
FP8_DTYPE
,
group_shape
=
group_shape
,
has_col_major_scales
=
has_col_major_scales
,
is_e8m0
=
is_e8m0
,
).
register
(
self
.
patterns
)
self
.
dump_patterns
(
config
,
self
.
patterns
)
@
VllmInductorPass
.
time_and_log
...
...
vllm/compilation/matcher_utils.py
View file @
994acec0
...
...
@@ -234,24 +234,30 @@ class MatcherQuantFP8(MatcherCustomOp):
self
,
quant_key
:
QuantKey
,
enabled
:
bool
|
None
=
None
,
use
_col_major_scales
:
bool
=
False
,
use
_e8m0
:
bool
=
False
,
has
_col_major_scales
:
bool
=
False
,
is
_e8m0
:
bool
=
False
,
):
if
enabled
is
None
:
enabled
=
QuantFP8
.
enabled
()
super
().
__init__
(
enabled
)
self
.
quant_key
=
quant_key
self
.
use_col_major_scales
=
use_col_major_scales
self
.
use_e8m0
=
use_e8m0
assert
quant_key
in
QUANT_OPS
,
f
"unsupported quantization scheme
{
quant_key
}
"
self
.
QUANT_OP
=
QUANT_OPS
[
quant_key
]
self
.
has_col_major_scales
=
has_col_major_scales
self
.
is_e8m0
=
is_e8m0
assert
quant_key
.
dtype
==
current_platform
.
fp8_dtype
(),
(
"Only QuantFP8 supported by"
)
assert
quant_key
.
scale2
is
None
self
.
quant_fp8
=
QuantFP8
(
quant_key
.
scale
.
static
,
quant_key
.
scale
.
group_shape
)
self
.
quant_fp8
=
QuantFP8
(
quant_key
.
scale
.
static
,
quant_key
.
scale
.
group_shape
,
column_major_scales
=
has_col_major_scales
,
use_ue8m0
=
is_e8m0
,
)
def
forward_custom
(
self
,
...
...
@@ -264,7 +270,7 @@ class MatcherQuantFP8(MatcherCustomOp):
if
self
.
quant_key
.
scale
.
group_shape
.
is_per_group
():
assert
scale
is
None
scale
=
self
.
make_scale
(
input
,
transposed
=
self
.
use
_col_major_scales
)
scale
=
self
.
make_scale
(
input
,
transposed
=
self
.
has
_col_major_scales
)
finfo
=
torch
.
finfo
(
self
.
quant_key
.
dtype
)
fp8_min
=
finfo
.
min
...
...
@@ -279,7 +285,7 @@ class MatcherQuantFP8(MatcherCustomOp):
eps
=
1e-10
,
fp8_min
=
fp8_min
,
fp8_max
=
fp8_max
,
scale_ue8m0
=
self
.
use
_e8m0
,
scale_ue8m0
=
self
.
is
_e8m0
,
)
return
result
,
scale
...
...
vllm/utils/deep_gemm.py
View file @
994acec0
...
...
@@ -381,22 +381,6 @@ def should_use_deepgemm_for_fp8_linear(
)
def
should_use_deepgemm_for_fp8_linear_for_nk
(
output_dtype
:
torch
.
dtype
,
shape0
:
int
,
shape1
:
int
,
supports_deep_gemm
:
bool
|
None
=
None
,
):
if
supports_deep_gemm
is
None
:
supports_deep_gemm
=
is_deep_gemm_supported
()
return
(
supports_deep_gemm
and
output_dtype
==
torch
.
bfloat16
and
shape0
%
128
==
0
and
shape1
%
128
==
0
)
__all__
=
[
"calc_diff"
,
"DeepGemmQuantScaleFMT"
,
...
...
@@ -411,7 +395,6 @@ __all__ = [
"is_deep_gemm_supported"
,
"get_num_sms"
,
"should_use_deepgemm_for_fp8_linear"
,
"should_use_deepgemm_for_fp8_linear_for_nk"
,
"get_col_major_tma_aligned_tensor"
,
"get_mk_alignment_for_contiguous_layout"
,
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment