Unverified Commit 2196bac1 authored by BadrBasowid's avatar BadrBasowid Committed by GitHub
Browse files

[Compilation] Refactor SiluMul activation+quant Fusion Pass (#39684)


Signed-off-by: default avatarBadrBasowid <badr.basowid@gmail.com>
parent 4b7869d6
......@@ -189,7 +189,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
# TODO: Remove log counting in unit tests
# once all matchers implement VllmFusionPatternMatcherPass
n_expected = tp_size * num_ranges_activated
if match_name != "attn_quant_fusion":
if match_name not in ("attn_quant_fusion", "act_quant_fusion"):
assert len(log_matches) == n_expected, (
f"Could not find {n_expected} {match_name} "
f"(found {len(log_matches)}) in:\n {log_holder.text}"
......@@ -250,6 +250,12 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f"entries (SP took precedence), found: {log_matches}"
)
elif match_name == "act_quant_fusion":
actual_match = match_table.get("activation_quant_fusion_pass", 0)
assert actual_match == expected_matches * n_expected, (
f"Could not find {expected_matches * n_expected} "
f"{match_name} (found {actual_match})."
)
elif match_name == "attn_quant_fusion":
actual_match = match_table.get(
"attn_quant_fusion", 0
......
......@@ -168,7 +168,7 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
def forward(self, x):
y = self.silu_and_mul(x)
x2 = self.w8a8_block_fp8_linear(y, self.w, self.wscale)
x2 = self.w8a8_block_fp8_linear(y)
return x2
def ops_in_model_before(self):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import itertools
from typing import Any
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (
PatternMatcherPass,
fwd_only,
register_replacement,
)
from torch._ops import OpOverload
from vllm.config import VllmConfig
......@@ -24,8 +19,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from ..vllm_inductor_pass import VllmFusionPatternMatcherPass, VllmPatternReplacement
from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
......@@ -50,9 +44,9 @@ if current_platform.is_cuda_alike():
FUSED_OPS[kFp8Dynamic64Sym] = torch.ops._C.silu_and_mul_per_block_quant.default
class ActivationQuantPattern(ABC):
class ActivationQuantPattern(VllmPatternReplacement):
"""
The base class for Activation+Quant fusions.
Base class for Activation+Quant fusions.
Should not be used directly.
"""
......@@ -79,10 +73,6 @@ class ActivationQuantPattern(ABC):
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs)
@abstractmethod
def register(self, pm_pass: PatternMatcherPass) -> None:
raise NotImplementedError
class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
"""
......@@ -100,8 +90,9 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
scale,
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
@property
def pattern(self):
def _pattern(
input: torch.Tensor,
scale: torch.Tensor,
) -> torch.Tensor:
......@@ -109,7 +100,11 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
result_quant = self.quant_matcher(result_silu_mul, scale)
return result_quant[0]
def replacement(
return _pattern
@property
def replacement(self):
def _replacement(
input: torch.Tensor,
scale: torch.Tensor,
) -> torch.Tensor:
......@@ -123,10 +118,7 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
)
return at[1]
inps = self.get_inputs()
pattern(*inps)
register_replacement(pattern, replacement, inps, fwd_only, pm_pass)
return _replacement
class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
......@@ -144,8 +136,9 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
scale = empty_fp32(1, 1)
return [result, output_scale, input_, scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
@property
def pattern(self):
def _pattern(
result: torch.Tensor,
output_scale: torch.Tensor,
input: torch.Tensor,
......@@ -162,7 +155,11 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
)
return at[1], at[2]
def replacement(
return _pattern
@property
def replacement(self):
def _replacement(
result: torch.Tensor,
output_scale: torch.Tensor,
input: torch.Tensor,
......@@ -177,7 +174,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
)
return at[1], at[2]
register_replacement(pattern, replacement, self.get_inputs(), fwd_only, pm_pass)
return _replacement
class SiluMulBlockQuantPattern(ActivationQuantPattern):
......@@ -210,10 +207,9 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
scale = self.quant_matcher.empty_f32(1, 1)
return self.silu_and_mul_matcher.inputs() + [scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
is_scale_transposed = self.is_scale_transposed
def pattern(
@property
def pattern(self):
def _pattern(
input: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
......@@ -235,12 +231,16 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
fp8_min=finfo.min,
fp8_max=finfo.max,
scale_ue8m0=self.is_e8m0,
dummy_is_scale_transposed=is_scale_transposed,
dummy_is_scale_transposed=self.is_scale_transposed,
dummy_is_tma_aligned=self.is_tma_aligned,
)
return result, scale
def replacement(
return _pattern
@property
def replacement(self):
def _replacement(
input: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
......@@ -249,7 +249,7 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
result = torch.empty(
output_shape, device=input.device, dtype=self.quant_dtype
)
if is_scale_transposed:
if self.is_scale_transposed:
scale = torch.empty(
(d // self.group_size, input.shape[0]),
device=input.device,
......@@ -268,15 +268,14 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
scales=scale,
group_size=self.group_size,
scale_ub=None,
is_scale_transposed=is_scale_transposed,
is_scale_transposed=self.is_scale_transposed,
)
return at[1], at[2]
inps = self.get_inputs()
register_replacement(pattern, replacement, inps, fwd_only, pm_pass)
return _replacement
class ActivationQuantFusionPass(VllmPatternMatcherPass):
class ActivationQuantFusionPass(VllmFusionPatternMatcherPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
......@@ -286,45 +285,33 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
super().__init__(config, "activation_quant_fusion_pass")
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="activation_quant_fusion_pass"
)
pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern()
pattern_silu_mul_fp8.register(self.patterns)
self.register(SiluMulFp8StaticQuantPattern())
if silu_and_mul_nvfp4_quant_supported:
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
pattern_silu_mul_nvfp4.register(self.patterns)
self.register(SiluMulNvfp4QuantPattern())
if current_platform.is_cuda():
for quant_key in [kFp8Dynamic128Sym, kFp8Dynamic64Sym]:
for is_scale_transposed in [False, True]:
for is_e8m0 in [True, False]:
for is_tma_aligned in [False, True]:
for (
quant_key,
is_scale_transposed,
is_e8m0,
is_tma_aligned,
) in itertools.product(
[kFp8Dynamic128Sym, kFp8Dynamic64Sym],
[False, True],
[True, False],
[False, True],
):
self.register(
SiluMulBlockQuantPattern(
quant_key,
is_scale_transposed=is_scale_transposed,
is_e8m0=is_e8m0,
is_tma_aligned=is_tma_aligned,
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> str:
return VllmInductorPass.hash_source(
self,
ActivationQuantPattern,
SiluMulFp8StaticQuantPattern,
SiluMulNvfp4QuantPattern,
SiluMulBlockQuantPattern,
)
)
self.dump_patterns(config, self.pm_pass)
......@@ -28,7 +28,6 @@ from ..vllm_inductor_pass import (
VllmPatternMatcherPass,
VllmPatternReplacement,
)
from .act_quant_fusion import ActivationQuantPattern
from .matcher_utils import (
MatcherFusedAddRMSNorm,
MatcherQuantFP8,
......@@ -345,7 +344,7 @@ class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
return self.hash_source(self, *fusion_patterns)
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
class AiterSiluMulFp8GroupQuantPattern(VllmPatternReplacement):
"""
This pattern fuses aiter silu_and_mul & group fp8 quant custom
ops into an aiter silu_and_mul_group_fp8_quant op.
......@@ -364,26 +363,29 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
self.silu_and_mul_matcher.inputs()[0],
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
@property
def pattern(self):
def _pattern(
input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at1 = self.silu_and_mul_matcher(input)
at2 = self.quant_matcher(at1)
return at2[0], at2[1]
def replacement(
return _pattern
@property
def replacement(self):
def _replacement(
input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
return at[0], at[1]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
return _replacement
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmFusionPatternMatcherPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
......@@ -393,29 +395,12 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)
AiterSiluMulFp8GroupQuantPattern().register(self.patterns)
self.dump_patterns(config, self.patterns)
super().__init__(config, "rocm_aiter_silu_mul_fp8_group_quant_fusion_pass")
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
self.register(AiterSiluMulFp8GroupQuantPattern())
def uuid(self) -> str:
fusion_patterns = [
ActivationQuantPattern,
AiterSiluMulFp8GroupQuantPattern,
]
return VllmInductorPass.hash_source(self, *fusion_patterns)
self.dump_patterns(config, self.pm_pass)
class AddAiterRMSNormPadPattern:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment