Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fe393be8
Commit
fe393be8
authored
Jul 31, 2025
by
王敏
Browse files
[feat]支持v1 engine mtp cudagraph
parent
741dbbbb
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
521 additions
and
52 deletions
+521
-52
vllm/config.py
vllm/config.py
+5
-0
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+42
-9
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+2
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+424
-1
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+42
-42
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+6
-0
No files found.
vllm/config.py
View file @
fe393be8
...
@@ -4776,6 +4776,11 @@ class VllmConfig:
...
@@ -4776,6 +4776,11 @@ class VllmConfig:
if
size
<=
max_num_tokens
if
size
<=
max_num_tokens
]
]
# add for spec decode
if
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
num_lookahead_slots
>
0
:
batch_size_capture_list
=
list
(
map
(
lambda
x
:
x
*
(
1
+
self
.
speculative_config
.
num_lookahead_slots
),
batch_size_capture_list
))
self
.
compilation_config
.
init_with_cudagraph_sizes
(
self
.
compilation_config
.
init_with_cudagraph_sizes
(
batch_size_capture_list
)
batch_size_capture_list
)
...
...
vllm/v1/attention/backends/mla/common.py
View file @
fe393be8
...
@@ -400,6 +400,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -400,6 +400,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self
.
use_spec_decode
=
False
self
.
use_spec_decode
=
False
self
.
num_scheduled_tokens_np
=
np
.
zeros
(
scheduler_config
.
max_num_seqs
,
dtype
=
np
.
int32
)
self
.
num_scheduled_tokens_np
=
np
.
zeros
(
scheduler_config
.
max_num_seqs
,
dtype
=
np
.
int32
)
# support for cudagraph spec decoding
self
.
spec_decode_block_table_tensor
=
None
self
.
spec_decode_seq_lens
=
None
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
...
@@ -496,17 +500,36 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -496,17 +500,36 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
Currently, only decode is supported for full cudagraphs with MLA.
Currently, only decode is supported for full cudagraphs with MLA.
"""
"""
m
=
common_attn_metadata
m
=
common_attn_metadata
assert
m
.
num_reqs
==
m
.
num_actual_tokens
,
\
#
assert m.num_reqs == m.num_actual_tokens, \
"MLA only supports decode-only full CUDAGraph capture. "
\
#
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
#
"Make sure all cudagraph capture sizes <= max_num_seq."
m
.
max_query_len
=
1
# decode-only
#
m.max_query_len = 1 # decode-only
# Update state usually set in reorder_batch.
# Update state usually set in reorder_batch.
self
.
_num_decodes
=
m
.
num_reqs
self
.
_num_decodes
=
m
.
num_reqs
self
.
_num_decode_tokens
=
m
.
num_actual_tokens
self
.
_num_decode_tokens
=
m
.
num_actual_tokens
self
.
_num_prefills
=
0
self
.
_num_prefills
=
0
self
.
_num_prefill_tokens
=
0
self
.
_num_prefill_tokens
=
0
self
.
use_spec_decode
=
m
.
num_speculative_tokens
>
0
# support for cudagraph spec decoding
if
self
.
use_spec_decode
:
for
i
in
range
(
m
.
num_reqs
):
self
.
num_scheduled_tokens_np
[
i
]
=
m
.
num_actual_tokens
//
m
.
num_reqs
if
self
.
spec_decode_block_table_tensor
is
None
:
max_num_reqs
=
m
.
seq_lens
.
shape
[
0
]
block_table_tensor
=
self
.
block_table
.
get_device_tensor
()
tokens_per_seq
=
1
+
m
.
num_speculative_tokens
self
.
spec_decode_block_table_tensor
=
torch
.
zeros
((
block_table_tensor
.
shape
[
0
]
*
tokens_per_seq
,
block_table_tensor
.
shape
[
1
]),
dtype
=
block_table_tensor
.
dtype
,
device
=
m
.
seq_lens
.
device
)
self
.
spec_decode_seq_lens
=
torch
.
zeros
(
max_num_reqs
*
tokens_per_seq
,
dtype
=
m
.
seq_lens
.
dtype
,
device
=
m
.
seq_lens
.
device
)
return
self
.
build
(
0
,
m
)
return
self
.
build
(
0
,
m
)
def
build
(
self
,
common_prefix_len
:
int
,
def
build
(
self
,
common_prefix_len
:
int
,
...
@@ -633,10 +656,19 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -633,10 +656,19 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_lens
.
device
,
non_blocking
=
True
)
seq_lens
.
device
,
non_blocking
=
True
)
decode_seq_lens
=
decode_seq_lens
-
seq_lens_minus
decode_seq_lens
=
decode_seq_lens
-
seq_lens_minus
decode_metadata
=
self
.
_build_decode
(
if
self
.
spec_decode_block_table_tensor
is
not
None
:
block_table_tensor
=
decode_block_table_tensor
,
self
.
spec_decode_block_table_tensor
[:
self
.
_num_decode_tokens
].
copy_
(
decode_block_table_tensor
)
seq_lens
=
decode_seq_lens
,
self
.
spec_decode_seq_lens
[:
self
.
_num_decode_tokens
].
copy_
(
decode_seq_lens
)
)
decode_metadata
=
self
.
_build_decode
(
block_table_tensor
=
self
.
spec_decode_block_table_tensor
[:
self
.
_num_decode_tokens
,
...],
seq_lens
=
self
.
spec_decode_seq_lens
[:
self
.
_num_decode_tokens
],
)
else
:
decode_metadata
=
self
.
_build_decode
(
block_table_tensor
=
decode_block_table_tensor
,
seq_lens
=
decode_seq_lens
,
)
else
:
else
:
decode_metadata
=
self
.
_build_decode
(
decode_metadata
=
self
.
_build_decode
(
block_table_tensor
=
block_table_tensor
[:
self
.
_num_decodes
,
...],
block_table_tensor
=
block_table_tensor
[:
self
.
_num_decodes
,
...],
...
@@ -658,7 +690,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -658,7 +690,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def
can_run_in_cudagraph
(
def
can_run_in_cudagraph
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
bool
:
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
bool
:
return
common_attn_metadata
.
max_query_len
==
1
#return common_attn_metadata.max_query_len == 1
return
self
.
_num_prefills
==
0
class
MLACommonImpl
(
MLAAttentionImpl
[
M
],
Generic
[
M
]):
class
MLACommonImpl
(
MLAAttentionImpl
[
M
],
Generic
[
M
]):
...
...
vllm/v1/attention/backends/utils.py
View file @
fe393be8
...
@@ -43,6 +43,8 @@ class CommonAttentionMetadata:
...
@@ -43,6 +43,8 @@ class CommonAttentionMetadata:
"""Longest query in batch"""
"""Longest query in batch"""
num_rejected_tokens
:
list
[
int
]
=
None
num_rejected_tokens
:
list
[
int
]
=
None
"""(batch_size,), record the rejected tokens number in cpu and gpu"""
"""(batch_size,), record the rejected tokens number in cpu and gpu"""
num_speculative_tokens
:
int
=
0
"""Number of speculative tokens"""
M
=
TypeVar
(
"M"
)
M
=
TypeVar
(
"M"
)
...
...
vllm/v1/core/sched/scheduler.py
View file @
fe393be8
This diff is collapsed.
Click to expand it.
vllm/v1/spec_decode/eagle.py
View file @
fe393be8
...
@@ -29,10 +29,10 @@ PADDING_SLOT_ID = -1
...
@@ -29,10 +29,10 @@ PADDING_SLOT_ID = -1
class
EagleProposer
:
class
EagleProposer
:
def
__init__
(
def
__init__
(
self
,
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
device
:
torch
.
device
,
runner
=
None
,
runner
=
None
,
):
):
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
speculative_config
=
vllm_config
.
speculative_config
...
@@ -79,25 +79,25 @@ class EagleProposer:
...
@@ -79,25 +79,25 @@ class EagleProposer:
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
def
propose
(
def
propose
(
self
,
self
,
# [num_tokens]
# [num_tokens]
target_token_ids
:
torch
.
Tensor
,
target_token_ids
:
torch
.
Tensor
,
# [num_tokens]
# [num_tokens]
target_positions
:
torch
.
Tensor
,
target_positions
:
torch
.
Tensor
,
# [num_tokens, hidden_size]
# [num_tokens, hidden_size]
target_hidden_states
:
torch
.
Tensor
,
target_hidden_states
:
torch
.
Tensor
,
# [num_tokens]
# [num_tokens]
target_slot_mapping
:
torch
.
Tensor
,
target_slot_mapping
:
torch
.
Tensor
,
# [batch_size]
# [batch_size]
next_token_ids
:
torch
.
Tensor
,
next_token_ids
:
torch
.
Tensor
,
# [batch_size + 1] starting with 0
# [batch_size + 1] starting with 0
cu_num_tokens
:
torch
.
Tensor
,
cu_num_tokens
:
torch
.
Tensor
,
# [batch_size, max_num_blocks_per_req]
# [batch_size, max_num_blocks_per_req]
block_table
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
# [batch_size]
# [batch_size]
num_rejected_tokens
:
list
[
int
],
num_rejected_tokens
:
list
[
int
],
# [batch_size]
# [batch_size]
sampling_metadata
:
SamplingMetadata
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_tokens
=
target_token_ids
.
shape
[
0
]
num_tokens
=
target_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
...
@@ -168,7 +168,7 @@ class EagleProposer:
...
@@ -168,7 +168,7 @@ class EagleProposer:
for
layer_name
in
self
.
attn_layer_names
:
for
layer_name
in
self
.
attn_layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
if
self
.
use_cuda_graph
and
\
if
self
.
use_cuda_graph
and
\
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_tokens
)
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_tokens
)
else
:
else
:
num_input_tokens
=
num_tokens
num_input_tokens
=
num_tokens
...
@@ -212,7 +212,7 @@ class EagleProposer:
...
@@ -212,7 +212,7 @@ class EagleProposer:
hidden_states
=
hidden_states
[
last_token_indices
]
hidden_states
=
hidden_states
[
last_token_indices
]
if
self
.
use_cuda_graph
and
\
if
self
.
use_cuda_graph
and
\
batch_size
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
batch_size
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
input_batch_size
=
self
.
vllm_config
.
pad_for_cudagraph
(
batch_size
)
input_batch_size
=
self
.
vllm_config
.
pad_for_cudagraph
(
batch_size
)
else
:
else
:
input_batch_size
=
batch_size
input_batch_size
=
batch_size
...
@@ -259,7 +259,7 @@ class EagleProposer:
...
@@ -259,7 +259,7 @@ class EagleProposer:
# Consider max model length.
# Consider max model length.
attn_metadata
.
max_seq_len
=
min
(
attn_metadata
.
max_seq_len
,
attn_metadata
.
max_seq_len
=
min
(
attn_metadata
.
max_seq_len
,
self
.
max_model_len
)
self
.
max_model_len
)
# For the requests that exceed the max model length, we set the
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
# sequence length to 1 to minimize their overheads in attention.
attn_metadata
.
seq_lens
.
masked_fill_
(
exceeds_max_model_len
,
1
)
attn_metadata
.
seq_lens
.
masked_fill_
(
exceeds_max_model_len
,
1
)
...
@@ -267,10 +267,10 @@ class EagleProposer:
...
@@ -267,10 +267,10 @@ class EagleProposer:
# Compute the slot mapping.
# Compute the slot mapping.
block_numbers
=
clamped_positions
//
self
.
block_size
block_numbers
=
clamped_positions
//
self
.
block_size
block_ids
=
block_table
.
gather
(
dim
=
1
,
block_ids
=
block_table
.
gather
(
dim
=
1
,
index
=
block_numbers
.
view
(
-
1
,
1
))
index
=
block_numbers
.
view
(
-
1
,
1
))
block_ids
=
block_ids
.
view
(
-
1
)
block_ids
=
block_ids
.
view
(
-
1
)
attn_metadata
.
slot_mapping
=
(
block_ids
*
self
.
block_size
+
attn_metadata
.
slot_mapping
=
(
block_ids
*
self
.
block_size
+
clamped_positions
%
self
.
block_size
)
clamped_positions
%
self
.
block_size
)
# Mask out the slot mappings that exceed the max model length.
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
# padding tokens.
...
@@ -311,11 +311,11 @@ class EagleProposer:
...
@@ -311,11 +311,11 @@ class EagleProposer:
@
staticmethod
@
staticmethod
def
prepare_inputs
(
def
prepare_inputs
(
# [batch_size + 1]
# [batch_size + 1]
cu_target_query_lens
:
torch
.
Tensor
,
cu_target_query_lens
:
torch
.
Tensor
,
# [batch_size]
# [batch_size]
num_rejected_tokens
:
torch
.
Tensor
,
num_rejected_tokens
:
torch
.
Tensor
,
num_tokens
:
int
,
num_tokens
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
# num_rejected_tokens: [n1, n2, n3]
...
@@ -342,7 +342,7 @@ class EagleProposer:
...
@@ -342,7 +342,7 @@ class EagleProposer:
)
)
batch_size
=
num_rejected_tokens
.
shape
[
0
]
batch_size
=
num_rejected_tokens
.
shape
[
0
]
BLOCK_SIZE
=
1024
BLOCK_SIZE
=
1024
prepare_eagle_input_kernel
[(
batch_size
,
)](
prepare_eagle_input_kernel
[(
batch_size
,)](
token_indices
,
token_indices
,
cu_target_query_lens
,
cu_target_query_lens
,
cu_num_tokens
,
cu_num_tokens
,
...
@@ -362,8 +362,8 @@ class EagleProposer:
...
@@ -362,8 +362,8 @@ class EagleProposer:
model_config
=
draft_model_config
)
model_config
=
draft_model_config
)
draft_attn_layer_names
=
(
draft_attn_layer_names
=
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
()
-
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
()
-
target_attn_layer_names
)
target_attn_layer_names
)
self
.
attn_layer_names
=
list
(
draft_attn_layer_names
)
self
.
attn_layer_names
=
list
(
draft_attn_layer_names
)
...
@@ -376,8 +376,8 @@ class EagleProposer:
...
@@ -376,8 +376,8 @@ class EagleProposer:
target_language_model
=
target_model
target_language_model
=
target_model
# share embed_tokens with the target model if needed
# share embed_tokens with the target model if needed
if
get_pp_group
().
world_size
==
1
\
if
get_pp_group
().
world_size
==
1
\
and
self
.
method
!=
"deepseek_mtp"
\
and
self
.
method
!=
"deepseek_mtp"
\
and
self
.
model
.
model
.
embed_tokens
.
weight
.
shape
\
and
self
.
model
.
model
.
embed_tokens
.
weight
.
shape
\
==
target_language_model
.
model
.
embed_tokens
.
weight
.
shape
:
==
target_language_model
.
model
.
embed_tokens
.
weight
.
shape
:
logger
.
info
(
logger
.
info
(
"Assuming the EAGLE head shares the same vocab embedding"
\
"Assuming the EAGLE head shares the same vocab embedding"
\
...
@@ -402,8 +402,8 @@ class EagleProposer:
...
@@ -402,8 +402,8 @@ class EagleProposer:
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
dummy_run
(
def
dummy_run
(
self
,
self
,
num_tokens
:
int
,
num_tokens
:
int
,
)
->
None
:
)
->
None
:
with
set_forward_context
(
None
,
self
.
vllm_config
,
with
set_forward_context
(
None
,
self
.
vllm_config
,
num_tokens
=
num_tokens
):
num_tokens
=
num_tokens
):
...
@@ -440,8 +440,8 @@ class EagleProposer:
...
@@ -440,8 +440,8 @@ class EagleProposer:
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
# We should refactor this to reuse the same sampling implementation.
def
compute_probs_and_sample_next_token
(
def
compute_probs_and_sample_next_token
(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
sampling_metadata
.
all_greedy
:
if
sampling_metadata
.
all_greedy
:
# For greedy requests, draft_probs is not used in rejection sampling.
# For greedy requests, draft_probs is not used in rejection sampling.
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
fe393be8
...
@@ -1989,6 +1989,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1989,6 +1989,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
num_reqs
=
min
(
num_tokens
,
max_num_reqs
)
num_reqs
=
min
(
num_tokens
,
max_num_reqs
)
min_tokens_per_req
=
num_tokens
//
num_reqs
min_tokens_per_req
=
num_tokens
//
num_reqs
if
not
is_profile
and
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
num_lookahead_slots
>
0
:
min_tokens_per_req
=
(
1
+
self
.
speculative_config
.
num_lookahead_slots
)
num_reqs
=
num_tokens
//
min_tokens_per_req
num_scheduled_tokens_list
=
[
min_tokens_per_req
]
*
num_reqs
num_scheduled_tokens_list
=
[
min_tokens_per_req
]
*
num_reqs
num_scheduled_tokens_list
[
-
1
]
+=
num_tokens
%
num_reqs
num_scheduled_tokens_list
[
-
1
]
+=
num_tokens
%
num_reqs
assert
sum
(
num_scheduled_tokens_list
)
==
num_tokens
assert
sum
(
num_scheduled_tokens_list
)
==
num_tokens
...
@@ -2008,12 +2012,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2008,12 +2012,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
non_blocking
=
True
)
non_blocking
=
True
)
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
num_speculative_tokens
=
0
if
self
.
speculative_config
is
None
else
self
.
speculative_config
.
num_lookahead_slots
common_attn_metadata
=
CommonAttentionMetadata
(
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
num_tokens
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
num_tokens
,
max_query_len
=
num_tokens
,
num_speculative_tokens
=
num_speculative_tokens
,
)
)
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment