Unverified Commit 2196bac1 authored by BadrBasowid's avatar BadrBasowid Committed by GitHub
Browse files

[Compilation] Refactor SiluMul activation+quant Fusion Pass (#39684)


Signed-off-by: default avatarBadrBasowid <badr.basowid@gmail.com>
parent 4b7869d6
...@@ -189,7 +189,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): ...@@ -189,7 +189,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
# TODO: Remove log counting in unit tests # TODO: Remove log counting in unit tests
# once all matchers implement VllmFusionPatternMatcherPass # once all matchers implement VllmFusionPatternMatcherPass
n_expected = tp_size * num_ranges_activated n_expected = tp_size * num_ranges_activated
if match_name != "attn_quant_fusion": if match_name not in ("attn_quant_fusion", "act_quant_fusion"):
assert len(log_matches) == n_expected, ( assert len(log_matches) == n_expected, (
f"Could not find {n_expected} {match_name} " f"Could not find {n_expected} {match_name} "
f"(found {len(log_matches)}) in:\n {log_holder.text}" f"(found {len(log_matches)}) in:\n {log_holder.text}"
...@@ -250,6 +250,12 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): ...@@ -250,6 +250,12 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f"entries (SP took precedence), found: {log_matches}" f"entries (SP took precedence), found: {log_matches}"
) )
elif match_name == "act_quant_fusion":
actual_match = match_table.get("activation_quant_fusion_pass", 0)
assert actual_match == expected_matches * n_expected, (
f"Could not find {expected_matches * n_expected} "
f"{match_name} (found {actual_match})."
)
elif match_name == "attn_quant_fusion": elif match_name == "attn_quant_fusion":
actual_match = match_table.get( actual_match = match_table.get(
"attn_quant_fusion", 0 "attn_quant_fusion", 0
......
...@@ -168,7 +168,7 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module): ...@@ -168,7 +168,7 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
def forward(self, x): def forward(self, x):
y = self.silu_and_mul(x) y = self.silu_and_mul(x)
x2 = self.w8a8_block_fp8_linear(y, self.w, self.wscale) x2 = self.w8a8_block_fp8_linear(y)
return x2 return x2
def ops_in_model_before(self): def ops_in_model_before(self):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod import itertools
from typing import Any from typing import Any
import torch import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (
PatternMatcherPass,
fwd_only,
register_replacement,
)
from torch._ops import OpOverload from torch._ops import OpOverload
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -24,8 +19,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -24,8 +19,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode from ..vllm_inductor_pass import VllmFusionPatternMatcherPass, VllmPatternReplacement
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
...@@ -50,9 +44,9 @@ if current_platform.is_cuda_alike(): ...@@ -50,9 +44,9 @@ if current_platform.is_cuda_alike():
FUSED_OPS[kFp8Dynamic64Sym] = torch.ops._C.silu_and_mul_per_block_quant.default FUSED_OPS[kFp8Dynamic64Sym] = torch.ops._C.silu_and_mul_per_block_quant.default
class ActivationQuantPattern(ABC): class ActivationQuantPattern(VllmPatternReplacement):
""" """
The base class for Activation+Quant fusions. Base class for Activation+Quant fusions.
Should not be used directly. Should not be used directly.
""" """
...@@ -79,10 +73,6 @@ class ActivationQuantPattern(ABC): ...@@ -79,10 +73,6 @@ class ActivationQuantPattern(ABC):
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs} kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs) return torch.empty(*args, **kwargs)
@abstractmethod
def register(self, pm_pass: PatternMatcherPass) -> None:
raise NotImplementedError
class SiluMulFp8StaticQuantPattern(ActivationQuantPattern): class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
""" """
...@@ -100,8 +90,9 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern): ...@@ -100,8 +90,9 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
scale, scale,
] ]
def register(self, pm_pass: PatternMatcherPass) -> None: @property
def pattern( def pattern(self):
def _pattern(
input: torch.Tensor, input: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -109,7 +100,11 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern): ...@@ -109,7 +100,11 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
result_quant = self.quant_matcher(result_silu_mul, scale) result_quant = self.quant_matcher(result_silu_mul, scale)
return result_quant[0] return result_quant[0]
def replacement( return _pattern
@property
def replacement(self):
def _replacement(
input: torch.Tensor, input: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -123,10 +118,7 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern): ...@@ -123,10 +118,7 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
) )
return at[1] return at[1]
inps = self.get_inputs() return _replacement
pattern(*inps)
register_replacement(pattern, replacement, inps, fwd_only, pm_pass)
class SiluMulNvfp4QuantPattern(ActivationQuantPattern): class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
...@@ -144,8 +136,9 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): ...@@ -144,8 +136,9 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
scale = empty_fp32(1, 1) scale = empty_fp32(1, 1)
return [result, output_scale, input_, scale] return [result, output_scale, input_, scale]
def register(self, pm_pass: PatternMatcherPass) -> None: @property
def pattern( def pattern(self):
def _pattern(
result: torch.Tensor, result: torch.Tensor,
output_scale: torch.Tensor, output_scale: torch.Tensor,
input: torch.Tensor, input: torch.Tensor,
...@@ -162,7 +155,11 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): ...@@ -162,7 +155,11 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
) )
return at[1], at[2] return at[1], at[2]
def replacement( return _pattern
@property
def replacement(self):
def _replacement(
result: torch.Tensor, result: torch.Tensor,
output_scale: torch.Tensor, output_scale: torch.Tensor,
input: torch.Tensor, input: torch.Tensor,
...@@ -177,7 +174,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): ...@@ -177,7 +174,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
) )
return at[1], at[2] return at[1], at[2]
register_replacement(pattern, replacement, self.get_inputs(), fwd_only, pm_pass) return _replacement
class SiluMulBlockQuantPattern(ActivationQuantPattern): class SiluMulBlockQuantPattern(ActivationQuantPattern):
...@@ -210,10 +207,9 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern): ...@@ -210,10 +207,9 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
scale = self.quant_matcher.empty_f32(1, 1) scale = self.quant_matcher.empty_f32(1, 1)
return self.silu_and_mul_matcher.inputs() + [scale] return self.silu_and_mul_matcher.inputs() + [scale]
def register(self, pm_pass: PatternMatcherPass) -> None: @property
is_scale_transposed = self.is_scale_transposed def pattern(self):
def _pattern(
def pattern(
input: torch.Tensor, input: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
...@@ -235,12 +231,16 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern): ...@@ -235,12 +231,16 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
fp8_min=finfo.min, fp8_min=finfo.min,
fp8_max=finfo.max, fp8_max=finfo.max,
scale_ue8m0=self.is_e8m0, scale_ue8m0=self.is_e8m0,
dummy_is_scale_transposed=is_scale_transposed, dummy_is_scale_transposed=self.is_scale_transposed,
dummy_is_tma_aligned=self.is_tma_aligned, dummy_is_tma_aligned=self.is_tma_aligned,
) )
return result, scale return result, scale
def replacement( return _pattern
@property
def replacement(self):
def _replacement(
input: torch.Tensor, input: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
...@@ -249,7 +249,7 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern): ...@@ -249,7 +249,7 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
result = torch.empty( result = torch.empty(
output_shape, device=input.device, dtype=self.quant_dtype output_shape, device=input.device, dtype=self.quant_dtype
) )
if is_scale_transposed: if self.is_scale_transposed:
scale = torch.empty( scale = torch.empty(
(d // self.group_size, input.shape[0]), (d // self.group_size, input.shape[0]),
device=input.device, device=input.device,
...@@ -268,15 +268,14 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern): ...@@ -268,15 +268,14 @@ class SiluMulBlockQuantPattern(ActivationQuantPattern):
scales=scale, scales=scale,
group_size=self.group_size, group_size=self.group_size,
scale_ub=None, scale_ub=None,
is_scale_transposed=is_scale_transposed, is_scale_transposed=self.is_scale_transposed,
) )
return at[1], at[2] return at[1], at[2]
inps = self.get_inputs() return _replacement
register_replacement(pattern, replacement, inps, fwd_only, pm_pass)
class ActivationQuantFusionPass(VllmPatternMatcherPass): class ActivationQuantFusionPass(VllmFusionPatternMatcherPass):
""" """
This pass fuses a pre-defined set of custom ops into fused ops. This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them. It uses the torch pattern matcher to find the patterns and replace them.
...@@ -286,45 +285,33 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass): ...@@ -286,45 +285,33 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
""" """
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None: def __init__(self, config: VllmConfig) -> None:
super().__init__(config) super().__init__(config, "activation_quant_fusion_pass")
self.patterns: PatternMatcherPass = PatternMatcherPass( self.register(SiluMulFp8StaticQuantPattern())
pass_name="activation_quant_fusion_pass"
)
pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern()
pattern_silu_mul_fp8.register(self.patterns)
if silu_and_mul_nvfp4_quant_supported: if silu_and_mul_nvfp4_quant_supported:
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern() self.register(SiluMulNvfp4QuantPattern())
pattern_silu_mul_nvfp4.register(self.patterns)
if current_platform.is_cuda(): if current_platform.is_cuda():
for quant_key in [kFp8Dynamic128Sym, kFp8Dynamic64Sym]: for (
for is_scale_transposed in [False, True]: quant_key,
for is_e8m0 in [True, False]: is_scale_transposed,
for is_tma_aligned in [False, True]: is_e8m0,
is_tma_aligned,
) in itertools.product(
[kFp8Dynamic128Sym, kFp8Dynamic64Sym],
[False, True],
[True, False],
[False, True],
):
self.register(
SiluMulBlockQuantPattern( SiluMulBlockQuantPattern(
quant_key, quant_key,
is_scale_transposed=is_scale_transposed, is_scale_transposed=is_scale_transposed,
is_e8m0=is_e8m0, is_e8m0=is_e8m0,
is_tma_aligned=is_tma_aligned, is_tma_aligned=is_tma_aligned,
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> str:
return VllmInductorPass.hash_source(
self,
ActivationQuantPattern,
SiluMulFp8StaticQuantPattern,
SiluMulNvfp4QuantPattern,
SiluMulBlockQuantPattern,
) )
)
self.dump_patterns(config, self.pm_pass)
...@@ -28,7 +28,6 @@ from ..vllm_inductor_pass import ( ...@@ -28,7 +28,6 @@ from ..vllm_inductor_pass import (
VllmPatternMatcherPass, VllmPatternMatcherPass,
VllmPatternReplacement, VllmPatternReplacement,
) )
from .act_quant_fusion import ActivationQuantPattern
from .matcher_utils import ( from .matcher_utils import (
MatcherFusedAddRMSNorm, MatcherFusedAddRMSNorm,
MatcherQuantFP8, MatcherQuantFP8,
...@@ -345,7 +344,7 @@ class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass): ...@@ -345,7 +344,7 @@ class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
return self.hash_source(self, *fusion_patterns) return self.hash_source(self, *fusion_patterns)
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern): class AiterSiluMulFp8GroupQuantPattern(VllmPatternReplacement):
""" """
This pattern fuses aiter silu_and_mul & group fp8 quant custom This pattern fuses aiter silu_and_mul & group fp8 quant custom
ops into an aiter silu_and_mul_group_fp8_quant op. ops into an aiter silu_and_mul_group_fp8_quant op.
...@@ -364,26 +363,29 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern): ...@@ -364,26 +363,29 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
self.silu_and_mul_matcher.inputs()[0], self.silu_and_mul_matcher.inputs()[0],
] ]
def register(self, pm_pass: PatternMatcherPass) -> None: @property
def pattern( def pattern(self):
def _pattern(
input: torch.Tensor, input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
at1 = self.silu_and_mul_matcher(input) at1 = self.silu_and_mul_matcher(input)
at2 = self.quant_matcher(at1) at2 = self.quant_matcher(at1)
return at2[0], at2[1] return at2[0], at2[1]
def replacement( return _pattern
@property
def replacement(self):
def _replacement(
input: torch.Tensor, input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128) at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
return at[0], at[1] return at[0], at[1]
pm.register_replacement( return _replacement
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmFusionPatternMatcherPass):
""" """
This pass fuses a pre-defined set of custom ops into fused ops. This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them. It uses the torch pattern matcher to find the patterns and replace them.
...@@ -393,29 +395,12 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): ...@@ -393,29 +395,12 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
""" """
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None: def __init__(self, config: VllmConfig) -> None:
super().__init__(config) super().__init__(config, "rocm_aiter_silu_mul_fp8_group_quant_fusion_pass")
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)
AiterSiluMulFp8GroupQuantPattern().register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log self.register(AiterSiluMulFp8GroupQuantPattern())
def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> str: self.dump_patterns(config, self.pm_pass)
fusion_patterns = [
ActivationQuantPattern,
AiterSiluMulFp8GroupQuantPattern,
]
return VllmInductorPass.hash_source(self, *fusion_patterns)
class AddAiterRMSNormPadPattern: class AddAiterRMSNormPadPattern:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment