Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6d53efd2
Unverified
Commit
6d53efd2
authored
Mar 14, 2026
by
haosdent
Committed by
GitHub
Mar 13, 2026
Browse files
[Bugfix] Fix MLA attention crash with AWQ/GPTQ quantized models (#34695)
Signed-off-by:
haosdent
<
haosdent@gmail.com
>
parent
8b346309
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
5 deletions
+10
-5
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+10
-5
No files found.
vllm/model_executor/layers/attention/mla_attention.py
View file @
6d53efd2
...
@@ -442,6 +442,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
...
@@ -442,6 +442,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
# If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
self
.
is_aiter_triton_fp4_bmm_enabled
=
(
self
.
is_aiter_triton_fp4_bmm_enabled
=
(
rocm_aiter_ops
.
is_fp4bmm_enabled
()
rocm_aiter_ops
.
is_fp4bmm_enabled
()
and
hasattr
(
self
.
kv_b_proj
,
"weight"
)
and
self
.
kv_b_proj
.
weight
.
dtype
==
torch
.
bfloat16
and
self
.
kv_b_proj
.
weight
.
dtype
==
torch
.
bfloat16
)
)
...
@@ -2492,11 +2493,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -2492,11 +2493,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_c_normed
=
workspace
[:
toks
][...,
:
self
.
kv_lora_rank
]
kv_c_normed
=
workspace
[:
toks
][...,
:
self
.
kv_lora_rank
]
# When FP8 weights are used without FP8 prefill, kv_b_proj expects
# When FP8 weights are used without FP8 prefill, kv_b_proj expects
# model dtype input and will quantize internally.
# model dtype input and will quantize internally.
if
(
# For quantized layers (AWQ/GPTQ) that lack a .weight attribute,
use_fp8_prefill
# use params_dtype which is the expected input dtype.
or
self
.
kv_b_proj
.
weight
.
dtype
!=
current_platform
.
fp8_dtype
()
_kv_b_proj_w_dtype
=
(
):
self
.
kv_b_proj
.
weight
.
dtype
kv_c_normed
=
kv_c_normed
.
to
(
self
.
kv_b_proj
.
weight
.
dtype
)
if
hasattr
(
self
.
kv_b_proj
,
"weight"
)
else
self
.
kv_b_proj
.
params_dtype
)
if
use_fp8_prefill
or
_kv_b_proj_w_dtype
!=
current_platform
.
fp8_dtype
():
kv_c_normed
=
kv_c_normed
.
to
(
_kv_b_proj_w_dtype
)
k_pe
=
workspace
[:
toks
][...,
self
.
kv_lora_rank
:].
unsqueeze
(
1
)
k_pe
=
workspace
[:
toks
][...,
self
.
kv_lora_rank
:].
unsqueeze
(
1
)
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment