Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
00e13357
Commit
00e13357
authored
Aug 01, 2025
by
zhuwenwen
Browse files
[feat]支持v1 engine mtp cudagraph
parent
3de379de
Changes
5
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
482 additions
and
10 deletions
+482
-10
vllm/config.py
vllm/config.py
+5
-0
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+44
-9
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+2
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+424
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+7
-0
No files found.
vllm/config.py
View file @
00e13357
...
...
@@ -4802,6 +4802,11 @@ class VllmConfig:
if
size
<=
max_num_tokens
]
# add for spec decode
if
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
num_lookahead_slots
>
0
:
batch_size_capture_list
=
list
(
map
(
lambda
x
:
x
*
(
1
+
self
.
speculative_config
.
num_lookahead_slots
),
batch_size_capture_list
))
self
.
compilation_config
.
init_with_cudagraph_sizes
(
batch_size_capture_list
)
...
...
vllm/v1/attention/backends/mla/common.py
View file @
00e13357
...
...
@@ -488,6 +488,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self
.
use_spec_decode
=
False
self
.
num_scheduled_tokens_np
=
np
.
zeros
(
scheduler_config
.
max_num_seqs
,
dtype
=
np
.
int32
)
# support for cudagraph spec decoding
self
.
spec_decode_block_table_tensor
=
None
self
.
spec_decode_seq_lens
=
None
def
_build_fi_prefill_wrappers
(
self
,
prefill
:
FlashInferPrefillMetadata
):
qo_indptr
=
prefill
.
query_start_loc
...
...
@@ -589,11 +593,30 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
Currently, only decode is supported for full cudagraphs with MLA.
"""
m
=
common_attn_metadata
assert
m
.
num_reqs
==
m
.
num_actual_tokens
,
\
"MLA only supports decode-only full CUDAGraph capture. "
\
"Make sure all cudagraph capture sizes <= max_num_seq."
# assert m.num_reqs == m.num_actual_tokens, \
# "MLA only supports decode-only full CUDAGraph capture. " \
# "Make sure all cudagraph capture sizes <= max_num_seq."
# m.max_query_len = 1 # decode-only
m
.
max_query_len
=
1
# decode-only
self
.
use_spec_decode
=
m
.
num_speculative_tokens
>
0
# support for cudagraph spec decoding
if
self
.
use_spec_decode
:
for
i
in
range
(
m
.
num_reqs
):
self
.
num_scheduled_tokens_np
[
i
]
=
m
.
num_actual_tokens
//
m
.
num_reqs
if
self
.
spec_decode_block_table_tensor
is
None
:
max_num_reqs
=
m
.
seq_lens
.
shape
[
0
]
block_table_tensor
=
self
.
block_table
.
get_device_tensor
()
tokens_per_seq
=
1
+
m
.
num_speculative_tokens
self
.
spec_decode_block_table_tensor
=
torch
.
zeros
((
block_table_tensor
.
shape
[
0
]
*
tokens_per_seq
,
block_table_tensor
.
shape
[
1
]),
dtype
=
block_table_tensor
.
dtype
,
device
=
m
.
seq_lens
.
device
)
self
.
spec_decode_seq_lens
=
torch
.
zeros
(
max_num_reqs
*
tokens_per_seq
,
dtype
=
m
.
seq_lens
.
dtype
,
device
=
m
.
seq_lens
.
device
)
return
self
.
build
(
0
,
m
)
...
...
@@ -742,6 +765,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_lens
.
device
,
non_blocking
=
True
)
decode_seq_lens
=
decode_seq_lens
-
seq_lens_minus
if
self
.
spec_decode_block_table_tensor
is
not
None
:
self
.
spec_decode_block_table_tensor
[:
self
.
_num_decode_tokens
].
copy_
(
decode_block_table_tensor
)
self
.
spec_decode_seq_lens
[:
self
.
_num_decode_tokens
].
copy_
(
decode_seq_lens
)
decode_metadata
=
self
.
_build_decode
(
block_table_tensor
=
self
.
spec_decode_block_table_tensor
[:
self
.
_num_decode_tokens
,
...],
seq_lens
=
self
.
spec_decode_seq_lens
[:
self
.
_num_decode_tokens
],
)
else
:
decode_metadata
=
self
.
_build_decode
(
block_table_tensor
=
decode_block_table_tensor
,
seq_lens
=
decode_seq_lens
,
...
...
@@ -775,7 +807,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def
can_run_in_cudagraph
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
bool
:
# return common_attn_metadata.max_query_len == 1
if
not
self
.
use_spec_decode
:
return
common_attn_metadata
.
max_query_len
==
1
return
self
.
_num_prefills
==
0
class
MLACommonImpl
(
MLAAttentionImpl
[
M
],
Generic
[
M
]):
...
...
vllm/v1/attention/backends/utils.py
View file @
00e13357
...
...
@@ -55,6 +55,8 @@ class CommonAttentionMetadata:
"""Longest query in batch"""
num_rejected_tokens
:
list
[
int
]
"""(batch_size,), record the rejected tokens number in cpu and gpu"""
num_speculative_tokens
:
int
=
0
"""Number of speculative tokens"""
block_table_tensor
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
...
...
vllm/v1/core/sched/scheduler.py
View file @
00e13357
This diff is collapsed.
Click to expand it.
vllm/v1/worker/gpu_model_runner.py
View file @
00e13357
...
...
@@ -2091,6 +2091,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
num_reqs
=
min
(
num_tokens
,
max_num_reqs
)
min_tokens_per_req
=
num_tokens
//
num_reqs
if
not
is_profile
and
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
num_lookahead_slots
>
0
:
min_tokens_per_req
=
(
1
+
self
.
speculative_config
.
num_lookahead_slots
)
num_reqs
=
num_tokens
//
min_tokens_per_req
num_scheduled_tokens_list
=
[
min_tokens_per_req
]
*
num_reqs
num_scheduled_tokens_list
[
-
1
]
+=
num_tokens
%
num_reqs
assert
sum
(
num_scheduled_tokens_list
)
==
num_tokens
...
...
@@ -2108,6 +2112,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
seq_lens
[:
num_reqs
].
copy_
(
self
.
seq_lens_cpu
[:
num_reqs
],
non_blocking
=
True
)
num_speculative_tokens
=
0
if
self
.
speculative_config
is
None
else
self
.
speculative_config
.
num_lookahead_slots
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
common_attn_metadata
=
CommonAttentionMetadata
(
...
...
@@ -2121,6 +2127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_reqs
=
num_reqs
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
num_tokens
,
num_speculative_tokens
=
num_speculative_tokens
,
block_table_tensor
=
self
.
input_batch
.
block_table
[
kv_cache_group_id
].
get_device_tensor
()[:
num_reqs
],
slot_mapping
=
self
.
input_batch
.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment