Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
33485749
Commit
33485749
authored
Aug 06, 2025
by
zhuwenwen
Browse files
Revert "[feat]1.支持mtp模型 full_cuda_graph; 2.优化mtp拒绝采样"
This reverts commit
93fae6b1
.
parent
c34fa0bf
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
16 additions
and
691 deletions
+16
-691
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+1
-1
vllm/v1/sample/rejection_sampler_mtp.py
vllm/v1/sample/rejection_sampler_mtp.py
+0
-518
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+4
-94
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+0
-42
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+11
-36
No files found.
vllm/model_executor/models/deepseek_mtp.py
View file @
33485749
...
...
@@ -152,7 +152,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
return
logits
@
support_torch_compile
#
@support_torch_compile
class
DeepSeekMTP
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
vllm/v1/sample/rejection_sampler_mtp.py
deleted
100644 → 0
View file @
c34fa0bf
This diff is collapsed.
Click to expand it.
vllm/v1/spec_decode/eagle.py
View file @
33485749
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Optional
import
numpy
as
np
import
torch
import
torch.nn
as
nn
...
...
@@ -59,9 +57,6 @@ class EagleProposer:
self
.
use_cuda_graph
=
(
self
.
vllm_config
.
compilation_config
.
level
==
CompilationLevel
.
PIECEWISE
and
not
self
.
vllm_config
.
model_config
.
enforce_eager
)
self
.
use_full_cuda_graph
=
(
self
.
use_cuda_graph
and
vllm_config
.
compilation_config
.
full_cuda_graph
)
self
.
cudagraph_batch_sizes
=
list
(
reversed
(
self
.
vllm_config
.
compilation_config
.
cudagraph_capture_sizes
))
...
...
@@ -77,8 +72,6 @@ class EagleProposer:
(
self
.
max_num_tokens
,
self
.
hidden_size
),
dtype
=
self
.
dtype
,
device
=
device
)
# attention metadata captured in full cudagraph mode
self
.
attn_metadata_cudagraph
=
None
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self
.
arange
=
torch
.
arange
(
vllm_config
.
scheduler_config
.
max_num_seqs
+
...
...
@@ -138,38 +131,6 @@ class EagleProposer:
# copy inputs to buffer for cudagraph
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
if
(
self
.
use_full_cuda_graph
and
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
assert
self
.
attn_metadata_cudagraph
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
self
.
attn_metadata_cudagraph
.
seq_lens
[:
batch_size
]
=
(
attn_metadata
.
seq_lens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
num_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
block_table
[:
batch_size
]
=
(
attn_metadata
.
block_table
)
elif
self
.
method
==
"deepseek_mtp"
:
self
.
attn_metadata_cudagraph
.
num_actual_tokens
=
(
attn_metadata
.
num_actual_tokens
)
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
num_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
num_decodes
=
(
attn_metadata
.
num_decodes
)
self
.
attn_metadata_cudagraph
.
num_decode_tokens
=
(
attn_metadata
.
num_decode_tokens
)
self
.
attn_metadata_cudagraph
.
num_prefills
=
(
attn_metadata
.
num_prefills
)
if
attn_metadata
.
decode
is
not
None
:
self
.
attn_metadata_cudagraph
.
decode
.
block_table
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
block_table
)
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
seq_lens
)
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
...
...
@@ -186,15 +147,11 @@ class EagleProposer:
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
draft_probs_list
=
[
draft_prob
]
# Early exit if there is only one draft token to be generated.
if
self
.
num_speculative_tokens
==
1
:
# [batch_size, 1]
return
draft_token_ids
.
view
(
-
1
,
1
)
,
draft_probs_list
return
draft_token_ids
.
view
(
-
1
,
1
)
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
...
...
@@ -234,7 +191,7 @@ class EagleProposer:
seq_lens
=
(
seq_lens
+
1
),
)
for
i
in
range
(
self
.
num_speculative_tokens
-
1
):
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
...
...
@@ -285,43 +242,6 @@ class EagleProposer:
self
.
input_ids
[:
batch_size
]
=
input_ids
self
.
positions
[:
batch_size
]
=
clamped_positions
self
.
hidden_states
[:
batch_size
]
=
hidden_states
if
(
self
.
use_full_cuda_graph
and
batch_size
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
assert
self
.
attn_metadata_cudagraph
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
self
.
attn_metadata_cudagraph
.
seq_lens
[:
batch_size
]
=
(
attn_metadata
.
seq_lens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
batch_size
]
=
(
attn_metadata
.
slot_mapping
)
if
i
==
0
:
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
block_table
[:
batch_size
]
=
(
attn_metadata
.
block_table
)
elif
self
.
method
==
"deepseek_mtp"
:
self
.
attn_metadata_cudagraph
.
num_actual_tokens
=
(
attn_metadata
.
num_actual_tokens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
num_decodes
=
(
attn_metadata
.
num_decodes
)
self
.
attn_metadata_cudagraph
.
num_decode_tokens
=
(
attn_metadata
.
num_decode_tokens
)
self
.
attn_metadata_cudagraph
.
num_prefills
=
(
attn_metadata
.
num_prefills
)
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
seq_lens
)
if
i
==
0
:
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
decode
.
block_table
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
block_table
)
# Run the model.
with
set_forward_context
(
per_layer_attn_metadata
,
...
...
@@ -345,15 +265,10 @@ class EagleProposer:
# TODO(wenlong): get more than one token for tree attention
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids_list
.
append
(
draft_token_ids
)
draft_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
draft_probs_list
.
append
(
draft_prob
)
# [batch_size, num_speculative_tokens]
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
draft_probs
=
torch
.
stack
(
draft_probs_list
,
dim
=
1
).
contiguous
()
return
draft_token_ids
,
draft_probs
return
draft_token_ids
def
prepare_inputs
(
self
,
...
...
@@ -503,13 +418,8 @@ class EagleProposer:
def
dummy_run
(
self
,
num_tokens
:
int
,
attn_metadata
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
None
:
if
attn_metadata
is
not
None
and
self
.
attn_metadata_cudagraph
is
None
:
self
.
attn_metadata_cudagraph
=
attn_metadata
[
self
.
attn_layer_names
[
0
]]
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
with
set_forward_context
(
None
,
self
.
vllm_config
,
num_tokens
=
num_tokens
):
self
.
model
(
self
.
input_ids
[:
num_tokens
],
...
...
vllm/v1/spec_decode/utils.py
View file @
33485749
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
msgspec
from
abc
import
ABC
import
torch
from
vllm.sampling_params
import
SamplingParams
_SAMPLING_EPS
=
1e-5
...
...
@@ -16,41 +12,3 @@ def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
or
sampling_params
.
repetition_penalty
!=
1.0
or
sampling_params
.
min_p
>
_SAMPLING_EPS
or
sampling_params
.
logprobs
is
not
None
)
class
DraftProbs
(
ABC
):
# type: ignore[call-arg]
"""Draft probs corresponding to in-progress sequences."""
# spec tokens probs.
draft_probs
:
torch
.
Tensor
# The request id list.
_req_ids
:
list
[
str
]
def
__init__
(
self
,
draft_probs
,
req_ids
):
assert
len
(
req_ids
)
==
len
(
draft_probs
)
self
.
draft_probs
=
draft_probs
self
.
_req_ids
=
req_ids
def
update
(
self
,
draft_probs
:
torch
.
Tensor
,
tmp_req_ids
:
list
[
str
]):
diff_req_ids
=
[
item
for
item
in
self
.
_req_ids
if
item
not
in
tmp_req_ids
]
index
=
[
self
.
_req_ids
.
index
(
req_id
)
for
req_id
in
diff_req_ids
]
self
.
_req_ids
=
diff_req_ids
self
.
draft_probs
=
self
.
draft_probs
[
index
]
self
.
draft_probs
=
torch
.
cat
([
self
.
draft_probs
,
draft_probs
])
self
.
_req_ids
.
extend
(
tmp_req_ids
)
assert
len
(
self
.
_req_ids
)
==
len
(
self
.
draft_probs
)
def
prune
(
self
,
req_ids
:
list
[
str
]):
new_req_ids
=
[
req_id
for
req_id
in
self
.
_req_ids
if
req_id
not
in
req_ids
]
if
new_req_ids
!=
self
.
_req_ids
:
# Batch contents changed - prune removed sequences.
index
=
[
self
.
_req_ids
.
index
(
req_id
)
for
req_id
in
new_req_ids
]
self
.
draft_probs
=
self
.
draft_probs
[
index
]
self
.
_req_ids
=
new_req_ids
def
get_probs
(
self
,
req_ids
:
list
[
str
]):
index
=
[
self
.
_req_ids
.
index
(
req_id
)
for
req_id
in
req_ids
]
return
self
.
draft_probs
[
index
]
vllm/v1/worker/gpu_model_runner.py
View file @
33485749
...
...
@@ -60,13 +60,11 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
from
vllm.v1.sample.rejection_sampler_mtp
import
MtpRejectionSampler
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.utils
import
DraftProbs
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.platforms
import
current_platform
...
...
@@ -196,13 +194,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else
:
raise
ValueError
(
"Unknown speculative decoding method: "
f
"
{
self
.
speculative_config
.
method
}
"
)
self
.
use_mtp
=
self
.
speculative_config
.
method
==
"deepseek_mtp"
if
not
self
.
use_mtp
:
self
.
rejection_sampler
=
RejectionSampler
()
else
:
self
.
rejection_sampler
=
MtpRejectionSampler
()
self
.
rejection_sampler
=
RejectionSampler
()
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
...
...
@@ -328,8 +320,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# means this layer will perform attention using the keys and values
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self
.
shared_kv_cache_layers
:
dict
[
str
,
str
]
=
{}
self
.
draft_probs
:
Optional
[
DraftProbs
]
=
None
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""
...
...
@@ -389,7 +379,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for
req_id
in
scheduler_output
.
finished_req_ids
:
self
.
requests
.
pop
(
req_id
,
None
)
self
.
encoder_cache
.
pop
(
req_id
,
None
)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
...
...
@@ -398,10 +387,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# and handling the second as a new request.
for
req_id
in
scheduler_output
.
finished_req_ids
:
self
.
input_batch
.
remove_request
(
req_id
)
# prune draft probs of finished requests
if
self
.
use_mtp
and
self
.
draft_probs
is
not
None
and
len
(
scheduler_output
.
finished_req_ids
)
>
0
:
self
.
draft_probs
.
prune
(
list
(
scheduler_output
.
finished_req_ids
))
# Free the cached encoder outputs.
for
req_id
,
input_id
in
scheduler_output
.
free_encoder_input_ids
:
...
...
@@ -1556,8 +1541,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_logits
=
logits
[
spec_decode_metadata
.
target_logits_indices
]
output_token_ids
=
self
.
rejection_sampler
(
spec_decode_metadata
,
self
.
draft_probs
.
get_probs
(
self
.
input_batch
.
req_ids
)
\
if
self
.
draft_probs
is
not
None
else
None
,
# draft_probs
None
,
# draft_probs
target_logits
,
bonus_token_ids
,
sampling_metadata
,
...
...
@@ -1643,7 +1627,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids
=
None
else
:
assert
spec_decode_common_attn_metadata
is
not
None
spec_token_ids
,
draft_probs
=
self
.
propose_draft_token_ids
(
spec_token_ids
=
self
.
propose_draft_token_ids
(
scheduler_output
,
valid_sampled_token_ids
,
sampling_metadata
,
...
...
@@ -1653,14 +1637,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_decode_metadata
,
spec_decode_common_attn_metadata
,
)
if
self
.
use_mtp
:
if
self
.
draft_probs
is
None
:
self
.
draft_probs
=
DraftProbs
(
draft_probs
,
self
.
input_batch
.
req_ids
)
else
:
self
.
draft_probs
.
update
(
draft_probs
,
self
.
input_batch
.
req_ids
)
spec_token_ids
=
spec_token_ids
.
tolist
()
self
.
eplb_step
()
...
...
@@ -1767,7 +1743,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
[
h
[
token_indices
]
for
h
in
aux_hidden_states
],
dim
=-
1
)
else
:
target_hidden_states
=
hidden_states
[
token_indices
]
spec
_token_ids
,
draft_probs
=
self
.
drafter
.
propose
(
draft
_token_ids
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
...
...
@@ -1776,8 +1752,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
common_attn_metadata
=
common_attn_metadata
,
num_rejected_tokens
=
num_rejected_tokens
)
return
spec_token_ids
,
draft_probs
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
@
staticmethod
def
maybe_setup_kv_connector
(
scheduler_output
:
"SchedulerOutput"
):
...
...
@@ -2224,7 +2200,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
self
.
drafter
.
dummy_run
(
num_tokens
,
attn_metadata
)
self
.
drafter
.
dummy_run
(
num_tokens
)
# This is necessary to avoid blocking DP.
# For dummy runs, we typically skip EPLB since we don't have any real
...
...
@@ -2291,11 +2267,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids
,
self
.
device
)
num_tokens
=
sum
(
len
(
ids
)
for
ids
in
draft_token_ids
)
draft_probs
=
torch
.
randn
(
num_tokens
,
logits
.
shape
[
-
1
],
device
=
self
.
device
,
dtype
=
logits
.
dtype
)
# draft_probs = None
# draft_probs = torch.randn(
# num_tokens, logits.shape[-1], device=self.device,
# dtype=logits.dtype)
draft_probs
=
None
target_logits
=
torch
.
randn
(
num_tokens
,
logits
.
shape
[
-
1
],
device
=
self
.
device
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment