Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
467490e6
Commit
467490e6
authored
Aug 06, 2025
by
zhuwenwen
Browse files
Revert "Revert "[feat]1.支持mtp模型 full_cuda_graph; 2.优化mtp拒绝采样""
This reverts commit
33485749
.
parent
33485749
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
691 additions
and
16 deletions
+691
-16
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+1
-1
vllm/v1/sample/rejection_sampler_mtp.py
vllm/v1/sample/rejection_sampler_mtp.py
+518
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+94
-4
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+42
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+36
-11
No files found.
vllm/model_executor/models/deepseek_mtp.py
View file @
467490e6
...
...
@@ -152,7 +152,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
return
logits
#
@support_torch_compile
@
support_torch_compile
class
DeepSeekMTP
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
vllm/v1/sample/rejection_sampler_mtp.py
0 → 100644
View file @
467490e6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
from
functools
import
cached_property
import
torch
import
torch.nn
as
nn
from
vllm.logger
import
init_logger
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN
=
32
PLACEHOLDER_TOKEN_ID
=
-
1
class
MtpRejectionSampler
(
RejectionSampler
):
"""
The implementation strictly follows the algorithm described in
https://arxiv.org/abs/2211.17192.
However, we want to clarify the terminology used in the implementation:
accepted tokens: tokens that are accepted based on the relationship
between the "raw" draft and target probabilities.
recovered tokens: tokens that are sampled based on the adjusted probability
distribution, which is derived from both the draft and target
probabilities.
bonus tokens:
If all proposed tokens are accepted, the bonus token is added to the
end of the sequence. The bonus token is only sampled from the target
probabilities. We pass in the bonus tokens instead of sampling them
in the rejection sampler to allow for more flexibility in the
sampling process. For example, we can use top_p, top_k sampling for
bonus tokens, while spec decode does not support these sampling
strategies.
output tokens:
Tokens are finally generated with the rejection sampler.
output tokens = accepted tokens + recovered tokens + bonus tokens
"""
def
__init__
(
self
):
super
().
__init__
()
# NOTE: A "bonus token" is accepted iff all proposal tokens are
# accepted. There is always only one possible bonus token. We store this
# value in a variable for readability.
self
.
_num_bonus_tokens
=
1
def
forward
(
self
,
metadata
:
SpecDecodeMetadata
,
# [num_tokens, vocab_size]
draft_probs
:
Optional
[
torch
.
Tensor
],
# [num_tokens, vocab_size]
target_logits
:
torch
.
Tensor
,
# [batch_size, 1]
bonus_token_ids
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
'''
Args:
metadata:
Metadata for spec decoding.
draft_probs (Optional[torch.Tensor]):
Probability distribution for the draft tokens. Shape is
[num_tokens, vocab_size]. Can be None if probabilities are
not provided, which is the case for ngram spec decode.
target_logits (torch.Tensor):
Target model's logits probability distribution.
Shape is [num_tokens, vocab_size]. Here, probabilities from
different requests are flattened into a single tensor because
this is the shape of the output logits.
NOTE: `target_logits` can be updated in place to save memory.
bonus_token_ids_tensor (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all
proposed tokens are accepted. We generate the bonus tokens
outside of the rejection sampler with the default sampling
strategy. It allows for more flexibility in the sampling
process such as top_p, top_k sampling.
sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
Additional metadata needed for sampling, such as temperature,
top-k/top-p parameters, or other relevant information.
Returns:
output_token_ids (torch.Tensor):
A tensor containing the final output token IDs.
'''
assert
metadata
.
max_spec_len
<=
MAX_SPEC_LEN
assert
draft_probs
is
not
None
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
# `compute_probs` function.
num_draft_tokens
=
metadata
.
num_draft_tokens
[
0
]
target_probs
=
compute_probs
(
target_logits
,
metadata
.
cu_num_draft_tokens
,
sampling_metadata
,
num_draft_tokens
)
target_probs
=
target_probs
.
view
(
-
1
,
num_draft_tokens
,
target_probs
.
shape
[
-
1
])
draft_probs
=
draft_probs
.
view
(
-
1
,
num_draft_tokens
,
draft_probs
.
shape
[
-
1
])
draft_token_ids
=
metadata
.
draft_token_ids
.
view
(
-
1
,
num_draft_tokens
)
accepted
,
recovered_token_ids
=
(
self
.
_batch_modified_rejection_sampling
(
target_probs
,
draft_probs
,
draft_token_ids
,
None
,
))
output_token_ids
=
self
.
_create_output
(
accepted
,
recovered_token_ids
,
draft_token_ids
,
bonus_token_ids
,
)
return
output_token_ids
def
_create_uniform_samples
(
self
,
seeded_seqs
:
Optional
[
dict
[
int
,
torch
.
Generator
]],
batch_size
:
int
,
k
:
int
,
device
:
torch
.
device
)
->
torch
.
Tensor
:
"""
Generates a batch of uniform random samples, with optional seeding
for specific sequences.
This method creates a tensor of shape `(batch_size, k + 1)` filled
with uniform random values in the range [0, 1). If `seeded_seqs`
is provided, the sequences corresponding to specific indices
will be generated using the provided `torch.Generator` for
reproducibility. The other sequences will be generated without
a seed.
Args:
seeded_seqs : Optional[dict[int, torch.Generator]]
A dictionary mapping indices in the batch to
`torch.Generator` objects. If `None`, all samples are
generated without a seed.
batch_size : int
The number of sequences to generate.
k : int
The number of random samples per sequence.
device : torch.device
The device on which to allocate the tensor.
Returns:
uniform_rand : torch.Tensor
A tensor of shape `(batch_size, k + 1)` containing uniform
random values in the range [0, 1).
"""
if
not
seeded_seqs
:
return
torch
.
rand
(
batch_size
,
k
+
1
,
device
=
device
)
uniform_rand
=
torch
.
empty
(
batch_size
,
k
+
1
,
device
=
device
)
non_seeded_indices
=
[]
for
idx
in
range
(
batch_size
):
generator
=
seeded_seqs
.
get
(
idx
)
if
generator
is
None
:
non_seeded_indices
.
append
(
idx
)
else
:
uniform_rand
[
idx
,
:]
=
torch
.
rand
(
1
,
k
+
1
,
dtype
=
self
.
probs_dtype
,
device
=
device
,
generator
=
generator
)
if
non_seeded_indices
:
uniform_rand
[
non_seeded_indices
,
:]
=
torch
.
rand
(
len
(
non_seeded_indices
),
k
+
1
,
dtype
=
self
.
probs_dtype
,
device
=
device
)
return
uniform_rand
def
_get_accepted
(
self
,
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
seeded_seqs
:
Optional
[
dict
[
int
,
torch
.
Generator
]],
)
->
torch
.
Tensor
:
r
"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
rejected.
Given $q(\hat{x}_{n+1}|x_1, \dots, x_n)$, the probability of
$\hat{x}_{n+1}$ given context $x_1, \dots, x_n$ according
to the target model, and $p(\hat{x}_{n+1}|x_1, \dots, x_n)$, the
same conditional probability according to the draft model, the token
is accepted with probability:
$$
\min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
{p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
$$
This implementation does not apply causality. When using the output,
if a token is rejected, subsequent tokens should not be used.
Returns a bool tensor of shape [batch_size, k] specifying which tokens
are accepted.
"""
batch_size
,
k
,
_
=
draft_probs
.
shape
batch_indices
=
torch
.
arange
(
batch_size
,
device
=
target_probs
.
device
)[:,
None
]
probs_indices
=
torch
.
arange
(
k
,
device
=
target_probs
.
device
)
# shape [batch_size, k]
selected_draft_probs
=
draft_probs
[
batch_indices
,
probs_indices
,
draft_token_ids
]
# shape [batch_size, k]
selected_target_probs
=
target_probs
[
batch_indices
,
probs_indices
,
draft_token_ids
]
uniform_rand
=
self
.
_create_uniform_samples
(
seeded_seqs
,
batch_size
,
k
-
1
,
target_probs
.
device
)
capped_ratio
=
torch
.
minimum
(
selected_target_probs
/
selected_draft_probs
,
torch
.
full
((
1
,
),
1
,
device
=
target_probs
.
device
))
accepted
=
uniform_rand
<
capped_ratio
return
accepted
def
_get_recovered_probs
(
self
,
target_probs
:
torch
.
Tensor
,
# [k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [k, vocab_size]
)
->
torch
.
Tensor
:
r
"""Create a probability distribution for each proposed token which can
be sampled if the proposed token is rejected.
When this routine is applied sequentially, the true distribution of the
target model is recovered (within hardware numerics).
The probability distribution used in this rejection case is constructed
as follows. Given $q(x|x_1, \dots, x_n)$, the probability of
$x$ given context $x_1, \dots, x_n$ according to the target
model and $p(x|x_1, \dots, x_n)$, the same conditional probability
according to the draft model:
$$
x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
$$
where $(f(x))_+$ is defined as:
$$
(f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
$$
See https://github.com/vllm-project/vllm/pull/2336 for a visualization
of the draft, target, and recovered probability distributions.
Returns a tensor of shape [batch_size, k, vocab_size].
Note:
This batches operations on GPU and thus constructs the recovered
distribution for all tokens, even if they are accepted. This causes
division-by-zero errors, so we use self._smallest_positive_value to
avoid that. This introduces some drift to the distribution.
"""
_
,
k
,
_
=
draft_probs
.
shape
# shape [batch_size, k, vocab_size]
difference
=
target_probs
-
draft_probs
# TODO(cade): Can we use logprobs instead of probs, and avoid the
# division-by-zero errors without introducing distribution drift?
# shape [batch_size, k, vocab_size]
f
=
torch
.
clamp
(
difference
,
min
=
self
.
_smallest_positive_value
)
# shape [batch_size, k, vocab_size]
recovered_probs
=
f
/
torch
.
sum
(
f
,
dim
=-
1
).
reshape
(
-
1
,
k
,
1
)
return
recovered_probs
def
_batch_modified_rejection_sampling
(
self
,
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
seeded_seqs
:
Optional
[
dict
[
int
,
torch
.
Generator
]],
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Perform modified rejection sampling on each sequence.
Returns:
A tuple of two tensors:
0: A bool tensor of which tokens in each sequence is accepted.
shape = [batch_size, k]
1: Token ids sampled from a recovered distribution, to be used
when a token is rejected.
shape = [batch_size, k]
"""
batch_size
,
k
,
vocab_size
=
target_probs
.
shape
# shape [batch_size, k]
accepted
=
self
.
_get_accepted
(
target_probs
,
draft_probs
,
draft_token_ids
,
seeded_seqs
)
recovered_probs
=
self
.
_get_recovered_probs
(
target_probs
,
draft_probs
).
reshape
(
batch_size
*
k
,
vocab_size
)
# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids
=
_multinomial
(
recovered_probs
,
num_samples
=
1
,
k
=
k
,
seeded_seqs
=
seeded_seqs
or
{},
).
reshape
(
batch_size
,
k
)
return
accepted
,
recovered_token_ids
def
_create_output
(
self
,
accepted
:
torch
.
Tensor
,
# [batch_size, k]
substitute_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
bonus_token_ids
:
torch
.
Tensor
,
# [batch_size]
)
->
torch
.
Tensor
:
"""Format output. Returns a matrix of token ids. When
a token is rejected via sampling, all subsequent token ids are
set to -1 for the sequence.
Args:
accepted: A boolean tensor indicating if the corresponding
draft token in draft_token_ids should be accepted or not.
substitute_token_ids: A tensor of token_ids that can be used
as substitutes for the draft token ids if the proposed token
is rejected.
draft_token_ids: A tensor of token ids speculated by the
draft model.
bonus_token_ids: Token ids to use as the bonus token if
all the draft tokens are accepted.
Returns:
A tensor containing the accepted token ids. The shape of the
tensor is [batch_size, k + num_bonus_tokens]
"""
batch_size
,
k
=
substitute_token_ids
.
shape
bonus_token_ids
=
bonus_token_ids
.
squeeze
(
-
1
)
# Determine the index of the first False value for each row.
limits
=
(
accepted
==
0
).
max
(
1
).
indices
limits
[
~
(
accepted
==
0
).
any
(
1
)]
=
k
# Create masks using the indices.
indices
=
torch
.
arange
(
k
,
device
=
accepted
.
device
).
unsqueeze
(
0
)
accepted_mask
=
indices
<
limits
.
unsqueeze
(
1
)
after_false_mask
=
indices
==
limits
.
unsqueeze
(
1
)
# Create an extended output tensor
output_with_bonus_tokens
=
-
torch
.
ones
(
(
batch_size
,
k
+
self
.
_num_bonus_tokens
),
dtype
=
self
.
token_id_dtype
,
device
=
accepted
.
device
)
output
=
output_with_bonus_tokens
[:,
:
k
]
# Fill in the first k columns of the output tensor using masks and data
# tensors.
output
[:,
:
k
]
=
torch
.
where
(
accepted_mask
,
draft_token_ids
,
-
torch
.
ones_like
(
draft_token_ids
))
# Fill the last column.
# We check output directly as accepted may have True values inconsistent
# with causal acceptance.
output_with_bonus_tokens
[:,
-
1
]
=
torch
.
where
(
output
[:,
-
1
]
!=
-
1
,
bonus_token_ids
,
-
1
)
# Fill the recovered token ids.
output
.
mul_
(
~
after_false_mask
).
add_
(
substitute_token_ids
.
mul
(
after_false_mask
))
return
output_with_bonus_tokens
@
staticmethod
def
parse_output
(
output_token_ids
:
torch
.
Tensor
,
vocab_size
:
int
,
)
->
list
[
list
[
int
]]:
"""Parse the output of the rejection sampler.
Args:
output_token_ids: The sampled token IDs in shape
[batch_size, max_spec_len + 1]. The rejected tokens are
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function.
vocab_size: The size of the vocabulary.
Returns:
A list of lists of token IDs.
"""
output_token_ids_np
=
output_token_ids
.
cpu
().
numpy
()
# Create mask for valid tokens.
valid_mask
=
((
output_token_ids_np
!=
PLACEHOLDER_TOKEN_ID
)
&
(
output_token_ids_np
<
vocab_size
))
outputs
=
[
row
[
valid_mask
[
i
]].
tolist
()
for
i
,
row
in
enumerate
(
output_token_ids_np
)
]
return
outputs
@
cached_property
def
_smallest_positive_value
(
self
)
->
float
:
"""Return the smallest positive value representable by the probs dtype.
This value is used when constructing a distribution from which to sample
recovered tokens in the first rejection case.
See _get_recovered_probs for more details
Note that this isn't actually the smallest positive value representable
by float32, but the smallest positive normal value.
See https://en.wikipedia.org/wiki/Subnormal_number for more information.
"""
return
torch
.
finfo
(
self
.
probs_dtype
).
tiny
@
property
def
probs_dtype
(
self
):
return
torch
.
float32
@
property
def
token_id_dtype
(
self
):
return
torch
.
int64
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@
torch
.
compile
(
dynamic
=
True
,
backend
=
current_platform
.
simple_compile_backend
)
def
_multinomial
(
probs
:
torch
.
Tensor
,
num_samples
:
int
,
k
:
int
,
seeded_seqs
:
dict
[
int
,
torch
.
Generator
],
)
->
torch
.
Tensor
:
if
num_samples
>
1
:
# This is equivalent to torch.repeat_interleaved (which also
# forces a GPU<->CPU sync).
probs
=
probs
[:,
None
,
:].
expand
(
probs
.
shape
[
0
],
num_samples
,
probs
.
shape
[
1
]).
contiguous
().
view
(
-
1
,
probs
.
shape
[
1
])
q
=
torch
.
empty_like
(
probs
)
if
not
seeded_seqs
:
q
.
exponential_
(
1.0
)
else
:
start
=
0
for
idx
in
range
(
len
(
q
)
//
k
):
end
=
start
+
k
generator
=
seeded_seqs
.
get
(
idx
)
# Note: generator might be None for non seeded
q
[
start
:
end
].
exponential_
(
1.0
,
generator
=
generator
)
start
=
end
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
def
compute_probs
(
logits
:
torch
.
Tensor
,
# [num_tokens, vocab_size]
cu_num_draft_tokens
:
torch
.
Tensor
,
# [batch_size]
sampling_metadata
:
SamplingMetadata
,
spec_len
:
int
)
->
torch
.
Tensor
:
"""Compute probability distribution from logits based on sampling metadata.
This function applies temperature scaling to the logits and converts
them to probabilities using softmax. For greedy decoding, it returns
the original logits.
Args:
logits: Input logits tensor to be converted to probabilities.
cu_num_draft_tokens: Cumulative number of draft tokens.
sampling_metadata: Metadata containing sampling parameters such as
temperature and whether greedy sampling is used.
Returns:
torch.Tensor: Probability distribution (softmax of scaled logits)
if non-greedy sampling is used, otherwise returns the
original logits.
"""
assert
logits
.
ndim
==
2
assert
cu_num_draft_tokens
.
ndim
==
1
if
sampling_metadata
.
all_greedy
:
return
logits
# num_tokens = logits.shape[0]
temperature
=
sampling_metadata
.
temperature
.
view
(
-
1
,
1
).
repeat
(
1
,
spec_len
).
view
(
-
1
)
temperature
=
torch
.
where
(
temperature
>
0
,
temperature
,
1
)
# NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
logits
.
div_
(
temperature
.
unsqueeze
(
-
1
))
# Get expanded top_k and top_p tensors.
top_k
=
None
if
sampling_metadata
.
top_k
is
not
None
:
top_k
=
sampling_metadata
.
top_k
.
view
(
-
1
,
1
).
repeat
(
1
,
spec_len
).
view
(
-
1
)
top_p
=
None
if
sampling_metadata
.
top_p
is
not
None
:
top_p
=
sampling_metadata
.
top_p
.
view
(
-
1
,
1
).
repeat
(
1
,
spec_len
).
view
(
-
1
)
# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
# which is slow for large vocab sizes. This may cause performance issues.
logits
=
apply_top_k_top_p
(
logits
,
top_k
,
top_p
)
output_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
return
output_prob
\ No newline at end of file
vllm/v1/spec_decode/eagle.py
View file @
467490e6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Optional
import
numpy
as
np
import
torch
import
torch.nn
as
nn
...
...
@@ -57,6 +59,9 @@ class EagleProposer:
self
.
use_cuda_graph
=
(
self
.
vllm_config
.
compilation_config
.
level
==
CompilationLevel
.
PIECEWISE
and
not
self
.
vllm_config
.
model_config
.
enforce_eager
)
self
.
use_full_cuda_graph
=
(
self
.
use_cuda_graph
and
vllm_config
.
compilation_config
.
full_cuda_graph
)
self
.
cudagraph_batch_sizes
=
list
(
reversed
(
self
.
vllm_config
.
compilation_config
.
cudagraph_capture_sizes
))
...
...
@@ -72,6 +77,8 @@ class EagleProposer:
(
self
.
max_num_tokens
,
self
.
hidden_size
),
dtype
=
self
.
dtype
,
device
=
device
)
# attention metadata captured in full cudagraph mode
self
.
attn_metadata_cudagraph
=
None
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self
.
arange
=
torch
.
arange
(
vllm_config
.
scheduler_config
.
max_num_seqs
+
...
...
@@ -131,6 +138,38 @@ class EagleProposer:
# copy inputs to buffer for cudagraph
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
if
(
self
.
use_full_cuda_graph
and
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
assert
self
.
attn_metadata_cudagraph
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
self
.
attn_metadata_cudagraph
.
seq_lens
[:
batch_size
]
=
(
attn_metadata
.
seq_lens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
num_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
block_table
[:
batch_size
]
=
(
attn_metadata
.
block_table
)
elif
self
.
method
==
"deepseek_mtp"
:
self
.
attn_metadata_cudagraph
.
num_actual_tokens
=
(
attn_metadata
.
num_actual_tokens
)
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
num_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
num_decodes
=
(
attn_metadata
.
num_decodes
)
self
.
attn_metadata_cudagraph
.
num_decode_tokens
=
(
attn_metadata
.
num_decode_tokens
)
self
.
attn_metadata_cudagraph
.
num_prefills
=
(
attn_metadata
.
num_prefills
)
if
attn_metadata
.
decode
is
not
None
:
self
.
attn_metadata_cudagraph
.
decode
.
block_table
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
block_table
)
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
seq_lens
)
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
...
...
@@ -147,11 +186,15 @@ class EagleProposer:
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
draft_probs_list
=
[
draft_prob
]
# Early exit if there is only one draft token to be generated.
if
self
.
num_speculative_tokens
==
1
:
# [batch_size, 1]
return
draft_token_ids
.
view
(
-
1
,
1
)
return
draft_token_ids
.
view
(
-
1
,
1
)
,
draft_probs_list
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
...
...
@@ -191,7 +234,7 @@ class EagleProposer:
seq_lens
=
(
seq_lens
+
1
),
)
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
for
i
in
range
(
self
.
num_speculative_tokens
-
1
):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
...
...
@@ -242,6 +285,43 @@ class EagleProposer:
self
.
input_ids
[:
batch_size
]
=
input_ids
self
.
positions
[:
batch_size
]
=
clamped_positions
self
.
hidden_states
[:
batch_size
]
=
hidden_states
if
(
self
.
use_full_cuda_graph
and
batch_size
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
assert
self
.
attn_metadata_cudagraph
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
self
.
attn_metadata_cudagraph
.
seq_lens
[:
batch_size
]
=
(
attn_metadata
.
seq_lens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
batch_size
]
=
(
attn_metadata
.
slot_mapping
)
if
i
==
0
:
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
block_table
[:
batch_size
]
=
(
attn_metadata
.
block_table
)
elif
self
.
method
==
"deepseek_mtp"
:
self
.
attn_metadata_cudagraph
.
num_actual_tokens
=
(
attn_metadata
.
num_actual_tokens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
num_decodes
=
(
attn_metadata
.
num_decodes
)
self
.
attn_metadata_cudagraph
.
num_decode_tokens
=
(
attn_metadata
.
num_decode_tokens
)
self
.
attn_metadata_cudagraph
.
num_prefills
=
(
attn_metadata
.
num_prefills
)
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
seq_lens
)
if
i
==
0
:
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
decode
.
block_table
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
block_table
)
# Run the model.
with
set_forward_context
(
per_layer_attn_metadata
,
...
...
@@ -265,10 +345,15 @@ class EagleProposer:
# TODO(wenlong): get more than one token for tree attention
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids_list
.
append
(
draft_token_ids
)
draft_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
draft_probs_list
.
append
(
draft_prob
)
# [batch_size, num_speculative_tokens]
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
return
draft_token_ids
draft_probs
=
torch
.
stack
(
draft_probs_list
,
dim
=
1
).
contiguous
()
return
draft_token_ids
,
draft_probs
def
prepare_inputs
(
self
,
...
...
@@ -418,8 +503,13 @@ class EagleProposer:
def
dummy_run
(
self
,
num_tokens
:
int
,
attn_metadata
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
None
:
with
set_forward_context
(
None
,
self
.
vllm_config
,
if
attn_metadata
is
not
None
and
self
.
attn_metadata_cudagraph
is
None
:
self
.
attn_metadata_cudagraph
=
attn_metadata
[
self
.
attn_layer_names
[
0
]]
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_tokens
):
self
.
model
(
self
.
input_ids
[:
num_tokens
],
...
...
vllm/v1/spec_decode/utils.py
View file @
467490e6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
msgspec
from
abc
import
ABC
import
torch
from
vllm.sampling_params
import
SamplingParams
_SAMPLING_EPS
=
1e-5
...
...
@@ -12,3 +16,41 @@ def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
or
sampling_params
.
repetition_penalty
!=
1.0
or
sampling_params
.
min_p
>
_SAMPLING_EPS
or
sampling_params
.
logprobs
is
not
None
)
class
DraftProbs
(
ABC
):
# type: ignore[call-arg]
"""Draft probs corresponding to in-progress sequences."""
# spec tokens probs.
draft_probs
:
torch
.
Tensor
# The request id list.
_req_ids
:
list
[
str
]
def
__init__
(
self
,
draft_probs
,
req_ids
):
assert
len
(
req_ids
)
==
len
(
draft_probs
)
self
.
draft_probs
=
draft_probs
self
.
_req_ids
=
req_ids
def
update
(
self
,
draft_probs
:
torch
.
Tensor
,
tmp_req_ids
:
list
[
str
]):
diff_req_ids
=
[
item
for
item
in
self
.
_req_ids
if
item
not
in
tmp_req_ids
]
index
=
[
self
.
_req_ids
.
index
(
req_id
)
for
req_id
in
diff_req_ids
]
self
.
_req_ids
=
diff_req_ids
self
.
draft_probs
=
self
.
draft_probs
[
index
]
self
.
draft_probs
=
torch
.
cat
([
self
.
draft_probs
,
draft_probs
])
self
.
_req_ids
.
extend
(
tmp_req_ids
)
assert
len
(
self
.
_req_ids
)
==
len
(
self
.
draft_probs
)
def
prune
(
self
,
req_ids
:
list
[
str
]):
new_req_ids
=
[
req_id
for
req_id
in
self
.
_req_ids
if
req_id
not
in
req_ids
]
if
new_req_ids
!=
self
.
_req_ids
:
# Batch contents changed - prune removed sequences.
index
=
[
self
.
_req_ids
.
index
(
req_id
)
for
req_id
in
new_req_ids
]
self
.
draft_probs
=
self
.
draft_probs
[
index
]
self
.
_req_ids
=
new_req_ids
def
get_probs
(
self
,
req_ids
:
list
[
str
]):
index
=
[
self
.
_req_ids
.
index
(
req_id
)
for
req_id
in
req_ids
]
return
self
.
draft_probs
[
index
]
vllm/v1/worker/gpu_model_runner.py
View file @
467490e6
...
...
@@ -60,11 +60,13 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
from
vllm.v1.sample.rejection_sampler_mtp
import
MtpRejectionSampler
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.utils
import
DraftProbs
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.platforms
import
current_platform
...
...
@@ -194,7 +196,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else
:
raise
ValueError
(
"Unknown speculative decoding method: "
f
"
{
self
.
speculative_config
.
method
}
"
)
self
.
rejection_sampler
=
RejectionSampler
()
self
.
use_mtp
=
self
.
speculative_config
.
method
==
"deepseek_mtp"
if
not
self
.
use_mtp
:
self
.
rejection_sampler
=
RejectionSampler
()
else
:
self
.
rejection_sampler
=
MtpRejectionSampler
()
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
...
...
@@ -320,6 +328,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# means this layer will perform attention using the keys and values
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self
.
shared_kv_cache_layers
:
dict
[
str
,
str
]
=
{}
self
.
draft_probs
:
Optional
[
DraftProbs
]
=
None
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""
...
...
@@ -379,6 +389,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for
req_id
in
scheduler_output
.
finished_req_ids
:
self
.
requests
.
pop
(
req_id
,
None
)
self
.
encoder_cache
.
pop
(
req_id
,
None
)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
...
...
@@ -387,6 +398,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# and handling the second as a new request.
for
req_id
in
scheduler_output
.
finished_req_ids
:
self
.
input_batch
.
remove_request
(
req_id
)
# prune draft probs of finished requests
if
self
.
use_mtp
and
self
.
draft_probs
is
not
None
and
len
(
scheduler_output
.
finished_req_ids
)
>
0
:
self
.
draft_probs
.
prune
(
list
(
scheduler_output
.
finished_req_ids
))
# Free the cached encoder outputs.
for
req_id
,
input_id
in
scheduler_output
.
free_encoder_input_ids
:
...
...
@@ -1541,7 +1556,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_logits
=
logits
[
spec_decode_metadata
.
target_logits_indices
]
output_token_ids
=
self
.
rejection_sampler
(
spec_decode_metadata
,
None
,
# draft_probs
self
.
draft_probs
.
get_probs
(
self
.
input_batch
.
req_ids
)
\
if
self
.
draft_probs
is
not
None
else
None
,
# draft_probs
target_logits
,
bonus_token_ids
,
sampling_metadata
,
...
...
@@ -1627,7 +1643,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids
=
None
else
:
assert
spec_decode_common_attn_metadata
is
not
None
spec_token_ids
=
self
.
propose_draft_token_ids
(
spec_token_ids
,
draft_probs
=
self
.
propose_draft_token_ids
(
scheduler_output
,
valid_sampled_token_ids
,
sampling_metadata
,
...
...
@@ -1637,6 +1653,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_decode_metadata
,
spec_decode_common_attn_metadata
,
)
if
self
.
use_mtp
:
if
self
.
draft_probs
is
None
:
self
.
draft_probs
=
DraftProbs
(
draft_probs
,
self
.
input_batch
.
req_ids
)
else
:
self
.
draft_probs
.
update
(
draft_probs
,
self
.
input_batch
.
req_ids
)
spec_token_ids
=
spec_token_ids
.
tolist
()
self
.
eplb_step
()
...
...
@@ -1743,7 +1767,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
[
h
[
token_indices
]
for
h
in
aux_hidden_states
],
dim
=-
1
)
else
:
target_hidden_states
=
hidden_states
[
token_indices
]
draft
_token_ids
=
self
.
drafter
.
propose
(
spec
_token_ids
,
draft_probs
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
...
...
@@ -1752,8 +1776,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
common_attn_metadata
=
common_attn_metadata
,
num_rejected_tokens
=
num_rejected_tokens
)
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
return
spec_token_ids
,
draft_probs
@
staticmethod
def
maybe_setup_kv_connector
(
scheduler_output
:
"SchedulerOutput"
):
...
...
@@ -2200,7 +2224,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
self
.
drafter
.
dummy_run
(
num_tokens
)
self
.
drafter
.
dummy_run
(
num_tokens
,
attn_metadata
)
# This is necessary to avoid blocking DP.
# For dummy runs, we typically skip EPLB since we don't have any real
...
...
@@ -2267,10 +2291,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids
,
self
.
device
)
num_tokens
=
sum
(
len
(
ids
)
for
ids
in
draft_token_ids
)
# draft_probs = torch.randn(
# num_tokens, logits.shape[-1], device=self.device,
# dtype=logits.dtype)
draft_probs
=
None
draft_probs
=
torch
.
randn
(
num_tokens
,
logits
.
shape
[
-
1
],
device
=
self
.
device
,
dtype
=
logits
.
dtype
)
# draft_probs = None
target_logits
=
torch
.
randn
(
num_tokens
,
logits
.
shape
[
-
1
],
device
=
self
.
device
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment