Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
095d2f87
Unverified
Commit
095d2f87
authored
Apr 24, 2026
by
qli88
Committed by
GitHub
Apr 24, 2026
Browse files
[Bug] Fix GLM-5.1 running error on ROCm platform (#40763)
Signed-off-by:
Qiang Li
<
qiang.li2@amd.com
>
parent
21792520
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
68 additions
and
27 deletions
+68
-27
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+55
-24
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
+10
-3
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
+3
-0
No files found.
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
095d2f87
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
ClassVar
from
typing
import
ClassVar
,
Final
import
torch
...
...
@@ -389,6 +389,53 @@ def _expand_page_indices_kernel(
)
class
AiterMLAHelper
:
"""
AITER MLA implementation requires num_heads >= 16. If num_heads < 16 and
16 % num_heads == 0, we can pad q to 16 heads; otherwise AITER has to fail.
"""
_AITER_MIN_MLA_HEADS
:
Final
=
16
@
staticmethod
def
check_num_heads_validity
(
num_heads
:
int
):
assert
AiterMLAHelper
.
is_valid_num_heads
(
num_heads
),
(
f
"Aiter MLA requires that num_heads be multiples or divisors of 16, "
f
"but provided
{
num_heads
}
number of heads.
\n
"
f
"Try adjusting tensor_parallel_size value."
)
@
staticmethod
def
is_valid_num_heads
(
num_heads
:
int
)
->
bool
:
return
(
num_heads
%
AiterMLAHelper
.
_AITER_MIN_MLA_HEADS
==
0
if
num_heads
>=
AiterMLAHelper
.
_AITER_MIN_MLA_HEADS
else
AiterMLAHelper
.
_AITER_MIN_MLA_HEADS
%
num_heads
==
0
)
@
staticmethod
def
get_actual_mla_num_heads
(
num_heads
:
int
)
->
int
:
return
max
(
num_heads
,
AiterMLAHelper
.
_AITER_MIN_MLA_HEADS
)
@
staticmethod
def
get_mla_padded_q
(
num_heads
:
int
,
q
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
(
q
if
num_heads
>=
AiterMLAHelper
.
_AITER_MIN_MLA_HEADS
else
q
.
repeat_interleave
(
AiterMLAHelper
.
_AITER_MIN_MLA_HEADS
//
num_heads
,
dim
=
1
)
)
@
staticmethod
def
get_mla_unpadded_o
(
num_heads
:
int
,
o
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
(
o
if
num_heads
>=
AiterMLAHelper
.
_AITER_MIN_MLA_HEADS
else
o
[:,
::
AiterMLAHelper
.
_AITER_MIN_MLA_HEADS
//
num_heads
,
:]
)
class
AiterMLAImpl
(
MLACommonImpl
[
AiterMLAMetadata
]):
def
__init__
(
self
,
...
...
@@ -418,17 +465,8 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_sharing_target_layer_name
,
**
mla_args
,
)
_valid_heads
=
num_heads
in
(
4
,
8
)
or
(
num_heads
%
16
==
0
and
16
<=
num_heads
<=
128
)
assert
_valid_heads
,
(
f
"Aiter MLA supports num_heads of 4, 8, or multiples of 16 "
f
"in [16, 128].
\n
"
f
"Provided
{
num_heads
}
number of heads.
\n
"
"Try adjusting tensor_parallel_size value."
)
self
.
_needs_head_repeat
=
num_heads
<
16
self
.
_head_repeat_factor
=
16
//
num_heads
if
num_heads
<
16
else
1
AiterMLAHelper
.
check_num_heads_validity
(
num_heads
)
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
logits_soft_cap
]
if
any
(
unsupported_features
):
raise
NotImplementedError
(
...
...
@@ -471,15 +509,11 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
assert
isinstance
(
q
,
torch
.
Tensor
)
B
=
q
.
shape
[
0
]
if
self
.
_needs_head_repeat
:
q
=
q
.
repeat_interleave
(
self
.
_head_repeat_factor
,
dim
=
1
)
kernel_num_heads
=
16
else
:
kernel_num_heads
=
self
.
num_heads
mla_padded_q
=
AiterMLAHelper
.
get_mla_padded_q
(
self
.
num_heads
,
q
)
mla_num_heads
=
AiterMLAHelper
.
get_actual_mla_num_heads
(
self
.
num_heads
)
o
=
torch
.
empty
(
B
,
kernel
_num_heads
,
mla
_num_heads
,
self
.
kv_lora_rank
,
dtype
=
attn_metadata
.
decode
.
attn_out_dtype
,
device
=
q
.
device
,
...
...
@@ -506,7 +540,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
)
rocm_aiter_ops
.
mla_decode_fwd
(
q
,
mla_padded_
q
,
kv_buffer
,
o
,
self
.
scale
,
...
...
@@ -518,7 +552,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
**
mla_kwargs
,
)
if
self
.
_needs_head_repeat
:
o
=
o
[:,
::
self
.
_head_repeat_factor
,
:]
return
o
,
None
return
AiterMLAHelper
.
get_mla_unpadded_o
(
self
.
num_heads
,
o
),
None
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
View file @
095d2f87
...
...
@@ -28,6 +28,9 @@ from vllm.v1.attention.backend import (
from
vllm.v1.attention.backends.mla.flashmla_sparse
import
(
triton_convert_req_index_to_global_index
,
)
from
vllm.v1.attention.backends.mla.rocm_aiter_mla
import
(
AiterMLAHelper
,
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
if
TYPE_CHECKING
:
...
...
@@ -299,6 +302,8 @@ class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata])
indexer
:
"Indexer | None"
=
None
,
**
mla_args
,
)
->
None
:
AiterMLAHelper
.
check_num_heads_validity
(
num_heads
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
...
...
@@ -317,8 +322,9 @@ class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata])
attn_metadata
:
ROCMAiterMLASparseMetadata
,
)
->
torch
.
Tensor
:
num_tokens
=
q
.
shape
[
0
]
mla_num_heads
=
AiterMLAHelper
.
get_actual_mla_num_heads
(
self
.
num_heads
)
output
=
torch
.
empty
(
[
num_tokens
,
self
.
num_heads
,
self
.
kv_lora_rank
],
[
num_tokens
,
mla_
num_heads
,
self
.
kv_lora_rank
],
dtype
=
q
.
dtype
,
device
=
q
.
device
,
)
...
...
@@ -344,7 +350,7 @@ class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata])
attn_metadata
.
paged_kv_last_page_len
,
)
return
output
[:,
:
self
.
num_heads
,
:]
return
AiterMLAHelper
.
get_mla_unpadded_o
(
self
.
num_heads
,
output
)
def
forward_mqa
(
self
,
...
...
@@ -374,8 +380,9 @@ class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata])
NUM_TOPK_TOKENS
=
attn_metadata
.
topk_tokens
,
)
mla_padded_q
=
AiterMLAHelper
.
get_mla_padded_q
(
self
.
num_heads
,
q
)
attn_out
=
self
.
_forward_bf16_kv
(
q
,
kv_c_and_k_pe_cache
,
topk_indices_global
,
attn_metadata
mla_padded_
q
,
kv_c_and_k_pe_cache
,
topk_indices_global
,
attn_metadata
)
return
attn_out
,
None
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
View file @
095d2f87
...
...
@@ -339,6 +339,8 @@ def rocm_fp8_paged_mqa_logits(
device
=
"cuda"
,
dtype
=
torch
.
float32
,
)
# TODO: 1. Replace _stage1 and out_qk.sum with another fused variant;
# 2. Remove ChunkQ when AITER PR #2891 merged
deepgemm_fp8_paged_mqa_logits_stage1
(
q_fp8
,
kv_cache_fp8
,
...
...
@@ -347,6 +349,7 @@ def rocm_fp8_paged_mqa_logits(
context_lens
,
block_tables
,
max_model_len
,
ChunkQ
=
heads
,
)
return
out_qk
.
sum
(
dim
=
0
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment