Unverified commit 077a9a8e, authored by BadrBasowid, committed by GitHub
Browse files

[torch.compile] Refactor Attention Quant Fusion Pass and Remove Boilerplate (#37373)


Signed-off-by: BadrBasowid <badr.basowid@gmail.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
parent 07edd551
...@@ -225,7 +225,7 @@ outputs = model.generate( ...@@ -225,7 +225,7 @@ outputs = model.generate(
### Piecewise compilation and full graph custom passes (attention fusion, sequence parallelism) ### Piecewise compilation and full graph custom passes (attention fusion, sequence parallelism)
Unfortunately, some custom compile passes have to see the whole graph to be effective and hence aren't compatible with piecewise compilation. This includes `AttnFusionPass` and `SequenceParallelismPass`. As a short-term solution, we automatically disable piecewise compilation (by setting `splitting_ops=[]`) when attention fusion is enabled. We use CUDA Graph modes `FULL` or `FULL_DECODE_ONLY` (depending on backend support). However, this leads to another optimization incompatibility and confusing performance tradeoffs. Unfortunately, some custom compile passes have to see the whole graph to be effective and hence aren't compatible with piecewise compilation. This includes `AttnQuantFusionPass` and `SequenceParallelismPass`. As a short-term solution, we automatically disable piecewise compilation (by setting `splitting_ops=[]`) when attention fusion is enabled. We use CUDA Graph modes `FULL` or `FULL_DECODE_ONLY` (depending on backend support). However, this leads to another optimization incompatibility and confusing performance tradeoffs.
Long term, we've added the ability to partition the graph in Inductor instead of right after Dynamo. It can be enabled with `CompilationConfig.use_inductor_graph_partition=True` but is currently experimental and only available with `torch>=2.9`. This also increases compilation time as it has to compile the whole graph and cannot reuse piecewise compilation artifacts. Once vLLM supports 2.9, we plan to make this the default approach as it will also speed up piecewise cudagraph capture. Long term, we've added the ability to partition the graph in Inductor instead of right after Dynamo. It can be enabled with `CompilationConfig.use_inductor_graph_partition=True` but is currently experimental and only available with `torch>=2.9`. This also increases compilation time as it has to compile the whole graph and cannot reuse piecewise compilation artifacts. Once vLLM supports 2.9, we plan to make this the default approach as it will also speed up piecewise cudagraph capture.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging import logging
from collections import defaultdict
import pytest import pytest
import regex as re import regex as re
...@@ -52,6 +53,16 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg ...@@ -52,6 +53,16 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
llm.llm_engine.vllm_config.compilation_config.compile_ranges_endpoints llm.llm_engine.vllm_config.compilation_config.compile_ranges_endpoints
) )
# Fetch match table from each worker via RPC and sum across workers.
worker_tables = llm.llm_engine.engine_core.collective_rpc(
"get_compilation_match_table"
)
combined: defaultdict[str, int] = defaultdict(int)
for table in worker_tables:
for k, v in table.items():
combined[k] += v
return dict(combined)
@pytest.fixture @pytest.fixture
def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
...@@ -113,7 +124,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): ...@@ -113,7 +124,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
) )
with caplog_mp_spawn(logging.DEBUG) as log_holder: with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(full_compilation_config, model_name, **model_kwargs) match_table = run_model(full_compilation_config, model_name, **model_kwargs)
num_compile_ranges = len(full_compilation_config.get_compile_ranges()) num_compile_ranges = len(full_compilation_config.get_compile_ranges())
assert num_compile_ranges in [1, 2, 3] assert num_compile_ranges in [1, 2, 3]
...@@ -155,7 +166,10 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): ...@@ -155,7 +166,10 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
else: else:
num_ranges_activated = num_compile_ranges num_ranges_activated = num_compile_ranges
# TODO: Remove log counting in unit tests
# once all matchers implement VllmFusionPatternMatcherPass
n_expected = tp_size * num_ranges_activated n_expected = tp_size * num_ranges_activated
if match_name != "attn_quant_fusion":
assert len(log_matches) == n_expected, ( assert len(log_matches) == n_expected, (
f"Could not find {n_expected} {match_name} " f"Could not find {n_expected} {match_name} "
f"(found {len(log_matches)}) in:\n {log_holder.text}" f"(found {len(log_matches)}) in:\n {log_holder.text}"
...@@ -215,6 +229,13 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): ...@@ -215,6 +229,13 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f"{tp_size * (num_ranges_activated - 1)} large-range " f"{tp_size * (num_ranges_activated - 1)} large-range "
f"entries (SP took precedence), found: {log_matches}" f"entries (SP took precedence), found: {log_matches}"
) )
elif match_name == "attn_quant_fusion":
actual_match = match_table.get(match_name, 0)
assert actual_match == expected_matches * n_expected, (
f"Could not find {expected_matches * n_expected} "
f"{match_name} (found {actual_match})."
)
else: else:
expected_matches_list = [expected_matches] * n_expected expected_matches_list = [expected_matches] * n_expected
assert sorted(log_matches) == expected_matches_list, ( assert sorted(log_matches) == expected_matches_list, (
......
...@@ -9,7 +9,10 @@ from tests.compile.backend import LazyInitPass, TestBackend ...@@ -9,7 +9,10 @@ from tests.compile.backend import LazyInitPass, TestBackend
from tests.utils import TestFP8Layer, flat_product from tests.utils import TestFP8Layer, flat_product
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.passes.fusion.attn_quant_fusion import ATTN_OP, AttnFusionPass from vllm.compilation.passes.fusion.attn_quant_fusion import (
ATTN_OP,
AttnQuantFusionPass,
)
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
from vllm.compilation.passes.fx_utils import find_op_nodes from vllm.compilation.passes.fx_utils import find_op_nodes
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
...@@ -384,7 +387,7 @@ def test_attention_quant_pattern( ...@@ -384,7 +387,7 @@ def test_attention_quant_pattern(
# Create test backend with fusion passes enabled # Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
attn_pass = LazyInitPass(AttnFusionPass, vllm_config) attn_pass = LazyInitPass(AttnQuantFusionPass, vllm_config)
cleanup_pass = PostCleanupPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config)
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass) test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
...@@ -434,7 +437,7 @@ def test_attention_quant_pattern( ...@@ -434,7 +437,7 @@ def test_attention_quant_pattern(
# Only output quant ops are fused into attention. # Only output quant ops are fused into attention.
test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic) test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic)
# access the underlying `AttnFusionPass` on the `LazyInitPass` # access the underlying `AttnQuantFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
# Check attention ops in the graph before and after fusion # Check attention ops in the graph before and after fusion
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from typing import Any, ParamSpec
import torch import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -22,14 +18,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -22,14 +18,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
from ..fx_utils import is_func from ..vllm_inductor_pass import VllmFusionPatternMatcherPass, VllmPatternReplacement
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherQuantFP8 from .matcher_utils import MatcherQuantFP8
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 from .rms_quant_fusion import QUANT_OPS
logger = init_logger(__name__) logger = init_logger(__name__)
P = ParamSpec("P")
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8 FP4_DTYPE = torch.uint8
...@@ -37,83 +31,10 @@ ATTN_OP = torch.ops.vllm.unified_attention_with_output.default ...@@ -37,83 +31,10 @@ ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default RESHAPE_OP = torch.ops.aten.reshape.default
class AttentionQuantPattern(ABC): _FP8_QUANT_KEY = QuantKey(dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=True)
"""
The base class for Attn+Quant fusions.
Should not be used directly.
"""
def __init__(
self,
layer: Attention,
quant_key: QuantKey,
dtype: torch.dtype,
) -> None:
self.layer = layer
self.layer_name = layer.layer_name
self.num_heads = layer.num_heads
self.head_size = layer.head_size
self.quant_key = quant_key
self.quant_dtype = quant_key.dtype
self.dtype = dtype
assert self.quant_key in QUANT_OPS, (
f"unsupported quantization scheme {self.quant_key}"
)
self.QUANT_OP = QUANT_OPS[self.quant_key]
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs)
def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs)
@staticmethod
def wrap_trace_fn(
trace_fn: Callable[P, fx.GraphModule],
*process_fx_fns: Callable[[fx.GraphModule], None],
) -> Callable[P, fx.GraphModule]:
def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns:
process_fx(gm)
return gm
return wrapped
@staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
@staticmethod
def remove_noop_permutes(gm: torch.fx.GraphModule) -> None:
for node in gm.graph.nodes:
if not is_func(node, torch.ops.aten.permute.default):
continue
dims = node.args[1]
if any(dim != i for i, dim in enumerate(dims)):
continue
# this is now an identity op, remove
node.replace_all_uses_with(node.args[0])
gm.graph.erase_node(node)
def register_if_supported(self, pm_pass: PatternMatcherPass) -> None:
if self.layer.impl.fused_output_quant_supported(self.quant_key):
self._register(pm_pass)
@abstractmethod class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
def _register(self, pm_pass: PatternMatcherPass) -> None:
raise NotImplementedError
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
""" """
Fusion for Attention+Fp8StaticQuant. Fusion for Attention+Fp8StaticQuant.
...@@ -123,20 +44,16 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): ...@@ -123,20 +44,16 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
will be passed into Attention op as the `output_scale` argument. will be passed into Attention op as the `output_scale` argument.
""" """
def __init__( def __init__(self, layer: Attention, dtype: torch.dtype):
self, self._layer_name = layer.layer_name
layer: Attention, self._num_heads = layer.num_heads
dtype: torch.dtype, self._head_size = layer.head_size
symmetric: bool = True, self._dtype = dtype
) -> None: self._quant_matcher = MatcherQuantFP8(_FP8_QUANT_KEY)
quant_key = QuantKey(
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
)
super().__init__(layer, quant_key, dtype)
self.quant_matcher = MatcherQuantFP8(quant_key)
def _register(self, pm_pass: PatternMatcherPass) -> None: @property
def pattern( def pattern(self) -> Callable[..., torch.Tensor]:
def _pattern(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
...@@ -150,18 +67,21 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): ...@@ -150,18 +67,21 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
key=k, key=k,
value=v, value=v,
output=output_attn, output=output_attn,
layer_name=self.layer_name, layer_name=self._layer_name,
output_scale=None, output_scale=None,
output_block_scale=None, output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
) )
attn_out_view = RESHAPE_OP( attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size] at1[1], [q.shape[0], self._num_heads * self._head_size]
) )
return self._quant_matcher(attn_out_view, scale)[0]
return self.quant_matcher(attn_out_view, scale)[0] return _pattern
def replacement( @property
def replacement(self) -> Callable[..., torch.Tensor]:
def _replacement(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
...@@ -169,10 +89,9 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): ...@@ -169,10 +89,9 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
scale: torch.Tensor, scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor, kv_cache_dummy_dep: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# attn output in quant_dtype
output_attn = torch.empty( output_attn = torch.empty(
[q.shape[0], self.num_heads, self.head_size], [q.shape[0], self._num_heads, self._head_size],
dtype=self.quant_dtype, dtype=FP8_DTYPE,
device=q.device, device=q.device,
) )
at1 = auto_functionalized( at1 = auto_functionalized(
...@@ -181,36 +100,32 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): ...@@ -181,36 +100,32 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
key=k, key=k,
value=v, value=v,
output=output_attn, output=output_attn,
layer_name=self.layer_name, layer_name=self._layer_name,
output_scale=scale, output_scale=scale,
output_block_scale=None, output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
) )
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) return RESHAPE_OP(at1[1], [-1, self._num_heads * self._head_size])
inputs = [ return _replacement
self.empty(5, self.num_heads, self.head_size), # q
self.empty(5, self.num_heads, self.head_size), # k def get_inputs(self):
self.empty(5, self.num_heads, self.head_size), # v dtype = self._dtype
self.empty(5, self.num_heads, self.head_size), # attn_output num_heads = self._num_heads
empty_fp32(1, 1), # scale head_size = self._head_size
self.empty(0), # kv_cache_dummy_dep return [
self.empty(5, num_heads, head_size, dtype=dtype), # q
self.empty(5, num_heads, head_size, dtype=dtype), # k
self.empty(5, num_heads, head_size, dtype=dtype), # v
self.empty(5, num_heads, head_size, dtype=dtype), # attn_output
self.empty_fp32(1, 1), # scale
self.empty(0, dtype=dtype), # kv_cache_dummy_dep
] ]
pm.register_replacement(
pattern,
replacement,
inputs,
AttentionQuantPattern.wrap_trace_fn(
pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
),
pm_pass,
)
class AttentionNvfp4QuantPattern(AttentionQuantPattern): class AttnNvfp4QuantPattern(
VllmPatternReplacement[..., tuple[torch.Tensor, torch.Tensor]]
):
""" """
Fusion for Attention+Nvfp4Quant. Fusion for Attention+Nvfp4Quant.
...@@ -220,11 +135,16 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): ...@@ -220,11 +135,16 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
will be passed into Attention op as the `output_scale` argument. will be passed into Attention op as the `output_scale` argument.
""" """
def __init__(self, layer: Attention, dtype: torch.dtype) -> None: def __init__(self, layer: Attention, dtype: torch.dtype):
super().__init__(layer, kNvfp4Dynamic, dtype) self._layer_name = layer.layer_name
self._num_heads = layer.num_heads
self._head_size = layer.head_size
self._dtype = dtype
self._QUANT_OP = QUANT_OPS[kNvfp4Dynamic]
def _register(self, pm_pass: PatternMatcherPass) -> None: @property
def pattern( def pattern(self) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
def _pattern(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
...@@ -240,16 +160,16 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): ...@@ -240,16 +160,16 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
key=k, key=k,
value=v, value=v,
output=output_attn, output=output_attn,
layer_name=self.layer_name, layer_name=self._layer_name,
output_scale=None, output_scale=None,
output_block_scale=None, output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
) )
attn_out_view = RESHAPE_OP( attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size] at1[1], [q.shape[0], self._num_heads * self._head_size]
) )
at2 = auto_functionalized( at2 = auto_functionalized(
self.QUANT_OP, self._QUANT_OP,
input=attn_out_view, input=attn_out_view,
input_scale=input_scale, input_scale=input_scale,
is_sf_swizzled_layout=True, is_sf_swizzled_layout=True,
...@@ -259,23 +179,25 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): ...@@ -259,23 +179,25 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE) output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return at2[1], output_scale_view return at2[1], output_scale_view
def replacement( return _pattern
@property
def replacement(self) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
def _replacement(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_attn: torch.Tensor,
output_quant: torch.Tensor, _output_quant: torch.Tensor,
output_scale: torch.Tensor, output_scale: torch.Tensor,
input_scale: torch.Tensor, input_scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor, kv_cache_dummy_dep: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# attention output in quant_dtype
output_attn = torch.empty( output_attn = torch.empty(
[q.shape[0], self.num_heads, self.head_size // 2], [q.shape[0], self._num_heads, self._head_size // 2],
dtype=self.quant_dtype, dtype=FP4_DTYPE,
device=q.device, device=q.device,
) )
# attention output block scale
output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE) output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
at2 = auto_functionalized( at2 = auto_functionalized(
ATTN_OP, ATTN_OP,
...@@ -283,41 +205,35 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): ...@@ -283,41 +205,35 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
key=k, key=k,
value=v, value=v,
output=output_attn, output=output_attn,
layer_name=self.layer_name, layer_name=self._layer_name,
output_scale=input_scale, output_scale=input_scale,
output_block_scale=output_scale_view, output_block_scale=output_scale_view,
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
) )
output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2]) output = RESHAPE_OP(at2[1], [-1, self._num_heads * self._head_size // 2])
return output, at2[2] return output, at2[2]
inputs = [ return _replacement
empty_bf16(5, self.num_heads, self.head_size), # q
empty_bf16(5, self.num_heads, self.head_size), # k def get_inputs(self):
empty_bf16(5, self.num_heads, self.head_size), # v dtype = self._dtype
empty_bf16(5, self.num_heads, self.head_size), # output_attn num_heads = self._num_heads
self.empty_quant(5, self.num_heads * self.head_size // 2), # output_quant head_size = self._head_size
empty_i32( return [
128, round_up(self.num_heads * self.head_size // 16, 4) self.empty_bf16(5, num_heads, head_size), # q
self.empty_bf16(5, num_heads, head_size), # k
self.empty_bf16(5, num_heads, head_size), # v
self.empty_bf16(5, num_heads, head_size), # output_attn
self.empty(5, num_heads * head_size // 2, dtype=FP4_DTYPE), # output_quant
self.empty_i32(
128, round_up(num_heads * head_size // 16, 4)
), # output_scale ), # output_scale
empty_fp32(1, 1), # input_scale self.empty_fp32(1, 1), # input_scale
self.empty(0), # kv_cache_dummy_dep self.empty(0, dtype=dtype), # kv_cache_dummy_dep
] ]
pm.register_replacement(
pattern,
replacement,
inputs,
AttentionQuantPattern.wrap_trace_fn(
pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
),
pm_pass,
)
class AttnFusionPass(VllmPatternMatcherPass): class AttnQuantFusionPass(VllmFusionPatternMatcherPass):
""" """
This pass fuses post-attention quantization onto attention if supported. This pass fuses post-attention quantization onto attention if supported.
...@@ -330,43 +246,26 @@ class AttnFusionPass(VllmPatternMatcherPass): ...@@ -330,43 +246,26 @@ class AttnFusionPass(VllmPatternMatcherPass):
support are attention kernels, which need to support fusing output quant. support are attention kernels, which need to support fusing output quant.
""" """
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None: def __init__(self, config: VllmConfig) -> None:
super().__init__(config) super().__init__(config, "attn_quant_fusion")
self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")
attn_layers = get_layers_from_vllm_config(config, Attention)
for layer_name, layer in attn_layers.items():
pattern_fp8 = AttentionFp8StaticQuantPattern(
layer, config.model_config.dtype
)
pattern_fp8.register_if_supported(self.patterns)
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): dtype = config.model_config.dtype
pattern_nvfp4 = AttentionNvfp4QuantPattern( layers = list(get_layers_from_vllm_config(config, Attention).values())
layer, config.model_config.dtype
)
pattern_nvfp4.register_if_supported(self.patterns)
if len(attn_layers) == 0: if len(layers) == 0:
logger.warning( logger.warning(
"Attention + quant fusion is enabled, but no attention layers " "Attention + quant fusion is enabled, but no attention layers "
"were found in CompilationConfig.static_forward_context " "were found in CompilationConfig.static_forward_context "
"so no fusion patterns were registered." "so no fusion patterns were registered."
) )
self.dump_patterns(config, self.patterns) for layer in layers:
if layer.impl.fused_output_quant_supported(_FP8_QUANT_KEY):
self.register(AttnFp8StaticQuantPattern(layer, dtype))
@VllmInductorPass.time_and_log if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
def __call__(self, graph: torch.fx.graph.Graph) -> None: for layer in layers:
self.matched_count = self.patterns.apply(graph) if layer.impl.fused_output_quant_supported(kNvfp4Dynamic):
logger.debug("Fused quant onto %s attention nodes", self.matched_count) self.register(AttnNvfp4QuantPattern(layer, dtype))
def uuid(self) -> str: self.dump_patterns(config, self.pm_pass)
return VllmInductorPass.hash_source(
self,
AttentionQuantPattern,
AttentionFp8StaticQuantPattern,
AttentionNvfp4QuantPattern,
)
...@@ -14,7 +14,7 @@ from vllm.logger import init_logger ...@@ -14,7 +14,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.system_utils import set_env_var from vllm.utils.system_utils import set_env_var
from .vllm_inductor_pass import VllmInductorPass from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
if rocm_aiter_ops.is_enabled(): if rocm_aiter_ops.is_enabled():
from .fusion.rocm_aiter_fusion import ( from .fusion.rocm_aiter_fusion import (
...@@ -25,7 +25,7 @@ if rocm_aiter_ops.is_enabled(): ...@@ -25,7 +25,7 @@ if rocm_aiter_ops.is_enabled():
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .fusion.act_quant_fusion import ActivationQuantFusionPass from .fusion.act_quant_fusion import ActivationQuantFusionPass
from .fusion.attn_quant_fusion import AttnFusionPass from .fusion.attn_quant_fusion import AttnQuantFusionPass
from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
from .fusion.rms_quant_fusion import RMSNormQuantFusionPass from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
...@@ -108,6 +108,8 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc] ...@@ -108,6 +108,8 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
self.fix_functionalization(graph) self.fix_functionalization(graph)
VllmInductorPass.dump_prefix = None # Cleanup index VllmInductorPass.dump_prefix = None # Cleanup index
VllmPatternMatcherPass.log_match_summary()
def configure(self, config: VllmConfig) -> None: def configure(self, config: VllmConfig) -> None:
self.pass_config = config.compilation_config.pass_config self.pass_config = config.compilation_config.pass_config
...@@ -144,7 +146,7 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc] ...@@ -144,7 +146,7 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
self.passes += [RopeKVCacheFusionPass(config)] self.passes += [RopeKVCacheFusionPass(config)]
if self.pass_config.fuse_attn_quant: if self.pass_config.fuse_attn_quant:
self.passes += [AttnFusionPass(config)] self.passes += [AttnQuantFusionPass(config)]
if self.pass_config.enable_qk_norm_rope_fusion: if self.pass_config.enable_qk_norm_rope_fusion:
self.passes += [SplitCoalescingPass(config)] self.passes += [SplitCoalescingPass(config)]
......
...@@ -3,19 +3,24 @@ ...@@ -3,19 +3,24 @@
import functools import functools
import operator import operator
import time import time
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar from typing import Any, ClassVar, Generic, ParamSpec, TypeVar
import regex as re import regex as re
import torch import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._dynamo.utils import lazy_format_graph_code from torch._dynamo.utils import lazy_format_graph_code
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from .inductor_pass import InductorPass from .fx_utils import is_func
from .inductor_pass import InductorPass, enable_fake_mode
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -79,18 +84,23 @@ class VllmInductorPass(InductorPass): ...@@ -79,18 +84,23 @@ class VllmInductorPass(InductorPass):
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
def get_match_table() -> dict[str, int]:
    """Copy the class-level match table into a plain ``dict`` and return it.

    The returned mapping is a new object, so later updates to
    ``VllmPatternMatcherPass.match_table`` are not reflected in it.
    """
    snapshot: dict[str, int] = {}
    snapshot.update(VllmPatternMatcherPass.match_table)
    return snapshot
class VllmPatternMatcherPass(VllmInductorPass): class VllmPatternMatcherPass(VllmInductorPass):
""" """
A VllmInductorPass that uses the Inductor pattern matcher. A VllmInductorPass that uses the Inductor pattern matcher.
Its main use is providing the dump_patterns utility that dumps the Provides pattern registration with match counting, debug dumping, and logging.
Inductor pattern matcher patterns into a file, which greatly aids debugging.
TODO(luka) move more utilities to this pass.
""" """
matched_count: int = 0 matched_count: int = 0
"""The number of matched patterns in the pass.""" """The number of matched patterns in the pass."""
match_table: ClassVar[defaultdict[str, int]] = defaultdict(int)
"""Global table mapping pass name to its total match count."""
_OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile( _OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>" r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>"
) )
...@@ -104,6 +114,11 @@ class VllmPatternMatcherPass(VllmInductorPass): ...@@ -104,6 +114,11 @@ class VllmPatternMatcherPass(VllmInductorPass):
) )
) )
@classmethod
def log_match_summary(cls) -> None:
    """Emit one debug log line with the accumulated match table.

    Nothing is logged while the table is still empty.
    """
    if not cls.match_table:
        return
    # Convert to a plain dict so the log shows only the key/value pairs.
    logger.debug("fusion pass matches: %s", dict(cls.match_table))
def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass) -> None: def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass) -> None:
""" """
If debug dumping is enabled, dump the Inductor pattern-matcher patterns If debug dumping is enabled, dump the Inductor pattern-matcher patterns
...@@ -171,6 +186,124 @@ class VllmPatternMatcherPass(VllmInductorPass): ...@@ -171,6 +186,124 @@ class VllmPatternMatcherPass(VllmInductorPass):
print(f"{pattern_repr}\n", file=f) print(f"{pattern_repr}\n", file=f)
P = ParamSpec("P")
R = TypeVar("R")
class VllmPatternReplacement(ABC, Generic[P, R]):
"""
A pattern/replacement pair for FX graph fusion.
Implement the three abstract members below, then pass
instances to VllmFusionPatternMatcherPass.register(). The pass will
find every occurrence of `pattern` in the graph and substitute it
with `replacement`.
"""
# TODO(Badr): bound methods work for pattern registration since
# PyTorch 2.10. Once vLLM requires torch>=2.11, replace these properties
# with plain methods and drop the closure indirection.
@property
@abstractmethod
def pattern(self) -> Callable[P, R]:
"""Returns a closure defining the FX subgraph to search for."""
...
@property
@abstractmethod
def replacement(self) -> Callable[P, R]:
"""
Returns a closure defining the FX subgraph to
substitute in place of each match.
"""
...
@abstractmethod
def get_inputs(self) -> list[torch.Tensor]:
"""Example tensors used to trace pattern and replacement."""
...
# Helpers for get_inputs: uninitialized tensors of common dtypes.
@staticmethod
def empty(*args, **kwargs) -> torch.Tensor:
return torch.empty(*args, device="cuda", **kwargs)
@staticmethod
def empty_bf16(*args, **kwargs) -> torch.Tensor:
return torch.empty(*args, dtype=torch.bfloat16, device="cuda", **kwargs)
@staticmethod
def empty_fp16(*args, **kwargs) -> torch.Tensor:
return torch.empty(*args, dtype=torch.float16, device="cuda", **kwargs)
@staticmethod
def empty_fp32(*args, **kwargs) -> torch.Tensor:
return torch.empty(*args, dtype=torch.float32, device="cuda", **kwargs)
@staticmethod
def empty_i32(*args, **kwargs) -> torch.Tensor:
return torch.empty(*args, dtype=torch.int32, device="cuda", **kwargs)
def _fx_view_to_reshape(gm: fx.GraphModule) -> None:
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
def _remove_noop_permutes(gm: fx.GraphModule) -> None:
    """Drop ``aten.permute`` nodes whose dims are the identity ordering.

    A permute whose dims list is ``[0, 1, 2, ...]`` leaves its input
    unchanged, so each such node is bypassed (its users are rewired to the
    permute's input) and then erased from ``gm.graph``.
    """
    permute_op = torch.ops.aten.permute.default
    for node in gm.graph.nodes:
        if is_func(node, permute_op):
            order = node.args[1]
            # Identity permutation: every axis stays at its own position.
            keeps_layout = all(axis == pos for pos, axis in enumerate(order))
            if keeps_layout:
                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
class VllmFusionPatternMatcherPass(VllmPatternMatcherPass):
    """
    Pattern-matcher pass driven by VllmPatternReplacement objects.

    A subclass calls self.register() from its own __init__, once per
    pattern/replacement pair. Calling the pass applies every registered
    pattern to a graph and adds the match count to the shared match table.
    """

    def __init__(self, config: VllmConfig, pass_name: str) -> None:
        super().__init__(config)
        # Set after the base initializer has run, so this value is the one
        # that remains on the instance.
        self.pass_name = pass_name
        self._pattern_replacements: list[VllmPatternReplacement] = []
        self.pm_pass = PatternMatcherPass(pass_name=pass_name)

    @enable_fake_mode
    def register(self, pr: VllmPatternReplacement) -> None:
        """Register ``pr`` with the underlying Inductor pattern matcher.

        The pair is added to the bookkeeping list only after registration
        returns, so a failed registration is not tracked.
        """
        search_fn = pr.pattern
        replace_fn = pr.replacement
        example_inputs = pr.get_inputs()
        pm.register_replacement(
            search_fn,
            replace_fn,
            example_inputs,
            self._trace_fn,
            self.pm_pass,
        )
        self._pattern_replacements.append(pr)

    def uuid(self) -> str:
        """Hash this pass's class together with each registered pattern's class."""
        sources = [type(self)]
        sources.extend(type(pr) for pr in self._pattern_replacements)
        return VllmInductorPass.hash_source(*sources)

    @staticmethod
    def _trace_fn(*args: Any, **kwargs: Any) -> fx.GraphModule:
        """Trace with ``pm.fwd_only``, then clean up the traced graph.

        The cleanup steps run in a fixed order: views are turned into
        reshapes first, then identity permutes are removed.
        """
        traced = pm.fwd_only(*args, **kwargs)
        for cleanup in (_fx_view_to_reshape, _remove_noop_permutes):
            cleanup(traced)
        return traced

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """Apply all registered patterns to ``graph`` and record the count."""
        num_matches = self.pm_pass.apply(graph)
        self.matched_count = num_matches
        # The table is keyed by pass name and accumulates over every call.
        VllmPatternMatcherPass.match_table[self.pass_name] += num_matches
class PrinterInductorPass(VllmInductorPass): class PrinterInductorPass(VllmInductorPass):
def __init__(self, name: str, config: VllmConfig) -> None: def __init__(self, name: str, config: VllmConfig) -> None:
super().__init__(config) super().__init__(config)
......
...@@ -703,6 +703,11 @@ class Worker(WorkerBase): ...@@ -703,6 +703,11 @@ class Worker(WorkerBase):
def get_supported_tasks(self) -> tuple[SupportedTask, ...]: def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.model_runner.get_supported_tasks() return self.model_runner.get_supported_tasks()
def get_compilation_match_table(self) -> dict[str, int]:
    """Return the snapshot produced by ``get_match_table()``.

    The compilation module is imported inside the method, as in the
    original, rather than at module import time.
    """
    import vllm.compilation.passes.vllm_inductor_pass as inductor_pass_mod

    return inductor_pass_mod.get_match_table()
def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]: def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
"""Get encoder timing stats from model runner.""" """Get encoder timing stats from model runner."""
return self.model_runner.get_encoder_timing_stats() return self.model_runner.get_encoder_timing_stats()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment