Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bf9a5ddb
Unverified
Commit
bf9a5ddb
authored
Apr 16, 2026
by
Giancarlo Delfin
Committed by
GitHub
Apr 16, 2026
Browse files
[MLA] Optimize mla indexer prepare uniform decode for MTP > 1 (#39458)
Signed-off-by:
Giancarlo Delfin
<
gdelfin@inferact.ai
>
parent
79e799eb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
100 additions
and
43 deletions
+100
-43
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+100
-43
No files found.
vllm/v1/attention/backends/mla/indexer.py
View file @
bf9a5ddb
...
...
@@ -8,6 +8,7 @@ import vllm.envs as envs
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.deep_gemm
import
(
get_paged_mqa_logits_metadata
,
has_deep_gemm
,
...
...
@@ -30,6 +31,40 @@ from vllm.v1.worker.cp_utils import get_total_cp_world_size
logger
=
init_logger
(
__name__
)
@
triton
.
jit
def
_prepare_uniform_decode_kernel
(
seq_lens_ptr
,
decode_seq_lens_ptr
,
block_table_ptr
,
block_table_stride
,
expanded_block_table_ptr
,
expanded_bt_stride
,
decode_lens_ptr
,
max_decode_len
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
idx
=
tl
.
program_id
(
0
)
req_id
=
idx
//
max_decode_len
local_idx
=
idx
%
max_decode_len
# Compute number of KVs attended to by this token.
seq_len
=
tl
.
load
(
seq_lens_ptr
+
req_id
)
per_token_seq_len
=
seq_len
-
max_decode_len
+
local_idx
+
1
tl
.
store
(
decode_seq_lens_ptr
+
idx
,
per_token_seq_len
)
# Copy block table row.
src
=
block_table_ptr
+
req_id
*
block_table_stride
dst
=
expanded_block_table_ptr
+
idx
*
expanded_bt_stride
for
i
in
tl
.
range
(
0
,
expanded_bt_stride
,
BLOCK_SIZE
):
off
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
off
<
expanded_bt_stride
src_block
=
tl
.
load
(
src
+
off
,
mask
=
mask
)
tl
.
store
(
dst
+
off
,
src_block
,
mask
=
mask
)
# All reqs now have decode_len = 1.
tl
.
store
(
decode_lens_ptr
+
idx
,
1
)
def
split_indexer_prefill_chunks
(
seq_lens_cpu
:
torch
.
Tensor
,
query_lens_cpu
:
torch
.
Tensor
,
...
...
@@ -405,52 +440,75 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
Returns (seq_lens, block_table, decode_lens, batch_size, requires_padding).
seq_lens is 1D (batch_size,) for flatten/plain, 2D (B, next_n) for native MTP.
"""
min_decode_len
=
int
(
decode_lens_cpu
.
min
().
item
())
if
not
use_native
and
max_decode_len
>
1
:
assert
self
.
decode_seq_lens_buffer
.
dim
()
==
1
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
# padding) and decode_lens [3, 1, 4, 0] in the below example comments.
# The context lengths are therefore
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
# 3 + 1 + 4 + 0 = 8
actual_expanded
=
int
(
decode_lens_cpu
.
sum
().
item
())
# Fuse expanded_base and expanded_starts into a single repeat_interleave:
# seq_len_i = (context_start[b] - query_start_loc[b]) + arange[i] + 1
# where context_start[b] = seq_lens[b] - decode_lens[b].
# Example: offsets = [7-0, 6-3, 8-4, 0-8] = [7, 3, 4, -8]
# expanded_offsets = [7, 7, 7, 3, 4, 4, 4, 4]
# result = [8, 9, 10, 7, 9, 10, 11, 12]
expanded_offsets
=
torch
.
repeat_interleave
(
seq_lens
-
decode_lens
-
query_start_loc
,
decode_lens
,
output_size
=
actual_expanded
,
)
if
min_decode_len
==
max_decode_len
:
# Uniform decode lengths.
num_decode_tokens
=
num_decodes
*
max_decode_len
_prepare_uniform_decode_kernel
[(
num_decode_tokens
,)](
seq_lens
,
self
.
decode_seq_lens_buffer
,
block_table
,
block_table
.
stride
(
0
),
self
.
expanded_block_table_buffer
,
self
.
expanded_block_table_buffer
.
stride
(
0
),
self
.
decode_lens_buffer
,
max_decode_len
,
BLOCK_SIZE
=
1024
,
)
self
.
decode_seq_lens_buffer
[
num_decode_tokens
:]
=
0
seq_lens
=
self
.
decode_seq_lens_buffer
[:
num_decode_tokens
]
block_table
=
self
.
expanded_block_table_buffer
[:
num_decode_tokens
]
decode_lens
=
self
.
decode_lens_buffer
[:
num_decode_tokens
]
return
seq_lens
,
block_table
,
decode_lens
,
num_decode_tokens
,
False
else
:
# Variable decode lengths.
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
# padding) and decode_lens [3, 1, 4, 0] in the below example comments.
# The context lengths are therefore
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
# 3 + 1 + 4 + 0 = 8
actual_expanded
=
int
(
decode_lens_cpu
.
sum
().
item
())
# Fuse expanded_base and expanded_starts into a single
# repeat_interleave:
# seq_len_i = (context_start[b] - query_start_loc[b]) + arange[i] + 1
# where context_start[b] = seq_lens[b] - decode_lens[b].
# Example: offsets = [7-0, 6-3, 8-4, 0-8] = [7, 3, 4, -8]
# expanded_offsets = [7, 7, 7, 3, 4, 4, 4, 4]
# result = [8, 9, 10, 7, 9, 10, 11, 12]
expanded_offsets
=
torch
.
repeat_interleave
(
seq_lens
-
decode_lens
-
query_start_loc
,
decode_lens
,
output_size
=
actual_expanded
,
)
# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
self
.
decode_seq_lens_buffer
[:
actual_expanded
]
=
(
expanded_offsets
+
self
.
arange_buffer
[:
actual_expanded
]
+
1
)
self
.
decode_seq_lens_buffer
[
actual_expanded
:]
=
0
seq_lens
=
self
.
decode_seq_lens_buffer
[:
num_decode_tokens
]
# Give each of the flattened entries the same block table row as the
# original request.
self
.
expanded_block_table_buffer
[:
actual_expanded
]
=
(
torch
.
repeat_interleave
(
block_table
,
decode_lens
,
dim
=
0
,
output_size
=
actual_expanded
# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
self
.
decode_seq_lens_buffer
[:
actual_expanded
]
=
(
expanded_offsets
+
self
.
arange_buffer
[:
actual_expanded
]
+
1
)
)
if
actual_expanded
<
num_decode_tokens
:
self
.
expanded_block_table_buffer
[
actual_expanded
:
num_decode_tokens
,
0
]
=
0
block_table
=
self
.
expanded_block_table_buffer
[:
num_decode_tokens
]
# All reqs now have decode_len=1
self
.
decode_lens_buffer
[:
num_decode_tokens
]
=
1
decode_lens
=
self
.
decode_lens_buffer
[:
num_decode_tokens
]
return
seq_lens
,
block_table
,
decode_lens
,
num_decode_tokens
,
False
self
.
decode_seq_lens_buffer
[
actual_expanded
:]
=
0
seq_lens
=
self
.
decode_seq_lens_buffer
[:
num_decode_tokens
]
# Give each of the flattened entries the same block table row as the
# original request.
self
.
expanded_block_table_buffer
[:
actual_expanded
]
=
(
torch
.
repeat_interleave
(
block_table
,
decode_lens
,
dim
=
0
,
output_size
=
actual_expanded
)
)
if
actual_expanded
<
num_decode_tokens
:
self
.
expanded_block_table_buffer
[
actual_expanded
:
num_decode_tokens
,
0
]
=
0
block_table
=
self
.
expanded_block_table_buffer
[:
num_decode_tokens
]
# All reqs now have decode_len=1
self
.
decode_lens_buffer
[:
num_decode_tokens
]
=
1
decode_lens
=
self
.
decode_lens_buffer
[:
num_decode_tokens
]
return
seq_lens
,
block_table
,
decode_lens
,
num_decode_tokens
,
False
else
:
# Native path: plain decode (next_n==1) or spec decode
# with 2D per-token context lengths (next_n > 1).
...
...
@@ -459,7 +517,6 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
# decode_len < next_n due to padding or short prefills), the simple
# reshape in sparse_attn_indexer won't work. Use pack_seq_triton
# (requires_padding) instead.
min_decode_len
=
int
(
decode_lens_cpu
.
min
().
item
())
requires_padding
=
min_decode_len
!=
max_decode_len
if
use_native
and
next_n
>
1
:
assert
self
.
decode_seq_lens_buffer
.
dim
()
==
2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment