Commit a246d08c authored by zhuwenwen's avatar zhuwenwen
Browse files

skip fp8 and ActivationQuantFusionPass

parent d560429c
......@@ -586,23 +586,23 @@ class FusionPass(VllmInductorPass):
for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern(epsilon,
FP8_DTYPE).register(self.patterns)
# RMSNormStaticQuantPattern(epsilon,
# FP8_DTYPE).register(self.patterns)
# Matches for patterns below have 2 or more outputs,
# so we need to process them manually (see process_matches)
# Fuse rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
# FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
# self.patterns, self.record_match)
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
# RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
# self.patterns, self.record_match)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
# FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
# self.patterns, self.record_match)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
......
......@@ -6,7 +6,7 @@ from torch import fx as fx
from vllm.config import VllmConfig
from vllm.logger import init_logger
from .activation_quant_fusion import ActivationQuantFusionPass
# from .activation_quant_fusion import ActivationQuantFusionPass
from .collective_fusion import AsyncTPPass
from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
......@@ -56,9 +56,9 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.enable_async_tp:
self.passes += [AsyncTPPass(config)]
if self.pass_config.enable_fusion:
self.passes += [FusionPass.instance(config)]
self.passes += [ActivationQuantFusionPass(config)]
# if self.pass_config.enable_fusion:
# self.passes += [FusionPass.instance(config)]
# self.passes += [ActivationQuantFusionPass(config)]
if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]
......
......@@ -444,16 +444,16 @@ class SequenceParallelismPass(VllmInductorPass):
for epsilon in [1e-5, 1e-6]:
# RMSNorm + Static FP8 quantization patterns
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
FirstAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device,
fp8_quant_op).register(self.patterns)
MiddleAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device,
fp8_quant_op).register(self.patterns)
LastAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device,
fp8_quant_op).register(self.patterns)
# fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
# FirstAllReduceRMSNormStaticFP8Pattern(
# epsilon, self.model_dtype, self.device,
# fp8_quant_op).register(self.patterns)
# MiddleAllReduceRMSNormStaticFP8Pattern(
# epsilon, self.model_dtype, self.device,
# fp8_quant_op).register(self.patterns)
# LastAllReduceRMSNormStaticFP8Pattern(
# epsilon, self.model_dtype, self.device,
# fp8_quant_op).register(self.patterns)
# Normal RMSNorm patterns
FirstAllReduceRMSNormPattern(epsilon, self.model_dtype,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment