Unverified Commit 21b086d0 authored by Rita Brugarolas's avatar Rita Brugarolas Committed by GitHub
Browse files

[ROCm] Hotfix: guard MLA dual RMS norm fusion against older AITer versions (#40386)


Signed-off-by: default avatarRita Brugarolas Brufau <rita.brugarolasbrufau@amd.com>
parent 3173441b
......@@ -376,6 +376,22 @@ def _rocm_aiter_fused_topk_fake(
# Cache whether aiter supports FP8 MLA parameters
_AITER_MLA_SUPPORTS_FP8: bool | None = None
_AITER_HAS_FUSED_QK_RMSNORM: bool | None = None
def check_aiter_fused_qk_rmsnorm() -> bool:
"""Check if aiter provides fused_qk_rmsnorm (requires AITer >= PR #2442)."""
global _AITER_HAS_FUSED_QK_RMSNORM
if _AITER_HAS_FUSED_QK_RMSNORM is None:
try:
from aiter.ops.fused_qk_norm_rope_cache_quant import ( # noqa: F401
fused_qk_rmsnorm,
)
_AITER_HAS_FUSED_QK_RMSNORM = True
except (ImportError, ModuleNotFoundError, AttributeError):
_AITER_HAS_FUSED_QK_RMSNORM = False
return _AITER_HAS_FUSED_QK_RMSNORM
def _check_aiter_mla_fp8_support() -> bool:
......@@ -970,7 +986,14 @@ def _fused_mla_dual_rms_norm_impl(
x1_epsilon: float,
x2_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.fused_qk_norm_rope_cache_quant import fused_qk_rmsnorm
try:
from aiter.ops.fused_qk_norm_rope_cache_quant import fused_qk_rmsnorm
except (ImportError, ModuleNotFoundError) as exc:
raise ImportError(
"fused_qk_rmsnorm requires a newer AITer version "
"(>= PR #2442). Please upgrade aiter or disable the "
"fuse_mla_dual_rms_norm pass."
) from exc
return fused_qk_rmsnorm(
q=x1,
......
......@@ -166,10 +166,10 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool:
"""Enable MLA dual RMS norm fusion when AITer is available."""
from vllm._aiter_ops import rocm_aiter_ops
"""Enable MLA dual RMS norm fusion when AITer has fused_qk_rmsnorm."""
from vllm._aiter_ops import check_aiter_fused_qk_rmsnorm, rocm_aiter_ops
return rocm_aiter_ops.is_enabled()
return rocm_aiter_ops.is_enabled() and check_aiter_fused_qk_rmsnorm()
OPTIMIZATION_LEVEL_00 = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment