Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c188749b
Unverified
Commit
c188749b
authored
Mar 06, 2026
by
Chuan (Richard) Li
Committed by
GitHub
Mar 06, 2026
Browse files
[ROCm] Support MLA with nhead<16 and FP8 KV cache for TP=8 (Kimi K2.5/Linear) (#35850)
Signed-off-by:
Li
<
chuali@amd.com
>
parent
225d1090
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
3 deletions
+19
-3
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+19
-3
No files found.
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
c188749b
...
@@ -221,11 +221,17 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
...
@@ -221,11 +221,17 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_sharing_target_layer_name
,
kv_sharing_target_layer_name
,
**
mla_args
,
**
mla_args
,
)
)
assert
num_heads
==
16
or
num_heads
==
128
,
(
_valid_heads
=
num_heads
in
(
4
,
8
)
or
(
f
"Aiter MLA only supports 16 or 128 number of heads.
\n
"
num_heads
%
16
==
0
and
16
<=
num_heads
<=
128
)
assert
_valid_heads
,
(
f
"Aiter MLA supports num_heads of 4, 8, or multiples of 16 "
f
"in [16, 128].
\n
"
f
"Provided
{
num_heads
}
number of heads.
\n
"
f
"Provided
{
num_heads
}
number of heads.
\n
"
"Try adjusting tensor_parallel_size value."
"Try adjusting tensor_parallel_size value."
)
)
self
.
_needs_head_repeat
=
num_heads
<
16
self
.
_head_repeat_factor
=
16
//
num_heads
if
num_heads
<
16
else
1
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
logits_soft_cap
]
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
logits_soft_cap
]
if
any
(
unsupported_features
):
if
any
(
unsupported_features
):
raise
NotImplementedError
(
raise
NotImplementedError
(
...
@@ -267,9 +273,16 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
...
@@ -267,9 +273,16 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
assert
isinstance
(
q
,
torch
.
Tensor
)
assert
isinstance
(
q
,
torch
.
Tensor
)
B
=
q
.
shape
[
0
]
B
=
q
.
shape
[
0
]
if
self
.
_needs_head_repeat
:
q
=
q
.
repeat_interleave
(
self
.
_head_repeat_factor
,
dim
=
1
)
kernel_num_heads
=
16
else
:
kernel_num_heads
=
self
.
num_heads
o
=
torch
.
zeros
(
o
=
torch
.
zeros
(
B
,
B
,
self
.
num_heads
,
kernel_
num_heads
,
self
.
kv_lora_rank
,
self
.
kv_lora_rank
,
dtype
=
attn_metadata
.
decode
.
attn_out_dtype
,
dtype
=
attn_metadata
.
decode
.
attn_out_dtype
,
device
=
q
.
device
,
device
=
q
.
device
,
...
@@ -291,4 +304,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
...
@@ -291,4 +304,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_scale
=
layer
.
_k_scale
,
kv_scale
=
layer
.
_k_scale
,
)
)
if
self
.
_needs_head_repeat
:
o
=
o
[:,
::
self
.
_head_repeat_factor
,
:]
return
o
,
None
return
o
,
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment