Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a246d08c
Commit
a246d08c
authored
Jul 24, 2025
by
zhuwenwen
Browse files
skip fp8 and ActivationQuantFusionPass
parent
d560429c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
22 deletions
+22
-22
vllm/compilation/fusion.py
vllm/compilation/fusion.py
+8
-8
vllm/compilation/pass_manager.py
vllm/compilation/pass_manager.py
+4
-4
vllm/compilation/sequence_parallelism.py
vllm/compilation/sequence_parallelism.py
+10
-10
No files found.
vllm/compilation/fusion.py
View file @
a246d08c
...
@@ -586,23 +586,23 @@ class FusionPass(VllmInductorPass):
...
@@ -586,23 +586,23 @@ class FusionPass(VllmInductorPass):
for
epsilon
in
[
1e-5
,
1e-6
]:
for
epsilon
in
[
1e-5
,
1e-6
]:
# Fuse rms_norm + static fp8 quant
# Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern
(
epsilon
,
#
RMSNormStaticQuantPattern(epsilon,
FP8_DTYPE
).
register
(
self
.
patterns
)
#
FP8_DTYPE).register(self.patterns)
# Matches for patterns below have 2 or more outputs,
# Matches for patterns below have 2 or more outputs,
# so we need to process them manually (see process_matches)
# so we need to process them manually (see process_matches)
# Fuse rms_norm + static fp8 quant
# Fuse rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern
(
epsilon
,
FP8_DTYPE
).
register
(
#
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self
.
patterns
,
self
.
record_match
)
#
self.patterns, self.record_match)
# Fuse rms_norm + dynamic per-token fp8 quant
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern
(
epsilon
,
FP8_DTYPE
).
register
(
#
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self
.
patterns
,
self
.
record_match
)
#
self.patterns, self.record_match)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern
(
epsilon
,
FP8_DTYPE
).
register
(
#
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self
.
patterns
,
self
.
record_match
)
#
self.patterns, self.record_match)
# WARNING: This is a hack to clear the pattern matcher cache
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
# and allow multiple values of epsilon.
...
...
vllm/compilation/pass_manager.py
View file @
a246d08c
...
@@ -6,7 +6,7 @@ from torch import fx as fx
...
@@ -6,7 +6,7 @@ from torch import fx as fx
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
.activation_quant_fusion
import
ActivationQuantFusionPass
#
from .activation_quant_fusion import ActivationQuantFusionPass
from
.collective_fusion
import
AsyncTPPass
from
.collective_fusion
import
AsyncTPPass
from
.fix_functionalization
import
FixFunctionalizationPass
from
.fix_functionalization
import
FixFunctionalizationPass
from
.fusion
import
FusionPass
from
.fusion
import
FusionPass
...
@@ -56,9 +56,9 @@ class PostGradPassManager(CustomGraphPass):
...
@@ -56,9 +56,9 @@ class PostGradPassManager(CustomGraphPass):
if
self
.
pass_config
.
enable_async_tp
:
if
self
.
pass_config
.
enable_async_tp
:
self
.
passes
+=
[
AsyncTPPass
(
config
)]
self
.
passes
+=
[
AsyncTPPass
(
config
)]
if
self
.
pass_config
.
enable_fusion
:
#
if self.pass_config.enable_fusion:
self
.
passes
+=
[
FusionPass
.
instance
(
config
)]
#
self.passes += [FusionPass.instance(config)]
self
.
passes
+=
[
ActivationQuantFusionPass
(
config
)]
#
self.passes += [ActivationQuantFusionPass(config)]
if
self
.
pass_config
.
enable_attn_fusion
:
if
self
.
pass_config
.
enable_attn_fusion
:
self
.
passes
+=
[
AttnFusionPass
(
config
)]
self
.
passes
+=
[
AttnFusionPass
(
config
)]
...
...
vllm/compilation/sequence_parallelism.py
View file @
a246d08c
...
@@ -444,16 +444,16 @@ class SequenceParallelismPass(VllmInductorPass):
...
@@ -444,16 +444,16 @@ class SequenceParallelismPass(VllmInductorPass):
for
epsilon
in
[
1e-5
,
1e-6
]:
for
epsilon
in
[
1e-5
,
1e-6
]:
# RMSNorm + Static FP8 quantization patterns
# RMSNorm + Static FP8 quantization patterns
fp8_quant_op
=
torch
.
ops
.
_C
.
static_scaled_fp8_quant
.
default
#
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
FirstAllReduceRMSNormStaticFP8Pattern
(
#
FirstAllReduceRMSNormStaticFP8Pattern(
epsilon
,
self
.
model_dtype
,
self
.
device
,
#
epsilon, self.model_dtype, self.device,
fp8_quant_op
).
register
(
self
.
patterns
)
#
fp8_quant_op).register(self.patterns)
MiddleAllReduceRMSNormStaticFP8Pattern
(
#
MiddleAllReduceRMSNormStaticFP8Pattern(
epsilon
,
self
.
model_dtype
,
self
.
device
,
#
epsilon, self.model_dtype, self.device,
fp8_quant_op
).
register
(
self
.
patterns
)
#
fp8_quant_op).register(self.patterns)
LastAllReduceRMSNormStaticFP8Pattern
(
#
LastAllReduceRMSNormStaticFP8Pattern(
epsilon
,
self
.
model_dtype
,
self
.
device
,
#
epsilon, self.model_dtype, self.device,
fp8_quant_op
).
register
(
self
.
patterns
)
#
fp8_quant_op).register(self.patterns)
# Normal RMSNorm patterns
# Normal RMSNorm patterns
FirstAllReduceRMSNormPattern
(
epsilon
,
self
.
model_dtype
,
FirstAllReduceRMSNormPattern
(
epsilon
,
self
.
model_dtype
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment