Unverified Commit fb5635d3 authored by Rita Brugarolas's avatar Rita Brugarolas Committed by GitHub
Browse files

[ROCm] Add MLA dual RMS norm fusion (Q, KV) pass for DeepSeek/Kimi-K2 (#39242)


Signed-off-by: default avatarRita Brugarolas Brufau <rita.brugarolasbrufau@amd.com>
parent b42e878e
...@@ -31,6 +31,7 @@ or just on the low or high end. ...@@ -31,6 +31,7 @@ or just on the low or high end.
| [RMSNorm + Quant](#rmsnorm--quantization-fuse_norm_quant) | `fuse_norm_quant` | RMSNorm (+residual add) → FP8/FP4 quant | O1 (conditional) | 1-4% | No | Always | | [RMSNorm + Quant](#rmsnorm--quantization-fuse_norm_quant) | `fuse_norm_quant` | RMSNorm (+residual add) → FP8/FP4 quant | O1 (conditional) | 1-4% | No | Always |
| [SiLU+Mul + Quant](#silumul--quantization-fuse_act_quant) | `fuse_act_quant` | SiLU+Mul activation → FP8/FP4 quant | O1 (conditional) | 1-4% | No | Always | | [SiLU+Mul + Quant](#silumul--quantization-fuse_act_quant) | `fuse_act_quant` | SiLU+Mul activation → FP8/FP4 quant | O1 (conditional) | 1-4% | No | Always |
| [RMSNorm + Padding](#rmsnorm--padding-fuse_act_padding) | `fuse_act_padding` | Residual add + RMSNorm → padding | O1 (ROCm/AITER only) | TBD | No | Always | | [RMSNorm + Padding](#rmsnorm--padding-fuse_act_padding) | `fuse_act_padding` | Residual add + RMSNorm → padding | O1 (ROCm/AITER only) | TBD | No | Always |
| [MLA Dual RMSNorm](#mla-dual-rmsnorm-fuse_mla_dual_rms_norm) | `fuse_mla_dual_rms_norm` | Paired Q + KV RMSNorm → single kernel | O1 (ROCm/AITER only) | ~2% | No | Always |
## Support Matrix ## Support Matrix
...@@ -51,6 +52,7 @@ The table below lists the quantization schemes supported by each fusion on each ...@@ -51,6 +52,7 @@ The table below lists the quantization schemes supported by each fusion on each
| `fuse_norm_quant` | FP8 static, FP8 per-token, FP8 per-group | FP8 static, FP8 per-token, FP8 per-group | FP8 static, FP8 per-token, FP8 per-group | — | FP8 static, FP8 per-token, FP8 per-group | | `fuse_norm_quant` | FP8 static, FP8 per-token, FP8 per-group | FP8 static, FP8 per-token, FP8 per-group | FP8 static, FP8 per-token, FP8 per-group | — | FP8 static, FP8 per-token, FP8 per-group |
| `fuse_act_quant` | FP8 static, NVFP4 | FP8 static, FP8 per-group (128/64) | FP8 static, FP8 per-group (128/64) | — | FP8 per-group | | `fuse_act_quant` | FP8 static, NVFP4 | FP8 static, FP8 per-group (128/64) | FP8 static, FP8 per-group (128/64) | — | FP8 per-group |
| `fuse_act_padding` | — | — | — | — | FP16/BF16 | | `fuse_act_padding` | — | — | — | — | FP16/BF16 |
| `fuse_mla_dual_rms_norm` | — | — | — | — | BF16 |
\* `fuse_attn_quant` support depends on the attention backend in use; not all backends support \* `fuse_attn_quant` support depends on the attention backend in use; not all backends support
fused quantization output. See the [`fuse_attn_quant` section](#attention--quantization-fuse_attn_quant) fused quantization output. See the [`fuse_attn_quant` section](#attention--quantization-fuse_attn_quant)
...@@ -381,6 +383,44 @@ when the hidden size is 2880 and AITER Triton GEMMs *not* enabled. ...@@ -381,6 +383,44 @@ when the hidden size is 2880 and AITER Triton GEMMs *not* enabled.
- Pass: [`vllm/compilation/passes/fusion/rocm_aiter_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rocm_aiter_fusion.py) (`RocmAiterTritonAddRMSNormPadFusionPass`) - Pass: [`vllm/compilation/passes/fusion/rocm_aiter_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rocm_aiter_fusion.py) (`RocmAiterTritonAddRMSNormPadFusionPass`)
### MLA Dual RMSNorm (`fuse_mla_dual_rms_norm`)
!!! info
ROCm/AITER-only. Targeted at DeepSeek-V3 / Kimi-K2 MLA attention.
!!! note
When the native implementation of `rms_norm` is used (the default on CUDA and
ROCm for now), Inductor's built-in fusion already handles merging these norms
automatically. This explicit pass targets the case where AITER's custom
`rms_norm` op is active, which Inductor cannot fuse on its own.
**What it fuses.** Fuses the paired `q_a_layernorm` and `kv_a_layernorm` RMS norm
operations in MLA attention into a single `fused_qk_rmsnorm` HIP kernel call via AITER,
reducing kernel launch overhead from 2 launches to 1 per MLA layer.
```text
# Unfused:
q_c, kv_lora = split(projected, [q_dim, kv_dim])
kv_c, k_pe = split(kv_lora, [kv_c_dim, k_pe_dim])
q_c = rms_norm(q_c, q_weight, eps)
kv_c = rms_norm(kv_c, kv_weight, eps)
# Fused:
q_c, kv_lora = split(projected, [q_dim, kv_dim])
kv_c, k_pe = split(kv_lora, [kv_c_dim, k_pe_dim])
q_normed, kv_normed = fused_mla_dual_rms_norm(
q_c, q_weight, kv_c, kv_weight, eps1, eps2)
```
Requires: AMD ROCm with AITER enabled. Enabled by default at optimization level O1 and above
when AITER is available.
**Code locations.**
- Pass: [`vllm/compilation/passes/fusion/rocm_aiter_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rocm_aiter_fusion.py) (`MLADualRMSNormFusionPass`)
- Custom op: [`vllm/_aiter_ops.py`](https://github.com/vllm-project/vllm/blob/main/vllm/_aiter_ops.py) (`fused_mla_dual_rms_norm`)
- AITER kernel: [`fused_qk_rmsnorm`](https://github.com/ROCm/aiter/pull/2442)
## See Also ## See Also
- [Optimization Levels](optimization_levels.md) — high-level presets that set - [Optimization Levels](optimization_levels.md) — high-level presets that set
......
...@@ -56,6 +56,7 @@ Fusions: ...@@ -56,6 +56,7 @@ Fusions:
- `-cc.pass_config.fuse_norm_quant=True`* - `-cc.pass_config.fuse_norm_quant=True`*
- `-cc.pass_config.fuse_act_quant=True`* - `-cc.pass_config.fuse_act_quant=True`*
- `-cc.pass_config.fuse_act_padding=True` - `-cc.pass_config.fuse_act_padding=True`
- `-cc.pass_config.fuse_mla_dual_rms_norm=True`
\* These fusions are only enabled when either op is using a custom kernel, otherwise Inductor fusion is better.</br> \* These fusions are only enabled when either op is using a custom kernel, otherwise Inductor fusion is better.</br>
† These fusions are ROCm-only and require AITER. † These fusions are ROCm-only and require AITER.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit test for the MLADualRMSNormFusionPass.
The pass fuses paired q/kv RMS norms in MLA attention into a single
fused_mla_dual_rms_norm op backed by AITER's fused_qk_rmsnorm kernel.
"""
import pytest
import torch
import vllm.config
from tests.compile.backend import TestBackend
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.config import (
CompilationConfig,
CompilationMode,
ModelConfig,
PassConfig,
VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm
# MLA attention geometry for DeepSeek-V3 / Kimi-K2
Q_DIM = 1536
KV_C_DIM = 512
K_PE_DIM = 64
EPS = 1e-6
class MLADualRMSNormTestModel(torch.nn.Module):
"""
Minimal model reproducing the MLA dual RMS norm pattern:
linear -> split([q_dim, kv_dim])
+-- q_c (getitem 0) -> rms_norm(q_w, eps) -> linear
+-- kv_lora (getitem 1) -> split([kv_c_dim, k_pe_dim])
+-- kv_c (getitem 0) -> rms_norm(kv_w, eps)
+-- k_pe
"""
def __init__(
self,
hidden_size: int,
q_dim: int = Q_DIM,
kv_c_dim: int = KV_C_DIM,
k_pe_dim: int = K_PE_DIM,
eps: float = EPS,
):
super().__init__()
self.q_dim = q_dim
self.kv_dim = kv_c_dim + k_pe_dim
self.kv_c_dim = kv_c_dim
self.k_pe_dim = k_pe_dim
self.proj = torch.nn.Linear(hidden_size, q_dim + self.kv_dim, bias=False)
self.q_norm = RMSNorm(q_dim, eps=eps)
self.kv_norm = RMSNorm(kv_c_dim, eps=eps)
self.q_b_proj = torch.nn.Linear(q_dim, hidden_size, bias=False)
def forward(
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Avoid graph input being a direct arg to a matched pattern node
x = torch.relu(x)
projected = self.proj(x)
q_c, kv_lora = projected.split([self.q_dim, self.kv_dim], dim=-1)
kv_c, k_pe = kv_lora.split([self.kv_c_dim, self.k_pe_dim], dim=-1)
q_normed = self.q_norm(q_c)
kv_normed = self.kv_norm(kv_c)
q_out = self.q_b_proj(q_normed)
return q_out, kv_normed, k_pe
def ops_in_model_before(self):
return [torch.ops.vllm_ir.rms_norm.default]
def ops_in_model_after(self):
return [torch.ops.vllm.fused_mla_dual_rms_norm.default]
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [7168])
@pytest.mark.skipif(
not is_aiter_found_and_supported(),
reason="Only test on ROCm with AITER installed and supported",
)
def test_fuse_mla_dual_rms_norm(
dtype: torch.dtype,
hidden_size: int,
monkeypatch: pytest.MonkeyPatch,
):
torch._dynamo.reset()
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["+rms_norm"],
pass_config=PassConfig(
fuse_mla_dual_rms_norm=True,
eliminate_noops=True,
),
),
)
with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
MLADualRMSNormFusionPass,
)
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(42)
m.setenv("VLLM_ROCM_USE_AITER", "1")
rocm_aiter_ops.refresh_env_variables()
fusion_pass = MLADualRMSNormFusionPass(vllm_config)
passes = [
NoOpEliminationPass(vllm_config),
fusion_pass,
PostCleanupPass(vllm_config),
]
backend = TestBackend(*passes)
model = MLADualRMSNormTestModel(hidden_size)
x = torch.randn(1, hidden_size)
torch._dynamo.mark_dynamic(x, 0)
outputs_unfused = model(x)
model_fused = torch.compile(model, backend=backend)
outputs_fused = model_fused(x)
torch.testing.assert_close(outputs_unfused, outputs_fused, atol=1e-2, rtol=1e-2)
assert fusion_pass.matched_count == 1, (
f"Expected 1 fused pair, got {fusion_pass.matched_count}"
)
backend.check_before_ops(model.ops_in_model_before())
backend.check_after_ops(model.ops_in_model_after())
...@@ -962,6 +962,37 @@ def _rocm_aiter_triton_add_rmsnorm_pad_fake( ...@@ -962,6 +962,37 @@ def _rocm_aiter_triton_add_rmsnorm_pad_fake(
return out, residual_out return out, residual_out
def _fused_mla_dual_rms_norm_impl(
x1: torch.Tensor,
x1_weight: torch.Tensor,
x2: torch.Tensor,
x2_weight: torch.Tensor,
x1_epsilon: float,
x2_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.fused_qk_norm_rope_cache_quant import fused_qk_rmsnorm
return fused_qk_rmsnorm(
q=x1,
q_weight=x1_weight,
q_eps=x1_epsilon,
k=x2,
k_weight=x2_weight,
k_eps=x2_epsilon,
)
def _fused_mla_dual_rms_norm_fake(
x1: torch.Tensor,
x1_weight: torch.Tensor,
x2: torch.Tensor,
x2_weight: torch.Tensor,
x1_epsilon: float,
x2_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
return (torch.empty_like(x1), torch.empty_like(x2))
def _rocm_aiter_gemm_a8wfp4_impl( def _rocm_aiter_gemm_a8wfp4_impl(
x: torch.Tensor, x: torch.Tensor,
w: torch.Tensor, w: torch.Tensor,
...@@ -1491,6 +1522,13 @@ class rocm_aiter_ops: ...@@ -1491,6 +1522,13 @@ class rocm_aiter_ops:
fake_impl=_triton_rotary_embedding_fake, fake_impl=_triton_rotary_embedding_fake,
) )
direct_register_custom_op(
op_name="fused_mla_dual_rms_norm",
op_func=_fused_mla_dual_rms_norm_impl,
mutates_args=[],
fake_impl=_fused_mla_dual_rms_norm_fake,
)
_OPS_REGISTERED = True _OPS_REGISTERED = True
@staticmethod @staticmethod
...@@ -1537,6 +1575,10 @@ class rocm_aiter_ops: ...@@ -1537,6 +1575,10 @@ class rocm_aiter_ops:
def get_triton_rotary_embedding_op() -> OpOverload: def get_triton_rotary_embedding_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default return torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default
@staticmethod
def get_fused_mla_dual_rms_norm_op() -> OpOverload:
return torch.ops.vllm.fused_mla_dual_rms_norm.default
@staticmethod @staticmethod
def rms_norm( def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any from typing import Any
import torch import torch
...@@ -7,6 +8,7 @@ import torch._inductor.pattern_matcher as pm ...@@ -7,6 +8,7 @@ import torch._inductor.pattern_matcher as pm
from torch import fx from torch import fx
from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.pattern_matcher import PatternMatcherPass
import vllm.ir.ops
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401 import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -20,7 +22,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -20,7 +22,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from ..vllm_inductor_pass import (
VllmFusionPatternMatcherPass,
VllmInductorPass,
VllmPatternMatcherPass,
VllmPatternReplacement,
)
from .act_quant_fusion import ActivationQuantPattern from .act_quant_fusion import ActivationQuantPattern
from .matcher_utils import ( from .matcher_utils import (
MatcherFusedAddRMSNorm, MatcherFusedAddRMSNorm,
...@@ -512,3 +519,101 @@ class RocmAiterTritonAddRMSNormPadFusionPass(VllmPatternMatcherPass): ...@@ -512,3 +519,101 @@ class RocmAiterTritonAddRMSNormPadFusionPass(VllmPatternMatcherPass):
def uuid(self) -> str: def uuid(self) -> str:
return VllmInductorPass.hash_source(self, AddAiterRMSNormPadPattern) return VllmInductorPass.hash_source(self, AddAiterRMSNormPadPattern)
class MLADualRMSNormPattern(
VllmPatternReplacement[..., tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
):
"""
Fuse paired q_a_layernorm + kv_a_layernorm in MLA attention into
AITER's ``fused_qk_rmsnorm`` HIP kernel.
Target FX-graph pattern (unfused, ``vllm_ir`` stage)::
gemm -> split_with_sizes([q_dim, kv_dim])
+-- q_c -> vllm_ir.rms_norm(q_c, q_w, eps)
+-- kv_lora -> split_with_sizes([kv_c_dim, k_pe_dim])
+-- kv_c -> vllm_ir.rms_norm(kv_c, kv_w, eps)
+-- k_pe
The pattern covers the connected subgraph rooted at the first
``split_with_sizes`` (which produces ``q_c`` and ``kv_lora``),
through the two ``rms_norm`` calls, and the ``k_pe`` passthrough.
"""
def __init__(self, epsilon: float) -> None:
self._epsilon = epsilon
def get_inputs(self) -> list[torch.Tensor]:
q_dim, kv_c_dim, k_pe_dim = 8, 4, 2
return [
self.empty_bf16(5, q_dim + kv_c_dim + k_pe_dim),
self.empty_bf16(q_dim),
self.empty_bf16(kv_c_dim),
]
@property
def pattern(
self,
) -> Callable[..., tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
eps = self._epsilon
def _pattern(
projected: torch.Tensor,
q_weight: torch.Tensor,
kv_weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q_dim = q_weight.shape[0]
kv_dim = projected.shape[-1] - q_dim
kv_c_dim = kv_weight.shape[0]
k_pe_dim = kv_dim - kv_c_dim
q_c, kv_lora = projected.split([q_dim, kv_dim], dim=-1)
kv_c, k_pe = kv_lora.split([kv_c_dim, k_pe_dim], dim=-1)
q_normed = vllm.ir.ops.rms_norm(q_c, q_weight, eps)
kv_normed = vllm.ir.ops.rms_norm(kv_c, kv_weight, eps)
return q_normed, kv_normed, k_pe
return _pattern
@property
def replacement(
self,
) -> Callable[..., tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
eps = self._epsilon
def _replacement(
projected: torch.Tensor,
q_weight: torch.Tensor,
kv_weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q_dim = q_weight.shape[0]
kv_dim = projected.shape[-1] - q_dim
kv_c_dim = kv_weight.shape[0]
k_pe_dim = kv_dim - kv_c_dim
q_c, kv_lora = projected.split([q_dim, kv_dim], dim=-1)
kv_c, k_pe = kv_lora.split([kv_c_dim, k_pe_dim], dim=-1)
q_normed, kv_normed = torch.ops.vllm.fused_mla_dual_rms_norm(
q_c,
q_weight,
kv_c,
kv_weight,
eps,
eps,
)
return q_normed, kv_normed, k_pe
return _replacement
class MLADualRMSNormFusionPass(VllmFusionPatternMatcherPass):
"""
Post-grad PatternMatcher pass that fuses paired q / kv RMS norms in
MLA attention into ``fused_mla_dual_rms_norm`` backed by aiter's
``fused_qk_rmsnorm`` HIP kernel.
"""
def __init__(self, config: VllmConfig) -> None:
super().__init__(config, "mla_dual_rms_norm_fusion_pass")
for epsilon in [1e-5, 1e-6]:
self.register(MLADualRMSNormPattern(epsilon))
...@@ -19,6 +19,7 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass ...@@ -19,6 +19,7 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
if rocm_aiter_ops.is_enabled(): if rocm_aiter_ops.is_enabled():
from .fusion.rocm_aiter_fusion import ( from .fusion.rocm_aiter_fusion import (
MLADualRMSNormFusionPass,
RocmAiterRMSNormQuantFusionPass, RocmAiterRMSNormQuantFusionPass,
RocmAiterSiluMulFp8GroupQuantFusionPass, RocmAiterSiluMulFp8GroupQuantFusionPass,
RocmAiterTritonAddRMSNormPadFusionPass, RocmAiterTritonAddRMSNormPadFusionPass,
...@@ -155,6 +156,9 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc] ...@@ -155,6 +156,9 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled(): if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)] self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)]
if self.pass_config.fuse_mla_dual_rms_norm and rocm_aiter_ops.is_enabled():
self.passes += [MLADualRMSNormFusionPass(config)]
if self.pass_config.fuse_rope_kvcache: if self.pass_config.fuse_rope_kvcache:
self.passes += [SplitCoalescingPass(config)] self.passes += [SplitCoalescingPass(config)]
self.passes += [ScatterSplitReplacementPass(config)] self.passes += [ScatterSplitReplacementPass(config)]
......
...@@ -142,6 +142,8 @@ class PassConfig: ...@@ -142,6 +142,8 @@ class PassConfig:
# ROCm/AITER specific fusions # ROCm/AITER specific fusions
fuse_act_padding: bool = None # type: ignore[assignment] fuse_act_padding: bool = None # type: ignore[assignment]
"""Fuse the custom RMSNorm + padding ops.""" """Fuse the custom RMSNorm + padding ops."""
fuse_mla_dual_rms_norm: bool = None # type: ignore[assignment]
"""Fuse paired q/kv RMS norms in MLA attention."""
fuse_rope_kvcache: bool = None # type: ignore[assignment] fuse_rope_kvcache: bool = None # type: ignore[assignment]
"""Fuse the QK rope + KV cache ops.""" """Fuse the QK rope + KV cache ops."""
...@@ -224,6 +226,7 @@ class PassConfig: ...@@ -224,6 +226,7 @@ class PassConfig:
"fuse_gemm_comms", "fuse_gemm_comms",
"fuse_allreduce_rms", "fuse_allreduce_rms",
"fuse_act_padding", "fuse_act_padding",
"fuse_mla_dual_rms_norm",
"fuse_rope_kvcache", "fuse_rope_kvcache",
mode="wrap", mode="wrap",
) )
...@@ -270,6 +273,12 @@ class PassConfig: ...@@ -270,6 +273,12 @@ class PassConfig:
"The fusion will be disabled." "The fusion will be disabled."
) )
self.fuse_act_padding = False self.fuse_act_padding = False
if self.fuse_mla_dual_rms_norm and not current_platform.is_rocm():
logger.warning_once(
"MLA dual RMS norm fusion requires ROCm/AITER. "
"The fusion will be disabled."
)
self.fuse_mla_dual_rms_norm = False
if self.fuse_rope_kvcache and not current_platform.is_rocm(): if self.fuse_rope_kvcache and not current_platform.is_rocm():
logger.warning_once( logger.warning_once(
"KV cache fusion currently only enabled on ROCm. " "KV cache fusion currently only enabled on ROCm. "
......
...@@ -165,6 +165,13 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: ...@@ -165,6 +165,13 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
) )
def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool:
"""Enable MLA dual RMS norm fusion when AITer is available."""
from vllm._aiter_ops import rocm_aiter_ops
return rocm_aiter_ops.is_enabled()
OPTIMIZATION_LEVEL_00 = { OPTIMIZATION_LEVEL_00 = {
"compilation_config": { "compilation_config": {
"pass_config": { "pass_config": {
...@@ -175,6 +182,7 @@ OPTIMIZATION_LEVEL_00 = { ...@@ -175,6 +182,7 @@ OPTIMIZATION_LEVEL_00 = {
"enable_sp": False, "enable_sp": False,
"fuse_gemm_comms": False, "fuse_gemm_comms": False,
"fuse_act_padding": False, "fuse_act_padding": False,
"fuse_mla_dual_rms_norm": False,
"fuse_rope_kvcache": False, "fuse_rope_kvcache": False,
}, },
"cudagraph_mode": CUDAGraphMode.NONE, "cudagraph_mode": CUDAGraphMode.NONE,
...@@ -194,6 +202,7 @@ OPTIMIZATION_LEVEL_01 = { ...@@ -194,6 +202,7 @@ OPTIMIZATION_LEVEL_01 = {
"enable_sp": False, "enable_sp": False,
"fuse_gemm_comms": False, "fuse_gemm_comms": False,
"fuse_act_padding": enable_norm_pad_fusion, "fuse_act_padding": enable_norm_pad_fusion,
"fuse_mla_dual_rms_norm": enable_mla_dual_rms_norm_fusion,
"fuse_rope_kvcache": False, "fuse_rope_kvcache": False,
}, },
"cudagraph_mode": CUDAGraphMode.PIECEWISE, "cudagraph_mode": CUDAGraphMode.PIECEWISE,
...@@ -213,6 +222,7 @@ OPTIMIZATION_LEVEL_02 = { ...@@ -213,6 +222,7 @@ OPTIMIZATION_LEVEL_02 = {
"enable_sp": IS_DENSE, "enable_sp": IS_DENSE,
"fuse_gemm_comms": IS_DENSE, "fuse_gemm_comms": IS_DENSE,
"fuse_act_padding": enable_norm_pad_fusion, "fuse_act_padding": enable_norm_pad_fusion,
"fuse_mla_dual_rms_norm": enable_mla_dual_rms_norm_fusion,
"fuse_rope_kvcache": enable_rope_kvcache_fusion, "fuse_rope_kvcache": enable_rope_kvcache_fusion,
}, },
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
...@@ -232,6 +242,7 @@ OPTIMIZATION_LEVEL_03 = { ...@@ -232,6 +242,7 @@ OPTIMIZATION_LEVEL_03 = {
"enable_sp": IS_DENSE, "enable_sp": IS_DENSE,
"fuse_gemm_comms": IS_DENSE, "fuse_gemm_comms": IS_DENSE,
"fuse_act_padding": enable_norm_pad_fusion, "fuse_act_padding": enable_norm_pad_fusion,
"fuse_mla_dual_rms_norm": enable_mla_dual_rms_norm_fusion,
"fuse_rope_kvcache": enable_rope_kvcache_fusion, "fuse_rope_kvcache": enable_rope_kvcache_fusion,
}, },
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment