Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7eb2446c
Commit
7eb2446c
authored
Mar 21, 2026
by
王敏
Browse files
[perf]DSA架构模型支持mtp>1
parent
84b9fe55
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
122 additions
and
12 deletions
+122
-12
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+8
-2
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+100
-10
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+14
-0
No files found.
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
7eb2446c
...
...
@@ -74,6 +74,12 @@ def sparse_attn_indexer(
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
# During speculative decoding, k may be padded to the CUDA graph batch
# size while slot_mapping only covers actual tokens. Truncate k to avoid
# out-of-bounds reads in the kernel.
num_tokens
=
slot_mapping
.
shape
[
0
]
k
=
k
[:
num_tokens
]
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
ops
.
indexer_k_quant_and_cache
(
k
,
...
...
@@ -135,10 +141,10 @@ def sparse_attn_indexer(
k_scale
.
view
(
torch
.
float32
).
flatten
(),
True
)
else
:
else
:
logits
=
op
.
mqa_logits
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
k
,
k
,
weights
[
chunk
.
token_start
:
chunk
.
token_end
].
to
(
torch
.
float32
),
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
7eb2446c
...
...
@@ -8,6 +8,7 @@ import torch
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.deep_gemm
import
get_paged_mqa_logits_metadata
,
is_deep_gemm_supported
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
...
...
@@ -21,8 +22,10 @@ from vllm.v1.attention.backends.utils import (
split_prefill_chunks
,
)
from
vllm.platforms
import
current_platform
from
vllm.v1.worker.cp_utils
import
get_total_cp_world_size
from
lightop
import
gemmopt
logger
=
init_logger
(
__name__
)
...
...
@@ -214,14 +217,44 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
else
0
)
# Now deepgemm fp8_paged_mqa_logits does not support next_n > 2
self
.
reorder_batch_threshold
+=
min
(
self
.
num_speculative_tokens
,
1
)
#self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)
self
.
reorder_batch_threshold
+=
self
.
num_speculative_tokens
props
=
torch
.
cuda
.
get_device_properties
(
self
.
device
)
sm_count
=
props
.
multi_processor_count
self
.
num_sms
=
sm_count
# self.decode_lens_buffer = torch.empty(
# (scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
# )
self
.
decode_lens_buffer
=
torch
.
empty
(
(
scheduler_config
.
max_num_seqs
,),
dtype
=
torch
.
int32
,
device
=
self
.
device
(
scheduler_config
.
max_num_batched_tokens
,),
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
# Pre-allocated buffers for flattening (spec decode).
self
.
arange_buffer
=
torch
.
arange
(
scheduler_config
.
max_num_seqs
*
(
1
+
self
.
num_speculative_tokens
),
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
self
.
expanded_seq_lens_buffer
=
torch
.
zeros
(
(
scheduler_config
.
max_num_batched_tokens
,),
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
max_num_blocks_per_req
=
cdiv
(
self
.
vllm_config
.
model_config
.
max_model_len
,
self
.
kv_cache_spec
.
block_size
*
get_total_cp_world_size
(),
)
self
.
expanded_block_table_buffer
=
torch
.
zeros
(
(
scheduler_config
.
max_num_batched_tokens
,
max_num_blocks_per_req
,
),
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
# See: DeepGMM/csrc/apis/attention.hpp
...
...
@@ -320,24 +353,81 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
common_attn_metadata
.
query_start_loc_cpu
[:
num_decodes
+
1
]
)
# Use CPU to avoid GPU sync; breaking async scheduling
requires_padding
=
(
decode_lens_cpu
.
max
()
>
decode_lens_cpu
.
min
()).
item
()
seq_lens
=
common_attn_metadata
.
seq_lens
[:
num_decodes
]
# if is_deep_gemm_supported():
block_table
=
common_attn_metadata
.
block_table_tensor
[:
num_decodes
,
...]
# Padded CUDA graph requests have block_table entries of -1.
# Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
# This is safe because padded requests have seq_lens=0, so the
# kernel produces no meaningful output for those rows.
block_table
.
clamp_
(
min
=
0
)
max_decode_len
=
int
(
decode_lens_cpu
.
max
().
item
())
if
max_decode_len
>
1
:
# Flatten multi-token decode requests into single-token
# batch entries, expanding seq_lens and block tables so
# the kernel always sees next_n=1.
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
# padding) and decode_lens [3, 1, 4, 0] in the below example comments.
# The context lengths are therefore
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
# 3 + 1 + 4 + 0 = 8
actual_expanded
=
int
(
decode_lens_cpu
.
sum
().
item
())
# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
expanded_base
=
torch
.
repeat_interleave
(
seq_lens
-
decode_lens
,
decode_lens
)
# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
expanded_starts
=
torch
.
repeat_interleave
(
common_attn_metadata
.
query_start_loc
[:
num_decodes
],
decode_lens
)
# [0, 1, 2, 0, 0, 1, 2, 3]
positions_within
=
(
self
.
arange_buffer
[:
actual_expanded
]
-
expanded_starts
)
# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
self
.
expanded_seq_lens_buffer
[:
actual_expanded
]
=
(
expanded_base
+
positions_within
+
1
)
self
.
expanded_seq_lens_buffer
[
actual_expanded
:]
=
0
seq_lens
=
self
.
expanded_seq_lens_buffer
[:
num_decode_tokens
]
# Give each of the flattened entries the same block table row as the
# original request.
self
.
expanded_block_table_buffer
[:
actual_expanded
]
=
(
torch
.
repeat_interleave
(
block_table
,
decode_lens
,
dim
=
0
)
)
if
actual_expanded
<
num_decode_tokens
:
self
.
expanded_block_table_buffer
[
actual_expanded
:
num_decode_tokens
,
0
]
=
0
block_table
=
self
.
expanded_block_table_buffer
[:
num_decode_tokens
]
# All reqs now have decode_len=1
self
.
decode_lens_buffer
[:
num_decode_tokens
]
=
1
decode_lens
=
self
.
decode_lens_buffer
[:
num_decode_tokens
]
# DeepGEMM is required for the paged MQA logits on CUDA devices
if
current_platform
.
is_rocm
():
self
.
scheduler_metadata_buffer
=
gemmopt
.
get_paged_mqa_logits_metadata
(
self
.
scheduler_metadata_buffer
=
gemmopt
.
get_paged_mqa_logits_metadata
(
seq_lens
,
self
.
kv_cache_spec
.
block_size
,
self
.
num_sms
)
else
:
self
.
scheduler_metadata_buffer
[:]
=
get_paged_mqa_logits_metadata
(
seq_lens
,
self
.
kv_cache_spec
.
block_size
,
self
.
num_sms
)
decode_metadata
=
DeepSeekV32IndexerDecodeMetadata
(
block_table
=
common_attn_metadata
.
block_table_tensor
[:
num_decodes
,
...]
,
seq_lens
=
common_attn_metadata
.
seq_lens
[:
num_decodes
]
,
block_table
=
block_table
,
seq_lens
=
seq_lens
,
decode_lens
=
decode_lens
,
requires_padding
=
requires_padding
,
requires_padding
=
False
,
schedule_metadata
=
self
.
scheduler_metadata_buffer
,
)
...
...
vllm/v1/spec_decode/eagle.py
View file @
7eb2446c
...
...
@@ -562,9 +562,23 @@ class SpecDecodeBaseProposer:
attn_metadata
=
attn_metadata_builder
.
build_for_drafting
(
# type: ignore
common_attn_metadata
=
common_attn_metadata
,
draft_index
=
token_index
+
1
)
if
self
.
draft_indexer_metadata_builder
:
draft_indexer_metadata
=
(
self
.
draft_indexer_metadata_builder
.
build_for_drafting
(
common_attn_metadata
=
common_attn_metadata
,
draft_index
=
token_index
+
1
,
)
)
else
:
draft_indexer_metadata
=
None
for
layer_name
in
self
.
attn_layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
for
layer_name
in
self
.
indexer_layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
draft_indexer_metadata
# copy inputs to buffer for cudagraph
self
.
input_ids
[:
batch_size
]
=
input_ids
self
.
_set_positions
(
batch_size
,
clamped_positions
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment