Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2e984060
Unverified
Commit
2e984060
authored
Apr 08, 2026
by
Yongye Zhu
Committed by
GitHub
Apr 08, 2026
Browse files
[Refactor] Improve indexer decode path metadata preparation (#38865)
parent
ef5a2268
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
162 additions
and
102 deletions
+162
-102
csrc/sampler.cu
csrc/sampler.cu
+24
-9
csrc/topk.cu
csrc/topk.cu
+4
-2
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+7
-16
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+127
-75
No files found.
csrc/sampler.cu
View file @
2e984060
...
...
@@ -564,8 +564,9 @@ template <int kNumThreadsPerBlock, bool useRadixSort,
bool
multipleBlocksPerRow
=
false
,
bool
mergeBlocks
=
false
>
static
__global__
__launch_bounds__
(
kNumThreadsPerBlock
)
void
topKPerRowDecode
(
const
float
*
logits
,
const
int
*
seqLens
,
int
*
outIndices
,
int
stride0
,
int
stride1
,
const
int
topK
,
int
next_n
,
float
*
outLogits
=
nullptr
,
const
int
numBlocksToMerge
=
0
,
const
int
*
indices
=
nullptr
)
{
int
stride1
,
const
int
topK
,
int
next_n
,
int
seqLensIs2D
=
0
,
float
*
outLogits
=
nullptr
,
const
int
numBlocksToMerge
=
0
,
const
int
*
indices
=
nullptr
)
{
// The number of bins in the histogram.
static
constexpr
int
kNumBins
=
2048
;
...
...
@@ -574,8 +575,16 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(
// The range of logits within the row.
int
rowStart
=
0
;
int
seq_len
=
seqLens
[
rowIdx
/
next_n
];
int
rowEnd
=
max
(
0
,
seq_len
-
next_n
+
(
rowIdx
%
next_n
)
+
1
);
int
batch_idx
=
rowIdx
/
next_n
;
int
next_n_idx
=
rowIdx
%
next_n
;
// seqLensIs2D=0: 1D seqLens — all rows in a batch share the same seq_len;
// kernel computes per-row effective length via offset.
// seqLensIs2D=1: 2D seqLens — each logit row has its own pre-computed
// effective length (flat index rowIdx = b*next_n + j maps
// directly to seqLens[b, j] in C-contiguous layout).
int
seq_len
=
seqLensIs2D
?
seqLens
[
rowIdx
]
:
seqLens
[
batch_idx
];
int
rowEnd
=
seqLensIs2D
?
max
(
0
,
seq_len
)
:
max
(
0
,
seq_len
-
next_n
+
next_n_idx
+
1
);
// Local pointers to this block
if
constexpr
(
!
multipleBlocksPerRow
&&
!
mergeBlocks
)
{
...
...
@@ -653,6 +662,11 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
auto
numColumns
=
logits
.
size
(
1
);
// True if seqLens is 2D (B, next_n): each logit row has its own pre-computed
// effective seq_len. False if seqLens is 1D (B,): all rows in a batch share
// the same seq_len and the kernel computes the per-row offset itself.
int
seqLensIs2D
=
seqLens
.
dim
()
==
2
?
1
:
0
;
if
(
numColumns
<
kSortingAlgorithmThreshold
)
{
// Use insertion sort
vllm
::
topKPerRowDecode
<
kNumThreadsPerBlock
,
false
>
...
...
@@ -660,7 +674,7 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
logits
.
data_ptr
<
float
>
(),
seqLens
.
data_ptr
<
int
>
(),
indices
.
data_ptr
<
int
>
(),
static_cast
<
int
>
(
stride0
),
static_cast
<
int
>
(
stride1
),
static_cast
<
int
>
(
topK
),
static_cast
<
int
>
(
next_n
));
static_cast
<
int
>
(
next_n
)
,
seqLensIs2D
);
}
else
if
(
numColumns
<
kSplitWorkThreshold
)
{
// From this threshold, use radix sort instead
vllm
::
topKPerRowDecode
<
kNumThreadsPerBlock
,
true
>
...
...
@@ -668,7 +682,7 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
logits
.
data_ptr
<
float
>
(),
seqLens
.
data_ptr
<
int
>
(),
indices
.
data_ptr
<
int
>
(),
static_cast
<
int
>
(
stride0
),
static_cast
<
int
>
(
stride1
),
static_cast
<
int
>
(
topK
),
static_cast
<
int
>
(
next_n
));
static_cast
<
int
>
(
next_n
)
,
seqLensIs2D
);
}
else
{
// Long sequences are run in two steps
constexpr
auto
multipleBlocksPerRowConfig
=
10
;
...
...
@@ -686,15 +700,16 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
logits
.
data_ptr
<
float
>
(),
seqLens
.
data_ptr
<
int
>
(),
outIndicesAux
.
data_ptr
<
int
>
(),
static_cast
<
int
>
(
stride0
),
static_cast
<
int
>
(
stride1
),
static_cast
<
int
>
(
topK
),
static_cast
<
int
>
(
next_n
),
outLogitsAux
.
data_ptr
<
float
>
());
static_cast
<
int
>
(
next_n
),
seqLensIs2D
,
outLogitsAux
.
data_ptr
<
float
>
());
constexpr
int
kNumThreadsPerBlockMerge
=
1024
;
vllm
::
topKPerRowDecode
<
kNumThreadsPerBlockMerge
,
true
,
false
,
true
>
<<<
numRows
,
kNumThreadsPerBlockMerge
,
topK
*
sizeof
(
int32_t
),
stream
>>>
(
outLogitsAux
.
data_ptr
<
float
>
(),
seqLens
.
data_ptr
<
int
>
(),
indices
.
data_ptr
<
int
>
(),
multipleBlocksPerRowConfig
*
topK
,
1
,
static_cast
<
int
>
(
topK
),
static_cast
<
int
>
(
next_n
),
nullptr
,
multipleBlocksPerRowConfig
,
outIndicesAux
.
data_ptr
<
int
>
());
static_cast
<
int
>
(
topK
),
static_cast
<
int
>
(
next_n
),
seqLensIs2D
,
nullptr
,
multipleBlocksPerRowConfig
,
outIndicesAux
.
data_ptr
<
int
>
());
}
}
...
...
csrc/topk.cu
View file @
2e984060
...
...
@@ -21,13 +21,15 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
TORCH_CHECK
(
lengths
.
dtype
()
==
torch
::
kInt32
,
"lengths must be int32"
);
TORCH_CHECK
(
output
.
dtype
()
==
torch
::
kInt32
,
"output must be int32"
);
TORCH_CHECK
(
logits
.
dim
()
==
2
,
"logits must be 2D"
);
TORCH_CHECK
(
lengths
.
dim
()
==
1
,
"lengths must be 1D"
);
TORCH_CHECK
(
lengths
.
dim
()
==
1
||
lengths
.
dim
()
==
2
,
"lengths must be 1D or 2D"
);
TORCH_CHECK
(
lengths
.
is_contiguous
(),
"lengths must be contiguous"
);
TORCH_CHECK
(
output
.
dim
()
==
2
,
"output must be 2D"
);
const
int64_t
num_rows
=
logits
.
size
(
0
);
const
int64_t
stride
=
logits
.
size
(
1
);
TORCH_CHECK
(
lengths
.
size
(
0
)
==
num_rows
,
"lengths size mismatch"
);
TORCH_CHECK
(
lengths
.
numel
(
)
==
num_rows
,
"lengths size mismatch"
);
TORCH_CHECK
(
output
.
size
(
0
)
==
num_rows
&&
output
.
size
(
1
)
==
k
,
"output size mismatch"
);
namespace
P
=
vllm
::
persistent
;
...
...
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
2e984060
...
...
@@ -183,13 +183,15 @@ def sparse_attn_indexer(
# TODO: move and optimize below logic with triton kernels
batch_size
=
padded_q_fp8_decode_tokens
.
shape
[
0
]
next_n
=
padded_q_fp8_decode_tokens
.
shape
[
1
]
assert
batch_size
==
decode_metadata
.
seq_lens
.
shape
[
0
]
num_padded_tokens
=
batch_size
*
next_n
seq_lens
=
decode_metadata
.
seq_lens
[:
batch_size
]
# seq_lens is (B, next_n) for native spec decode, (B,) otherwise.
# fp8_paged_mqa_logits and all topk kernels accept both shapes.
logits
=
fp8_paged_mqa_logits
(
padded_q_fp8_decode_tokens
,
kv_cache
,
weights
[:
num_padded_tokens
],
decode_metadata
.
seq_lens
,
seq_lens
,
decode_metadata
.
block_table
,
decode_metadata
.
schedule_metadata
,
max_model_len
=
max_model_len
,
...
...
@@ -198,17 +200,6 @@ def sparse_attn_indexer(
num_rows
=
logits
.
shape
[
0
]
topk_indices
=
topk_indices_buffer
[:
num_padded_tokens
,
:
topk_tokens
]
if
next_n
==
1
:
lengths
=
decode_metadata
.
seq_lens
else
:
# (bs,) -> (bs, 1) + (next_n,) -> (bs, next_n) -> (bs * next_n,)
lengths
=
(
decode_metadata
.
seq_lens
.
unsqueeze
(
1
)
-
next_n
+
1
+
decode_metadata
.
offsets
).
flatten
()
if
current_platform
.
is_cuda
():
workspace_manager
=
current_workspace_manager
()
(
topk_workspace
,)
=
workspace_manager
.
get_simultaneous
(
...
...
@@ -216,7 +207,7 @@ def sparse_attn_indexer(
)
torch
.
ops
.
_C
.
persistent_topk
(
logits
,
len
gth
s
,
decode_metadata
.
seq_
lens
,
topk_indices
,
topk_workspace
,
topk_tokens
,
...
...
@@ -227,7 +218,7 @@ def sparse_attn_indexer(
ops
.
top_k_per_row_decode
(
logits
,
next_n
,
decode_metadata
.
seq_lens
,
seq_lens
,
topk_indices
,
num_rows
,
logits
.
stride
(
0
),
...
...
@@ -238,7 +229,7 @@ def sparse_attn_indexer(
torch
.
ops
.
_C
.
top_k_per_row_decode
(
logits
,
next_n
,
decode_metadata
.
seq_lens
,
seq_lens
,
topk_indices
,
num_rows
,
logits
.
stride
(
0
),
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
2e984060
...
...
@@ -141,11 +141,14 @@ class DeepseekV32IndexerPrefillMetadata:
@
dataclass
class
DeepSeekV32IndexerDecodeMetadata
:
block_table
:
torch
.
Tensor
# seq_lens: per-token effective context lengths.
# - flatten path / plain decode: 1D (batch_size,)
# - native MTP path: 2D (B, next_n) where [b,j] = L_b - next_n + j + 1
# Both fp8_paged_mqa_logits and the topk kernels accept both shapes.
seq_lens
:
torch
.
Tensor
decode_lens
:
torch
.
Tensor
requires_padding
:
bool
schedule_metadata
:
torch
.
Tensor
offsets
:
torch
.
Tensor
|
None
# Precomputed offsets for speculative decoding
@
dataclass
...
...
@@ -283,24 +286,34 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
sm_count
=
num_compute_units
(
self
.
device
.
index
)
self
.
num_sms
=
sm_count
self
.
decode_lens_buffer
=
torch
.
empty
(
self
.
offsets_buffer
=
torch
.
arange
(
next_n
,
device
=
self
.
device
,
dtype
=
torch
.
int32
)
self
.
decode_lens_buffer
=
torch
.
zeros
(
(
scheduler_config
.
max_num_batched_tokens
,),
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
self
.
offsets_buffer
=
torch
.
arange
(
next_n
,
device
=
self
.
device
,
dtype
=
torch
.
int32
)
self
.
arange
_buffer
=
torch
.
arange
(
scheduler_config
.
max_num_seqs
*
next_n
,
if
not
self
.
use_flattening
and
next_n
>
1
:
# Native MTP: 2D buffer for per-token seq_lens.
# Flattening path is never used, so no expanded_seq_lens_buffer.
self
.
decode_seq_lens
_buffer
=
torch
.
zeros
(
(
scheduler_config
.
max_num_seqs
,
next_n
)
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
self
.
expanded_seq_lens_buffer
=
torch
.
zeros
(
else
:
# Flattening or no MTP: 1D buffer for expanded per-token seq_lens.
self
.
decode_seq_lens_buffer
=
torch
.
zeros
(
(
scheduler_config
.
max_num_batched_tokens
,),
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
self
.
arange_buffer
=
torch
.
arange
(
scheduler_config
.
max_num_seqs
*
next_n
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
max_num_blocks_per_req
=
cdiv
(
self
.
vllm_config
.
model_config
.
max_model_len
,
self
.
kv_cache_spec
.
block_size
*
get_total_cp_world_size
(),
...
...
@@ -367,6 +380,96 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
skip_kv_gather
=
skip_kv_gather
,
)
def
_prepare_decode_tensors
(
self
,
seq_lens
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
decode_lens
:
torch
.
Tensor
,
decode_lens_cpu
:
torch
.
Tensor
,
query_start_loc
:
torch
.
Tensor
,
num_decodes
:
int
,
num_decode_tokens
:
int
,
use_native
:
bool
,
next_n
:
int
,
max_decode_len
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
bool
]:
"""Expand seq_lens/block_table/decode_lens for the decode kernels.
Flatten path (not use_native, max_decode_len > 1):
Each multi-token decode request is expanded into individual
single-token entries so the kernel always sees next_n=1.
Native path (use_native or max_decode_len == 1):
Plain decode or spec-decode with 2D per-token context lengths.
Returns (seq_lens, block_table, decode_lens, batch_size, requires_padding).
seq_lens is 1D (batch_size,) for flatten/plain, 2D (B, next_n) for native MTP.
"""
if
not
use_native
and
max_decode_len
>
1
:
assert
self
.
decode_seq_lens_buffer
.
dim
()
==
1
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
# padding) and decode_lens [3, 1, 4, 0] in the below example comments.
# The context lengths are therefore
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
# 3 + 1 + 4 + 0 = 8
actual_expanded
=
int
(
decode_lens_cpu
.
sum
().
item
())
# Fuse expanded_base and expanded_starts into a single repeat_interleave:
# seq_len_i = (context_start[b] - query_start_loc[b]) + arange[i] + 1
# where context_start[b] = seq_lens[b] - decode_lens[b].
# Example: offsets = [7-0, 6-3, 8-4, 0-8] = [7, 3, 4, -8]
# expanded_offsets = [7, 7, 7, 3, 4, 4, 4, 4]
# result = [8, 9, 10, 7, 9, 10, 11, 12]
expanded_offsets
=
torch
.
repeat_interleave
(
seq_lens
-
decode_lens
-
query_start_loc
,
decode_lens
,
output_size
=
actual_expanded
,
)
# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
self
.
decode_seq_lens_buffer
[:
actual_expanded
]
=
(
expanded_offsets
+
self
.
arange_buffer
[:
actual_expanded
]
+
1
)
self
.
decode_seq_lens_buffer
[
actual_expanded
:]
=
0
seq_lens
=
self
.
decode_seq_lens_buffer
[:
num_decode_tokens
]
# Give each of the flattened entries the same block table row as the
# original request.
self
.
expanded_block_table_buffer
[:
actual_expanded
]
=
(
torch
.
repeat_interleave
(
block_table
,
decode_lens
,
dim
=
0
,
output_size
=
actual_expanded
)
)
if
actual_expanded
<
num_decode_tokens
:
self
.
expanded_block_table_buffer
[
actual_expanded
:
num_decode_tokens
,
0
]
=
0
block_table
=
self
.
expanded_block_table_buffer
[:
num_decode_tokens
]
# All reqs now have decode_len=1
self
.
decode_lens_buffer
[:
num_decode_tokens
]
=
1
decode_lens
=
self
.
decode_lens_buffer
[:
num_decode_tokens
]
return
seq_lens
,
block_table
,
decode_lens
,
num_decode_tokens
,
False
else
:
# Native path: plain decode (next_n==1) or spec decode
# with 2D per-token context lengths (next_n > 1).
#
# When decode_lens are not truly uniform (e.g. some requests have
# decode_len < next_n due to padding or short prefills), the simple
# reshape in sparse_attn_indexer won't work. Use pack_seq_triton
# (requires_padding) instead.
min_decode_len
=
int
(
decode_lens_cpu
.
min
().
item
())
requires_padding
=
min_decode_len
!=
max_decode_len
if
use_native
and
next_n
>
1
:
assert
self
.
decode_seq_lens_buffer
.
dim
()
==
2
# (B, next_n): token j attends to L - next_n + j + 1 KV tokens
self
.
decode_seq_lens_buffer
[:
num_decodes
]
=
(
seq_lens
.
unsqueeze
(
1
)
-
next_n
+
1
+
self
.
offsets_buffer
)
seq_lens
=
self
.
decode_seq_lens_buffer
[:
num_decodes
]
return
seq_lens
,
block_table
,
decode_lens
,
num_decodes
,
requires_padding
def
build
(
self
,
common_prefix_len
:
int
,
...
...
@@ -434,68 +537,20 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
next_n
=
1
+
self
.
num_speculative_tokens
use_native
=
not
self
.
use_flattening
and
max_decode_len
==
next_n
if
use_native
and
next_n
>
1
:
offsets
=
self
.
offsets_buffer
elif
max_decode_len
>
1
:
# Flatten multi-token decode requests into single-token
# batch entries, expanding seq_lens and block tables so
# the kernel always sees next_n=1.
# Also handles the edge case where use_flattening=False
# but max_decode_len != next_n (e.g. a batch containing some
# short prefills (q_len < next_n) and no true decodes).
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
# padding) and decode_lens [3, 1, 4, 0] in the below example comments.
# The context lengths are therefore
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
# 3 + 1 + 4 + 0 = 8
actual_expanded
=
int
(
decode_lens_cpu
.
sum
().
item
())
# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
expanded_base
=
torch
.
repeat_interleave
(
seq_lens
-
decode_lens
,
decode_lens
,
output_size
=
actual_expanded
)
# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
expanded_starts
=
torch
.
repeat_interleave
(
common_attn_metadata
.
query_start_loc
[:
num_decodes
],
decode_lens
,
output_size
=
actual_expanded
,
)
# [0, 1, 2, 0, 0, 1, 2, 3]
positions_within
=
(
self
.
arange_buffer
[:
actual_expanded
]
-
expanded_starts
)
# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
self
.
expanded_seq_lens_buffer
[:
actual_expanded
]
=
(
expanded_base
+
positions_within
+
1
)
self
.
expanded_seq_lens_buffer
[
actual_expanded
:]
=
0
seq_lens
=
self
.
expanded_seq_lens_buffer
[:
num_decode_tokens
]
# Give each of the flattened entries the same block table row as the
# original request.
self
.
expanded_block_table_buffer
[:
actual_expanded
]
=
(
torch
.
repeat_interleave
(
block_table
,
decode_lens
,
dim
=
0
,
output_size
=
actual_expanded
seq_lens
,
block_table
,
decode_lens
,
batch_size
,
requires_padding
=
(
self
.
_prepare_decode_tensors
(
seq_lens
=
seq_lens
,
block_table
=
block_table
,
decode_lens
=
decode_lens
,
decode_lens_cpu
=
decode_lens_cpu
,
query_start_loc
=
common_attn_metadata
.
query_start_loc
[:
num_decodes
],
num_decodes
=
num_decodes
,
num_decode_tokens
=
num_decode_tokens
,
use_native
=
use_native
,
next_n
=
next_n
,
max_decode_len
=
max_decode_len
,
)
)
if
actual_expanded
<
num_decode_tokens
:
self
.
expanded_block_table_buffer
[
actual_expanded
:
num_decode_tokens
,
0
]
=
0
block_table
=
self
.
expanded_block_table_buffer
[:
num_decode_tokens
]
# All reqs now have decode_len=1
self
.
decode_lens_buffer
[:
num_decode_tokens
]
=
1
decode_lens
=
self
.
decode_lens_buffer
[:
num_decode_tokens
]
offsets
=
None
else
:
offsets
=
None
# DeepGEMM is required for the paged MQA logits on CUDA devices
if
current_platform
.
is_cuda
()
and
has_deep_gemm
():
...
...
@@ -509,9 +564,8 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
block_table
=
block_table
,
seq_lens
=
seq_lens
,
decode_lens
=
decode_lens
,
requires_padding
=
False
,
requires_padding
=
requires_padding
,
schedule_metadata
=
self
.
scheduler_metadata_buffer
,
offsets
=
offsets
,
)
attn_metadata
=
DeepseekV32IndexerMetadata
(
...
...
@@ -531,6 +585,4 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
decode
=
decode_metadata
,
)
# if get_tensor_model_parallel_rank() == 0:
# logger.info(f"attn_metadata: {attn_metadata}")
return
attn_metadata
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment