Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
825c2dc1
Unverified
Commit
825c2dc1
authored
Jan 01, 2026
by
Kevin McKay
Committed by
GitHub
Jan 01, 2026
Browse files
[Bugfix][Hardware][AMD] Fix last_page_len calculation in AITER MLA decode (#31282)
Signed-off-by:
c0de128
<
kevin.mckay@outlook.com
>
parent
1f43c121
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
9 deletions
+12
-9
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+12
-9
No files found.
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
825c2dc1
...
...
@@ -88,6 +88,13 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
# TODO: we can disambiguate between decode and mixed-prefill decode here
# so that the persistent buffer is used only when a cudagraph is actually
# being used.
# paged_kv_last_page_len is always all ones (the kernel block size is always 1),
# so we create it once and reuse slices of it in both eager and cudagraph modes.
self
.
paged_kv_last_page_len
=
torch
.
ones
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
if
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
():
self
.
paged_kv_indptr
=
torch
.
zeros
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
device
...
...
@@ -95,9 +102,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self
.
paged_kv_indices
=
torch
.
zeros
(
max_num_pages
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
paged_kv_last_page_len
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
qo_indptr
=
torch
.
zeros
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
device
...
...
@@ -122,7 +126,9 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
).
unsqueeze
(
0
)
<
seq_lens_device
.
unsqueeze
(
1
)
paged_kv_indices
=
block_table_tensor
[
mask
]
paged_kv_last_page_len
=
torch
.
where
(
seq_lens_device
==
0
,
1
,
seq_lens_device
)
# kernel block size is always 1, so each page has exactly 1 token.
# last_page_len is always 1 - just slice the pre-initialized buffer.
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
[:
num_reqs
]
paged_kv_indptr
=
torch
.
cat
(
[
...
...
@@ -148,11 +154,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self
.
paged_kv_indptr
[
1
+
num_reqs
:].
fill_
(
paged_kv_indptr
[
-
1
])
paged_kv_indptr
=
self
.
paged_kv_indptr
[:
1
+
num_reqs
]
self
.
paged_kv_last_page_len
[:
num_reqs
].
copy_
(
paged_kv_last_page_len
,
non_blocking
=
True
)
self
.
paged_kv_last_page_len
[
num_reqs
:].
fill_
(
1
)
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
[:
num_reqs
]
# paged_kv_last_page_len already uses the pre-initialized buffer slice
# (set above), so no copy needed - buffer is always 1s.
self
.
qo_indptr
[:
1
+
num_reqs
].
copy_
(
query_start_loc_device
,
non_blocking
=
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment