Commit a246d08c authored by zhuwenwen's avatar zhuwenwen
Browse files

skip fp8 and ActivationQuantFusionPass

parent d560429c
...@@ -586,23 +586,23 @@ class FusionPass(VllmInductorPass): ...@@ -586,23 +586,23 @@ class FusionPass(VllmInductorPass):
for epsilon in [1e-5, 1e-6]: for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + static fp8 quant # Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern(epsilon, # RMSNormStaticQuantPattern(epsilon,
FP8_DTYPE).register(self.patterns) # FP8_DTYPE).register(self.patterns)
# Matches for patterns below have 2 or more outputs, # Matches for patterns below have 2 or more outputs,
# so we need to process them manually (see process_matches) # so we need to process them manually (see process_matches)
# Fuse rms_norm + static fp8 quant # Fuse rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( # FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match) # self.patterns, self.record_match)
# Fuse rms_norm + dynamic per-token fp8 quant # Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( # RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match) # self.patterns, self.record_match)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( # FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match) # self.patterns, self.record_match)
# WARNING: This is a hack to clear the pattern matcher cache # WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon. # and allow multiple values of epsilon.
......
...@@ -6,7 +6,7 @@ from torch import fx as fx ...@@ -6,7 +6,7 @@ from torch import fx as fx
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from .activation_quant_fusion import ActivationQuantFusionPass # from .activation_quant_fusion import ActivationQuantFusionPass
from .collective_fusion import AsyncTPPass from .collective_fusion import AsyncTPPass
from .fix_functionalization import FixFunctionalizationPass from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass from .fusion import FusionPass
...@@ -56,9 +56,9 @@ class PostGradPassManager(CustomGraphPass): ...@@ -56,9 +56,9 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.enable_async_tp: if self.pass_config.enable_async_tp:
self.passes += [AsyncTPPass(config)] self.passes += [AsyncTPPass(config)]
if self.pass_config.enable_fusion: # if self.pass_config.enable_fusion:
self.passes += [FusionPass.instance(config)] # self.passes += [FusionPass.instance(config)]
self.passes += [ActivationQuantFusionPass(config)] # self.passes += [ActivationQuantFusionPass(config)]
if self.pass_config.enable_attn_fusion: if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)] self.passes += [AttnFusionPass(config)]
......
...@@ -444,16 +444,16 @@ class SequenceParallelismPass(VllmInductorPass): ...@@ -444,16 +444,16 @@ class SequenceParallelismPass(VllmInductorPass):
for epsilon in [1e-5, 1e-6]: for epsilon in [1e-5, 1e-6]:
# RMSNorm + Static FP8 quantization patterns # RMSNorm + Static FP8 quantization patterns
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default # fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
FirstAllReduceRMSNormStaticFP8Pattern( # FirstAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device, # epsilon, self.model_dtype, self.device,
fp8_quant_op).register(self.patterns) # fp8_quant_op).register(self.patterns)
MiddleAllReduceRMSNormStaticFP8Pattern( # MiddleAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device, # epsilon, self.model_dtype, self.device,
fp8_quant_op).register(self.patterns) # fp8_quant_op).register(self.patterns)
LastAllReduceRMSNormStaticFP8Pattern( # LastAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device, # epsilon, self.model_dtype, self.device,
fp8_quant_op).register(self.patterns) # fp8_quant_op).register(self.patterns)
# Normal RMSNorm patterns # Normal RMSNorm patterns
FirstAllReduceRMSNormPattern(epsilon, self.model_dtype, FirstAllReduceRMSNormPattern(epsilon, self.model_dtype,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment