Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fe393be8
Commit
fe393be8
authored
Jul 31, 2025
by
王敏
Browse files
[feat]支持v1 engine mtp cudagraph
parent
741dbbbb
Changes
6
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
521 additions
and
52 deletions
+521
-52
vllm/config.py
vllm/config.py
+5
-0
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+42
-9
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+2
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+424
-1
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+42
-42
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+6
-0
No files found.
vllm/config.py
View file @
fe393be8
...
...
@@ -4776,6 +4776,11 @@ class VllmConfig:
if
size
<=
max_num_tokens
]
# add for spec decode
if
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
num_lookahead_slots
>
0
:
batch_size_capture_list
=
list
(
map
(
lambda
x
:
x
*
(
1
+
self
.
speculative_config
.
num_lookahead_slots
),
batch_size_capture_list
))
self
.
compilation_config
.
init_with_cudagraph_sizes
(
batch_size_capture_list
)
...
...
vllm/v1/attention/backends/mla/common.py
View file @
fe393be8
...
...
@@ -400,6 +400,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self
.
use_spec_decode
=
False
self
.
num_scheduled_tokens_np
=
np
.
zeros
(
scheduler_config
.
max_num_seqs
,
dtype
=
np
.
int32
)
# support for cudagraph spec docoding
self
.
spec_decode_block_table_tensor
=
None
self
.
spec_decode_seq_lens
=
None
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
...
...
@@ -496,17 +500,36 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
Currently, only decode is supported for full cudagraphs with MLA.
"""
m
=
common_attn_metadata
assert
m
.
num_reqs
==
m
.
num_actual_tokens
,
\
"MLA only supports decode-only full CUDAGraph capture. "
\
"Make sure all cudagraph capture sizes <= max_num_seq."
#
assert m.num_reqs == m.num_actual_tokens, \
#
"MLA only supports decode-only full CUDAGraph capture. " \
#
"Make sure all cudagraph capture sizes <= max_num_seq."
m
.
max_query_len
=
1
# decode-only
#
m.max_query_len = 1 # decode-only
# Update state usually set in reorder_batch.
self
.
_num_decodes
=
m
.
num_reqs
self
.
_num_decode_tokens
=
m
.
num_actual_tokens
self
.
_num_prefills
=
0
self
.
_num_prefill_tokens
=
0
self
.
use_spec_decode
=
m
.
num_speculative_tokens
>
0
# support for cudagraph spec docoding
if
self
.
use_spec_decode
:
for
i
in
range
(
m
.
num_reqs
):
self
.
num_scheduled_tokens_np
[
i
]
=
m
.
num_actual_tokens
//
m
.
num_reqs
if
self
.
spec_decode_block_table_tensor
is
None
:
max_num_reqs
=
m
.
seq_lens
.
shape
[
0
]
block_table_tensor
=
self
.
block_table
.
get_device_tensor
()
tokens_per_seq
=
1
+
m
.
num_speculative_tokens
self
.
spec_decode_block_table_tensor
=
torch
.
zeros
((
block_table_tensor
.
shape
[
0
]
*
tokens_per_seq
,
block_table_tensor
.
shape
[
1
]),
dtype
=
block_table_tensor
.
dtype
,
device
=
m
.
seq_lens
.
device
)
self
.
spec_decode_seq_lens
=
torch
.
zeros
(
max_num_reqs
*
tokens_per_seq
,
dtype
=
m
.
seq_lens
.
dtype
,
device
=
m
.
seq_lens
.
device
)
return
self
.
build
(
0
,
m
)
def
build
(
self
,
common_prefix_len
:
int
,
...
...
@@ -633,6 +656,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_lens
.
device
,
non_blocking
=
True
)
decode_seq_lens
=
decode_seq_lens
-
seq_lens_minus
if
self
.
spec_decode_block_table_tensor
is
not
None
:
self
.
spec_decode_block_table_tensor
[:
self
.
_num_decode_tokens
].
copy_
(
decode_block_table_tensor
)
self
.
spec_decode_seq_lens
[:
self
.
_num_decode_tokens
].
copy_
(
decode_seq_lens
)
decode_metadata
=
self
.
_build_decode
(
block_table_tensor
=
self
.
spec_decode_block_table_tensor
[:
self
.
_num_decode_tokens
,
...],
seq_lens
=
self
.
spec_decode_seq_lens
[:
self
.
_num_decode_tokens
],
)
else
:
decode_metadata
=
self
.
_build_decode
(
block_table_tensor
=
decode_block_table_tensor
,
seq_lens
=
decode_seq_lens
,
...
...
@@ -658,7 +690,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def
can_run_in_cudagraph
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
bool
:
return
common_attn_metadata
.
max_query_len
==
1
#return common_attn_metadata.max_query_len == 1
return
self
.
_num_prefills
==
0
class
MLACommonImpl
(
MLAAttentionImpl
[
M
],
Generic
[
M
]):
...
...
vllm/v1/attention/backends/utils.py
View file @
fe393be8
...
...
@@ -43,6 +43,8 @@ class CommonAttentionMetadata:
"""Longest query in batch"""
num_rejected_tokens
:
list
[
int
]
=
None
"""(batch_size,), record the rejected tokens number in cpu and gpu"""
num_speculative_tokens
:
int
=
0
"""Number of speculative tokens"""
M
=
TypeVar
(
"M"
)
...
...
vllm/v1/core/sched/scheduler.py
View file @
fe393be8
This diff is collapsed.
Click to expand it.
vllm/v1/spec_decode/eagle.py
View file @
fe393be8
...
...
@@ -342,7 +342,7 @@ class EagleProposer:
)
batch_size
=
num_rejected_tokens
.
shape
[
0
]
BLOCK_SIZE
=
1024
prepare_eagle_input_kernel
[(
batch_size
,
)](
prepare_eagle_input_kernel
[(
batch_size
,)](
token_indices
,
cu_target_query_lens
,
cu_num_tokens
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
fe393be8
...
...
@@ -1989,6 +1989,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
num_reqs
=
min
(
num_tokens
,
max_num_reqs
)
min_tokens_per_req
=
num_tokens
//
num_reqs
if
not
is_profile
and
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
num_lookahead_slots
>
0
:
min_tokens_per_req
=
(
1
+
self
.
speculative_config
.
num_lookahead_slots
)
num_reqs
=
num_tokens
//
min_tokens_per_req
num_scheduled_tokens_list
=
[
min_tokens_per_req
]
*
num_reqs
num_scheduled_tokens_list
[
-
1
]
+=
num_tokens
%
num_reqs
assert
sum
(
num_scheduled_tokens_list
)
==
num_tokens
...
...
@@ -2008,12 +2012,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
non_blocking
=
True
)
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
num_speculative_tokens
=
0
if
self
.
speculative_config
is
None
else
self
.
speculative_config
.
num_lookahead_slots
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
num_tokens
,
num_speculative_tokens
=
num_speculative_tokens
,
)
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment