Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f2218895
Commit
f2218895
authored
Aug 07, 2025
by
zhuwenwen
Browse files
[feat]支持mtp模型full_cuda_graph
parent
236f11df
Changes
4
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
16 additions
and
601 deletions
+16
-601
vllm/v1/sample/rejection_sampler_mtp.py
vllm/v1/sample/rejection_sampler_mtp.py
+0
-518
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+4
-11
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+1
-41
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+11
-31
No files found.
vllm/v1/sample/rejection_sampler_mtp.py
deleted
100644 → 0
View file @
236f11df
This diff is collapsed.
Click to expand it.
vllm/v1/spec_decode/eagle.py
View file @
f2218895
...
...
@@ -98,7 +98,7 @@ class EagleProposer:
next_token_ids
:
torch
.
Tensor
,
common_attn_metadata
:
CommonAttentionMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
num_tokens
=
target_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
last_token_indices
=
common_attn_metadata
.
query_start_loc
[
1
:]
-
1
...
...
@@ -185,16 +185,13 @@ class EagleProposer:
last_hidden_states
,
hidden_states
=
ret_hidden_states
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
draft_
probs_list
=
[
draft_prob
]
draft_
token_ids
=
torch
.
argmax
(
logits
,
dim
=-
1
)
# Early exit if there is only one draft token to be generated.
if
self
.
num_speculative_tokens
==
1
:
# [batch_size, 1]
return
draft_token_ids
.
view
(
-
1
,
1
)
,
draft_prob
.
view
(
-
1
,
1
,
draft_prob
.
shape
[
-
1
])
return
draft_token_ids
.
view
(
-
1
,
1
)
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
...
...
@@ -346,14 +343,10 @@ class EagleProposer:
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids_list
.
append
(
draft_token_ids
)
draft_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
draft_probs_list
.
append
(
draft_prob
)
# [batch_size, num_speculative_tokens]
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
draft_probs
=
torch
.
stack
(
draft_probs_list
,
dim
=
1
).
contiguous
()
return
draft_token_ids
,
draft_probs
return
draft_token_ids
def
prepare_inputs
(
self
,
...
...
vllm/v1/spec_decode/utils.py
View file @
f2218895
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
msgspec
from
abc
import
ABC
import
torch
from
vllm.sampling_params
import
SamplingParams
...
...
@@ -17,40 +14,3 @@ def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
or
sampling_params
.
min_p
>
_SAMPLING_EPS
or
sampling_params
.
logprobs
is
not
None
)
\ No newline at end of file
class DraftProbs(ABC):
    """Draft probs corresponding to in-progress sequences.

    Keeps a tensor of draft probabilities whose first dimension is aligned
    one-to-one with an ordered list of request ids: row ``i`` of
    ``draft_probs`` belongs to ``_req_ids[i]``.
    """

    # Spec-token probs; dim 0 is aligned with ``_req_ids``.
    draft_probs: torch.Tensor
    # The request id list, in the same order as the rows of ``draft_probs``.
    _req_ids: list[str]

    def __init__(self, draft_probs, req_ids):
        """Store ``draft_probs`` together with the ids owning each row.

        Args:
            draft_probs: tensor whose first dimension has one entry per
                request in ``req_ids``.
            req_ids: request ids, in row order.
        """
        assert len(req_ids) == len(draft_probs), (
            "req_ids and draft_probs must have the same length")
        self.draft_probs = draft_probs
        # Copy the ids: the caller may keep mutating the list it passed in,
        # which would otherwise silently break the id <-> row alignment.
        self._req_ids = list(req_ids)

    def _row_lookup(self) -> dict[str, int]:
        """Return a map from request id to its row (first occurrence wins)."""
        rows: dict[str, int] = {}
        for row, req_id in enumerate(self._req_ids):
            # setdefault keeps the first row, matching ``list.index``.
            rows.setdefault(req_id, row)
        return rows

    def _select_rows(self, rows: list[int]) -> torch.Tensor:
        """Gather ``rows`` along dim 0; an empty list yields zero rows."""
        index = torch.tensor(rows, dtype=torch.long,
                             device=self.draft_probs.device)
        return self.draft_probs[index]

    def update(self, draft_probs: torch.Tensor, tmp_req_ids: list[str]):
        """Replace the rows for ``tmp_req_ids`` and append them at the end.

        Rows belonging to ids that are not in ``tmp_req_ids`` are kept, in
        their existing order, ahead of the newly supplied rows.

        Args:
            draft_probs: new rows, one per entry of ``tmp_req_ids``.
            tmp_req_ids: ids owning the rows of ``draft_probs``.
        """
        # Validate before touching any state so that a bad call cannot leave
        # the ids and the tensor out of sync.
        assert len(tmp_req_ids) == len(draft_probs), (
            "tmp_req_ids and draft_probs must have the same length")

        incoming = set(tmp_req_ids)
        rows = self._row_lookup()
        kept_ids = [
            req_id for req_id in self._req_ids if req_id not in incoming
        ]
        kept_probs = self._select_rows([rows[req_id] for req_id in kept_ids])

        self.draft_probs = torch.cat([kept_probs, draft_probs])
        self._req_ids = kept_ids + list(tmp_req_ids)
        assert len(self._req_ids) == len(self.draft_probs)

    def prune(self, req_ids: list[str]):
        """Drop the rows that belong to any id in ``req_ids``.

        Ids in ``req_ids`` that are not tracked are ignored.
        """
        removed = set(req_ids)
        new_req_ids = [
            req_id for req_id in self._req_ids if req_id not in removed
        ]
        if new_req_ids != self._req_ids:
            # Batch contents changed - prune removed sequences.
            rows = self._row_lookup()
            self.draft_probs = self._select_rows(
                [rows[req_id] for req_id in new_req_ids])
            self._req_ids = new_req_ids

    def get_probs(self, req_ids: list[str]):
        """Return the rows for ``req_ids``, in the order requested.

        Raises:
            ValueError: if an id in ``req_ids`` is not tracked.
        """
        rows = self._row_lookup()
        index = []
        for req_id in req_ids:
            row = rows.get(req_id)
            if row is None:
                # Same exception type ``list.index`` raised for unknown ids.
                raise ValueError(f"{req_id!r} is not in list")
            index.append(row)
        return self._select_rows(index)
vllm/v1/worker/gpu_model_runner.py
View file @
f2218895
...
...
@@ -60,13 +60,11 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
from
vllm.v1.sample.rejection_sampler_mtp
import
MtpRejectionSampler
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.utils
import
DraftProbs
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.platforms
import
current_platform
...
...
@@ -197,12 +195,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
raise
ValueError
(
"Unknown speculative decoding method: "
f
"
{
self
.
speculative_config
.
method
}
"
)
self
.
use_mtp
=
self
.
speculative_config
.
method
==
"deepseek_mtp"
if
not
self
.
use_mtp
:
self
.
rejection_sampler
=
RejectionSampler
()
else
:
self
.
rejection_sampler
=
MtpRejectionSampler
()
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
...
...
@@ -329,8 +322,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self
.
shared_kv_cache_layers
:
dict
[
str
,
str
]
=
{}
self
.
draft_probs
:
Optional
[
DraftProbs
]
=
None
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""
Update the order of requests in the batch based on the attention
...
...
@@ -399,10 +390,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for
req_id
in
scheduler_output
.
finished_req_ids
:
self
.
input_batch
.
remove_request
(
req_id
)
# prune draft probs of finished requests
if
self
.
use_mtp
and
self
.
draft_probs
is
not
None
and
len
(
scheduler_output
.
finished_req_ids
)
>
0
:
self
.
draft_probs
.
prune
(
list
(
scheduler_output
.
finished_req_ids
))
# Free the cached encoder outputs.
for
req_id
,
input_id
in
scheduler_output
.
free_encoder_input_ids
:
encoder_outputs
=
self
.
encoder_cache
.
get
(
req_id
)
...
...
@@ -1556,8 +1543,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_logits
=
logits
[
spec_decode_metadata
.
target_logits_indices
]
output_token_ids
=
self
.
rejection_sampler
(
spec_decode_metadata
,
self
.
draft_probs
.
get_probs
(
self
.
input_batch
.
req_ids
)
\
if
self
.
draft_probs
is
not
None
else
None
,
# draft_probs
None
,
# draft_probs
target_logits
,
bonus_token_ids
,
sampling_metadata
,
...
...
@@ -1643,7 +1629,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids
=
None
else
:
assert
spec_decode_common_attn_metadata
is
not
None
spec_token_ids
,
draft_probs
=
self
.
propose_draft_token_ids
(
spec_token_ids
=
self
.
propose_draft_token_ids
(
scheduler_output
,
valid_sampled_token_ids
,
sampling_metadata
,
...
...
@@ -1653,12 +1639,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_decode_metadata
,
spec_decode_common_attn_metadata
,
)
if
self
.
use_mtp
:
if
self
.
draft_probs
is
None
:
self
.
draft_probs
=
DraftProbs
(
draft_probs
,
self
.
input_batch
.
req_ids
)
else
:
self
.
draft_probs
.
update
(
draft_probs
,
self
.
input_batch
.
req_ids
)
spec_token_ids
=
spec_token_ids
.
tolist
()
...
...
@@ -1687,8 +1667,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states
:
Optional
[
torch
.
Tensor
],
spec_decode_metadata
:
Optional
[
SpecDecodeMetadata
],
common_attn_metadata
:
CommonAttentionMetadata
,
)
->
tuple
[
list
[
list
[
int
]],
torch
.
Tensor
]:
draft_probs
=
None
)
->
list
[
list
[
int
]]:
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
self
.
speculative_config
.
method
==
"ngram"
:
assert
isinstance
(
self
.
drafter
,
NgramProposer
)
...
...
@@ -1768,7 +1747,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
[
h
[
token_indices
]
for
h
in
aux_hidden_states
],
dim
=-
1
)
else
:
target_hidden_states
=
hidden_states
[
token_indices
]
spec
_token_ids
,
draft_probs
=
self
.
drafter
.
propose
(
draft
_token_ids
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
...
...
@@ -1777,8 +1756,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
common_attn_metadata
=
common_attn_metadata
,
num_rejected_tokens
=
num_rejected_tokens
)
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
,
draft_probs
return
spec_token_ids
@
staticmethod
def
maybe_setup_kv_connector
(
scheduler_output
:
"SchedulerOutput"
):
...
...
@@ -2292,10 +2272,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids
,
self
.
device
)
num_tokens
=
sum
(
len
(
ids
)
for
ids
in
draft_token_ids
)
draft_probs
=
torch
.
randn
(
num_tokens
,
logits
.
shape
[
-
1
],
device
=
self
.
device
,
dtype
=
logits
.
dtype
)
#
draft_probs = None
#
draft_probs = torch.randn(
#
num_tokens, logits.shape[-1], device=self.device,
#
dtype=logits.dtype)
draft_probs
=
None
target_logits
=
torch
.
randn
(
num_tokens
,
logits
.
shape
[
-
1
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment