Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1e50f1be
Unverified
Commit
1e50f1be
authored
Oct 02, 2025
by
Chen Zhang
Committed by
GitHub
Oct 02, 2025
Browse files
[Deepseek v3.2] Support indexer prefill chunking (#25999)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
ad87ba92
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
149 additions
and
79 deletions
+149
-79
tests/v1/attention/test_sparse_mla_backends.py
tests/v1/attention/test_sparse_mla_backends.py
+22
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+37
-38
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+90
-41
No files found.
tests/v1/attention/test_sparse_mla_backends.py
View file @
1e50f1be
...
@@ -22,6 +22,7 @@ from vllm.utils import cdiv
...
@@ -22,6 +22,7 @@ from vllm.utils import cdiv
from
vllm.v1.attention.backends.mla.flashmla_sparse
import
(
from
vllm.v1.attention.backends.mla.flashmla_sparse
import
(
FlashMLASparseBackend
,
FlashMLASparseDecodeAndContextMetadata
,
FlashMLASparseBackend
,
FlashMLASparseDecodeAndContextMetadata
,
FlashMLASparseImpl
,
FlashMLASparseMetadata
)
FlashMLASparseImpl
,
FlashMLASparseMetadata
)
from
vllm.v1.attention.backends.mla.indexer
import
split_prefill_chunks
SPARSE_BACKEND_BATCH_SPECS
=
{
SPARSE_BACKEND_BATCH_SPECS
=
{
name
:
BATCH_SPECS
[
name
]
name
:
BATCH_SPECS
[
name
]
...
@@ -424,3 +425,24 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
...
@@ -424,3 +425,24 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
sdpa_reference
,
sdpa_reference
,
rtol
=
0.5
,
rtol
=
0.5
,
atol
=
0.5
)
atol
=
0.5
)
@
pytest
.
mark
.
parametrize
(
"seq_lens,max_buf,start,expected"
,
[
# Basic split: totals per chunk ≤ max_buf
(
torch
.
tensor
([
2
,
3
,
4
,
2
]),
5
,
0
,
[(
0
,
2
),
(
2
,
3
),
(
3
,
4
)]),
# Non-zero start index
(
torch
.
tensor
([
2
,
3
,
4
,
2
]),
5
,
1
,
[(
1
,
2
),
(
2
,
3
),
(
3
,
4
)]),
# Exact fits should split between items when adding the next would
# overflow
(
torch
.
tensor
([
5
,
5
,
5
]),
5
,
0
,
[(
0
,
1
),
(
1
,
2
),
(
2
,
3
)]),
# All requests fit in a single chunk
(
torch
.
tensor
([
1
,
1
,
1
]),
10
,
0
,
[(
0
,
3
)]),
# Large buffer with non-zero start
(
torch
.
tensor
([
4
,
4
,
4
]),
100
,
1
,
[(
1
,
3
)]),
],
)
def
test_split_prefill_chunks
(
seq_lens
,
max_buf
,
start
,
expected
):
out
=
split_prefill_chunks
(
seq_lens
,
max_buf
,
start
)
assert
out
==
expected
vllm/model_executor/models/deepseek_v2.py
View file @
1e50f1be
...
@@ -583,44 +583,43 @@ def sparse_attn_indexer(
...
@@ -583,44 +583,43 @@ def sparse_attn_indexer(
topk_indices_buffer
[:
hidden_states
.
shape
[
0
]]
=
-
1
topk_indices_buffer
[:
hidden_states
.
shape
[
0
]]
=
-
1
if
has_prefill
:
if
has_prefill
:
prefill_metadata
=
attn_metadata
.
prefill
prefill_metadata
=
attn_metadata
.
prefill
num_prefills
=
attn_metadata
.
num_prefills
for
chunk
in
prefill_metadata
.
chunks
:
k_fp8
=
torch
.
empty
([
prefill_metadata
.
total_seq_lens
,
head_dim
],
k_fp8
=
torch
.
empty
([
chunk
.
total_seq_lens
,
head_dim
],
device
=
k
.
device
,
device
=
k
.
device
,
dtype
=
torch
.
float8_e4m3fn
)
dtype
=
torch
.
float8_e4m3fn
)
k_scale
=
torch
.
empty
([
prefill_metadata
.
total_seq_lens
,
1
],
k_scale
=
torch
.
empty
([
chunk
.
total_seq_lens
,
1
],
device
=
k
.
device
,
device
=
k
.
device
,
dtype
=
torch
.
float32
)
dtype
=
torch
.
float32
)
cp_gather_indexer_k_quant_cache
(
cp_gather_indexer_k_quant_cache
(
kv_cache
,
kv_cache
,
k_fp8
,
k_fp8
,
k_scale
,
k_scale
,
prefill_metadata
.
block_table
,
chunk
.
block_table
,
prefill_metadata
.
cu_seq_lens
,
chunk
.
cu_seq_lens
,
num_prefills
,
chunk
.
num_reqs
,
)
)
cu_seqlen_ks
=
prefill_metadata
.
cu_seqlen_ks
logits
=
fp8_mqa_logits
(
cu_seqlen_ke
=
prefill_metadata
.
cu_seqlen_ke
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
num_tokens
=
attn_metadata
.
num_actual_tokens
(
k_fp8
,
k_scale
),
logits
=
fp8_mqa_logits
(
weights
[
chunk
.
token_start
:
chunk
.
token_end
],
q_fp8
[
num_decode_tokens
:
num_tokens
],
chunk
.
cu_seqlen_ks
,
(
k_fp8
,
k_scale
),
chunk
.
cu_seqlen_ke
,
weights
[
num_decode_tokens
:
num_tokens
],
)
cu_seqlen_ks
,
topk_indices
=
logits
.
topk
(
min
(
topk_tokens
,
logits
.
shape
[
-
1
]),
cu_seqlen_ke
,
dim
=-
1
)[
1
]
)
topk_indices
-=
chunk
.
cu_seqlen_ks
[:,
None
]
topk_indices
=
logits
.
topk
(
min
(
topk_tokens
,
logits
.
shape
[
-
1
]),
mask_lo
=
topk_indices
>=
0
dim
=-
1
)[
1
]
mask_hi
=
topk_indices
-
(
chunk
.
cu_seqlen_ke
-
topk_indices
-=
cu_seqlen_ks
[:,
None
]
chunk
.
cu_seqlen_ks
)[:,
None
]
<
0
mask_lo
=
topk_indices
>=
0
mask
=
torch
.
full_like
(
topk_indices
,
mask_hi
=
topk_indices
-
(
cu_seqlen_ke
-
cu_seqlen_ks
)[:,
None
]
<
0
False
,
mask
=
torch
.
full_like
(
topk_indices
,
dtype
=
torch
.
bool
,
False
,
device
=
topk_indices
.
device
)
dtype
=
torch
.
bool
,
mask
=
mask_lo
&
mask_hi
device
=
topk_indices
.
device
)
topk_indices
=
topk_indices
.
masked_fill
(
~
mask
,
-
1
)
mask
=
mask_lo
&
mask_hi
topk_indices_buffer
[
topk_indices
=
topk_indices
.
masked_fill
(
~
mask
,
-
1
)
chunk
.
token_start
:
chunk
.
token_end
,
:
topk_indices
.
topk_indices_buffer
[
num_decode_tokens
:
num_tokens
,
:
topk_indices
.
shape
[
-
1
]]
=
topk_indices
.
to
(
dtype
=
torch
.
int32
)
shape
[
-
1
]]
=
topk_indices
.
to
(
dtype
=
torch
.
int32
)
if
has_decode
:
if
has_decode
:
decode_metadata
=
attn_metadata
.
decode
decode_metadata
=
attn_metadata
.
decode
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
1e50f1be
...
@@ -49,14 +49,20 @@ class DeepseekV32IndexerBackend(AttentionBackend):
...
@@ -49,14 +49,20 @@ class DeepseekV32IndexerBackend(AttentionBackend):
@
dataclass
@
dataclass
class
DeepseekV32IndexerPrefillMetadata
:
class
DeepseekV32IndexerPrefill
Chunk
Metadata
:
block_table
:
torch
.
Tensor
block_table
:
torch
.
Tensor
query_start_loc
:
torch
.
Tensor
max_query_len
:
int
cu_seqlen_ks
:
torch
.
Tensor
cu_seqlen_ks
:
torch
.
Tensor
cu_seqlen_ke
:
torch
.
Tensor
cu_seqlen_ke
:
torch
.
Tensor
cu_seq_lens
:
torch
.
Tensor
cu_seq_lens
:
torch
.
Tensor
total_seq_lens
:
int
total_seq_lens
:
int
token_start
:
int
token_end
:
int
num_reqs
:
int
@
dataclass
class
DeepseekV32IndexerPrefillMetadata
:
chunks
:
list
[
DeepseekV32IndexerPrefillChunkMetadata
]
@
dataclass
@
dataclass
...
@@ -98,8 +104,8 @@ class DeepseekV32IndexerMetadata:
...
@@ -98,8 +104,8 @@ class DeepseekV32IndexerMetadata:
# TODO (zyongye) optimize this, this is now vibe coded
# TODO (zyongye) optimize this, this is now vibe coded
def
kv_spans_from_batches
(
def
kv_spans_from_batches
(
start_seq_loc
:
torch
.
Tensor
,
start_seq_loc
:
torch
.
Tensor
,
seq_len_per_batch
:
torch
.
Tensor
,
seq_len_per_batch
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
device
:
torch
.
device
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Args:
Args:
start_seq_loc: 1D long tensor [B+1], cumulative counts of
start_seq_loc: 1D long tensor [B+1], cumulative counts of
...
@@ -122,7 +128,7 @@ def kv_spans_from_batches(
...
@@ -122,7 +128,7 @@ def kv_spans_from_batches(
are the **last** `counts[i]` positions of that sequence.
are the **last** `counts[i]` positions of that sequence.
"""
"""
q
=
start_seq_loc
.
to
(
dtype
=
torch
.
long
)
q
=
start_seq_loc
.
to
(
dtype
=
torch
.
long
)
L
=
seq_len_per_batch
.
to
(
dtype
=
torch
.
long
,
device
=
q
.
device
)
L
=
seq_len_per_batch
.
to
(
dtype
=
torch
.
long
)
assert
q
.
dim
()
==
1
and
L
.
dim
()
==
1
assert
q
.
dim
()
==
1
and
L
.
dim
()
==
1
assert
q
.
numel
()
==
L
.
numel
()
+
1
,
"start_seq_loc must have length B+1"
assert
q
.
numel
()
==
L
.
numel
()
+
1
,
"start_seq_loc must have length B+1"
...
@@ -130,7 +136,6 @@ def kv_spans_from_batches(
...
@@ -130,7 +136,6 @@ def kv_spans_from_batches(
counts
=
q
[
1
:]
-
q
[:
-
1
]
# [B]
counts
=
q
[
1
:]
-
q
[:
-
1
]
# [B]
N
=
int
(
q
[
-
1
].
item
())
# total selected tokens
N
=
int
(
q
[
-
1
].
item
())
# total selected tokens
B
=
L
.
numel
()
B
=
L
.
numel
()
device
=
L
.
device
if
N
==
0
:
if
N
==
0
:
return
(
torch
.
empty
(
0
,
dtype
=
torch
.
long
,
device
=
device
),
return
(
torch
.
empty
(
0
,
dtype
=
torch
.
long
,
device
=
device
),
...
@@ -140,8 +145,7 @@ def kv_spans_from_batches(
...
@@ -140,8 +145,7 @@ def kv_spans_from_batches(
kv_starts_per_batch
=
torch
.
cumsum
(
L
,
dim
=
0
)
-
L
# [B]
kv_starts_per_batch
=
torch
.
cumsum
(
L
,
dim
=
0
)
-
L
# [B]
# For each selected token, which batch does it belong to?
# For each selected token, which batch does it belong to?
batch_id
=
torch
.
repeat_interleave
(
torch
.
arange
(
B
,
device
=
device
),
batch_id
=
torch
.
repeat_interleave
(
torch
.
arange
(
B
),
counts
)
# [N]
counts
)
# [N]
# Map batch KV start to each token
# Map batch KV start to each token
start_tensor
=
kv_starts_per_batch
[
batch_id
]
# [N]
start_tensor
=
kv_starts_per_batch
[
batch_id
]
# [N]
...
@@ -151,22 +155,51 @@ def kv_spans_from_batches(
...
@@ -151,22 +155,51 @@ def kv_spans_from_batches(
L_expand
=
torch
.
repeat_interleave
(
L
,
counts
)
# [N]
L_expand
=
torch
.
repeat_interleave
(
L
,
counts
)
# [N]
m_expand
=
torch
.
repeat_interleave
(
counts
,
counts
)
# [N]
m_expand
=
torch
.
repeat_interleave
(
counts
,
counts
)
# [N]
# position within the selected block: 1..counts[b]
# position within the selected block: 1..counts[b]
pos_within
=
(
torch
.
arange
(
N
,
device
=
device
,
dtype
=
torch
.
long
)
-
pos_within
=
(
torch
.
arange
(
N
,
dtype
=
torch
.
long
)
-
torch
.
repeat_interleave
(
q
[:
-
1
],
counts
)
+
1
)
torch
.
repeat_interleave
(
q
[:
-
1
],
counts
)
+
1
)
local_pos
=
L_expand
-
m_expand
+
pos_within
# [N], 1-based
local_pos
=
L_expand
-
m_expand
+
pos_within
# [N], 1-based
end_location
=
start_tensor
+
local_pos
# exclusive end
end_location
=
start_tensor
+
local_pos
# exclusive end
return
start_tensor
.
int
(),
end_location
.
int
()
return
start_tensor
.
int
()
.
to
(
device
)
,
end_location
.
int
()
.
to
(
device
)
def
get_max_prefill_buffer_size
(
vllm_config
:
VllmConfig
):
def
get_max_prefill_buffer_size
(
vllm_config
:
VllmConfig
):
max_model_len
=
vllm_config
.
model_config
.
max_model_len
max_model_len
=
vllm_config
.
model_config
.
max_model_len
# max_num_batched_tokens = \
# NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
# vllm_config.scheduler_config.max_num_batched_tokens
# May be tuned later.
max_num_seq
=
vllm_config
.
scheduler_config
.
max_num_seqs
return
max_model_len
*
2
# NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
return
max_model_len
*
max_num_seq
def
split_prefill_chunks
(
seq_lens_cpu
:
torch
.
Tensor
,
max_prefill_buffer_size
:
int
,
reqs_start
:
int
)
->
list
[
tuple
[
int
,
int
]]:
"""
Split the prefill chunks into a list of tuples of (reqs_start, reqs_end)
such that the total sequence length of each chunk is less than the
maximum prefill buffer size.
Args:
seq_lens_cpu: The sequence lengths of the prefill requests.
max_prefill_buffer_size: The maximum prefill buffer size.
reqs_start: The start index of the prefill requests.
Returns:
A list of tuples of (reqs_start, reqs_end).
"""
chunk_seq_ids
=
[]
total_seq_lens
=
0
for
i
in
range
(
reqs_start
,
len
(
seq_lens_cpu
)):
cur_seq_len
=
seq_lens_cpu
[
i
].
item
()
assert
cur_seq_len
<=
max_prefill_buffer_size
total_seq_lens
+=
cur_seq_len
if
total_seq_lens
>
max_prefill_buffer_size
:
chunk_seq_ids
.
append
((
reqs_start
,
i
))
reqs_start
=
i
total_seq_lens
=
cur_seq_len
if
total_seq_lens
>
0
:
chunk_seq_ids
.
append
((
reqs_start
,
len
(
seq_lens_cpu
)))
return
chunk_seq_ids
class
DeepseekV32IndexerMetadataBuilder
(
AttentionMetadataBuilder
):
class
DeepseekV32IndexerMetadataBuilder
(
AttentionMetadataBuilder
):
...
@@ -201,6 +234,33 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
...
@@ -201,6 +234,33 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
self
.
device
)
def
build_one_prefill_chunk
(
self
,
reqs_start
,
reqs_end
,
query_start_loc_cpu
,
seq_lens_cpu
,
block_table
):
prefill_query_start_loc
=
query_start_loc_cpu
[
reqs_start
:
reqs_end
+
1
]
-
query_start_loc_cpu
[
reqs_start
]
cu_seqlen_ks
,
cu_seqlen_ke
=
kv_spans_from_batches
(
prefill_query_start_loc
,
seq_lens_cpu
[
reqs_start
:
reqs_end
],
self
.
device
)
token_start
=
query_start_loc_cpu
[
reqs_start
].
item
()
token_end
=
query_start_loc_cpu
[
reqs_end
].
item
()
total_seq_lens
=
seq_lens_cpu
[
reqs_start
:
reqs_end
].
sum
()
assert
total_seq_lens
<=
self
.
max_prefill_buffer_size
cu_seq_lens
=
torch
.
cat
([
torch
.
zeros
(
1
,
dtype
=
torch
.
int32
),
seq_lens_cpu
[
reqs_start
:
reqs_end
].
cumsum
(
dim
=
0
)
]).
to
(
torch
.
int32
).
to
(
self
.
device
)
return
DeepseekV32IndexerPrefillChunkMetadata
(
cu_seqlen_ks
=
cu_seqlen_ks
,
cu_seqlen_ke
=
cu_seqlen_ke
,
cu_seq_lens
=
cu_seq_lens
,
total_seq_lens
=
total_seq_lens
,
block_table
=
block_table
[
reqs_start
:
reqs_end
],
token_start
=
token_start
,
token_end
=
token_end
,
num_reqs
=
reqs_end
-
reqs_start
,
)
def
build
(
self
,
def
build
(
self
,
common_prefix_len
:
int
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
common_attn_metadata
:
CommonAttentionMetadata
,
...
@@ -209,11 +269,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
...
@@ -209,11 +269,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
num_reqs
=
common_attn_metadata
.
num_reqs
num_reqs
=
common_attn_metadata
.
num_reqs
num_tokens
=
common_attn_metadata
.
num_actual_tokens
num_tokens
=
common_attn_metadata
.
num_actual_tokens
device
=
self
.
device
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
query_start_loc
=
common_attn_metadata
.
query_start_loc
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
\
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
\
split_decodes_and_prefills
(
split_decodes_and_prefills
(
common_attn_metadata
,
common_attn_metadata
,
...
@@ -224,27 +280,20 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
...
@@ -224,27 +280,20 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
prefill_metadata
=
None
prefill_metadata
=
None
if
num_prefills
>
0
:
if
num_prefills
>
0
:
reqs_start
=
num_decodes
chunk_seq_ids
=
split_prefill_chunks
(
prefill_query_start_loc
=
query_start_loc
[
common_attn_metadata
.
seq_lens_cpu
,
reqs_start
:]
-
query_start_loc
[
reqs_start
]
self
.
max_prefill_buffer_size
,
cu_seqlen_ks
,
cu_seqlen_ke
=
kv_spans_from_batches
(
num_decodes
,
prefill_query_start_loc
,
common_attn_metadata
.
seq_lens
[
reqs_start
:])
total_seq_lens
=
common_attn_metadata
.
seq_lens
[
reqs_start
:].
sum
()
assert
total_seq_lens
<
self
.
max_prefill_buffer_size
cu_seq_lens
=
torch
.
cat
([
torch
.
zeros
(
1
,
dtype
=
torch
.
int32
,
device
=
device
),
common_attn_metadata
.
seq_lens
[
reqs_start
:].
cumsum
(
dim
=
0
)
]).
to
(
torch
.
int32
).
cuda
()
prefill_metadata
=
DeepseekV32IndexerPrefillMetadata
(
block_table
=
block_table_tensor
[
reqs_start
:,
...],
query_start_loc
=
prefill_query_start_loc
,
max_query_len
=
common_attn_metadata
.
max_query_len
,
cu_seqlen_ks
=
cu_seqlen_ks
,
cu_seqlen_ke
=
cu_seqlen_ke
,
cu_seq_lens
=
cu_seq_lens
,
total_seq_lens
=
total_seq_lens
,
)
)
chunks
=
[
self
.
build_one_prefill_chunk
(
reqs_start
,
reqs_end
,
query_start_loc_cpu
,
common_attn_metadata
.
seq_lens_cpu
,
common_attn_metadata
.
block_table_tensor
)
for
reqs_start
,
reqs_end
in
chunk_seq_ids
]
prefill_metadata
=
DeepseekV32IndexerPrefillMetadata
(
chunks
=
chunks
,
)
decode_metadata
=
None
decode_metadata
=
None
if
num_decodes
>
0
:
if
num_decodes
>
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment