Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9dbbc59b
Unverified
Commit
9dbbc59b
authored
Dec 16, 2025
by
Pleaplusone
Committed by
GitHub
Dec 16, 2025
Browse files
[ROCm][MTP] Support MTP for AITER MLA backend (#28624)
Signed-off-by:
ganyi
<
ygan@amd.com
>
parent
104003dc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
9 deletions
+15
-9
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+15
-9
No files found.
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
9dbbc59b
...
...
@@ -15,6 +15,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonImpl
,
MLACommonMetadata
,
MLACommonMetadataBuilder
,
QueryLenSupport
,
)
from
vllm.v1.attention.backends.utils
import
AttentionCGSupport
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
...
@@ -51,6 +52,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
qo_indptr
:
torch
.
Tensor
|
None
=
None
# The dtype of MLA out tensor
attn_out_dtype
:
torch
.
dtype
=
torch
.
bfloat16
# The max query output length: int
max_qo_len
:
int
|
None
=
None
class
AiterMLAMetadata
(
MLACommonMetadata
[
AiterMLADecodeMetadata
]):
...
...
@@ -60,9 +63,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class
AiterMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
AiterMLAMetadata
]):
# TODO(luka, lucas): audit this as part of:
# https://github.com/vllm-project/vllm/issues/22945
_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
(
AttentionCGSupport
.
UNIFORM_SINGLE_TOKEN_DECODE
)
_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
UNIFORM_BATCH
query_len_support
:
ClassVar
[
QueryLenSupport
]
=
QueryLenSupport
.
UNIFORM
def
__init__
(
self
,
...
...
@@ -97,8 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
qo_indptr
=
torch
.
arange
(
0
,
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
device
self
.
qo_indptr
=
torch
.
zeros
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
def
_build_decode
(
...
...
@@ -128,6 +130,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
seq_lens_device
.
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
),
]
)
qo_len
=
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
]
max_qo_len
=
qo_len
.
max
().
item
()
if
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
():
num_actual_pages
=
paged_kv_indices
.
size
(
0
)
...
...
@@ -150,6 +154,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self
.
paged_kv_last_page_len
[
num_reqs
:].
fill_
(
1
)
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
[:
num_reqs
]
self
.
qo_indptr
[:
1
+
num_reqs
].
copy_
(
query_start_loc_device
,
non_blocking
=
True
)
self
.
qo_indptr
[
1
+
num_reqs
:]
=
query_start_loc_device
[
-
1
]
qo_indptr
=
self
.
qo_indptr
[:
1
+
num_reqs
]
else
:
...
...
@@ -165,6 +173,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
paged_kv_last_page_len
=
paged_kv_last_page_len
,
qo_indptr
=
qo_indptr
,
dcp_tot_seq_lens
=
dcp_tot_seq_lens_device
,
max_qo_len
=
max_qo_len
,
attn_out_dtype
=
self
.
decode_attn_out_dtype
,
)
...
...
@@ -255,16 +264,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_buffer
=
kv_c_and_k_pe_cache
.
unsqueeze
(
2
)
# max_seqlen_qo must be 1 except for MTP
# TODO: Find the best value for MTP
max_seqlen_qo
=
1
rocm_aiter_ops
.
mla_decode_fwd
(
q
,
kv_buffer
,
o
,
self
.
scale
,
attn_metadata
.
decode
.
qo_indptr
,
max_seqlen_qo
,
attn_metadata
.
decode
.
max_qo_len
,
attn_metadata
.
decode
.
paged_kv_indptr
,
attn_metadata
.
decode
.
paged_kv_indices
,
attn_metadata
.
decode
.
paged_kv_last_page_len
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment