Unverified Commit 4506319a authored by Carl Y's avatar Carl Y Committed by GitHub
Browse files

[compile] mla + group fp8 fusion (#38877)


Signed-off-by: default avatarCarl You <4531192+carlyou@users.noreply.github.com>
Co-authored-by: default avatarmergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
parent 9b60e2ff
...@@ -44,7 +44,7 @@ The table below lists the quantization schemes supported by each fusion on each ...@@ -44,7 +44,7 @@ The table below lists the quantization schemes supported by each fusion on each
| `fuse_allreduce_rms` | FP16/BF16, FP8 static, NVFP4 | FP16/BF16, FP8 static | — | — | — | | `fuse_allreduce_rms` | FP16/BF16, FP8 static, NVFP4 | FP16/BF16, FP8 static | — | — | — |
| `fuse_minimax_qk_norm`\* | FP16/BF16 | FP16/BF16 | FP16/BF16 | FP16/BF16 | — | | `fuse_minimax_qk_norm`\* | FP16/BF16 | FP16/BF16 | FP16/BF16 | FP16/BF16 | — |
| `fuse_attn_quant`\* | FP8 static\*, NVFP4\* | FP8 static\* | FP8 static\* | — | FP8 static\* | | `fuse_attn_quant`\* | FP8 static\*, NVFP4\* | FP8 static\* | FP8 static\* | — | FP8 static\* |
| `fuse_attn_quant` (MLA)\* | FP8 static\*, NVFP4\* | FP8 static\* | FP8 static\* | — | FP8 static(untested)\* | | `fuse_attn_quant` (MLA)\* | FP8 static\*, FP8 per-group\*, NVFP4\* | FP8 static\*, FP8 per-group\* | FP8 static\*, FP8 per-group\* | — | FP8 static\* (untested) |
| `fuse_rope_kvcache` | — | — | — | — | FP16/BF16 | | `fuse_rope_kvcache` | — | — | — | — | FP16/BF16 |
| `enable_qk_norm_rope_fusion` | FP16/BF16 | FP16/BF16 | FP16/BF16† | FP16/BF16† | — | | `enable_qk_norm_rope_fusion` | FP16/BF16 | FP16/BF16 | FP16/BF16† | FP16/BF16† | — |
| `enable_sp` | FP16/BF16, FP8 static† | FP16/BF16, FP8 static | FP16/BF16† | FP16/BF16† | — | | `enable_sp` | FP16/BF16, FP8 static† | FP16/BF16, FP8 static | FP16/BF16† | FP16/BF16† | — |
...@@ -152,7 +152,7 @@ standard `Attention` and `MLAAttention` (used by DeepSeek-V2/V3/R1 models). Patt ...@@ -152,7 +152,7 @@ standard `Attention` and `MLAAttention` (used by DeepSeek-V2/V3/R1 models). Patt
- `FLASHINFER`: CUDA sm100+ with FlashInfer installed - `FLASHINFER`: CUDA sm100+ with FlashInfer installed
`MLAAttention → FP8 static quant` / `MLAAttention → NVFP4 dynamic quant`: `MLAAttention → FP8 static, FP8 per-group, NVFP4 dynamic quant`
The MLA fusion operates at the graph level on the `unified_mla_attention_with_output` op and works The MLA fusion operates at the graph level on the `unified_mla_attention_with_output` op and works
with all MLA decode and prefill backend combinations. Unlike standard `Attention` backends (where with all MLA decode and prefill backend combinations. Unlike standard `Attention` backends (where
......
...@@ -116,6 +116,22 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): ...@@ -116,6 +116,22 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
model_kwargs["attention_config"] = {"backend": attn_backend.backend.name} model_kwargs["attention_config"] = {"backend": attn_backend.backend.name}
model_kwargs["tensor_parallel_size"] = tp_size model_kwargs["tensor_parallel_size"] = tp_size
# Sparse MLA models (DSv3.2) hit an over-strict inductor assertion in
# decompose_auto_functionalized when +rotary_embedding is forced into
# the compile graph. Disable qk_norm+rope fusion (which auto-enables
# +rotary_embedding) for this combo to avoid the known torch bug.
# TODO: remove once upstream torch fix lands.
if requires_sparse:
if "pass_config" in compilation_config:
compilation_config["pass_config"].enable_qk_norm_rope_fusion = False
matches_check = [m for m in matches_check if m != "norm_rope_fusion"]
# DSv3.2 sparse indexer uses persistent_topk with k=config.index_topk
# (2048 for the default config). max_model_len must be >= index_topk
# or the topk kernel raises "k out of range" at runtime.
model_kwargs["max_model_len"] = max(
model_kwargs.get("max_model_len", 0), 2048
)
# Always compile the full graph instead of piecewise # Always compile the full graph instead of piecewise
if not compilation_config["use_inductor_graph_partition"]: if not compilation_config["use_inductor_graph_partition"]:
compilation_config["splitting_ops"] = [] compilation_config["splitting_ops"] = []
......
...@@ -59,7 +59,10 @@ TRITON_MLA_ATTN = pytest.param( ...@@ -59,7 +59,10 @@ TRITON_MLA_ATTN = pytest.param(
) )
FLASHMLA_SPARSE_ATTN = pytest.param( FLASHMLA_SPARSE_ATTN = pytest.param(
AttentionBackendCase(backend=AttentionBackendEnum.FLASHMLA_SPARSE), AttentionBackendCase(
backend=AttentionBackendEnum.FLASHMLA_SPARSE,
model_kwargs=dict(kv_cache_dtype="fp8_ds_mla"),
),
id="FLASHMLA_SPARSE", id="FLASHMLA_SPARSE",
marks=pytest.mark.skipif( marks=pytest.mark.skipif(
not is_blackwell(), not is_blackwell(),
...@@ -173,9 +176,8 @@ deepseek_v3_fp8 = ModelFusionInfo( ...@@ -173,9 +176,8 @@ deepseek_v3_fp8 = ModelFusionInfo(
rms_quant_fusion=n_layers * 2 + min(3, n_layers), # add for 3 dense layers rms_quant_fusion=n_layers * 2 + min(3, n_layers), # add for 3 dense layers
# silu+block quant # silu+block quant
act_quant_fusion=min(3, n_layers), # dense layers only act_quant_fusion=min(3, n_layers), # dense layers only
# MLA attn + per-group FP8 quant not supported yet: # MLA attn + per-group FP8 quant
# https://github.com/vllm-project/vllm/issues/35792 attn_quant_fusion=n_layers,
attn_quant_fusion=0,
ar_rms_fusion=n_layers * 2 + 1, ar_rms_fusion=n_layers * 2 + 1,
# TODO # TODO
# sequence_parallel= n_layers * 2 + 1, # sequence_parallel= n_layers * 2 + 1,
...@@ -183,11 +185,23 @@ deepseek_v3_fp8 = ModelFusionInfo( ...@@ -183,11 +185,23 @@ deepseek_v3_fp8 = ModelFusionInfo(
), ),
) )
deepseek_r1_fp4 = ModelFusionInfo(
model_name="nvidia/DeepSeek-R1-0528-NVFP4-v2",
matches=lambda n_layers: Matches(
rms_quant_fusion=0,
act_quant_fusion=min(3, n_layers),
attn_quant_fusion=n_layers,
ar_rms_fusion=n_layers * 2 + 1,
),
)
deepseek_v32_fp4 = ModelFusionInfo( deepseek_v32_fp4 = ModelFusionInfo(
model_name="nvidia/DeepSeek-V3.2-NVFP4", model_name="nvidia/DeepSeek-V3.2-NVFP4",
matches=lambda n_layers: Matches( matches=lambda n_layers: Matches(
rms_quant_fusion=0, rms_quant_fusion=0,
act_quant_fusion=0, # silu+quant on dense layers only; MoE hides the act+quant site
act_quant_fusion=min(3, n_layers),
# MLA attn + NVFP4 output quant fuses on sparse MLA output path
attn_quant_fusion=n_layers, attn_quant_fusion=n_layers,
ar_rms_fusion=n_layers * 2 + 1, ar_rms_fusion=n_layers * 2 + 1,
), ),
......
...@@ -24,6 +24,7 @@ from .models import ( ...@@ -24,6 +24,7 @@ from .models import (
TRITON_ATTN, TRITON_ATTN,
TRITON_MLA_ATTN, TRITON_MLA_ATTN,
deepseek_coder_v2_lite_fp8, deepseek_coder_v2_lite_fp8,
deepseek_r1_fp4,
deepseek_v3_fp8, deepseek_v3_fp8,
deepseek_v32_fp4, deepseek_v32_fp4,
llama3_8b_fp4, llama3_8b_fp4,
...@@ -148,11 +149,11 @@ def test_tp1_fp8_fusions( ...@@ -148,11 +149,11 @@ def test_tp1_fp8_fusions(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides", "model_name, matches_fn, model_kwargs, hf_overrides",
[llama3_8b_fp4, llama4_scout_fp4, deepseek_v32_fp4], [llama3_8b_fp4, llama4_scout_fp4, deepseek_r1_fp4, deepseek_v32_fp4],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"attn_backend", "attn_backend",
[FLASHINFER_ATTN, FLASHMLA_SPARSE_ATTN], [FLASHINFER_ATTN, FLASHINFER_MLA_ATTN, FLASHMLA_SPARSE_ATTN],
) )
@pytest.mark.parametrize("n_layers", [6]) @pytest.mark.parametrize("n_layers", [6])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm")) @pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
......
...@@ -21,6 +21,7 @@ from .models import ( ...@@ -21,6 +21,7 @@ from .models import (
FLASHMLA_SPARSE_ATTN, FLASHMLA_SPARSE_ATTN,
TRITON_ATTN, TRITON_ATTN,
deepseek_coder_v2_lite_fp8, deepseek_coder_v2_lite_fp8,
deepseek_r1_fp4,
deepseek_v3_fp8, deepseek_v3_fp8,
deepseek_v32_fp4, deepseek_v32_fp4,
gpt_oss_20b, gpt_oss_20b,
...@@ -113,11 +114,11 @@ def test_tp2_ar_rms_fp8_fusions( ...@@ -113,11 +114,11 @@ def test_tp2_ar_rms_fp8_fusions(
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides", "model_name, matches_fn, model_kwargs, hf_overrides",
[llama3_8b_fp4, llama4_scout_fp4, deepseek_v32_fp4], [llama3_8b_fp4, llama4_scout_fp4, deepseek_r1_fp4, deepseek_v32_fp4],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"attn_backend", "attn_backend",
[FLASHINFER_ATTN, FLASHMLA_SPARSE_ATTN], [FLASHINFER_ATTN, FLASHINFER_MLA_ATTN, FLASHMLA_SPARSE_ATTN],
) )
@pytest.mark.parametrize("n_layers", [4]) @pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm")) @pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
......
...@@ -29,12 +29,18 @@ from vllm.config import ( ...@@ -29,12 +29,18 @@ from vllm.config import (
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.forward_context import get_forward_context, set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
CutlassFp8BlockScaledMMKernel,
)
from vllm.model_executor.layers.attention import MLAAttention from vllm.model_executor.layers.attention import MLAAttention
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey, QuantKey,
create_fp8_quant_key,
kFp8Dynamic128Sym,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kNvfp4Dynamic, kNvfp4Dynamic,
) )
...@@ -279,6 +285,67 @@ class TestMLAAttentionNvfp4QuantPatternModel(MLAAttentionQuantPatternModel): ...@@ -279,6 +285,67 @@ class TestMLAAttentionNvfp4QuantPatternModel(MLAAttentionQuantPatternModel):
) )
class TestMLAAttentionFp8GroupQuantPatternModel(MLAAttentionQuantPatternModel):
"""Test model for MLA Attention + per-group FP8 (block quant) fusion."""
quant_key = kFp8Dynamic128Sym
quant_config = Fp8Config(
is_checkpoint_fp8_serialized=True,
weight_block_size=[128, 128],
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
weight_quant_key = create_fp8_quant_key(
static=True, group_shape=GroupShape(128, 128)
)
device = kwargs.get("device", torch.device("cuda:0"))
# Subclass to set weight_block_size before process_weights_after_loading
class _BlockFP8Layer(TestFP8Layer):
def __init__(self, *a, **kw):
self.weight_block_size = [128, 128]
super().__init__(*a, **kw)
# Force CutlassFp8BlockScaledMMKernel to ensure the graph uses
# per_token_group_fp8_quant (not the deepgemm packed variant).
self.block_fp8_linear = _BlockFP8Layer(
weight_shape=(self.output_dim, self.output_dim),
activation_quant_key=self.quant_key,
weight_quant_key=weight_quant_key,
input_dtype=self.dtype,
device=device,
force_kernel=CutlassFp8BlockScaledMMKernel,
)
w = kwargs.get("w")
if w is not None:
self.block_fp8_linear.weight = w["weight"]
# Block-wise uses weight_scale_inv, not weight_scale
self.block_fp8_linear.weight_scale_inv = w["wscale"]
self.w = {
"weight": self.block_fp8_linear.weight,
"wscale": self.block_fp8_linear.weight_scale_inv,
}
def forward(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
):
"""Forward pass: MLA attention -> block FP8 linear (group quant)."""
attn_output = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(q.shape[0], self.output_dim),
)
return self.block_fp8_linear(attn_output)
def is_nvfp4_supported(): def is_nvfp4_supported():
return current_platform.has_device_capability(100) return current_platform.has_device_capability(100)
...@@ -286,6 +353,7 @@ def is_nvfp4_supported(): ...@@ -286,6 +353,7 @@ def is_nvfp4_supported():
# MLA test configuration # MLA test configuration
MLA_DIMS: list[tuple[int, int, int, int, int]] = [] MLA_DIMS: list[tuple[int, int, int, int, int]] = []
PATTERN_TEST_MODELS_MLA_FP8: list[tuple[str, type]] = [] PATTERN_TEST_MODELS_MLA_FP8: list[tuple[str, type]] = []
PATTERN_TEST_MODELS_MLA_GROUP_FP8: list[tuple[str, type]] = []
PATTERN_TEST_MODELS_MLA_FP4: list[tuple[str, type]] = [] PATTERN_TEST_MODELS_MLA_FP4: list[tuple[str, type]] = []
BACKENDS_MLA_FP8: list[AttentionBackendEnum] = [] BACKENDS_MLA_FP8: list[AttentionBackendEnum] = []
BACKENDS_MLA_FP4: list[AttentionBackendEnum] = [] BACKENDS_MLA_FP4: list[AttentionBackendEnum] = []
...@@ -299,6 +367,12 @@ if current_platform.is_cuda(): ...@@ -299,6 +367,12 @@ if current_platform.is_cuda():
TestMLAAttentionFp8StaticQuantPatternModel, TestMLAAttentionFp8StaticQuantPatternModel,
) )
] ]
PATTERN_TEST_MODELS_MLA_GROUP_FP8 = [
(
"deepseek-ai/DeepSeek-V3",
TestMLAAttentionFp8GroupQuantPatternModel,
)
]
PATTERN_TEST_MODELS_MLA_FP4 = [ PATTERN_TEST_MODELS_MLA_FP4 = [
( (
"deepseek-ai/DeepSeek-V2-Lite", "deepseek-ai/DeepSeek-V2-Lite",
...@@ -324,6 +398,13 @@ if current_platform.is_cuda(): ...@@ -324,6 +398,13 @@ if current_platform.is_cuda():
["+quant_fp8", "-quant_fp8"], ["+quant_fp8", "-quant_fp8"],
) )
) )
+ list(
flat_product(
BACKENDS_MLA_FP8,
PATTERN_TEST_MODELS_MLA_GROUP_FP8,
["+quant_fp8"],
)
)
+ list(flat_product(BACKENDS_MLA_FP4, PATTERN_TEST_MODELS_MLA_FP4, [""])), + list(flat_product(BACKENDS_MLA_FP4, PATTERN_TEST_MODELS_MLA_FP4, [""])),
) )
@pytest.mark.skipif( @pytest.mark.skipif(
...@@ -470,12 +551,13 @@ def test_mla_attention_quant_pattern( ...@@ -470,12 +551,13 @@ def test_mla_attention_quant_pattern(
) )
# Check quantization ops in the graph # Check quantization ops in the graph
is_per_group = quant_key.scale.group_shape.is_per_group()
quant_op = ( quant_op = (
torch.ops.aten.reciprocal torch.ops.aten.reciprocal
if "-quant_fp8" in custom_ops_list if "-quant_fp8" in custom_ops_list
else QUANT_OPS[quant_key] else QUANT_OPS[quant_key]
) )
test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic) test_backend.check_before_ops([quant_op], fully_replaced=is_per_group)
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
...@@ -487,25 +569,24 @@ def test_mla_attention_quant_pattern( ...@@ -487,25 +569,24 @@ def test_mla_attention_quant_pattern(
assert len(attn_nodes_pre) == len(attn_nodes_post), ( assert len(attn_nodes_pre) == len(attn_nodes_post), (
"Should have same number of MLA attention nodes before and after fusion" "Should have same number of MLA attention nodes before and after fusion"
) )
assert attn_nodes_pre[0].kwargs.get("output_scale") is None, (
"MLA attention should not have output_scale before fusion"
)
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, (
"MLA attention should have output_scale after fusion"
)
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, ( # Before fusion: neither scale should be set
"MLA attention should not have output_block_scale before fusion" assert attn_nodes_pre[0].kwargs.get("output_scale") is None
) assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None
if quant_key.dtype == FP8_DTYPE: # After fusion: derive expected scale presence from quant_key properties.
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, ( # - output_scale: present for static quant or non-FP8 (NVFP4 carries input_scale)
"MLA attention should not have output_block_scale after FP8 fusion" # - output_block_scale: present when quant uses per-group/block scaling
) has_output_scale = attn_nodes_post[0].kwargs.get("output_scale") is not None
elif quant_key.dtype == FP4_DTYPE: has_block_scale = attn_nodes_post[0].kwargs.get("output_block_scale") is not None
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, (
"MLA attention should have output_block_scale after FP4 fusion" expects_output_scale = quant_key.scale.static or quant_key.dtype != FP8_DTYPE
) assert has_output_scale == expects_output_scale, (
f"output_scale: expected present={expects_output_scale}, got {has_output_scale}"
)
assert has_block_scale == is_per_group, (
f"output_block_scale: expected present={is_per_group}, got {has_block_scale}"
)
# Check numerical correctness # Check numerical correctness
torch.testing.assert_close(result_unfused, result_fused, atol=1e-2, rtol=1e-2) torch.testing.assert_close(result_unfused, result_fused, atol=1e-2, rtol=1e-2)
...@@ -6,15 +6,18 @@ from collections.abc import Callable ...@@ -6,15 +6,18 @@ from collections.abc import Callable
import torch import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm._custom_ops import create_fp4_output_tensors
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import MLAAttention from vllm.model_executor.layers.attention.mla_attention import MLAAttention
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic64Sym,
kFp8Dynamic128Sym,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kNvfp4Dynamic, kNvfp4Dynamic,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import _USE_LAYERNAME, _encode_layer_name from vllm.utils.torch_utils import _USE_LAYERNAME, _encode_layer_name
from ..vllm_inductor_pass import VllmFusionPatternMatcherPass, VllmPatternReplacement from ..vllm_inductor_pass import VllmFusionPatternMatcherPass, VllmPatternReplacement
...@@ -203,6 +206,8 @@ class MLAAttnNvfp4QuantPattern( ...@@ -203,6 +206,8 @@ class MLAAttnNvfp4QuantPattern(
kv_c_normed, kv_c_normed,
k_pe, k_pe,
output_attn, output_attn,
output_quant,
output_scale,
input_scale, input_scale,
kv_cache_dummy_dep, kv_cache_dummy_dep,
layer_name, layer_name,
...@@ -218,9 +223,6 @@ class MLAAttnNvfp4QuantPattern( ...@@ -218,9 +223,6 @@ class MLAAttnNvfp4QuantPattern(
output_block_scale=None, output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
) )
output_quant, output_scale = create_fp4_output_tensors(
at1[1].shape[0], at1[1].shape[1], at1[1].device, True
)
at2 = auto_functionalized( at2 = auto_functionalized(
self._QUANT_OP, self._QUANT_OP,
input=at1[1], input=at1[1],
...@@ -235,7 +237,14 @@ class MLAAttnNvfp4QuantPattern( ...@@ -235,7 +237,14 @@ class MLAAttnNvfp4QuantPattern(
return _pattern_with_ln return _pattern_with_ln
def _pattern( def _pattern(
q, kv_c_normed, k_pe, output_attn, input_scale, kv_cache_dummy_dep q,
kv_c_normed,
k_pe,
output_attn,
output_quant,
output_scale,
input_scale,
kv_cache_dummy_dep,
): ):
at1 = auto_functionalized( at1 = auto_functionalized(
MLA_ATTN_OP, MLA_ATTN_OP,
...@@ -248,11 +257,6 @@ class MLAAttnNvfp4QuantPattern( ...@@ -248,11 +257,6 @@ class MLAAttnNvfp4QuantPattern(
output_block_scale=None, output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
) )
# Replicate what scaled_fp4_quant() does: allocate output
# tensors inline then call the .out variant.
output_quant, output_scale = create_fp4_output_tensors(
at1[1].shape[0], at1[1].shape[1], at1[1].device, True
)
at2 = auto_functionalized( at2 = auto_functionalized(
self._QUANT_OP, self._QUANT_OP,
input=at1[1], input=at1[1],
...@@ -279,6 +283,8 @@ class MLAAttnNvfp4QuantPattern( ...@@ -279,6 +283,8 @@ class MLAAttnNvfp4QuantPattern(
kv_c_normed, kv_c_normed,
k_pe, k_pe,
output_attn, output_attn,
_output_quant,
output_scale,
input_scale, input_scale,
kv_cache_dummy_dep, kv_cache_dummy_dep,
layer_name, layer_name,
...@@ -289,9 +295,6 @@ class MLAAttnNvfp4QuantPattern( ...@@ -289,9 +295,6 @@ class MLAAttnNvfp4QuantPattern(
dtype=FP4_DTYPE, dtype=FP4_DTYPE,
device=q.device, device=q.device,
) )
output_scale = create_fp4_output_tensors(
q.shape[0], self._output_dim, q.device, True
)[1]
output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE) output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
at2 = auto_functionalized( at2 = auto_functionalized(
MLA_ATTN_OP, MLA_ATTN_OP,
...@@ -309,7 +312,14 @@ class MLAAttnNvfp4QuantPattern( ...@@ -309,7 +312,14 @@ class MLAAttnNvfp4QuantPattern(
return _replacement_with_ln return _replacement_with_ln
def _replacement( def _replacement(
q, kv_c_normed, k_pe, output_attn, input_scale, kv_cache_dummy_dep q,
kv_c_normed,
k_pe,
output_attn,
_output_quant,
output_scale,
input_scale,
kv_cache_dummy_dep,
): ):
# MLA output in quant_dtype (FP4 packed as uint8) # MLA output in quant_dtype (FP4 packed as uint8)
output_attn = torch.empty( output_attn = torch.empty(
...@@ -317,10 +327,6 @@ class MLAAttnNvfp4QuantPattern( ...@@ -317,10 +327,6 @@ class MLAAttnNvfp4QuantPattern(
dtype=FP4_DTYPE, dtype=FP4_DTYPE,
device=q.device, device=q.device,
) )
# attention output block scale
output_scale = create_fp4_output_tensors(
q.shape[0], self._output_dim, q.device, True
)[1]
output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE) output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
at2 = auto_functionalized( at2 = auto_functionalized(
MLA_ATTN_OP, MLA_ATTN_OP,
...@@ -343,6 +349,8 @@ class MLAAttnNvfp4QuantPattern( ...@@ -343,6 +349,8 @@ class MLAAttnNvfp4QuantPattern(
self.empty(5, self._kv_lora_rank, dtype=self._dtype), self.empty(5, self._kv_lora_rank, dtype=self._dtype),
self.empty(5, 1, self._qk_rope_head_dim, dtype=self._dtype), self.empty(5, 1, self._qk_rope_head_dim, dtype=self._dtype),
self.empty(5, self._output_dim, dtype=self._dtype), self.empty(5, self._output_dim, dtype=self._dtype),
self.empty(5, self._output_dim // 2, dtype=FP4_DTYPE),
self.empty_i32(128, round_up(self._output_dim // 16, 4)),
self.empty_fp32(1, 1), self.empty_fp32(1, 1),
self.empty(0, dtype=self._dtype), self.empty(0, dtype=self._dtype),
] ]
...@@ -351,6 +359,218 @@ class MLAAttnNvfp4QuantPattern( ...@@ -351,6 +359,218 @@ class MLAAttnNvfp4QuantPattern(
return inputs return inputs
class MLAAttnFp8GroupQuantPattern(
VllmPatternReplacement[..., tuple[torch.Tensor, torch.Tensor]]
):
"""
Fusion for MLA Attention+Fp8GroupQuant (per-group dynamic FP8).
Matches the pattern: MLA attention -> per_token_group_fp8_quant, and
replaces it with MLA attention(output_block_scale=group_scale_buffer).
Used by models with block FP8 quantization (e.g. DeepSeek V3).
"""
def __init__(
self,
layer: MLAAttention,
dtype: torch.dtype,
quant_key: QuantKey,
has_col_major_scales: bool,
is_e8m0: bool,
is_tma_aligned: bool,
) -> None:
self._layer_name = layer.layer_name
self._num_heads = layer.num_heads
self._v_head_dim = layer.v_head_dim
self._kv_lora_rank = layer.kv_lora_rank
self._qk_rope_head_dim = layer.qk_rope_head_dim
self._qk_head_dim = layer.qk_nope_head_dim + layer.qk_rope_head_dim
self._output_dim = layer.num_heads * layer.v_head_dim
self._dtype = dtype
self._layer = layer
self._group_size = quant_key.scale.group_shape[1]
self._has_col_major_scales = has_col_major_scales
self._is_e8m0 = is_e8m0
self._is_tma_aligned = is_tma_aligned
self._quant_matcher = MatcherQuantFP8(
quant_key,
has_col_major_scales=has_col_major_scales,
is_e8m0=is_e8m0,
is_tma_aligned=is_tma_aligned,
)
@property
def pattern(
self,
) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
_ln = _encode_layer_name(self._layer_name)
if _USE_LAYERNAME:
def _pattern_with_ln( # type: ignore[misc]
q,
kv_c_normed,
k_pe,
output_attn,
kv_cache_dummy_dep,
scale,
layer_name,
):
at1 = auto_functionalized(
MLA_ATTN_OP,
q=q,
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=layer_name,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
attn_out = at1[1]
result = torch.empty(
attn_out.shape, device=attn_out.device, dtype=FP8_DTYPE
)
finfo = torch.finfo(FP8_DTYPE)
_, result, scale = auto_functionalized(
self._quant_matcher.QUANT_OP,
input=attn_out,
output_q=result,
output_s=scale,
group_size=self._group_size,
eps=1e-10,
fp8_min=finfo.min,
fp8_max=finfo.max,
scale_ue8m0=self._is_e8m0,
dummy_is_scale_transposed=self._has_col_major_scales,
dummy_is_tma_aligned=self._is_tma_aligned,
)
return result, scale
return _pattern_with_ln
def _pattern(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output_attn: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at1 = auto_functionalized(
MLA_ATTN_OP,
q=q,
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=_ln,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
attn_out = at1[1]
result = torch.empty(
attn_out.shape, device=attn_out.device, dtype=FP8_DTYPE
)
finfo = torch.finfo(FP8_DTYPE)
_, result, scale = auto_functionalized(
self._quant_matcher.QUANT_OP,
input=attn_out,
output_q=result,
output_s=scale,
group_size=self._group_size,
eps=1e-10,
fp8_min=finfo.min,
fp8_max=finfo.max,
scale_ue8m0=self._is_e8m0,
dummy_is_scale_transposed=self._has_col_major_scales,
dummy_is_tma_aligned=self._is_tma_aligned,
)
return result, scale
return _pattern
@property
def replacement(
self,
) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
_ln = _encode_layer_name(self._layer_name)
if _USE_LAYERNAME:
def _replacement_with_ln( # type: ignore[misc]
q,
kv_c_normed,
k_pe,
output_attn,
kv_cache_dummy_dep,
scale,
layer_name,
):
output_attn = torch.empty(
[q.shape[0], self._output_dim],
dtype=FP8_DTYPE,
device=q.device,
)
at1 = auto_functionalized(
MLA_ATTN_OP,
q=q,
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=layer_name,
output_scale=None,
output_block_scale=scale,
kv_cache_dummy_dep=kv_cache_dummy_dep,
quant_group_size=self._group_size,
quant_scale_ue8m0=self._is_e8m0,
quant_col_major=self._has_col_major_scales,
quant_tma_aligned=self._is_tma_aligned,
)
return at1[1], at1[2]
return _replacement_with_ln
def _replacement(q, kv_c_normed, k_pe, output_attn, kv_cache_dummy_dep, scale):
output_attn = torch.empty(
[q.shape[0], self._output_dim],
dtype=FP8_DTYPE,
device=q.device,
)
at1 = auto_functionalized(
MLA_ATTN_OP,
q=q,
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=_ln,
output_scale=None,
output_block_scale=scale,
kv_cache_dummy_dep=kv_cache_dummy_dep,
quant_group_size=self._group_size,
quant_scale_ue8m0=self._is_e8m0,
quant_col_major=self._has_col_major_scales,
quant_tma_aligned=self._is_tma_aligned,
)
return at1[1], at1[2]
return _replacement
def get_inputs(self) -> list[torch.Tensor]:
inputs: list = [
self.empty(5, self._num_heads, self._qk_head_dim, dtype=self._dtype),
self.empty(5, self._kv_lora_rank, dtype=self._dtype),
self.empty(5, 1, self._qk_rope_head_dim, dtype=self._dtype),
self.empty(5, self._output_dim, dtype=self._dtype),
self.empty(0, dtype=self._dtype),
self._quant_matcher.empty_f32(1, 1),
]
if _USE_LAYERNAME:
inputs.append(_encode_layer_name(self._layer_name))
return inputs
class MLAAttnQuantFusionPass(VllmFusionPatternMatcherPass): class MLAAttnQuantFusionPass(VllmFusionPatternMatcherPass):
""" """
This pass fuses post-attention quantization onto MLA attention if supported. This pass fuses post-attention quantization onto MLA attention if supported.
...@@ -389,4 +609,25 @@ class MLAAttnQuantFusionPass(VllmFusionPatternMatcherPass): ...@@ -389,4 +609,25 @@ class MLAAttnQuantFusionPass(VllmFusionPatternMatcherPass):
if _USE_LAYERNAME: if _USE_LAYERNAME:
break break
# Per-group FP8 (block quant) — register all flag combinations.
if current_platform.is_cuda():
for quant_key in [kFp8Dynamic128Sym, kFp8Dynamic64Sym]:
for col_major in [True, False]:
for is_e8m0 in [True, False]:
for tma_aligned in [False, True]:
for layer in layers:
if layer.impl.fused_output_quant_supported(quant_key):
self.register(
MLAAttnFp8GroupQuantPattern(
layer,
dtype,
quant_key,
col_major,
is_e8m0,
tma_aligned,
)
)
if _USE_LAYERNAME:
break
self.dump_patterns(config, self.pm_pass) self.dump_patterns(config, self.pm_pass)
...@@ -234,7 +234,12 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -234,7 +234,12 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
QuantKey,
get_and_maybe_dequant_weights, get_and_maybe_dequant_weights,
kFp8Dynamic64Sym,
kFp8Dynamic128Sym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory
...@@ -276,6 +281,44 @@ from vllm.v1.kv_cache_interface import ( ...@@ -276,6 +281,44 @@ from vllm.v1.kv_cache_interface import (
logger = init_logger(__name__) logger = init_logger(__name__)
_FP8_DTYPE = current_platform.fp8_dtype()
def _detect_output_quant_key(
output: torch.Tensor,
output_scale: torch.Tensor | None,
output_block_scale: torch.Tensor | None,
output_dim: int,
) -> QuantKey | None:
"""Detect the output quantization key from fusion pass parameters.
Returns the appropriate QuantKey, or None if no quantization is needed.
Detection is based on output dtype and which scale tensors are present.
"""
if output_scale is None and output_block_scale is None:
return None
if output_block_scale is not None:
if output.dtype == _FP8_DTYPE:
# Per-group FP8 uses block scales only, not a separate output_scale
assert output_scale is None
# Infer group size from scale shape
num_groups = output_block_scale.shape[-1]
group_size = output_dim // num_groups
if group_size == 128:
return kFp8Dynamic128Sym
elif group_size == 64:
return kFp8Dynamic64Sym
else:
raise ValueError(
f"Unsupported group FP8 group_size={group_size} "
f"(output_dim={output_dim}, num_groups={num_groups}). "
f"Only group_size 128 and 64 are supported."
)
# output_scale None implies MXFP4, not supported
assert output_scale is not None
return kNvfp4Dynamic
return kFp8StaticTensorSym
class MLAAttention(nn.Module, AttentionLayerBase): class MLAAttention(nn.Module, AttentionLayerBase):
"""Multi-Head Latent Attention layer. """Multi-Head Latent Attention layer.
...@@ -549,9 +592,17 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -549,9 +592,17 @@ class MLAAttention(nn.Module, AttentionLayerBase):
output: torch.Tensor, output: torch.Tensor,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
quant_group_size: int | None = None,
quant_scale_ue8m0: bool | None = None,
quant_col_major: bool | None = None,
quant_tma_aligned: bool | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
use_quant = output_scale is not None or output_block_scale is not None assert output is not None, "Output tensor must be provided."
if use_quant:
quant_key = _detect_output_quant_key(
output, output_scale, output_block_scale, self.num_heads * self.v_head_dim
)
if quant_key is not None:
# The fusion pass has allocated output with quantized dtype # The fusion pass has allocated output with quantized dtype
# (FP8 or uint8 for FP4). We can't write into it directly, # (FP8 or uint8 for FP4). We can't write into it directly,
# so we swap in a temp buffer for computation, then quantize # so we swap in a temp buffer for computation, then quantize
...@@ -582,7 +633,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -582,7 +633,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# The zero fill is required when used with DP + EP # The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the # to ensure all ranks within a DP group compute the
# same expert outputs. # same expert outputs.
if use_quant: if quant_key is not None:
return quant_output.fill_(0) return quant_output.fill_(0)
return output.fill_(0) return output.fill_(0)
...@@ -724,18 +775,41 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -724,18 +775,41 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# v_up projection # v_up projection
self._v_up_proj(attn_out, out=mqa_output_slice) self._v_up_proj(attn_out, out=mqa_output_slice)
if use_quant: if quant_key is not None:
# Quantize the BF16 computation result into the quantized output # Quantize the BF16 computation result into the quantized output
actual = output[:num_actual_toks] actual = output[:num_actual_toks]
if output_block_scale is not None: if quant_key == kNvfp4Dynamic:
# NVFP4: two FP4 values packed into one uint8 # NVFP4: two FP4 values packed into one uint8
assert output_block_scale is not None
fp4_data, fp4_scales = ops.scaled_fp4_quant(actual, output_scale) fp4_data, fp4_scales = ops.scaled_fp4_quant(actual, output_scale)
quant_output[:num_actual_toks].copy_(fp4_data) quant_output[:num_actual_toks].copy_(fp4_data)
output_block_scale.copy_(fp4_scales) output_block_scale[: fp4_scales.shape[0]].copy_(fp4_scales)
else: elif quant_key in (kFp8Dynamic128Sym, kFp8Dynamic64Sym):
# Per-group FP8
assert output_block_scale is not None
assert quant_group_size is not None, (
"Group FP8 output quant requested but "
"quant_group_size not passed through custom op"
)
finfo = torch.finfo(_FP8_DTYPE)
torch.ops._C.per_token_group_fp8_quant(
actual,
quant_output[:num_actual_toks],
output_block_scale[:num_actual_toks],
quant_group_size,
1e-10, # eps
finfo.min,
finfo.max,
quant_scale_ue8m0,
quant_col_major,
quant_tma_aligned,
)
elif quant_key == kFp8StaticTensorSym:
# Static FP8 quantization # Static FP8 quantization
fp8_data, _ = self._quant_fp8_op(actual, output_scale) fp8_data, _ = self._quant_fp8_op(actual, output_scale)
quant_output[:num_actual_toks].copy_(fp8_data) quant_output[:num_actual_toks].copy_(fp8_data)
else:
raise ValueError(f"Unsupported quant_key: {quant_key}")
return quant_output return quant_output
return output_padded return output_padded
...@@ -980,6 +1054,10 @@ def unified_mla_attention_with_output( ...@@ -980,6 +1054,10 @@ def unified_mla_attention_with_output(
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None,
quant_group_size: int | None = None,
quant_scale_ue8m0: bool | None = None,
quant_col_major: bool | None = None,
quant_tma_aligned: bool | None = None,
) -> None: ) -> None:
# kv_cache_dummy_dep is not used but accepting it creates a data dependency # kv_cache_dummy_dep is not used but accepting it creates a data dependency
# that ensures torch.compile preserves ordering between KV cache update and # that ensures torch.compile preserves ordering between KV cache update and
...@@ -996,6 +1074,10 @@ def unified_mla_attention_with_output( ...@@ -996,6 +1074,10 @@ def unified_mla_attention_with_output(
output=output, output=output,
output_scale=output_scale, output_scale=output_scale,
output_block_scale=output_block_scale, output_block_scale=output_block_scale,
quant_group_size=quant_group_size,
quant_scale_ue8m0=quant_scale_ue8m0,
quant_col_major=quant_col_major,
quant_tma_aligned=quant_tma_aligned,
) )
...@@ -1008,6 +1090,10 @@ def unified_mla_attention_with_output_fake( ...@@ -1008,6 +1090,10 @@ def unified_mla_attention_with_output_fake(
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None,
quant_group_size: int | None = None,
quant_scale_ue8m0: bool | None = None,
quant_col_major: bool | None = None,
quant_tma_aligned: bool | None = None,
) -> None: ) -> None:
return return
...@@ -2078,13 +2164,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -2078,13 +2164,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
""" """
def fused_output_quant_supported(self, quant_key): def fused_output_quant_supported(self, quant_key):
from vllm.model_executor.layers.quantization.utils.quant_utils import ( return quant_key in (
kFp8StaticTensorSym, kFp8StaticTensorSym,
kNvfp4Dynamic, kNvfp4Dynamic,
kFp8Dynamic128Sym,
kFp8Dynamic64Sym,
) )
return quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)
def __init__( def __init__(
self, self,
num_heads: int, num_heads: int,
......
...@@ -11,6 +11,8 @@ import torch ...@@ -11,6 +11,8 @@ import torch
from typing_extensions import deprecated from typing_extensions import deprecated
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Dynamic64Sym,
kFp8Dynamic128Sym,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kNvfp4Dynamic, kNvfp4Dynamic,
) )
...@@ -880,7 +882,12 @@ class MLAAttentionImpl(AttentionImplBase[T], Generic[T]): ...@@ -880,7 +882,12 @@ class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
Since MLA quantization is done manually in forward_impl (common code), Since MLA quantization is done manually in forward_impl (common code),
all MLA backends support it by default. all MLA backends support it by default.
""" """
return quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic) return quant_key in (
kFp8StaticTensorSym,
kNvfp4Dynamic,
kFp8Dynamic128Sym,
kFp8Dynamic64Sym,
)
def do_kv_cache_update( def do_kv_cache_update(
self, self,
...@@ -918,7 +925,12 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]): ...@@ -918,7 +925,12 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
Since MLA quantization is done manually in forward_impl (common code), Since MLA quantization is done manually in forward_impl (common code),
all MLA backends support it by default. all MLA backends support it by default.
""" """
return quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic) return quant_key in (
kFp8StaticTensorSym,
kNvfp4Dynamic,
kFp8Dynamic128Sym,
kFp8Dynamic64Sym,
)
@abstractmethod @abstractmethod
def __init__( def __init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment