Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3ff124a2
Commit
3ff124a2
authored
Aug 05, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'origin/v0.9.2-dev-wm' into v0.9.2-dev
parents
2c8026d1
7e71c143
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
739 additions
and
61 deletions
+739
-61
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+1
-1
vllm/v1/sample/rejection_sampler_mtp.py
vllm/v1/sample/rejection_sampler_mtp.py
+519
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+139
-47
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+42
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+38
-13
No files found.
vllm/model_executor/models/deepseek_mtp.py
View file @
3ff124a2
...
@@ -150,7 +150,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
...
@@ -150,7 +150,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
#
@support_torch_compile
@
support_torch_compile
class
DeepSeekMTP
(
nn
.
Module
,
SupportsPP
):
class
DeepSeekMTP
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
vllm/v1/sample/rejection_sampler_mtp.py
0 → 100644
View file @
3ff124a2
This diff is collapsed.
Click to expand it.
vllm/v1/spec_decode/eagle.py
View file @
3ff124a2
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Optional
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -29,10 +32,10 @@ PADDING_SLOT_ID = -1
...
@@ -29,10 +32,10 @@ PADDING_SLOT_ID = -1
class
EagleProposer
:
class
EagleProposer
:
def
__init__
(
def
__init__
(
self
,
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
device
:
torch
.
device
,
runner
=
None
,
runner
=
None
,
):
):
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
speculative_config
=
vllm_config
.
speculative_config
...
@@ -56,6 +59,9 @@ class EagleProposer:
...
@@ -56,6 +59,9 @@ class EagleProposer:
self
.
use_cuda_graph
=
(
self
.
vllm_config
.
compilation_config
.
level
self
.
use_cuda_graph
=
(
self
.
vllm_config
.
compilation_config
.
level
==
CompilationLevel
.
PIECEWISE
and
==
CompilationLevel
.
PIECEWISE
and
not
self
.
vllm_config
.
model_config
.
enforce_eager
)
not
self
.
vllm_config
.
model_config
.
enforce_eager
)
self
.
use_full_cuda_graph
=
(
self
.
use_cuda_graph
and
vllm_config
.
compilation_config
.
full_cuda_graph
)
self
.
cudagraph_batch_sizes
=
list
(
self
.
cudagraph_batch_sizes
=
list
(
reversed
(
reversed
(
self
.
vllm_config
.
compilation_config
.
cudagraph_capture_sizes
))
self
.
vllm_config
.
compilation_config
.
cudagraph_capture_sizes
))
...
@@ -71,6 +77,9 @@ class EagleProposer:
...
@@ -71,6 +77,9 @@ class EagleProposer:
(
self
.
max_num_tokens
,
self
.
hidden_size
),
(
self
.
max_num_tokens
,
self
.
hidden_size
),
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
device
=
device
)
device
=
device
)
# attention metadata captured in full cudagraph mode
self
.
attn_metadata_cudagraph
=
None
# We need +1 here because the arange is used to set query_start_loc,
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
# which has one more element than batch_size.
self
.
arange
=
torch
.
arange
(
vllm_config
.
scheduler_config
.
max_num_seqs
+
self
.
arange
=
torch
.
arange
(
vllm_config
.
scheduler_config
.
max_num_seqs
+
...
@@ -79,26 +88,26 @@ class EagleProposer:
...
@@ -79,26 +88,26 @@ class EagleProposer:
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
def
propose
(
def
propose
(
self
,
self
,
# [num_tokens]
# [num_tokens]
target_token_ids
:
torch
.
Tensor
,
target_token_ids
:
torch
.
Tensor
,
# [num_tokens]
# [num_tokens]
target_positions
:
torch
.
Tensor
,
target_positions
:
torch
.
Tensor
,
# [num_tokens, hidden_size]
# [num_tokens, hidden_size]
target_hidden_states
:
torch
.
Tensor
,
target_hidden_states
:
torch
.
Tensor
,
# [num_tokens]
# [num_tokens]
target_slot_mapping
:
torch
.
Tensor
,
target_slot_mapping
:
torch
.
Tensor
,
# [batch_size]
# [batch_size]
next_token_ids
:
torch
.
Tensor
,
next_token_ids
:
torch
.
Tensor
,
# [batch_size + 1] starting with 0
# [batch_size + 1] starting with 0
cu_num_tokens
:
torch
.
Tensor
,
cu_num_tokens
:
torch
.
Tensor
,
# [batch_size, max_num_blocks_per_req]
# [batch_size, max_num_blocks_per_req]
block_table
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
# [batch_size]
# [batch_size]
num_rejected_tokens
:
list
[
int
],
num_rejected_tokens
:
list
[
int
],
# [batch_size]
# [batch_size]
sampling_metadata
:
SamplingMetadata
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
num_tokens
=
target_token_ids
.
shape
[
0
]
num_tokens
=
target_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
last_token_indices
=
cu_num_tokens
[
1
:]
-
1
last_token_indices
=
cu_num_tokens
[
1
:]
-
1
...
@@ -157,7 +166,7 @@ class EagleProposer:
...
@@ -157,7 +166,7 @@ class EagleProposer:
# FIXME: need to consider multiple kv_cache_groups
# FIXME: need to consider multiple kv_cache_groups
attn_metadata
=
self
.
runner
.
attn_metadata_builders
[
0
].
build
(
attn_metadata
=
self
.
runner
.
attn_metadata_builders
[
0
].
build
(
common_prefix_len
=
0
,
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
,
common_attn_metadata
=
common_attn_metadata
)
)
else
:
else
:
raise
ValueError
(
f
"Unsupported method:
{
self
.
method
}
"
)
raise
ValueError
(
f
"Unsupported method:
{
self
.
method
}
"
)
...
@@ -168,7 +177,7 @@ class EagleProposer:
...
@@ -168,7 +177,7 @@ class EagleProposer:
for
layer_name
in
self
.
attn_layer_names
:
for
layer_name
in
self
.
attn_layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
if
self
.
use_cuda_graph
and
\
if
self
.
use_cuda_graph
and
\
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_tokens
)
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_tokens
)
else
:
else
:
num_input_tokens
=
num_tokens
num_input_tokens
=
num_tokens
...
@@ -176,6 +185,38 @@ class EagleProposer:
...
@@ -176,6 +185,38 @@ class EagleProposer:
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
if
(
self
.
use_full_cuda_graph
and
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
assert
self
.
attn_metadata_cudagraph
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
self
.
attn_metadata_cudagraph
.
seq_lens
[:
batch_size
]
=
(
attn_metadata
.
seq_lens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
num_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
block_table
[:
batch_size
]
=
(
attn_metadata
.
block_table
)
elif
self
.
method
==
"deepseek_mtp"
:
self
.
attn_metadata_cudagraph
.
num_actual_tokens
=
(
attn_metadata
.
num_actual_tokens
)
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
num_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
num_decodes
=
(
attn_metadata
.
num_decodes
)
self
.
attn_metadata_cudagraph
.
num_decode_tokens
=
(
attn_metadata
.
num_decode_tokens
)
self
.
attn_metadata_cudagraph
.
num_prefills
=
(
attn_metadata
.
num_prefills
)
if
attn_metadata
.
decode
is
not
None
:
self
.
attn_metadata_cudagraph
.
decode
.
block_table
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
block_table
)
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
seq_lens
)
with
set_forward_context
(
per_layer_attn_metadata
,
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
):
num_tokens
=
num_input_tokens
):
...
@@ -192,10 +233,14 @@ class EagleProposer:
...
@@ -192,10 +233,14 @@ class EagleProposer:
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
draft_probs_list
=
[
draft_prob
]
# Early exit if there is only one draft token to be generated.
# Early exit if there is only one draft token to be generated.
if
self
.
num_speculative_tokens
==
1
:
if
self
.
num_speculative_tokens
==
1
:
# [batch_size, 1]
# [batch_size, 1]
return
draft_token_ids
.
view
(
-
1
,
1
)
return
draft_token_ids
.
view
(
-
1
,
1
)
,
draft_prob
.
view
(
-
1
,
1
,
draft_prob
.
shape
[
-
1
])
# TODO: Currently, MTP module released by deepseek only has
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
# one layer. Adapt this code to support multiple layers once
...
@@ -212,7 +257,7 @@ class EagleProposer:
...
@@ -212,7 +257,7 @@ class EagleProposer:
hidden_states
=
hidden_states
[
last_token_indices
]
hidden_states
=
hidden_states
[
last_token_indices
]
if
self
.
use_cuda_graph
and
\
if
self
.
use_cuda_graph
and
\
batch_size
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
batch_size
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
input_batch_size
=
self
.
vllm_config
.
pad_for_cudagraph
(
batch_size
)
input_batch_size
=
self
.
vllm_config
.
pad_for_cudagraph
(
batch_size
)
else
:
else
:
input_batch_size
=
batch_size
input_batch_size
=
batch_size
...
@@ -230,7 +275,7 @@ class EagleProposer:
...
@@ -230,7 +275,7 @@ class EagleProposer:
seq_lens
=
(
seq_lens
+
1
),
seq_lens
=
(
seq_lens
+
1
),
)
)
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
for
i
in
range
(
self
.
num_speculative_tokens
-
1
):
# Update the inputs.
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
# tensor.argmax() returns int64 by default.
...
@@ -267,10 +312,10 @@ class EagleProposer:
...
@@ -267,10 +312,10 @@ class EagleProposer:
# Compute the slot mapping.
# Compute the slot mapping.
block_numbers
=
clamped_positions
//
self
.
block_size
block_numbers
=
clamped_positions
//
self
.
block_size
block_ids
=
block_table
.
gather
(
dim
=
1
,
block_ids
=
block_table
.
gather
(
dim
=
1
,
index
=
block_numbers
.
view
(
-
1
,
1
))
index
=
block_numbers
.
view
(
-
1
,
1
))
block_ids
=
block_ids
.
view
(
-
1
)
block_ids
=
block_ids
.
view
(
-
1
)
attn_metadata
.
slot_mapping
=
(
block_ids
*
self
.
block_size
+
attn_metadata
.
slot_mapping
=
(
block_ids
*
self
.
block_size
+
clamped_positions
%
self
.
block_size
)
clamped_positions
%
self
.
block_size
)
# Mask out the slot mappings that exceed the max model length.
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
# padding tokens.
...
@@ -282,6 +327,43 @@ class EagleProposer:
...
@@ -282,6 +327,43 @@ class EagleProposer:
self
.
positions
[:
batch_size
]
=
clamped_positions
self
.
positions
[:
batch_size
]
=
clamped_positions
self
.
hidden_states
[:
batch_size
]
=
hidden_states
self
.
hidden_states
[:
batch_size
]
=
hidden_states
if
(
self
.
use_full_cuda_graph
and
batch_size
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
assert
self
.
attn_metadata_cudagraph
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
self
.
attn_metadata_cudagraph
.
seq_lens
[:
batch_size
]
=
(
attn_metadata
.
seq_lens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
batch_size
]
=
(
attn_metadata
.
slot_mapping
)
if
i
==
0
:
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
block_table
[:
batch_size
]
=
(
attn_metadata
.
block_table
)
elif
self
.
method
==
"deepseek_mtp"
:
self
.
attn_metadata_cudagraph
.
num_actual_tokens
=
(
attn_metadata
.
num_actual_tokens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
num_decodes
=
(
attn_metadata
.
num_decodes
)
self
.
attn_metadata_cudagraph
.
num_decode_tokens
=
(
attn_metadata
.
num_decode_tokens
)
self
.
attn_metadata_cudagraph
.
num_prefills
=
(
attn_metadata
.
num_prefills
)
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
seq_lens
)
if
i
==
0
:
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
decode
.
block_table
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
block_table
)
# Run the model.
# Run the model.
with
set_forward_context
(
per_layer_attn_metadata
,
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
self
.
vllm_config
,
...
@@ -305,17 +387,22 @@ class EagleProposer:
...
@@ -305,17 +387,22 @@ class EagleProposer:
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids_list
.
append
(
draft_token_ids
)
draft_token_ids_list
.
append
(
draft_token_ids
)
draft_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
draft_probs_list
.
append
(
draft_prob
)
# [batch_size, num_speculative_tokens]
# [batch_size, num_speculative_tokens]
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
return
draft_token_ids
draft_probs
=
torch
.
stack
(
draft_probs_list
,
dim
=
1
).
contiguous
()
return
draft_token_ids
,
draft_probs
@
staticmethod
@
staticmethod
def
prepare_inputs
(
def
prepare_inputs
(
# [batch_size + 1]
# [batch_size + 1]
cu_target_query_lens
:
torch
.
Tensor
,
cu_target_query_lens
:
torch
.
Tensor
,
# [batch_size]
# [batch_size]
num_rejected_tokens
:
torch
.
Tensor
,
num_rejected_tokens
:
torch
.
Tensor
,
num_tokens
:
int
,
num_tokens
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
# num_rejected_tokens: [n1, n2, n3]
...
@@ -342,7 +429,7 @@ class EagleProposer:
...
@@ -342,7 +429,7 @@ class EagleProposer:
)
)
batch_size
=
num_rejected_tokens
.
shape
[
0
]
batch_size
=
num_rejected_tokens
.
shape
[
0
]
BLOCK_SIZE
=
1024
BLOCK_SIZE
=
1024
prepare_eagle_input_kernel
[(
batch_size
,)](
prepare_eagle_input_kernel
[(
batch_size
,
)](
token_indices
,
token_indices
,
cu_target_query_lens
,
cu_target_query_lens
,
cu_num_tokens
,
cu_num_tokens
,
...
@@ -362,8 +449,8 @@ class EagleProposer:
...
@@ -362,8 +449,8 @@ class EagleProposer:
model_config
=
draft_model_config
)
model_config
=
draft_model_config
)
draft_attn_layer_names
=
(
draft_attn_layer_names
=
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
()
-
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
()
-
target_attn_layer_names
)
target_attn_layer_names
)
self
.
attn_layer_names
=
list
(
draft_attn_layer_names
)
self
.
attn_layer_names
=
list
(
draft_attn_layer_names
)
...
@@ -376,8 +463,8 @@ class EagleProposer:
...
@@ -376,8 +463,8 @@ class EagleProposer:
target_language_model
=
target_model
target_language_model
=
target_model
# share embed_tokens with the target model if needed
# share embed_tokens with the target model if needed
if
get_pp_group
().
world_size
==
1
\
if
get_pp_group
().
world_size
==
1
\
and
self
.
method
!=
"deepseek_mtp"
\
and
self
.
method
!=
"deepseek_mtp"
\
and
self
.
model
.
model
.
embed_tokens
.
weight
.
shape
\
and
self
.
model
.
model
.
embed_tokens
.
weight
.
shape
\
==
target_language_model
.
model
.
embed_tokens
.
weight
.
shape
:
==
target_language_model
.
model
.
embed_tokens
.
weight
.
shape
:
logger
.
info
(
logger
.
info
(
"Assuming the EAGLE head shares the same vocab embedding"
\
"Assuming the EAGLE head shares the same vocab embedding"
\
...
@@ -402,10 +489,15 @@ class EagleProposer:
...
@@ -402,10 +489,15 @@ class EagleProposer:
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
dummy_run
(
def
dummy_run
(
self
,
self
,
num_tokens
:
int
,
num_tokens
:
int
,
attn_metadata
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
None
:
)
->
None
:
with
set_forward_context
(
None
,
self
.
vllm_config
,
if
attn_metadata
is
not
None
and
self
.
attn_metadata_cudagraph
is
None
:
self
.
attn_metadata_cudagraph
=
attn_metadata
[
self
.
attn_layer_names
[
0
]]
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_tokens
):
num_tokens
=
num_tokens
):
self
.
model
(
self
.
model
(
self
.
input_ids
[:
num_tokens
],
self
.
input_ids
[:
num_tokens
],
...
@@ -440,8 +532,8 @@ class EagleProposer:
...
@@ -440,8 +532,8 @@ class EagleProposer:
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
# We should refactor this to reuse the same sampling implementation.
def
compute_probs_and_sample_next_token
(
def
compute_probs_and_sample_next_token
(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
sampling_metadata
.
all_greedy
:
if
sampling_metadata
.
all_greedy
:
# For greedy requests, draft_probs is not used in rejection sampling.
# For greedy requests, draft_probs is not used in rejection sampling.
...
...
vllm/v1/spec_decode/utils.py
View file @
3ff124a2
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
msgspec
from
abc
import
ABC
import
torch
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
...
@@ -39,3 +43,41 @@ def prepare_eagle_input_kernel(
...
@@ -39,3 +43,41 @@ def prepare_eagle_input_kernel(
index_start
+
offset
,
index_start
+
offset
,
mask
=
offset
<
num_tokens
,
mask
=
offset
<
num_tokens
,
)
)
class DraftProbs(ABC):  # type: ignore[call-arg]
    """Draft probabilities cached for in-progress requests.

    Row ``i`` (along dim 0) of ``draft_probs`` belongs to the request id
    stored at ``_req_ids[i]``; every method below keeps the two in lockstep.
    """

    # Per-request draft-token probabilities, indexed by request along dim 0.
    draft_probs: torch.Tensor
    # Request ids, in the same order as the rows of ``draft_probs``.
    _req_ids: list[str]

    def __init__(self, draft_probs, req_ids):
        """Create the cache from an initial batch.

        Args:
            draft_probs: Tensor whose first dimension has one entry per
                request in ``req_ids``.
            req_ids: Request ids, in row order.

        Raises:
            AssertionError: If the number of ids and rows differ.
        """
        assert len(req_ids) == len(draft_probs)
        self.draft_probs = draft_probs
        # Take a private copy: the caller typically hands in a live batch
        # list that it keeps mutating, which would silently desynchronise
        # the ids from the tensor rows.
        self._req_ids = list(req_ids)

    def _select_rows(self, rows: list[int]) -> torch.Tensor:
        """Return a copy of ``draft_probs`` restricted to ``rows`` (dim 0)."""
        # An explicit long index tensor keeps the empty-selection case
        # well-defined (result has shape (0, ...)).
        row_index = torch.tensor(rows,
                                 dtype=torch.long,
                                 device=self.draft_probs.device)
        return self.draft_probs.index_select(0, row_index)

    def update(self, draft_probs: torch.Tensor, tmp_req_ids: list[str]):
        """Replace or insert the rows for ``tmp_req_ids``.

        Requests already cached but absent from ``tmp_req_ids`` keep their
        rows (and relative order); the rows of ``draft_probs`` are then
        appended in the order given by ``tmp_req_ids``.

        Args:
            draft_probs: New rows, one per id in ``tmp_req_ids``.
            tmp_req_ids: Request ids that the new rows belong to.

        Raises:
            AssertionError: If ids and rows disagree in count afterwards.
        """
        incoming = set(tmp_req_ids)
        kept_rows = []
        kept_ids = []
        for row, req_id in enumerate(self._req_ids):
            if req_id not in incoming:
                kept_rows.append(row)
                kept_ids.append(req_id)

        self.draft_probs = torch.cat(
            [self._select_rows(kept_rows), draft_probs])
        kept_ids.extend(tmp_req_ids)
        self._req_ids = kept_ids
        assert len(self._req_ids) == len(self.draft_probs)

    def prune(self, req_ids: list[str]):
        """Drop the cached rows of every request listed in ``req_ids``.

        Ids that are not cached are ignored. Nothing is touched when no
        cached request is affected.
        """
        doomed = set(req_ids)
        kept_rows = []
        kept_ids = []
        for row, req_id in enumerate(self._req_ids):
            if req_id not in doomed:
                kept_rows.append(row)
                kept_ids.append(req_id)

        if len(kept_ids) != len(self._req_ids):
            # Batch contents changed - prune removed sequences.
            self.draft_probs = self._select_rows(kept_rows)
            self._req_ids = kept_ids

    def get_probs(self, req_ids: list[str]):
        """Return the cached rows for ``req_ids``, in the requested order.

        Raises:
            ValueError: If any requested id is not cached.
        """
        # Build the id -> row map once instead of a linear search per id.
        # ``setdefault`` keeps the first occurrence, like ``list.index``.
        row_of: dict[str, int] = {}
        for row, req_id in enumerate(self._req_ids):
            row_of.setdefault(req_id, row)

        rows = []
        for req_id in req_ids:
            if req_id not in row_of:
                raise ValueError(f"{req_id!r} is not in list")
            rows.append(row_of[req_id])
        return self._select_rows(rows)
vllm/v1/worker/gpu_model_runner.py
View file @
3ff124a2
...
@@ -58,11 +58,13 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
...
@@ -58,11 +58,13 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
from
vllm.v1.sample.rejection_sampler_mtp
import
MtpRejectionSampler
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.utils
import
DraftProbs
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
...
@@ -192,7 +194,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -192,7 +194,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else
:
else
:
raise
ValueError
(
"Unknown speculative decoding method: "
raise
ValueError
(
"Unknown speculative decoding method: "
f
"
{
self
.
speculative_config
.
method
}
"
)
f
"
{
self
.
speculative_config
.
method
}
"
)
self
.
rejection_sampler
=
RejectionSampler
()
self
.
use_mtp
=
self
.
speculative_config
.
method
==
"deepseek_mtp"
if
not
self
.
use_mtp
:
self
.
rejection_sampler
=
RejectionSampler
()
else
:
self
.
rejection_sampler
=
MtpRejectionSampler
()
# Request states.
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
...
@@ -319,6 +326,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -319,6 +326,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self
.
shared_kv_cache_layers
:
dict
[
str
,
str
]
=
{}
self
.
shared_kv_cache_layers
:
dict
[
str
,
str
]
=
{}
self
.
draft_probs
:
Optional
[
DraftProbs
]
=
None
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""
"""
Update the order of requests in the batch based on the attention
Update the order of requests in the batch based on the attention
...
@@ -378,6 +387,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -378,6 +387,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for
req_id
in
scheduler_output
.
finished_req_ids
:
for
req_id
in
scheduler_output
.
finished_req_ids
:
self
.
input_batch
.
remove_request
(
req_id
)
self
.
input_batch
.
remove_request
(
req_id
)
# prune draft probs of finished requests
if
self
.
use_mtp
and
self
.
draft_probs
is
not
None
and
len
(
scheduler_output
.
finished_req_ids
)
>
0
:
self
.
draft_probs
.
prune
(
list
(
scheduler_output
.
finished_req_ids
))
# Free the cached encoder outputs.
# Free the cached encoder outputs.
for
req_id
,
input_id
in
scheduler_output
.
free_encoder_input_ids
:
for
req_id
,
input_id
in
scheduler_output
.
free_encoder_input_ids
:
encoder_outputs
=
self
.
encoder_cache
.
get
(
req_id
)
encoder_outputs
=
self
.
encoder_cache
.
get
(
req_id
)
...
@@ -535,6 +548,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -535,6 +548,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Add spec_token_ids to token_ids_cpu.
# Add spec_token_ids to token_ids_cpu.
spec_token_ids
=
(
spec_token_ids
=
(
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
()))
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
()))
if
spec_token_ids
:
if
spec_token_ids
:
num_spec_tokens
=
len
(
spec_token_ids
)
num_spec_tokens
=
len
(
spec_token_ids
)
start_index
=
self
.
input_batch
.
num_tokens_no_spec
[
req_index
]
start_index
=
self
.
input_batch
.
num_tokens_no_spec
[
req_index
]
...
@@ -1458,7 +1472,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1458,7 +1472,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_logits
=
logits
[
spec_decode_metadata
.
target_logits_indices
]
target_logits
=
logits
[
spec_decode_metadata
.
target_logits_indices
]
output_token_ids
=
self
.
rejection_sampler
(
output_token_ids
=
self
.
rejection_sampler
(
spec_decode_metadata
,
spec_decode_metadata
,
None
,
# draft_probs
self
.
draft_probs
.
get_probs
(
self
.
input_batch
.
req_ids
)
\
if
self
.
draft_probs
is
not
None
else
None
,
# draft_probs
target_logits
,
target_logits
,
bonus_token_ids
,
bonus_token_ids
,
sampling_metadata
,
sampling_metadata
,
...
@@ -1543,7 +1558,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1543,7 +1558,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Speculative decoding is not enabled.
# Speculative decoding is not enabled.
spec_token_ids
=
None
spec_token_ids
=
None
else
:
else
:
spec_token_ids
=
self
.
propose_draft_token_ids
(
spec_token_ids
,
draft_probs
=
self
.
propose_draft_token_ids
(
scheduler_output
,
scheduler_output
,
valid_sampled_token_ids
,
valid_sampled_token_ids
,
sampling_metadata
,
sampling_metadata
,
...
@@ -1554,6 +1569,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1554,6 +1569,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata
,
attn_metadata
,
)
)
if
self
.
use_mtp
:
if
self
.
draft_probs
is
None
:
self
.
draft_probs
=
DraftProbs
(
draft_probs
,
self
.
input_batch
.
req_ids
)
else
:
self
.
draft_probs
.
update
(
draft_probs
,
self
.
input_batch
.
req_ids
)
spec_token_ids
=
spec_token_ids
.
tolist
()
# Clear KVConnector state after all KVs are generated.
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
get_kv_transfer_group
().
clear_connector_metadata
()
get_kv_transfer_group
().
clear_connector_metadata
()
...
@@ -1570,7 +1594,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1570,7 +1594,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pooler_output
=
[],
pooler_output
=
[],
finished_sending
=
finished_sending
,
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
finished_recving
=
finished_recving
,
num_nans_in_logits
=
num_nans_in_logits
,
num_nans_in_logits
=
num_nans_in_logits
)
)
def
propose_draft_token_ids
(
def
propose_draft_token_ids
(
...
@@ -1583,7 +1607,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1583,7 +1607,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states
:
Optional
[
torch
.
Tensor
],
aux_hidden_states
:
Optional
[
torch
.
Tensor
],
spec_decode_metadata
:
Optional
[
SpecDecodeMetadata
],
spec_decode_metadata
:
Optional
[
SpecDecodeMetadata
],
attn_metadata
:
dict
[
str
,
Any
],
attn_metadata
:
dict
[
str
,
Any
],
)
->
list
[
list
[
int
]]:
)
->
tuple
[
list
[
list
[
int
]],
torch
.
Tensor
]:
draft_probs
=
None
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
self
.
speculative_config
.
method
==
"ngram"
:
if
self
.
speculative_config
.
method
==
"ngram"
:
assert
isinstance
(
self
.
drafter
,
NgramProposer
)
assert
isinstance
(
self
.
drafter
,
NgramProposer
)
...
@@ -1682,7 +1707,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1682,7 +1707,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_hidden_states
=
hidden_states
[
token_indices
]
target_hidden_states
=
hidden_states
[
token_indices
]
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
token_indices
]
token_indices
]
draft
_token_ids
=
self
.
drafter
.
propose
(
spec
_token_ids
,
draft_probs
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
target_hidden_states
=
target_hidden_states
,
...
@@ -1693,8 +1718,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1693,8 +1718,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
num_rejected_tokens
=
num_rejected_tokens
num_rejected_tokens
=
num_rejected_tokens
)
)
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
return
spec_token_ids
,
draft_probs
def
kv_connector_no_forward
(
def
kv_connector_no_forward
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
ModelRunnerOutput
:
self
,
scheduler_output
:
"SchedulerOutput"
)
->
ModelRunnerOutput
:
...
@@ -2083,7 +2108,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2083,7 +2108,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
self
.
drafter
.
dummy_run
(
num_tokens
)
self
.
drafter
.
dummy_run
(
num_tokens
,
attn_metadata
)
# This is necessary to avoid blocking DP.
# This is necessary to avoid blocking DP.
# For dummy runs, we typically skip EPLB since we don't have any real
# For dummy runs, we typically skip EPLB since we don't have any real
...
@@ -2150,10 +2175,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2150,10 +2175,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids
,
self
.
device
)
draft_token_ids
,
self
.
device
)
num_tokens
=
sum
(
len
(
ids
)
for
ids
in
draft_token_ids
)
num_tokens
=
sum
(
len
(
ids
)
for
ids
in
draft_token_ids
)
#
draft_probs = torch.randn(
draft_probs
=
torch
.
randn
(
#
num_tokens, logits.shape[-1], device=self.device,
num_tokens
,
logits
.
shape
[
-
1
],
device
=
self
.
device
,
#
dtype=logits.dtype)
dtype
=
logits
.
dtype
)
draft_probs
=
None
#
draft_probs = None
target_logits
=
torch
.
randn
(
num_tokens
,
target_logits
=
torch
.
randn
(
num_tokens
,
logits
.
shape
[
-
1
],
logits
.
shape
[
-
1
],
device
=
self
.
device
,
device
=
self
.
device
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment