Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2196bac1
Unverified
Commit
2196bac1
authored
Apr 23, 2026
by
BadrBasowid
Committed by
GitHub
Apr 23, 2026
Browse files
[Compilation] Refactor SiluMul activation+quant Fusion Pass (#39684)
Signed-off-by:
BadrBasowid
<
badr.basowid@gmail.com
>
parent
4b7869d6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
81 additions
and
103 deletions
+81
-103
tests/compile/fusions_e2e/conftest.py
tests/compile/fusions_e2e/conftest.py
+7
-1
tests/compile/passes/test_silu_mul_quant_fusion.py
tests/compile/passes/test_silu_mul_quant_fusion.py
+1
-1
vllm/compilation/passes/fusion/act_quant_fusion.py
vllm/compilation/passes/fusion/act_quant_fusion.py
+59
-72
vllm/compilation/passes/fusion/rocm_aiter_fusion.py
vllm/compilation/passes/fusion/rocm_aiter_fusion.py
+14
-29
No files found.
tests/compile/fusions_e2e/conftest.py
View file @
2196bac1
...
@@ -189,7 +189,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
...
@@ -189,7 +189,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
# TODO: Remove log counting in unit tests
# TODO: Remove log counting in unit tests
# once all matchers implement VllmFusionPatternMatcherPass
# once all matchers implement VllmFusionPatternMatcherPass
n_expected
=
tp_size
*
num_ranges_activated
n_expected
=
tp_size
*
num_ranges_activated
if
match_name
!=
"attn_quant_fusion"
:
if
match_name
not
in
(
"attn_quant_fusion"
,
"act_quant_fusion"
)
:
assert
len
(
log_matches
)
==
n_expected
,
(
assert
len
(
log_matches
)
==
n_expected
,
(
f
"Could not find
{
n_expected
}
{
match_name
}
"
f
"Could not find
{
n_expected
}
{
match_name
}
"
f
"(found
{
len
(
log_matches
)
}
) in:
\n
{
log_holder
.
text
}
"
f
"(found
{
len
(
log_matches
)
}
) in:
\n
{
log_holder
.
text
}
"
...
@@ -250,6 +250,12 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
...
@@ -250,6 +250,12 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f
"entries (SP took precedence), found:
{
log_matches
}
"
f
"entries (SP took precedence), found:
{
log_matches
}
"
)
)
elif
match_name
==
"act_quant_fusion"
:
actual_match
=
match_table
.
get
(
"activation_quant_fusion_pass"
,
0
)
assert
actual_match
==
expected_matches
*
n_expected
,
(
f
"Could not find
{
expected_matches
*
n_expected
}
"
f
"
{
match_name
}
(found
{
actual_match
}
)."
)
elif
match_name
==
"attn_quant_fusion"
:
elif
match_name
==
"attn_quant_fusion"
:
actual_match
=
match_table
.
get
(
actual_match
=
match_table
.
get
(
"attn_quant_fusion"
,
0
"attn_quant_fusion"
,
0
...
...
tests/compile/passes/test_silu_mul_quant_fusion.py
View file @
2196bac1
...
@@ -168,7 +168,7 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
...
@@ -168,7 +168,7 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
y
=
self
.
silu_and_mul
(
x
)
y
=
self
.
silu_and_mul
(
x
)
x2
=
self
.
w8a8_block_fp8_linear
(
y
,
self
.
w
,
self
.
wscale
)
x2
=
self
.
w8a8_block_fp8_linear
(
y
)
return
x2
return
x2
def
ops_in_model_before
(
self
):
def
ops_in_model_before
(
self
):
...
...
vllm/compilation/passes/fusion/act_quant_fusion.py
View file @
2196bac1
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
import
itertools
from
typing
import
Any
from
typing
import
Any
import
torch
import
torch
from
torch._higher_order_ops.auto_functionalize
import
auto_functionalized
from
torch._higher_order_ops.auto_functionalize
import
auto_functionalized
from
torch._inductor.pattern_matcher
import
(
PatternMatcherPass
,
fwd_only
,
register_replacement
,
)
from
torch._ops
import
OpOverload
from
torch._ops
import
OpOverload
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
...
@@ -24,8 +19,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -24,8 +19,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
..inductor_pass
import
enable_fake_mode
from
..vllm_inductor_pass
import
VllmFusionPatternMatcherPass
,
VllmPatternReplacement
from
..vllm_inductor_pass
import
VllmInductorPass
,
VllmPatternMatcherPass
from
.matcher_utils
import
MatcherQuantFP8
,
MatcherSiluAndMul
from
.matcher_utils
import
MatcherQuantFP8
,
MatcherSiluAndMul
from
.rms_quant_fusion
import
QUANT_OPS
,
empty_bf16
,
empty_fp32
,
empty_i32
from
.rms_quant_fusion
import
QUANT_OPS
,
empty_bf16
,
empty_fp32
,
empty_i32
...
@@ -50,9 +44,9 @@ if current_platform.is_cuda_alike():
...
@@ -50,9 +44,9 @@ if current_platform.is_cuda_alike():
FUSED_OPS
[
kFp8Dynamic64Sym
]
=
torch
.
ops
.
_C
.
silu_and_mul_per_block_quant
.
default
FUSED_OPS
[
kFp8Dynamic64Sym
]
=
torch
.
ops
.
_C
.
silu_and_mul_per_block_quant
.
default
class
ActivationQuantPattern
(
ABC
):
class
ActivationQuantPattern
(
VllmPatternReplacement
):
"""
"""
The b
ase class for Activation+Quant fusions.
B
ase class for Activation+Quant fusions.
Should not be used directly.
Should not be used directly.
"""
"""
...
@@ -79,10 +73,6 @@ class ActivationQuantPattern(ABC):
...
@@ -79,10 +73,6 @@ class ActivationQuantPattern(ABC):
kwargs
=
{
"dtype"
:
self
.
quant_dtype
,
"device"
:
"cuda"
,
**
kwargs
}
kwargs
=
{
"dtype"
:
self
.
quant_dtype
,
"device"
:
"cuda"
,
**
kwargs
}
return
torch
.
empty
(
*
args
,
**
kwargs
)
return
torch
.
empty
(
*
args
,
**
kwargs
)
@
abstractmethod
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
raise
NotImplementedError
class
SiluMulFp8StaticQuantPattern
(
ActivationQuantPattern
):
class
SiluMulFp8StaticQuantPattern
(
ActivationQuantPattern
):
"""
"""
...
@@ -100,8 +90,9 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
...
@@ -100,8 +90,9 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
scale
,
scale
,
]
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
@
property
def
pattern
(
def
pattern
(
self
):
def
_pattern
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -109,7 +100,11 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
...
@@ -109,7 +100,11 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
result_quant
=
self
.
quant_matcher
(
result_silu_mul
,
scale
)
result_quant
=
self
.
quant_matcher
(
result_silu_mul
,
scale
)
return
result_quant
[
0
]
return
result_quant
[
0
]
def
replacement
(
return
_pattern
@
property
def
replacement
(
self
):
def
_replacement
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -123,10 +118,7 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
...
@@ -123,10 +118,7 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
)
)
return
at
[
1
]
return
at
[
1
]
inps
=
self
.
get_inputs
()
return
_replacement
pattern
(
*
inps
)
register_replacement
(
pattern
,
replacement
,
inps
,
fwd_only
,
pm_pass
)
class
SiluMulNvfp4QuantPattern
(
ActivationQuantPattern
):
class
SiluMulNvfp4QuantPattern
(
ActivationQuantPattern
):
...
@@ -144,8 +136,9 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
...
@@ -144,8 +136,9 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
scale
=
empty_fp32
(
1
,
1
)
scale
=
empty_fp32
(
1
,
1
)
return
[
result
,
output_scale
,
input_
,
scale
]
return
[
result
,
output_scale
,
input_
,
scale
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
@
property
def
pattern
(
def
pattern
(
self
):
def
_pattern
(
result
:
torch
.
Tensor
,
result
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
...
@@ -162,7 +155,11 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
...
@@ -162,7 +155,11 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
)
)
return
at
[
1
],
at
[
2
]
return
at
[
1
],
at
[
2
]
def
replacement
(
return
_pattern
@
property
def
replacement
(
self
):
def
_replacement
(
result
:
torch
.
Tensor
,
result
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
...
@@ -177,7 +174,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
...
@@ -177,7 +174,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
)
)
return
at
[
1
],
at
[
2
]
return
at
[
1
],
at
[
2
]
re
gister_replacement
(
pattern
,
replacement
,
self
.
get_inputs
(),
fwd_only
,
pm_pass
)
re
turn
_replacement
class
SiluMulBlockQuantPattern
(
ActivationQuantPattern
):
class
SiluMulBlockQuantPattern
(
ActivationQuantPattern
):
...
@@ -210,10 +207,9 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
...
@@ -210,10 +207,9 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
scale
=
self
.
quant_matcher
.
empty_f32
(
1
,
1
)
scale
=
self
.
quant_matcher
.
empty_f32
(
1
,
1
)
return
self
.
silu_and_mul_matcher
.
inputs
()
+
[
scale
]
return
self
.
silu_and_mul_matcher
.
inputs
()
+
[
scale
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
@
property
is_scale_transposed
=
self
.
is_scale_transposed
def
pattern
(
self
):
def
_pattern
(
def
pattern
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
@@ -235,12 +231,16 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
...
@@ -235,12 +231,16 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
fp8_min
=
finfo
.
min
,
fp8_min
=
finfo
.
min
,
fp8_max
=
finfo
.
max
,
fp8_max
=
finfo
.
max
,
scale_ue8m0
=
self
.
is_e8m0
,
scale_ue8m0
=
self
.
is_e8m0
,
dummy_is_scale_transposed
=
is_scale_transposed
,
dummy_is_scale_transposed
=
self
.
is_scale_transposed
,
dummy_is_tma_aligned
=
self
.
is_tma_aligned
,
dummy_is_tma_aligned
=
self
.
is_tma_aligned
,
)
)
return
result
,
scale
return
result
,
scale
def
replacement
(
return
_pattern
@
property
def
replacement
(
self
):
def
_replacement
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
@@ -249,7 +249,7 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
...
@@ -249,7 +249,7 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
result
=
torch
.
empty
(
result
=
torch
.
empty
(
output_shape
,
device
=
input
.
device
,
dtype
=
self
.
quant_dtype
output_shape
,
device
=
input
.
device
,
dtype
=
self
.
quant_dtype
)
)
if
is_scale_transposed
:
if
self
.
is_scale_transposed
:
scale
=
torch
.
empty
(
scale
=
torch
.
empty
(
(
d
//
self
.
group_size
,
input
.
shape
[
0
]),
(
d
//
self
.
group_size
,
input
.
shape
[
0
]),
device
=
input
.
device
,
device
=
input
.
device
,
...
@@ -268,15 +268,14 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
...
@@ -268,15 +268,14 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
scales
=
scale
,
scales
=
scale
,
group_size
=
self
.
group_size
,
group_size
=
self
.
group_size
,
scale_ub
=
None
,
scale_ub
=
None
,
is_scale_transposed
=
is_scale_transposed
,
is_scale_transposed
=
self
.
is_scale_transposed
,
)
)
return
at
[
1
],
at
[
2
]
return
at
[
1
],
at
[
2
]
inps
=
self
.
get_inputs
()
return
_replacement
register_replacement
(
pattern
,
replacement
,
inps
,
fwd_only
,
pm_pass
)
class
ActivationQuantFusionPass
(
VllmPatternMatcherPass
):
class
ActivationQuantFusionPass
(
Vllm
Fusion
PatternMatcherPass
):
"""
"""
This pass fuses a pre-defined set of custom ops into fused ops.
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
It uses the torch pattern matcher to find the patterns and replace them.
...
@@ -286,45 +285,33 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
...
@@ -286,45 +285,33 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
super
().
__init__
(
config
,
"activation_quant_fusion_pass"
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
pass_name
=
"activation_quant_fusion_pass"
)
pattern_silu_mul_fp8
=
SiluMulFp8StaticQuantPattern
()
self
.
register
(
SiluMulFp8StaticQuantPattern
())
pattern_silu_mul_fp8
.
register
(
self
.
patterns
)
if
silu_and_mul_nvfp4_quant_supported
:
if
silu_and_mul_nvfp4_quant_supported
:
pattern_silu_mul_nvfp4
=
SiluMulNvfp4QuantPattern
()
self
.
register
(
SiluMulNvfp4QuantPattern
())
pattern_silu_mul_nvfp4
.
register
(
self
.
patterns
)
if
current_platform
.
is_cuda
():
if
current_platform
.
is_cuda
():
for
quant_key
in
[
kFp8Dynamic128Sym
,
kFp8Dynamic64Sym
]:
for
(
for
is_scale_transposed
in
[
False
,
True
]:
quant_key
,
for
is_e8m0
in
[
True
,
False
]:
is_scale_transposed
,
for
is_tma_aligned
in
[
False
,
True
]:
is_e8m0
,
SiluMulBlockQuantPattern
(
is_tma_aligned
,
quant_key
,
)
in
itertools
.
product
(
is_scale_transposed
=
is_scale_transposed
,
[
kFp8Dynamic128Sym
,
kFp8Dynamic64Sym
],
is_e8m0
=
is_e8m0
,
[
False
,
True
],
is_tma_aligned
=
is_tma_aligned
,
[
True
,
False
],
).
register
(
self
.
patterns
)
[
False
,
True
],
):
self
.
dump_patterns
(
config
,
self
.
patterns
)
self
.
register
(
SiluMulBlockQuantPattern
(
@
VllmInductorPass
.
time_and_log
quant_key
,
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
is_scale_transposed
=
is_scale_transposed
,
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
is_e8m0
=
is_e8m0
,
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
is_tma_aligned
=
is_tma_aligned
,
)
def
uuid
(
self
)
->
str
:
)
return
VllmInductorPass
.
hash_source
(
self
,
self
.
dump_patterns
(
config
,
self
.
pm_pass
)
ActivationQuantPattern
,
SiluMulFp8StaticQuantPattern
,
SiluMulNvfp4QuantPattern
,
SiluMulBlockQuantPattern
,
)
vllm/compilation/passes/fusion/rocm_aiter_fusion.py
View file @
2196bac1
...
@@ -28,7 +28,6 @@ from ..vllm_inductor_pass import (
...
@@ -28,7 +28,6 @@ from ..vllm_inductor_pass import (
VllmPatternMatcherPass
,
VllmPatternMatcherPass
,
VllmPatternReplacement
,
VllmPatternReplacement
,
)
)
from
.act_quant_fusion
import
ActivationQuantPattern
from
.matcher_utils
import
(
from
.matcher_utils
import
(
MatcherFusedAddRMSNorm
,
MatcherFusedAddRMSNorm
,
MatcherQuantFP8
,
MatcherQuantFP8
,
...
@@ -345,7 +344,7 @@ class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
...
@@ -345,7 +344,7 @@ class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
return
self
.
hash_source
(
self
,
*
fusion_patterns
)
return
self
.
hash_source
(
self
,
*
fusion_patterns
)
class
AiterSiluMulFp8GroupQuantPattern
(
ActivationQuantPattern
):
class
AiterSiluMulFp8GroupQuantPattern
(
VllmPatternReplacement
):
"""
"""
This pattern fuses aiter silu_and_mul & group fp8 quant custom
This pattern fuses aiter silu_and_mul & group fp8 quant custom
ops into an aiter silu_and_mul_group_fp8_quant op.
ops into an aiter silu_and_mul_group_fp8_quant op.
...
@@ -364,26 +363,29 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
...
@@ -364,26 +363,29 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
self
.
silu_and_mul_matcher
.
inputs
()[
0
],
self
.
silu_and_mul_matcher
.
inputs
()[
0
],
]
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
@
property
def
pattern
(
def
pattern
(
self
):
def
_pattern
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
at1
=
self
.
silu_and_mul_matcher
(
input
)
at1
=
self
.
silu_and_mul_matcher
(
input
)
at2
=
self
.
quant_matcher
(
at1
)
at2
=
self
.
quant_matcher
(
at1
)
return
at2
[
0
],
at2
[
1
]
return
at2
[
0
],
at2
[
1
]
def
replacement
(
return
_pattern
@
property
def
replacement
(
self
):
def
_replacement
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
at
=
self
.
FUSED_SILU_MUL_QUANT_OP
(
x
=
input
,
group_size
=
128
)
at
=
self
.
FUSED_SILU_MUL_QUANT_OP
(
x
=
input
,
group_size
=
128
)
return
at
[
0
],
at
[
1
]
return
at
[
0
],
at
[
1
]
pm
.
register_replacement
(
return
_replacement
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
class
RocmAiterSiluMulFp8GroupQuantFusionPass
(
VllmPatternMatcherPass
):
class
RocmAiterSiluMulFp8GroupQuantFusionPass
(
Vllm
Fusion
PatternMatcherPass
):
"""
"""
This pass fuses a pre-defined set of custom ops into fused ops.
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
It uses the torch pattern matcher to find the patterns and replace them.
...
@@ -393,29 +395,12 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
...
@@ -393,29 +395,12 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
super
().
__init__
(
config
,
"rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
pass_name
=
"rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)
AiterSiluMulFp8GroupQuantPattern
().
register
(
self
.
patterns
)
self
.
dump_patterns
(
config
,
self
.
patterns
)
@
VllmInductorPass
.
time_and_log
self
.
register
(
AiterSiluMulFp8GroupQuantPattern
())
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
)
->
None
:
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"Replaced %s patterns"
,
self
.
matched_count
)
def
uuid
(
self
)
->
str
:
self
.
dump_patterns
(
config
,
self
.
pm_pass
)
fusion_patterns
=
[
ActivationQuantPattern
,
AiterSiluMulFp8GroupQuantPattern
,
]
return
VllmInductorPass
.
hash_source
(
self
,
*
fusion_patterns
)
class
AddAiterRMSNormPadPattern
:
class
AddAiterRMSNormPadPattern
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment