Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3ff124a2
Commit
3ff124a2
authored
Aug 05, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'origin/v0.9.2-dev-wm' into v0.9.2-dev
parents
2c8026d1
7e71c143
Changes
5
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
739 additions
and
61 deletions
+739
-61
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+1
-1
vllm/v1/sample/rejection_sampler_mtp.py
vllm/v1/sample/rejection_sampler_mtp.py
+519
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+139
-47
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+42
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+38
-13
No files found.
vllm/model_executor/models/deepseek_mtp.py
View file @
3ff124a2
...
...
@@ -150,7 +150,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
sampling_metadata
)
return
logits
#
@support_torch_compile
@
support_torch_compile
class
DeepSeekMTP
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
vllm/v1/sample/rejection_sampler_mtp.py
0 → 100644
View file @
3ff124a2
This diff is collapsed.
Click to expand it.
vllm/v1/spec_decode/eagle.py
View file @
3ff124a2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Optional
import
numpy
as
np
import
torch
import
torch.nn
as
nn
...
...
@@ -56,6 +59,9 @@ class EagleProposer:
self
.
use_cuda_graph
=
(
self
.
vllm_config
.
compilation_config
.
level
==
CompilationLevel
.
PIECEWISE
and
not
self
.
vllm_config
.
model_config
.
enforce_eager
)
self
.
use_full_cuda_graph
=
(
self
.
use_cuda_graph
and
vllm_config
.
compilation_config
.
full_cuda_graph
)
self
.
cudagraph_batch_sizes
=
list
(
reversed
(
self
.
vllm_config
.
compilation_config
.
cudagraph_capture_sizes
))
...
...
@@ -71,6 +77,9 @@ class EagleProposer:
(
self
.
max_num_tokens
,
self
.
hidden_size
),
dtype
=
self
.
dtype
,
device
=
device
)
# attention metadata captured in full cudagraph mode
self
.
attn_metadata_cudagraph
=
None
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self
.
arange
=
torch
.
arange
(
vllm_config
.
scheduler_config
.
max_num_seqs
+
...
...
@@ -98,7 +107,7 @@ class EagleProposer:
num_rejected_tokens
:
list
[
int
],
# [batch_size]
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
num_tokens
=
target_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
last_token_indices
=
cu_num_tokens
[
1
:]
-
1
...
...
@@ -157,7 +166,7 @@ class EagleProposer:
# FIXME: need to consider multiple kv_cache_groups
attn_metadata
=
self
.
runner
.
attn_metadata_builders
[
0
].
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
,
common_attn_metadata
=
common_attn_metadata
)
else
:
raise
ValueError
(
f
"Unsupported method:
{
self
.
method
}
"
)
...
...
@@ -176,6 +185,38 @@ class EagleProposer:
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
if
(
self
.
use_full_cuda_graph
and
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
assert
self
.
attn_metadata_cudagraph
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
self
.
attn_metadata_cudagraph
.
seq_lens
[:
batch_size
]
=
(
attn_metadata
.
seq_lens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
num_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
block_table
[:
batch_size
]
=
(
attn_metadata
.
block_table
)
elif
self
.
method
==
"deepseek_mtp"
:
self
.
attn_metadata_cudagraph
.
num_actual_tokens
=
(
attn_metadata
.
num_actual_tokens
)
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
num_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
num_decodes
=
(
attn_metadata
.
num_decodes
)
self
.
attn_metadata_cudagraph
.
num_decode_tokens
=
(
attn_metadata
.
num_decode_tokens
)
self
.
attn_metadata_cudagraph
.
num_prefills
=
(
attn_metadata
.
num_prefills
)
if
attn_metadata
.
decode
is
not
None
:
self
.
attn_metadata_cudagraph
.
decode
.
block_table
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
block_table
)
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
seq_lens
)
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
):
...
...
@@ -192,10 +233,14 @@ class EagleProposer:
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
draft_probs_list
=
[
draft_prob
]
# Early exit if there is only one draft token to be generated.
if
self
.
num_speculative_tokens
==
1
:
# [batch_size, 1]
return
draft_token_ids
.
view
(
-
1
,
1
)
return
draft_token_ids
.
view
(
-
1
,
1
)
,
draft_prob
.
view
(
-
1
,
1
,
draft_prob
.
shape
[
-
1
])
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
...
...
@@ -230,7 +275,7 @@ class EagleProposer:
seq_lens
=
(
seq_lens
+
1
),
)
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
for
i
in
range
(
self
.
num_speculative_tokens
-
1
):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
...
...
@@ -282,6 +327,43 @@ class EagleProposer:
self
.
positions
[:
batch_size
]
=
clamped_positions
self
.
hidden_states
[:
batch_size
]
=
hidden_states
if
(
self
.
use_full_cuda_graph
and
batch_size
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
assert
self
.
attn_metadata_cudagraph
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
self
.
attn_metadata_cudagraph
.
seq_lens
[:
batch_size
]
=
(
attn_metadata
.
seq_lens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
batch_size
]
=
(
attn_metadata
.
slot_mapping
)
if
i
==
0
:
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
block_table
[:
batch_size
]
=
(
attn_metadata
.
block_table
)
elif
self
.
method
==
"deepseek_mtp"
:
self
.
attn_metadata_cudagraph
.
num_actual_tokens
=
(
attn_metadata
.
num_actual_tokens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
num_decodes
=
(
attn_metadata
.
num_decodes
)
self
.
attn_metadata_cudagraph
.
num_decode_tokens
=
(
attn_metadata
.
num_decode_tokens
)
self
.
attn_metadata_cudagraph
.
num_prefills
=
(
attn_metadata
.
num_prefills
)
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
seq_lens
)
if
i
==
0
:
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
decode
.
block_table
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
block_table
)
# Run the model.
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
...
...
@@ -305,9 +387,14 @@ class EagleProposer:
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids_list
.
append
(
draft_token_ids
)
draft_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
draft_probs_list
.
append
(
draft_prob
)
# [batch_size, num_speculative_tokens]
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
return
draft_token_ids
draft_probs
=
torch
.
stack
(
draft_probs_list
,
dim
=
1
).
contiguous
()
return
draft_token_ids
,
draft_probs
@
staticmethod
def
prepare_inputs
(
...
...
@@ -342,7 +429,7 @@ class EagleProposer:
)
batch_size
=
num_rejected_tokens
.
shape
[
0
]
BLOCK_SIZE
=
1024
prepare_eagle_input_kernel
[(
batch_size
,)](
prepare_eagle_input_kernel
[(
batch_size
,
)](
token_indices
,
cu_target_query_lens
,
cu_num_tokens
,
...
...
@@ -404,8 +491,13 @@ class EagleProposer:
def
dummy_run
(
self
,
num_tokens
:
int
,
attn_metadata
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
None
:
with
set_forward_context
(
None
,
self
.
vllm_config
,
if
attn_metadata
is
not
None
and
self
.
attn_metadata_cudagraph
is
None
:
self
.
attn_metadata_cudagraph
=
attn_metadata
[
self
.
attn_layer_names
[
0
]]
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_tokens
):
self
.
model
(
self
.
input_ids
[:
num_tokens
],
...
...
vllm/v1/spec_decode/utils.py
View file @
3ff124a2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
msgspec
from
abc
import
ABC
import
torch
from
vllm.sampling_params
import
SamplingParams
from
vllm.triton_utils
import
tl
,
triton
...
...
@@ -39,3 +43,41 @@ def prepare_eagle_input_kernel(
index_start
+
offset
,
mask
=
offset
<
num_tokens
,
)
class
DraftProbs
(
ABC
):
# type: ignore[call-arg]
"""Draft probs corresponding to in-progress sequences."""
# spec tokens probs.
draft_probs
:
torch
.
Tensor
# The request id list.
_req_ids
:
list
[
str
]
def
__init__
(
self
,
draft_probs
,
req_ids
):
assert
len
(
req_ids
)
==
len
(
draft_probs
)
self
.
draft_probs
=
draft_probs
self
.
_req_ids
=
req_ids
def
update
(
self
,
draft_probs
:
torch
.
Tensor
,
tmp_req_ids
:
list
[
str
]):
diff_req_ids
=
[
item
for
item
in
self
.
_req_ids
if
item
not
in
tmp_req_ids
]
index
=
[
self
.
_req_ids
.
index
(
req_id
)
for
req_id
in
diff_req_ids
]
self
.
_req_ids
=
diff_req_ids
self
.
draft_probs
=
self
.
draft_probs
[
index
]
self
.
draft_probs
=
torch
.
cat
([
self
.
draft_probs
,
draft_probs
])
self
.
_req_ids
.
extend
(
tmp_req_ids
)
assert
len
(
self
.
_req_ids
)
==
len
(
self
.
draft_probs
)
def
prune
(
self
,
req_ids
:
list
[
str
]):
new_req_ids
=
[
req_id
for
req_id
in
self
.
_req_ids
if
req_id
not
in
req_ids
]
if
new_req_ids
!=
self
.
_req_ids
:
# Batch contents changed - prune removed sequences.
index
=
[
self
.
_req_ids
.
index
(
req_id
)
for
req_id
in
new_req_ids
]
self
.
draft_probs
=
self
.
draft_probs
[
index
]
self
.
_req_ids
=
new_req_ids
def
get_probs
(
self
,
req_ids
:
list
[
str
]):
index
=
[
self
.
_req_ids
.
index
(
req_id
)
for
req_id
in
req_ids
]
return
self
.
draft_probs
[
index
]
vllm/v1/worker/gpu_model_runner.py
View file @
3ff124a2
...
...
@@ -58,11 +58,13 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
from
vllm.v1.sample.rejection_sampler_mtp
import
MtpRejectionSampler
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.utils
import
DraftProbs
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
...
...
@@ -192,7 +194,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else
:
raise
ValueError
(
"Unknown speculative decoding method: "
f
"
{
self
.
speculative_config
.
method
}
"
)
self
.
use_mtp
=
self
.
speculative_config
.
method
==
"deepseek_mtp"
if
not
self
.
use_mtp
:
self
.
rejection_sampler
=
RejectionSampler
()
else
:
self
.
rejection_sampler
=
MtpRejectionSampler
()
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
...
...
@@ -319,6 +326,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self
.
shared_kv_cache_layers
:
dict
[
str
,
str
]
=
{}
self
.
draft_probs
:
Optional
[
DraftProbs
]
=
None
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""
Update the order of requests in the batch based on the attention
...
...
@@ -378,6 +387,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for
req_id
in
scheduler_output
.
finished_req_ids
:
self
.
input_batch
.
remove_request
(
req_id
)
# prune draft probs of finished requests
if
self
.
use_mtp
and
self
.
draft_probs
is
not
None
and
len
(
scheduler_output
.
finished_req_ids
)
>
0
:
self
.
draft_probs
.
prune
(
list
(
scheduler_output
.
finished_req_ids
))
# Free the cached encoder outputs.
for
req_id
,
input_id
in
scheduler_output
.
free_encoder_input_ids
:
encoder_outputs
=
self
.
encoder_cache
.
get
(
req_id
)
...
...
@@ -535,6 +548,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Add spec_token_ids to token_ids_cpu.
spec_token_ids
=
(
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
()))
if
spec_token_ids
:
num_spec_tokens
=
len
(
spec_token_ids
)
start_index
=
self
.
input_batch
.
num_tokens_no_spec
[
req_index
]
...
...
@@ -1458,7 +1472,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_logits
=
logits
[
spec_decode_metadata
.
target_logits_indices
]
output_token_ids
=
self
.
rejection_sampler
(
spec_decode_metadata
,
None
,
# draft_probs
self
.
draft_probs
.
get_probs
(
self
.
input_batch
.
req_ids
)
\
if
self
.
draft_probs
is
not
None
else
None
,
# draft_probs
target_logits
,
bonus_token_ids
,
sampling_metadata
,
...
...
@@ -1543,7 +1558,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Speculative decoding is not enabled.
spec_token_ids
=
None
else
:
spec_token_ids
=
self
.
propose_draft_token_ids
(
spec_token_ids
,
draft_probs
=
self
.
propose_draft_token_ids
(
scheduler_output
,
valid_sampled_token_ids
,
sampling_metadata
,
...
...
@@ -1554,6 +1569,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata
,
)
if
self
.
use_mtp
:
if
self
.
draft_probs
is
None
:
self
.
draft_probs
=
DraftProbs
(
draft_probs
,
self
.
input_batch
.
req_ids
)
else
:
self
.
draft_probs
.
update
(
draft_probs
,
self
.
input_batch
.
req_ids
)
spec_token_ids
=
spec_token_ids
.
tolist
()
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
get_kv_transfer_group
().
clear_connector_metadata
()
...
...
@@ -1570,7 +1594,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pooler_output
=
[],
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
num_nans_in_logits
=
num_nans_in_logits
,
num_nans_in_logits
=
num_nans_in_logits
)
def
propose_draft_token_ids
(
...
...
@@ -1583,7 +1607,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states
:
Optional
[
torch
.
Tensor
],
spec_decode_metadata
:
Optional
[
SpecDecodeMetadata
],
attn_metadata
:
dict
[
str
,
Any
],
)
->
list
[
list
[
int
]]:
)
->
tuple
[
list
[
list
[
int
]],
torch
.
Tensor
]:
draft_probs
=
None
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
self
.
speculative_config
.
method
==
"ngram"
:
assert
isinstance
(
self
.
drafter
,
NgramProposer
)
...
...
@@ -1682,7 +1707,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_hidden_states
=
hidden_states
[
token_indices
]
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
token_indices
]
draft
_token_ids
=
self
.
drafter
.
propose
(
spec
_token_ids
,
draft_probs
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
...
...
@@ -1693,8 +1718,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampling_metadata
=
sampling_metadata
,
num_rejected_tokens
=
num_rejected_tokens
)
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
return
spec_token_ids
,
draft_probs
def
kv_connector_no_forward
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
ModelRunnerOutput
:
...
...
@@ -2083,7 +2108,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
self
.
drafter
.
dummy_run
(
num_tokens
)
self
.
drafter
.
dummy_run
(
num_tokens
,
attn_metadata
)
# This is necessary to avoid blocking DP.
# For dummy runs, we typically skip EPLB since we don't have any real
...
...
@@ -2150,10 +2175,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids
,
self
.
device
)
num_tokens
=
sum
(
len
(
ids
)
for
ids
in
draft_token_ids
)
#
draft_probs = torch.randn(
#
num_tokens, logits.shape[-1], device=self.device,
#
dtype=logits.dtype)
draft_probs
=
None
draft_probs
=
torch
.
randn
(
num_tokens
,
logits
.
shape
[
-
1
],
device
=
self
.
device
,
dtype
=
logits
.
dtype
)
#
draft_probs = None
target_logits
=
torch
.
randn
(
num_tokens
,
logits
.
shape
[
-
1
],
device
=
self
.
device
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment