Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ab3e8004
Unverified
Commit
ab3e8004
authored
Oct 22, 2025
by
Jiangyun Zhu
Committed by
GitHub
Oct 22, 2025
Browse files
[torch.compile] Enable silu_mul_fp8_quant fusion without custom ops enabled (#27146)
Signed-off-by:
zjy0516
<
riverclouds.zhu@qq.com
>
parent
ceacedc1
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
128 additions
and
70 deletions
+128
-70
tests/compile/test_silu_mul_quant_fusion.py
tests/compile/test_silu_mul_quant_fusion.py
+75
-43
vllm/compilation/activation_quant_fusion.py
vllm/compilation/activation_quant_fusion.py
+21
-26
vllm/compilation/matcher_utils.py
vllm/compilation/matcher_utils.py
+30
-0
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+2
-1
No files found.
tests/compile/test_silu_mul_quant_fusion.py
View file @
ab3e8004
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
cast
import
itertools
import
pytest
import
torch
...
...
@@ -16,7 +16,13 @@ from vllm.compilation.activation_quant_fusion import (
from
vllm.compilation.fusion
import
QUANT_OPS
from
vllm.compilation.noop_elimination
import
NoOpEliminationPass
from
vllm.compilation.post_cleanup
import
PostCleanupPass
from
vllm.config
import
CompilationConfig
,
PassConfig
,
VllmConfig
from
vllm.config
import
(
CompilationConfig
,
CompilationMode
,
PassConfig
,
VllmConfig
,
set_current_vllm_config
,
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
...
...
@@ -25,7 +31,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
cutlass_fp8_supported
,
maybe_create_device_identity
,
)
from
vllm.platforms
import
current_platform
...
...
@@ -54,6 +60,8 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
act_quant_static
=
True
,
act_quant_group_shape
=
GroupShape
.
PER_TENSOR
,
)
self
.
enable_silu_mul_custom_op
=
self
.
silu_and_mul
.
enabled
()
self
.
enable_quant_fp8_custom_op
=
self
.
fp8_linear
.
quant_fp8
.
enabled
()
def
forward
(
self
,
x
):
y
=
self
.
silu_and_mul
(
x
)
...
...
@@ -61,7 +69,14 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
return
x2
def ops_in_model_before(self):
    """Return the two ops expected in the graph before fusion is applied.

    The expectation follows the flags stored on the instance:

    * ``enable_silu_mul_custom_op`` selects ``SILU_MUL_OP`` when true and
      ``torch.ops.aten.mul`` otherwise.
    * ``enable_quant_fp8_custom_op`` selects
      ``QUANT_OPS[kFp8StaticTensorSym]`` when true and
      ``torch.ops.aten.reciprocal`` otherwise.

    Returns:
        A two-element list ``[silu_mul_op, quant_op]``.
    """
    # An unconditional `return [SILU_MUL_OP, QUANT_OPS[...]]` used to sit
    # ahead of this logic and made the flag-dependent result unreachable.
    silu_mul_op = (
        SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul
    )
    quant_op = (
        QUANT_OPS[kFp8StaticTensorSym]
        if self.enable_quant_fp8_custom_op
        else torch.ops.aten.reciprocal
    )
    return [silu_mul_op, quant_op]
def ops_in_model_after(self):
    """Return the single fused op expected in the graph after fusion."""
    fused_op = FUSED_OPS[kFp8StaticTensorSym]
    return [fused_op]
...
...
@@ -77,6 +92,7 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
assert
silu_and_mul_nvfp4_quant_supported
self
.
silu_and_mul
=
SiluAndMul
()
self
.
enable_silu_mul_custom_op
=
self
.
silu_and_mul
.
enabled
()
# create nvfp4 weight
w
=
torch
.
rand
((
hidden_size
,
hidden_size
))
...
...
@@ -101,7 +117,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
return
out
def ops_in_model_before(self):
    """Return the two ops expected in the graph before fusion is applied.

    The first entry is ``SILU_MUL_OP`` when ``enable_silu_mul_custom_op`` is
    true on the instance and ``torch.ops.aten.mul`` otherwise. The second
    entry is always ``QUANT_OPS[kNvfp4Quant]``.
    """
    # An unconditional `return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]]` used to
    # precede this logic and made the flag-dependent result unreachable.
    silu_mul_op = (
        SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul
    )
    return [silu_mul_op, QUANT_OPS[kNvfp4Quant]]
def ops_in_model_after(self):
    """Return the single fused op expected in the graph after fusion."""
    fused_op = FUSED_OPS[kNvfp4Quant]
    return [fused_op]
...
...
@@ -110,44 +129,57 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"enable_silu_mul_custom_op"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"model_class"
,
cast
(
list
[
type
],
[
TestSiluMulFp8QuantModel
,
TestSiluMulNvfp4QuantModel
]
if
is_nvfp4_supported
()
else
[
TestSiluMulFp8QuantModel
],
),
"model_class, enable_quant_fp8_custom_op, cuda_force_torch"
,
list
(
itertools
.
product
([
TestSiluMulFp8QuantModel
],
[
True
,
False
],
[
True
,
False
]))
+
[(
TestSiluMulNvfp4QuantModel
,
False
,
False
)],
)
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@
pytest
.
mark
.
parametrize
(
"cuda_force_torch"
,
[
True
,
False
]
if
cutlass_fp8_supported
()
else
[
True
]
)
@
pytest
.
mark
.
skipif
(
envs
.
VLLM_TARGET_DEVICE
not
in
[
"cuda"
,
"rocm"
],
reason
=
"Only test on CUDA and ROCm"
)
def
test_fusion_silu_and_mul_quant
(
num_tokens
,
hidden_size
,
dtype
,
model_class
,
cuda_force_torch
num_tokens
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
model_class
:
type
[
TestSiluMulFp8QuantModel
|
TestSiluMulNvfp4QuantModel
],
enable_silu_mul_custom_op
:
bool
,
enable_quant_fp8_custom_op
:
bool
,
cuda_force_torch
:
bool
,
):
if
model_class
==
TestSiluMulNvfp4QuantModel
and
cuda_force_torch
:
pytest
.
skip
(
"
Duplicate tests for NVFP4
"
)
if
model_class
is
TestSiluMulNvfp4QuantModel
and
not
is_nvfp4_supported
()
:
pytest
.
skip
(
"
NVFP4 is not supported on this GPU.
"
)
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_dtype
(
dtype
)
maybe_create_device_identity
()
x
=
torch
.
rand
(
num_tokens
,
hidden_size
*
2
)
# Reshape pass is needed for the fusion pass to work
config
=
VllmConfig
()
config
.
compilation_config
=
CompilationConfig
(
pass_config
=
PassConfig
(
enable_fusion
=
True
,
enable_noop
=
True
)
custom_ops
=
[]
if
enable_silu_mul_custom_op
:
custom_ops
.
append
(
"+silu_and_mul"
)
if
enable_quant_fp8_custom_op
:
custom_ops
.
append
(
"+quant_fp8"
)
config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
custom_ops
=
custom_ops
,
pass_config
=
PassConfig
(
enable_fusion
=
True
,
enable_noop
=
True
),
),
)
with
set_current_vllm_config
(
config
):
fusion_pass
=
ActivationQuantFusionPass
(
config
)
passes
=
[
NoOpEliminationPass
(
config
),
fusion_pass
,
PostCleanupPass
(
config
)]
backend
=
TestBackend
(
*
passes
)
model
=
model_class
(
hidden_size
=
hidden_size
,
cuda_force_torch
=
cuda_force_torch
,
x
=
x
)
model
=
model_class
(
hidden_size
=
hidden_size
,
cuda_force_torch
=
cuda_force_torch
,
x
=
x
)
# First dimension dynamic
torch
.
_dynamo
.
mark_dynamic
(
x
,
0
)
...
...
vllm/compilation/activation_quant_fusion.py
View file @
ab3e8004
...
...
@@ -18,12 +18,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey
,
kFp8StaticTensorSym
,
kNvfp4Quant
,
kStaticTensorScale
,
)
from
vllm.platforms
import
current_platform
from
.fusion
import
QUANT_OPS
,
empty_bf16
,
empty_fp32
,
empty_i32
from
.inductor_pass
import
enable_fake_mode
from
.matcher_utils
import
MatcherQuantFP8
,
MatcherSiluAndMul
from
.vllm_inductor_pass
import
VllmInductorPass
,
VllmPatternMatcherPass
logger
=
init_logger
(
__name__
)
...
...
@@ -66,6 +66,8 @@ class ActivationQuantPattern(ABC):
)
self
.
FUSED_OP
=
FUSED_OPS
[
self
.
quant_key
]
self
.
silu_and_mul_matcher
=
MatcherSiluAndMul
()
def empty_quant(self, *args, **kwargs):
    """Allocate an uninitialized tensor using the pattern's quant dtype.

    ``dtype`` defaults to ``self.quant_dtype`` and ``device`` defaults to
    ``"cuda"``; either can be overridden through ``kwargs``. Positional
    arguments are forwarded unchanged to ``torch.empty``.
    """
    options = {"dtype": self.quant_dtype, "device": "cuda"}
    # Caller-supplied keywords take precedence over the defaults above.
    options.update(kwargs)
    return torch.empty(*args, **options)
...
...
@@ -80,42 +82,38 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
Fusion for SiluMul+Fp8StaticQuant Pattern
"""
def __init__(self):
    """Initialise the pattern for the ``kFp8StaticTensorSym`` quant key.

    The base-class initialiser receives the quant key, and a
    ``MatcherQuantFP8`` for the same key is stored as ``quant_matcher`` for
    use when the pattern is registered.
    """
    # An earlier `__init__(self, symmetric=True)` that built its own QuantKey
    # was immediately shadowed by this definition; only this one is kept.
    super().__init__(kFp8StaticTensorSym)
    self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def register(self, pm_pass: PatternMatcherPass):
    """Register the SiLU-mul + FP8 static quant replacement on ``pm_pass``.

    The pattern is expressed through ``self.silu_and_mul_matcher`` and
    ``self.quant_matcher`` rather than hard-coded ``auto_functionalized``
    calls to the custom ops. The replacement allocates the quantized output
    itself and calls ``self.FUSED_OP``.

    Args:
        pm_pass: Pattern matcher pass that receives the replacement.
    """

    def pattern(
        input: torch.Tensor,
        scale: torch.Tensor,
    ):
        result_silu_mul = self.silu_and_mul_matcher(input)
        result_quant = self.quant_matcher(result_silu_mul, scale)
        return result_quant[0]

    def replacement(
        input: torch.Tensor,
        scale: torch.Tensor,
    ):
        # The fused output keeps the leading dims and halves the last one.
        d = input.shape[-1] // 2
        output_shape = input.shape[:-1] + (d,)
        result = torch.empty(
            output_shape, device=input.device, dtype=self.quant_dtype
        )
        at = auto_functionalized(
            self.FUSED_OP, result=result, input=input, scale=scale
        )
        return at[1]

    # Example inputs come from the matchers so they agree with whichever
    # form of each op the matcher traces.
    inputs = [
        *self.silu_and_mul_matcher.inputs(),  # input
        self.quant_matcher.inputs()[1],  # scale
    ]
    pattern(*inputs)

    register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
...
...
@@ -132,24 +130,22 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
def
pattern
(
result
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
result_silu_mul
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
at1
=
auto_functionalized
(
SILU_MUL_OP
,
result
=
result_silu_mul
,
input
=
input
)
at
2
=
auto_functionalized
(
result_silu_mul
=
self
.
silu_and_mul_matcher
(
input
)
at
=
auto_functionalized
(
self
.
QUANT_OP
,
output
=
result
,
input
=
at1
[
1
]
,
input
=
result_silu_mul
,
output_scale
=
output_scale
,
input_scale
=
scale
,
)
return
at
2
[
1
],
at
2
[
2
]
return
at
[
1
],
at
[
2
]
def
replacement
(
result
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
result_silu_mul
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
):
...
...
@@ -165,7 +161,6 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
inputs
=
[
self
.
empty_quant
(
5
,
32
),
# result
empty_i32
(
128
,
4
),
# output_scale
empty_bf16
(
5
,
64
),
# result_silu_mul
empty_bf16
(
5
,
64
),
# input
empty_fp32
(
1
,
1
),
# scale
]
...
...
vllm/compilation/matcher_utils.py
View file @
ab3e8004
...
...
@@ -7,6 +7,7 @@ from torch._higher_order_ops import auto_functionalized
from
torch._ops
import
OpOverload
from
vllm.config
import
get_current_vllm_config
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
...
...
@@ -31,6 +32,8 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
if
current_platform
.
is_cuda
()
and
hasattr
(
torch
.
ops
.
_C
,
"scaled_fp4_quant"
):
QUANT_OPS
[
kNvfp4Quant
]
=
torch
.
ops
.
_C
.
scaled_fp4_quant
.
default
# noqa: E501
SILU_MUL_OP
=
torch
.
ops
.
_C
.
silu_and_mul
.
default
class
MatcherCustomOp
(
ABC
):
def
__init__
(
self
,
enabled
:
bool
):
...
...
@@ -206,3 +209,30 @@ class MatcherQuantFP8(MatcherCustomOp):
return
[
input
,
self
.
empty_f32
(
1
,
1
)]
return
[
input
]
class MatcherSiluAndMul(MatcherCustomOp):
    """Matcher that expresses SiLU-and-mul in custom-op or native form.

    NOTE(review): which of ``forward_custom`` / ``forward_native`` is used is
    presumably decided by ``MatcherCustomOp`` from ``enabled`` - confirm
    against the base class.
    """

    def __init__(self, enabled: bool | None = None):
        """Initialise the matcher.

        Args:
            enabled: Explicit enabled flag. When ``None``, the value of
                ``SiluAndMul.enabled()`` is used instead.
        """
        resolved = SiluAndMul.enabled() if enabled is None else enabled
        super().__init__(resolved)

    def inputs(self) -> list[torch.Tensor]:
        """Return the example inputs: one tensor from ``self.empty(5, 4)``."""
        return [self.empty(5, 4)]

    def forward_custom(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Call ``SILU_MUL_OP`` through ``auto_functionalized``.

        The output buffer has the dtype and device of ``x`` and the same
        shape except that the last dimension is halved (integer division).
        """
        half = x.shape[-1] // 2
        out_shape = x.shape[:-1] + (half,)
        buffer = torch.empty(out_shape, dtype=x.dtype, device=x.device)
        functionalized = auto_functionalized(SILU_MUL_OP, result=buffer, input=x)
        # Element 1 of the functionalized result is what callers receive.
        return functionalized[1]

    def forward_native(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate to ``SiluAndMul.forward_native``."""
        return SiluAndMul.forward_native(x)
vllm/model_executor/layers/activation.py
View file @
ab3e8004
...
...
@@ -80,7 +80,8 @@ class SiluAndMul(CustomOp):
elif
current_platform
.
is_cpu
():
self
.
_forward_method
=
self
.
forward_native
def
forward_native
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
@
staticmethod
def
forward_native
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""PyTorch-native implementation equivalent to forward()."""
d
=
x
.
shape
[
-
1
]
//
2
return
F
.
silu
(
x
[...,
:
d
])
*
x
[...,
d
:]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment