Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2196bac1
Unverified
Commit
2196bac1
authored
Apr 23, 2026
by
BadrBasowid
Committed by
GitHub
Apr 23, 2026
Browse files
[Compilation] Refactor SiluMul activation+quant Fusion Pass (#39684)
Signed-off-by:
BadrBasowid
<
badr.basowid@gmail.com
>
parent
4b7869d6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
81 additions
and
103 deletions
+81
-103
tests/compile/fusions_e2e/conftest.py
tests/compile/fusions_e2e/conftest.py
+7
-1
tests/compile/passes/test_silu_mul_quant_fusion.py
tests/compile/passes/test_silu_mul_quant_fusion.py
+1
-1
vllm/compilation/passes/fusion/act_quant_fusion.py
vllm/compilation/passes/fusion/act_quant_fusion.py
+59
-72
vllm/compilation/passes/fusion/rocm_aiter_fusion.py
vllm/compilation/passes/fusion/rocm_aiter_fusion.py
+14
-29
No files found.
tests/compile/fusions_e2e/conftest.py
View file @
2196bac1
...
...
@@ -189,7 +189,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
# TODO: Remove log counting in unit tests
# once all matchers implement VllmFusionPatternMatcherPass
n_expected
=
tp_size
*
num_ranges_activated
if
match_name
!=
"attn_quant_fusion"
:
if
match_name
not
in
(
"attn_quant_fusion"
,
"act_quant_fusion"
)
:
assert
len
(
log_matches
)
==
n_expected
,
(
f
"Could not find
{
n_expected
}
{
match_name
}
"
f
"(found
{
len
(
log_matches
)
}
) in:
\n
{
log_holder
.
text
}
"
...
...
@@ -250,6 +250,12 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f
"entries (SP took precedence), found:
{
log_matches
}
"
)
elif
match_name
==
"act_quant_fusion"
:
actual_match
=
match_table
.
get
(
"activation_quant_fusion_pass"
,
0
)
assert
actual_match
==
expected_matches
*
n_expected
,
(
f
"Could not find
{
expected_matches
*
n_expected
}
"
f
"
{
match_name
}
(found
{
actual_match
}
)."
)
elif
match_name
==
"attn_quant_fusion"
:
actual_match
=
match_table
.
get
(
"attn_quant_fusion"
,
0
...
...
tests/compile/passes/test_silu_mul_quant_fusion.py
View file @
2196bac1
...
...
@@ -168,7 +168,7 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
def
forward
(
self
,
x
):
y
=
self
.
silu_and_mul
(
x
)
x2
=
self
.
w8a8_block_fp8_linear
(
y
,
self
.
w
,
self
.
wscale
)
x2
=
self
.
w8a8_block_fp8_linear
(
y
)
return
x2
def
ops_in_model_before
(
self
):
...
...
vllm/compilation/passes/fusion/act_quant_fusion.py
View file @
2196bac1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
import
itertools
from
typing
import
Any
import
torch
from
torch._higher_order_ops.auto_functionalize
import
auto_functionalized
from
torch._inductor.pattern_matcher
import
(
PatternMatcherPass
,
fwd_only
,
register_replacement
,
)
from
torch._ops
import
OpOverload
from
vllm.config
import
VllmConfig
...
...
@@ -24,8 +19,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from
vllm.platforms
import
current_platform
from
..inductor_pass
import
enable_fake_mode
from
..vllm_inductor_pass
import
VllmInductorPass
,
VllmPatternMatcherPass
from
..vllm_inductor_pass
import
VllmFusionPatternMatcherPass
,
VllmPatternReplacement
from
.matcher_utils
import
MatcherQuantFP8
,
MatcherSiluAndMul
from
.rms_quant_fusion
import
QUANT_OPS
,
empty_bf16
,
empty_fp32
,
empty_i32
...
...
@@ -50,9 +44,9 @@ if current_platform.is_cuda_alike():
FUSED_OPS
[
kFp8Dynamic64Sym
]
=
torch
.
ops
.
_C
.
silu_and_mul_per_block_quant
.
default
class
ActivationQuantPattern
(
ABC
):
class
ActivationQuantPattern
(
VllmPatternReplacement
):
"""
The base class for Activation+Quant fusions.
Base class for Activation+Quant fusions.
Should not be used directly.
"""
...
...
@@ -79,10 +73,6 @@ class ActivationQuantPattern(ABC):
kwargs
=
{
"dtype"
:
self
.
quant_dtype
,
"device"
:
"cuda"
,
**
kwargs
}
return
torch
.
empty
(
*
args
,
**
kwargs
)
@
abstractmethod
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
raise
NotImplementedError
class
SiluMulFp8StaticQuantPattern
(
ActivationQuantPattern
):
"""
...
...
@@ -100,8 +90,9 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
scale
,
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
@
property
def
pattern
(
self
):
def
_pattern
(
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
...
...
@@ -109,7 +100,11 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
result_quant
=
self
.
quant_matcher
(
result_silu_mul
,
scale
)
return
result_quant
[
0
]
def
replacement
(
return
_pattern
@
property
def
replacement
(
self
):
def
_replacement
(
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
...
...
@@ -123,10 +118,7 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
)
return
at
[
1
]
inps
=
self
.
get_inputs
()
pattern
(
*
inps
)
register_replacement
(
pattern
,
replacement
,
inps
,
fwd_only
,
pm_pass
)
return
_replacement
class
SiluMulNvfp4QuantPattern
(
ActivationQuantPattern
):
...
...
@@ -144,8 +136,9 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
scale
=
empty_fp32
(
1
,
1
)
return
[
result
,
output_scale
,
input_
,
scale
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
@
property
def
pattern
(
self
):
def
_pattern
(
result
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
...
...
@@ -162,7 +155,11 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
)
return
at
[
1
],
at
[
2
]
def
replacement
(
return
_pattern
@
property
def
replacement
(
self
):
def
_replacement
(
result
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
...
...
@@ -177,7 +174,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
)
return
at
[
1
],
at
[
2
]
register_replacement
(
pattern
,
replacement
,
self
.
get_inputs
(),
fwd_only
,
pm_pass
)
return
_replacement
class
SiluMulBlockQuantPattern
(
ActivationQuantPattern
):
...
...
@@ -210,10 +207,9 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
scale
=
self
.
quant_matcher
.
empty_f32
(
1
,
1
)
return
self
.
silu_and_mul_matcher
.
inputs
()
+
[
scale
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
is_scale_transposed
=
self
.
is_scale_transposed
def
pattern
(
@
property
def
pattern
(
self
):
def
_pattern
(
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
@@ -235,12 +231,16 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
fp8_min
=
finfo
.
min
,
fp8_max
=
finfo
.
max
,
scale_ue8m0
=
self
.
is_e8m0
,
dummy_is_scale_transposed
=
is_scale_transposed
,
dummy_is_scale_transposed
=
self
.
is_scale_transposed
,
dummy_is_tma_aligned
=
self
.
is_tma_aligned
,
)
return
result
,
scale
def
replacement
(
return
_pattern
@
property
def
replacement
(
self
):
def
_replacement
(
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
@@ -249,7 +249,7 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
result
=
torch
.
empty
(
output_shape
,
device
=
input
.
device
,
dtype
=
self
.
quant_dtype
)
if
is_scale_transposed
:
if
self
.
is_scale_transposed
:
scale
=
torch
.
empty
(
(
d
//
self
.
group_size
,
input
.
shape
[
0
]),
device
=
input
.
device
,
...
...
@@ -268,15 +268,14 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
scales
=
scale
,
group_size
=
self
.
group_size
,
scale_ub
=
None
,
is_scale_transposed
=
is_scale_transposed
,
is_scale_transposed
=
self
.
is_scale_transposed
,
)
return
at
[
1
],
at
[
2
]
inps
=
self
.
get_inputs
()
register_replacement
(
pattern
,
replacement
,
inps
,
fwd_only
,
pm_pass
)
return
_replacement
class
ActivationQuantFusionPass
(
VllmPatternMatcherPass
):
class
ActivationQuantFusionPass
(
VllmFusionPatternMatcherPass
):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
...
...
@@ -286,45 +285,33 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
pass_name
=
"activation_quant_fusion_pass"
)
super
().
__init__
(
config
,
"activation_quant_fusion_pass"
)
pattern_silu_mul_fp8
=
SiluMulFp8StaticQuantPattern
()
pattern_silu_mul_fp8
.
register
(
self
.
patterns
)
self
.
register
(
SiluMulFp8StaticQuantPattern
())
if
silu_and_mul_nvfp4_quant_supported
:
pattern_silu_mul_nvfp4
=
SiluMulNvfp4QuantPattern
()
pattern_silu_mul_nvfp4
.
register
(
self
.
patterns
)
self
.
register
(
SiluMulNvfp4QuantPattern
())
if
current_platform
.
is_cuda
():
for
quant_key
in
[
kFp8Dynamic128Sym
,
kFp8Dynamic64Sym
]:
for
is_scale_transposed
in
[
False
,
True
]:
for
is_e8m0
in
[
True
,
False
]:
for
is_tma_aligned
in
[
False
,
True
]:
SiluMulBlockQuantPattern
(
quant_key
,
is_scale_transposed
=
is_scale_transposed
,
is_e8m0
=
is_e8m0
,
is_tma_aligned
=
is_tma_aligned
,
).
register
(
self
.
patterns
)
self
.
dump_patterns
(
config
,
self
.
patterns
)
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
def
uuid
(
self
)
->
str
:
return
VllmInductorPass
.
hash_source
(
self
,
ActivationQuantPattern
,
SiluMulFp8StaticQuantPattern
,
SiluMulNvfp4QuantPattern
,
SiluMulBlockQuantPattern
,
)
for
(
quant_key
,
is_scale_transposed
,
is_e8m0
,
is_tma_aligned
,
)
in
itertools
.
product
(
[
kFp8Dynamic128Sym
,
kFp8Dynamic64Sym
],
[
False
,
True
],
[
True
,
False
],
[
False
,
True
],
):
self
.
register
(
SiluMulBlockQuantPattern
(
quant_key
,
is_scale_transposed
=
is_scale_transposed
,
is_e8m0
=
is_e8m0
,
is_tma_aligned
=
is_tma_aligned
,
)
)
self
.
dump_patterns
(
config
,
self
.
pm_pass
)
vllm/compilation/passes/fusion/rocm_aiter_fusion.py
View file @
2196bac1
...
...
@@ -28,7 +28,6 @@ from ..vllm_inductor_pass import (
VllmPatternMatcherPass
,
VllmPatternReplacement
,
)
from
.act_quant_fusion
import
ActivationQuantPattern
from
.matcher_utils
import
(
MatcherFusedAddRMSNorm
,
MatcherQuantFP8
,
...
...
@@ -345,7 +344,7 @@ class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
return
self
.
hash_source
(
self
,
*
fusion_patterns
)
class
AiterSiluMulFp8GroupQuantPattern
(
ActivationQuantPattern
):
class
AiterSiluMulFp8GroupQuantPattern
(
VllmPatternReplacement
):
"""
This pattern fuses aiter silu_and_mul & group fp8 quant custom
ops into an aiter silu_and_mul_group_fp8_quant op.
...
...
@@ -364,26 +363,29 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
self
.
silu_and_mul_matcher
.
inputs
()[
0
],
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
@
property
def
pattern
(
self
):
def
_pattern
(
input
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
at1
=
self
.
silu_and_mul_matcher
(
input
)
at2
=
self
.
quant_matcher
(
at1
)
return
at2
[
0
],
at2
[
1
]
def
replacement
(
return
_pattern
@
property
def
replacement
(
self
):
def
_replacement
(
input
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
at
=
self
.
FUSED_SILU_MUL_QUANT_OP
(
x
=
input
,
group_size
=
128
)
return
at
[
0
],
at
[
1
]
pm
.
register_replacement
(
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
return
_replacement
class
RocmAiterSiluMulFp8GroupQuantFusionPass
(
VllmPatternMatcherPass
):
class
RocmAiterSiluMulFp8GroupQuantFusionPass
(
VllmFusionPatternMatcherPass
):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
...
...
@@ -393,29 +395,12 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
pass_name
=
"rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)
AiterSiluMulFp8GroupQuantPattern
().
register
(
self
.
patterns
)
self
.
dump_patterns
(
config
,
self
.
patterns
)
super
().
__init__
(
config
,
"rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
self
.
register
(
AiterSiluMulFp8GroupQuantPattern
())
def
uuid
(
self
)
->
str
:
fusion_patterns
=
[
ActivationQuantPattern
,
AiterSiluMulFp8GroupQuantPattern
,
]
return
VllmInductorPass
.
hash_source
(
self
,
*
fusion_patterns
)
self
.
dump_patterns
(
config
,
self
.
pm_pass
)
class
AddAiterRMSNormPadPattern
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment