Commit 6d53efd2 (unverified), authored by haosdent, committed by GitHub
Browse files

[Bugfix] Fix MLA attention crash with AWQ/GPTQ quantized models (#34695)


Signed-off-by: haosdent <haosdent@gmail.com>
parent 8b346309
...@@ -442,6 +442,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -442,6 +442,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
self.is_aiter_triton_fp4_bmm_enabled = ( self.is_aiter_triton_fp4_bmm_enabled = (
rocm_aiter_ops.is_fp4bmm_enabled() rocm_aiter_ops.is_fp4bmm_enabled()
and hasattr(self.kv_b_proj, "weight")
and self.kv_b_proj.weight.dtype == torch.bfloat16 and self.kv_b_proj.weight.dtype == torch.bfloat16
) )
...@@ -2492,11 +2493,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -2492,11 +2493,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_c_normed = workspace[:toks][..., : self.kv_lora_rank] kv_c_normed = workspace[:toks][..., : self.kv_lora_rank]
# When FP8 weights are used without FP8 prefill, kv_b_proj expects # When FP8 weights are used without FP8 prefill, kv_b_proj expects
# model dtype input and will quantize internally. # model dtype input and will quantize internally.
if ( # For quantized layers (AWQ/GPTQ) that lack a .weight attribute,
use_fp8_prefill # use params_dtype which is the expected input dtype.
or self.kv_b_proj.weight.dtype != current_platform.fp8_dtype() _kv_b_proj_w_dtype = (
): self.kv_b_proj.weight.dtype
kv_c_normed = kv_c_normed.to(self.kv_b_proj.weight.dtype) if hasattr(self.kv_b_proj, "weight")
else self.kv_b_proj.params_dtype
)
if use_fp8_prefill or _kv_b_proj_w_dtype != current_platform.fp8_dtype():
kv_c_normed = kv_c_normed.to(_kv_b_proj_w_dtype)
k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment