Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eac2dc2b
Unverified
Commit
eac2dc2b
authored
Mar 11, 2026
by
pschlan-amd
Committed by
GitHub
Mar 11, 2026
Browse files
AITER MLA backend: Avoid CPU sync in _build_decode (#35765)
Signed-off-by:
Patrick Schlangen
<
pschlan@amd.com
>
parent
d5080aea
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
15 deletions
+46
-15
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+46
-15
No files found.
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
eac2dc2b
...
...
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonMetadataBuilder
,
QueryLenSupport
,
)
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.backend
import
AttentionCGSupport
,
AttentionLayer
,
MultipleOf
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
...
@@ -108,13 +109,16 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
# Persistent buffer for paged_kv_indices to avoid blocking boolean mask
# indexing (block_table_tensor[mask]) which has data-dependent output size.
self
.
paged_kv_indices
=
torch
.
zeros
(
max_num_pages
,
dtype
=
torch
.
int32
,
device
=
device
)
if
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
():
self
.
paged_kv_indptr
=
torch
.
zeros
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
paged_kv_indices
=
torch
.
zeros
(
max_num_pages
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
qo_indptr
=
torch
.
zeros
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
device
...
...
@@ -134,11 +138,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
device
=
self
.
device
num_reqs
=
seq_lens_device
.
size
(
0
)
mask
=
torch
.
arange
(
block_table_tensor
.
size
(
1
),
dtype
=
block_table_tensor
.
dtype
,
device
=
device
).
unsqueeze
(
0
)
<
seq_lens_device
.
unsqueeze
(
1
)
paged_kv_indices
=
block_table_tensor
[
mask
]
# kernel block size is always 1, so each page has exactly 1 token.
# last_page_len is always 1 - just slice the pre-initialized buffer.
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
[:
num_reqs
]
...
...
@@ -153,14 +152,17 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_qo_len
=
qo_len
.
max
().
item
()
if
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
():
num_actual_pages
=
paged_kv_indices
.
size
(
0
)
self
.
paged_kv_indices
[:
num_actual_pages
].
copy_
(
paged_kv_indices
,
non_blocking
=
True
)
self
.
paged_kv_indices
[
num_actual_pages
:].
fill_
(
-
1
)
paged_kv_indices
=
self
.
paged_kv_indices
[:
num_actual_pages
]
self
.
paged_kv_indices
.
fill_
(
-
1
)
_copy_page_indices_kernel
[(
num_reqs
,)](
self
.
paged_kv_indices
,
block_table_tensor
,
block_table_tensor
.
stride
(
0
),
paged_kv_indptr
,
BLOCK_SIZE
=
1024
,
)
paged_kv_indices
=
self
.
paged_kv_indices
if
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
():
self
.
paged_kv_indptr
[:
1
+
num_reqs
].
copy_
(
paged_kv_indptr
,
non_blocking
=
True
)
...
...
@@ -196,6 +198,35 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
return
attn_metadata
@triton.jit
def _copy_page_indices_kernel(
    page_indices,
    block_table,
    block_table_stride,
    cu_num_blocks,
    BLOCK_SIZE: tl.constexpr,
):
    """Flatten per-request block-table rows into one ``page_indices`` buffer.

    One program instance (grid axis 0) handles one request. For request
    ``r`` the kernel reads ``cu_num_blocks[r]`` and ``cu_num_blocks[r + 1]``
    and copies that many leading entries of block-table row ``r`` to
    ``page_indices`` starting at offset ``cu_num_blocks[r]``.

    This replaces boolean-mask indexing (``tensor[mask]``), whose output
    size is data-dependent and therefore forces a blocking sync. It is the
    same kernel as the one introduced in backends/flashinfer.py.

    Args:
        page_indices: Pointer to the flat destination buffer.
        block_table: Pointer to the first element of the block table.
        block_table_stride: Element offset between consecutive block-table
            rows (the caller passes ``block_table_tensor.stride(0)``).
        cu_num_blocks: Pointer to the indptr array; entries ``r`` and
            ``r + 1`` bound the destination span of request ``r``.
        BLOCK_SIZE: Compile-time number of entries copied per loop step.
    """
    request = tl.program_id(0)

    # Destination span [dst_begin, dst_end) for this request.
    dst_begin = tl.load(cu_num_blocks + request)
    dst_end = tl.load(cu_num_blocks + request + 1)
    row_len = dst_end - dst_begin

    src_row = block_table + request * block_table_stride
    dst_row = page_indices + dst_begin

    lanes = tl.arange(0, BLOCK_SIZE)
    for chunk_start in tl.range(0, row_len, BLOCK_SIZE):
        cols = chunk_start + lanes
        # The last chunk may be partial; mask off lanes past the row end.
        in_row = cols < row_len
        entries = tl.load(src_row + cols, mask=in_row)
        tl.store(dst_row + cols, entries, mask=in_row)
class
AiterMLAImpl
(
MLACommonImpl
[
AiterMLAMetadata
]):
def
__init__
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment