Unverified Commit b8a6ae41 authored by Ye (Charlotte) Qi's avatar Ye (Charlotte) Qi Committed by GitHub
Browse files

[ROCm] add fallback for aiter fp8 decode mla (#30005)


Signed-off-by: default avatarYe (Charlotte) Qi <yeq@meta.com>
parent 899e2ef5
......@@ -283,6 +283,28 @@ def _rocm_aiter_grouped_topk_fake(
pass
# Cache whether aiter supports FP8 MLA parameters
_AITER_MLA_SUPPORTS_FP8: bool | None = None
def _check_aiter_mla_fp8_support() -> bool:
"""Check if aiter.mla.mla_decode_fwd supports q_scale and kv_scale parameters."""
global _AITER_MLA_SUPPORTS_FP8
if _AITER_MLA_SUPPORTS_FP8 is None:
try:
import inspect
from aiter.mla import mla_decode_fwd
sig = inspect.signature(mla_decode_fwd)
_AITER_MLA_SUPPORTS_FP8 = (
"q_scale" in sig.parameters and "kv_scale" in sig.parameters
)
except Exception:
_AITER_MLA_SUPPORTS_FP8 = False
return _AITER_MLA_SUPPORTS_FP8
def _rocm_aiter_mla_decode_fwd_impl(
q: torch.Tensor,
kv_buffer: torch.Tensor,
......@@ -299,6 +321,16 @@ def _rocm_aiter_mla_decode_fwd_impl(
) -> None:
from aiter.mla import mla_decode_fwd
kwargs = {
"sm_scale": sm_scale,
"logit_cap": logit_cap,
}
# Only pass q_scale and kv_scale if the aiter library supports them
if _check_aiter_mla_fp8_support():
kwargs["q_scale"] = q_scale
kwargs["kv_scale"] = kv_scale
mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
......@@ -308,10 +340,7 @@ def _rocm_aiter_mla_decode_fwd_impl(
kv_indices,
kv_last_page_lens,
max_seqlen_qo,
sm_scale=sm_scale,
logit_cap=logit_cap,
q_scale=q_scale,
kv_scale=kv_scale,
**kwargs,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment