Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ce1d4073
Commit
ce1d4073
authored
Feb 06, 2026
by
王敏
Browse files
[feat]支持宽松mtp
parent
6af85e40
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
595 additions
and
6 deletions
+595
-6
vllm/envs.py
vllm/envs.py
+6
-0
vllm/v1/sample/rejection_sampler_opt.py
vllm/v1/sample/rejection_sampler_opt.py
+442
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+20
-0
vllm/v1/spec_decode/metadata.py
vllm/v1/spec_decode/metadata.py
+3
-0
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+75
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+49
-6
No files found.
vllm/envs.py
View file @
ce1d4073
...
@@ -291,6 +291,7 @@ if TYPE_CHECKING:
...
@@ -291,6 +291,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_RMS_ROPE
:
bool
=
False
VLLM_USE_FUSED_RMS_ROPE
:
bool
=
False
VLLM_USE_FUSED_FILL_RMS_CAT
:
bool
=
False
VLLM_USE_FUSED_FILL_RMS_CAT
:
bool
=
False
VLLM_W8A8_BACKEND
:
int
=
3
VLLM_W8A8_BACKEND
:
int
=
3
VLLM_REJECT_SAMPLE_OPT
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -1836,6 +1837,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1836,6 +1837,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# blaslt: 3 (default)
# blaslt: 3 (default)
# rocblas: others
# rocblas: others
"VLLM_W8A8_BACKEND"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_W8A8_BACKEND"
,
"3"
)),
"VLLM_W8A8_BACKEND"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_W8A8_BACKEND"
,
"3"
)),
# vllm will use optimized reject sample
"VLLM_REJECT_SAMPLE_OPT"
:
lambda
:
(
os
.
getenv
(
'VLLM_REJECT_SAMPLE_OPT'
,
'True'
).
lower
()
in
(
"true"
,
"1"
)),
}
}
...
...
vllm/v1/sample/rejection_sampler_opt.py
0 → 100644
View file @
ce1d4073
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
from
collections.abc
import
Sequence
from
dataclasses
import
replace
import
torch
import
torch.nn
as
nn
from
vllm.logger
import
init_logger
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.outputs
import
LogprobsLists
,
LogprobsTensors
,
SamplerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.ops.bad_words
import
apply_bad_words_with_drafts
from
vllm.v1.sample.ops.penalties
import
apply_all_penalties
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
logger = init_logger(__name__)

# Sentinel written into every slot of the output buffer that holds no
# accepted/recovered/bonus token; `parse_output` filters these out.
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
# NOTE(review): not referenced by the code visible in this file — presumably
# kept for parity with the non-optimized rejection sampler; confirm before
# removing.
GREEDY_TEMPERATURE: tl.constexpr = 0
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 128
class OptRejectionSampler(nn.Module):
    """Relaxed rejection sampler for speculative decoding.

    The structure is derived from the rejection-sampling scheme of
    https://arxiv.org/abs/2211.17192, but it is NOT a strict implementation:
    `forward` always samples the target tokens greedily and delegates the
    accept/reject decision to the module-level `rejection_sample`.

    Terminology used in the implementation:

    accepted tokens: draft tokens that pass the acceptance test performed by
        `rejection_sample`.
    recovered tokens: the token emitted at the first rejected position. Here
        it is the target model's (greedily sampled) token for that position.
    bonus tokens:
        If all proposed tokens are accepted, the bonus token is added to the
        end of the sequence. It is taken from the target model's sampled
        tokens at the bonus logits positions.
    output tokens:
        Tokens that are finally generated by the rejection sampler.
        output tokens = accepted tokens + recovered tokens + bonus tokens
    """

    def __init__(self, sampler: Sampler):
        """Wrap `sampler` and cache which logprobs flavour it is set up for."""
        super().__init__()
        self.sampler = sampler
        logprobs_mode = self.sampler.logprobs_mode
        self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
        self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")

    def forward(
        self,
        metadata: SpecDecodeMetadata,
        # Draft-model probabilities, or None.
        draft_probs: torch.Tensor | None,
        # [num_tokens + batch_size, vocab_size]
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Run the target sampler and accept/reject the draft tokens.

        Args:
            metadata:
                Metadata for spec decoding.
            draft_probs (Optional[torch.Tensor]):
                Probability distribution for the draft tokens. Can be None
                if probabilities are not provided, which is the case for
                ngram spec decode. `rejection_sample` asserts that a
                non-None value is 3-dimensional and contiguous.
            logits (torch.Tensor):
                Target model's logits.
                Shape is [num_tokens + batch_size, vocab_size]. Here,
                logits from different requests are flattened into a
                single tensor because this is the shape of the output logits.
            sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
                Additional metadata needed for sampling. It is not modified;
                the greedy override below is applied to a copy.
        Returns:
            SamplerOutput:
                Contains the final output token IDs and their logprobs if
                requested.
        """
        assert metadata.max_spec_len <= MAX_SPEC_LEN
        assert logits is not None

        bonus_logits_indices = metadata.bonus_logits_indices
        target_logits_indices = metadata.target_logits_indices

        # Target tokens are always sampled greedily here. The override is
        # applied to a copy so that the caller's (possibly cached and reused)
        # SamplingMetadata keeps its own all_greedy / all_random flags.
        greedy_metadata = replace(
            sampling_metadata,
            all_greedy=True,
            all_random=False,
            max_num_logprobs=-1,
        )
        sampler_output = self.sampler(
            logits=logits,
            sampling_metadata=greedy_metadata,
            predict_bonus_token=True,
            # Override the logprobs mode to return logits because they are
            # needed later to compute the accepted token logprobs.
            logprobs_mode_override="processed_logits"
            if self.is_processed_logprobs_mode
            else "raw_logits",
        )

        # Indexing with a tensor creates new tensors with separate storage,
        # so these do not alias `logits` / `sampled_token_ids`.
        target_logits = logits[target_logits_indices]
        target_tokens = sampler_output.sampled_token_ids[target_logits_indices]
        bonus_token_ids = sampler_output.sampled_token_ids[bonus_logits_indices]

        # Compute probability distribution from target logits.
        target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)

        output_token_ids = rejection_sample(
            metadata.draft_token_ids,
            metadata.num_draft_tokens,
            metadata.max_spec_len,
            metadata.cu_num_draft_tokens,
            draft_probs,
            target_probs,
            target_tokens,
            bonus_token_ids,
            sampling_metadata,
        )

        logprobs_tensors = None
        if sampling_metadata.max_num_logprobs is not None:
            logprobs_tensors = self._get_logprobs_tensors(
                sampling_metadata.max_num_logprobs,
                metadata,
                sampler_output.logprobs_tensors.logprobs,
                output_token_ids,
            )

        return SamplerOutput(
            sampled_token_ids=output_token_ids,
            logprobs_tensors=logprobs_tensors,
        )

    def _get_logprobs_tensors(
        self,
        max_num_logprobs: int,
        metadata: SpecDecodeMetadata,
        logits: torch.Tensor,
        sampled_token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """Gather logprobs for every slot of `sampled_token_ids`.

        Args:
            max_num_logprobs: Forwarded to `self.sampler.gather_logprobs`.
            metadata: Supplies `cu_num_sampled_tokens`.
            logits: Per-position values returned by the target sampler.
            sampled_token_ids: Output of `rejection_sample`; rejected slots
                hold `PLACEHOLDER_TOKEN_ID`.
        Returns:
            The result of `self.sampler.gather_logprobs` over all slots,
            including rejected ones.
        """
        # Shift the inclusive cumulative counts right by one to obtain the
        # starting row of each request.
        cu_num_sampled_tokens = torch.zeros_like(metadata.cu_num_sampled_tokens)
        cu_num_sampled_tokens[1:] = metadata.cu_num_sampled_tokens[:-1]

        final_logits = logits.to(torch.float32)

        # NOTE: To avoid cpu-gpu synchronization, we now simply compute indices
        # for all draft tokens, including the rejected ones. The rejected
        # tokens will be filtered out in the `parse_output`.
        logit_start_indices = cu_num_sampled_tokens
        offsets = torch.arange(
            sampled_token_ids.shape[-1],
            device=logit_start_indices.device,
            dtype=logit_start_indices.dtype,
        )
        accepted_logit_indices = (
            logit_start_indices.unsqueeze(1) + offsets.unsqueeze(0)
        ).flatten()
        # Slots past a request's real token count would index out of range;
        # clamp them to the last valid row.
        accepted_logit_indices.clamp_(max=final_logits.shape[0] - 1)

        accepted_tokens = sampled_token_ids.clone().flatten()
        # we replace rejected token ids with 0 to avoid gather_logprobs error
        accepted_tokens[accepted_tokens == PLACEHOLDER_TOKEN_ID] = 0

        # Compute logprobs for accepted tokens.
        accepted_logits = final_logits[accepted_logit_indices]
        accepted_logprobs = (
            accepted_logits
            if self.is_logits_logprobs_mode
            else self.sampler.compute_logprobs(accepted_logits)
        )
        return self.sampler.gather_logprobs(
            accepted_logprobs,
            max_num_logprobs,
            accepted_tokens.to(torch.int64),
        )

    @staticmethod
    def parse_output(
        output_token_ids: torch.Tensor,
        vocab_size: int,
        discard_req_indices: Sequence[int] = (),
        logprobs_tensors: LogprobsTensors | None = None,
    ) -> tuple[list[list[int]], LogprobsLists | None]:
        """Parse the output of the rejection sampler.

        Args:
            output_token_ids: The sampled token IDs in shape
                [batch_size, max_spec_len + 1]. The rejected tokens are
                replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
                and will be filtered out in this function.
            vocab_size: The size of the vocabulary.
            discard_req_indices: Optional row indices to discard tokens in.
            logprobs_tensors: Optional logprobs tensors to filter.
        Returns:
            A tuple of (list of lists of token IDs, filtered logprobs or
            None when `logprobs_tensors` is None).
        """
        output_token_ids_np = output_token_ids.cpu().numpy()
        # Create mask for valid tokens.
        valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
            output_token_ids_np < vocab_size
        )

        # Logprobs are filtered before discarded rows are masked out.
        output_logprobs = None
        if logprobs_tensors is not None:
            cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
            filtered_tensors = logprobs_tensors.filter(valid_mask.flatten())
            output_logprobs = filtered_tensors.tolists(cu_num_tokens)

        if len(discard_req_indices) > 0:
            # Convert to a list so NumPy always performs row selection; a
            # tuple would be interpreted as a multi-dimensional index.
            valid_mask[list(discard_req_indices)] = False

        outputs = [
            row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
        ]
        return outputs, output_logprobs

    def apply_logits_processors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        metadata: SpecDecodeMetadata,
    ) -> torch.Tensor:
        """Apply penalties, the allowed-token mask and bad-word exclusion.

        Per-request sampling parameters are expanded to one entry per draft
        token (via `repeat_interleave` over `metadata.num_draft_tokens`)
        before being applied to `logits`.

        Returns:
            The processed logits (`masked_fill_` mutates them in place).
        """
        has_penalties = not sampling_metadata.no_penalties
        any_penalties_or_bad_words = (
            sampling_metadata.bad_words_token_ids or has_penalties
        )

        output_token_ids = sampling_metadata.output_token_ids
        if any_penalties_or_bad_words:
            output_token_ids = self._combine_outputs_with_spec_tokens(
                output_token_ids,
                sampling_metadata.spec_token_ids,
            )

        # Calculate indices of target logits.
        if sampling_metadata.allowed_token_ids_mask is not None or has_penalties:
            num_requests = len(sampling_metadata.output_token_ids)
            num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
            original_indices = torch.arange(num_requests, device="cpu")
            repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens)
            repeat_indices = repeat_indices_cpu.to(
                device=logits.device, non_blocking=True
            )
            logits = self.apply_penalties(
                logits, sampling_metadata, metadata, repeat_indices, output_token_ids
            )

            # Apply allowed token ids.
            if sampling_metadata.allowed_token_ids_mask is not None:
                token_mask = sampling_metadata.allowed_token_ids_mask[repeat_indices]
                logits.masked_fill_(token_mask, float("-inf"))

        # Apply bad words exclusion.
        if bad_words_token_ids := sampling_metadata.bad_words_token_ids:
            apply_bad_words_with_drafts(
                logits,
                bad_words_token_ids,
                output_token_ids,
                metadata.num_draft_tokens,
            )

        return logits

    @staticmethod
    def apply_penalties(
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        metadata: SpecDecodeMetadata,
        repeat_indices: torch.Tensor,
        output_token_ids: list[list[int]],
    ) -> torch.Tensor:
        """Apply presence/frequency/repetition penalties to `logits`.

        `repeat_indices` selects, for every logits row, the per-request entry
        of each penalty tensor. Returns `logits` untouched when
        `sampling_metadata.no_penalties` is set. `metadata` is accepted but
        not read.
        """
        if sampling_metadata.no_penalties:
            return logits

        assert sampling_metadata.prompt_token_ids is not None

        prompt_token_ids = sampling_metadata.prompt_token_ids[repeat_indices]
        presence_penalties = sampling_metadata.presence_penalties[repeat_indices]
        frequency_penalties = sampling_metadata.frequency_penalties[repeat_indices]
        repetition_penalties = sampling_metadata.repetition_penalties[repeat_indices]

        logits = apply_all_penalties(
            logits,
            prompt_token_ids,
            presence_penalties,
            frequency_penalties,
            repetition_penalties,
            output_token_ids,
        )
        return logits

    @staticmethod
    def _combine_outputs_with_spec_tokens(
        output_token_ids: list[list[int]],
        spec_token_ids: list[list[int]] | None = None,
    ) -> list[list[int]]:
        """Build one token history per draft position.

        For a request with output history `out` and draft tokens
        `[s0, s1, ..., sk]` the result contains `out`, `out + [s0]`, ...,
        `out + [s0, ..., s(k-1)]`. Requests without draft tokens contribute
        nothing. When `spec_token_ids` is None the input is returned as is.
        """
        if spec_token_ids is None:
            return output_token_ids

        result = []
        for out, spec in zip(output_token_ids, spec_token_ids):
            if len(spec) == 0:
                continue
            result.append(out)
            # Each further draft position sees the previous history plus the
            # preceding draft token.
            for i in range(len(spec) - 1):
                result.append([*result[-1], spec[i]])
        return result
def rejection_sample(
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [batch_size]
    num_draft_tokens: list[int],
    max_spec_len: int,
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # 3-D and contiguous (see asserts below), or None
    draft_probs: Optional[torch.Tensor],
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    # one target token id per draft position
    target_tokens,
    # one bonus token id per request
    bonus_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Launch the Triton kernel that accepts or rejects the draft tokens.

    One kernel program is launched per request. The per-position acceptance
    thresholds are drawn uniformly from [0.1, 0.2) instead of [0, 1).

    Args:
        draft_token_ids: Flattened draft token ids of all requests.
        num_draft_tokens: Number of draft tokens of each request; only its
            length (the batch size) is used here.
        max_spec_len: Largest number of draft tokens of any request.
        cu_num_draft_tokens: Inclusive cumulative sum of the draft counts.
        draft_probs: Draft-model probabilities, or None.
        target_probs: Target-model probabilities per draft position.
        target_tokens: Target-model token ids per draft position.
        bonus_token_ids: Token appended when all drafts are accepted.
        sampling_metadata: Accepted for interface compatibility; not read.

    Returns:
        An int32 tensor of shape [batch_size, max_spec_len + 1] in which
        unused slots hold `PLACEHOLDER_TOKEN_ID`.
    """
    # Rank checks.
    assert draft_token_ids.ndim == 1
    assert draft_probs is None or draft_probs.ndim == 3
    assert cu_num_draft_tokens.ndim == 1
    assert target_probs.ndim == 2

    n_reqs = len(num_draft_tokens)
    n_tokens = draft_token_ids.shape[0]
    n_vocab = target_probs.shape[-1]
    dev = target_probs.device

    # The kernel addresses these buffers with flat offsets.
    assert draft_token_ids.is_contiguous()
    assert draft_probs is None or draft_probs.is_contiguous()
    assert target_probs.is_contiguous()
    assert bonus_token_ids.is_contiguous()
    assert target_probs.shape == (n_tokens, n_vocab)
    # NOTE(review): the kernel reads draft_probs at
    # (token_offset * vocab_size + token_id); for a 3-D tensor that appears to
    # assume every request carries the same number of draft tokens — confirm.

    # Create output buffer.
    # dtype is consistent with SamplerOutput.sampled_token_ids.
    result = torch.full(
        (n_reqs, max_spec_len + 1),
        PLACEHOLDER_TOKEN_ID,
        dtype=torch.int32,
        device=dev,
    )

    # Acceptance thresholds, rescaled from [0, 1) to [0.1, 0.2).
    thresholds = torch.rand((n_tokens,), dtype=torch.float32, device=dev)
    thresholds = thresholds * 0.1 + 0.1

    # Rejection sampling, one program per request.
    grid = (n_reqs,)
    rejection_random_sample_kernel[grid](
        result,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        target_tokens,
        bonus_token_ids,
        thresholds,
        max_spec_len,
        n_vocab,
        NO_DRAFT_PROBS=draft_probs is None,
        num_warps=1,
    )
    return result
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    target_token_ids_ptr,  # read at one offset per draft position
    bonus_token_ids_ptr,  # [batch_size]
    uniform_probs_ptr,  # [num_tokens]
    max_spec_len,
    vocab_size,
    NO_DRAFT_PROBS: tl.constexpr,
):
    """Accept or reject the draft tokens of one request.

    Each program handles the request given by `tl.program_id(0)`. Draft
    positions are visited in order; a draft token is kept when it equals the
    target token or when target_prob / draft_prob reaches the position's
    threshold. At the first rejection the target token is written instead and
    the remaining positions are left untouched. If nothing was rejected the
    bonus token is written after the last draft token.
    """
    row = tl.program_id(0)

    # cu_num_draft_tokens is an inclusive prefix sum, so the first request
    # starts at offset 0 and the others at the previous request's end.
    if row == 0:
        first = 0
    else:
        first = tl.load(cu_num_draft_tokens_ptr + row - 1)
    last = tl.load(cu_num_draft_tokens_ptr + row)
    n_draft = last - first

    # Start of this request's row in the output buffer.
    out_row_ptr = output_token_ids_ptr + row * (max_spec_len + 1)

    rejected = False
    for pos in range(n_draft):
        if not rejected:
            flat = first + pos

            draft_token_id = tl.load(draft_token_ids_ptr + flat)
            # Negative ids are mapped to 0 so the probability lookups below
            # stay inside the row.
            if draft_token_id < 0:
                draft_token_id = 0

            if NO_DRAFT_PROBS:
                draft_prob = 1
            else:
                draft_prob = tl.load(
                    draft_probs_ptr + flat * vocab_size + draft_token_id
                )
            target_prob = tl.load(
                target_probs_ptr + flat * vocab_size + draft_token_id
            )
            draft_token_id = draft_token_id.to(tl.int64)

            target_token_id = tl.load(target_token_ids_ptr + flat)
            target_token_id = target_token_id.to(tl.int64)

            uniform_prob = tl.load(uniform_probs_ptr + flat)

            # NOTE(woosuk): While the draft probability should never be 0,
            # we check it to avoid NaNs. If it happens to be 0, we reject.
            if (draft_token_id == target_token_id) or (
                target_prob / draft_prob >= uniform_prob and draft_prob > 0
            ):
                token_id = draft_token_id
            else:
                rejected = True
                token_id = target_token_id
            tl.store(out_row_ptr + pos, token_id)

    if not rejected:
        # If all tokens are accepted, append the bonus token.
        bonus_token_id = tl.load(bonus_token_ids_ptr + row)
        tl.store(out_row_ptr + n_draft, bonus_token_id)
vllm/v1/spec_decode/eagle.py
View file @
ce1d4073
...
@@ -8,6 +8,7 @@ import numpy as np
...
@@ -8,6 +8,7 @@ import numpy as np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
vllm.envs
as
envs
from
vllm.config
import
(
from
vllm.config
import
(
CUDAGraphMode
,
CUDAGraphMode
,
VllmConfig
,
VllmConfig
,
...
@@ -397,9 +398,16 @@ class SpecDecodeBaseProposer:
...
@@ -397,9 +398,16 @@ class SpecDecodeBaseProposer:
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
# Early exit if there is only one draft token to be generated.
# Early exit if there is only one draft token to be generated.
if
self
.
num_speculative_tokens
==
1
:
if
self
.
num_speculative_tokens
==
1
:
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
return
draft_token_ids
.
view
(
-
1
,
1
),
draft_prob
.
view
(
-
1
,
1
,
logits
.
shape
[
-
1
])
return
draft_token_ids
.
view
(
-
1
,
1
)
return
draft_token_ids
.
view
(
-
1
,
1
)
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
...
@@ -472,6 +480,9 @@ class SpecDecodeBaseProposer:
...
@@ -472,6 +480,9 @@ class SpecDecodeBaseProposer:
common_attn_metadata
.
_seq_lens_cpu
=
None
common_attn_metadata
.
_seq_lens_cpu
=
None
common_attn_metadata
.
_num_computed_tokens_cpu
=
None
common_attn_metadata
.
_num_computed_tokens_cpu
=
None
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_probs_list
=
[
draft_prob
]
for
token_index
in
range
(
self
.
num_speculative_tokens
-
1
):
for
token_index
in
range
(
self
.
num_speculative_tokens
-
1
):
# Update the inputs.
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# cast to int32 is crucial when eagle model is compiled.
...
@@ -598,8 +609,17 @@ class SpecDecodeBaseProposer:
...
@@ -598,8 +609,17 @@ class SpecDecodeBaseProposer:
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids_list
.
append
(
draft_token_ids
)
draft_token_ids_list
.
append
(
draft_token_ids
)
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
draft_probs_list
.
append
(
draft_prob
)
# [batch_size, num_speculative_tokens]
# [batch_size, num_speculative_tokens]
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_probs
=
torch
.
stack
(
draft_probs_list
,
dim
=
1
).
contiguous
()
return
draft_token_ids
,
draft_probs
return
draft_token_ids
return
draft_token_ids
def
set_inputs_first_pass
(
def
set_inputs_first_pass
(
...
...
vllm/v1/spec_decode/metadata.py
View file @
ce1d4073
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -22,6 +23,8 @@ class SpecDecodeMetadata:
...
@@ -22,6 +23,8 @@ class SpecDecodeMetadata:
bonus_logits_indices
:
torch
.
Tensor
bonus_logits_indices
:
torch
.
Tensor
# [num_tokens + batch_size]
# [num_tokens + batch_size]
logits_indices
:
torch
.
Tensor
logits_indices
:
torch
.
Tensor
# [batch_size]
spec_decode_ids
:
Optional
[
list
[
str
]]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
self
.
max_spec_len
=
max
(
self
.
num_draft_tokens
)
self
.
max_spec_len
=
max
(
self
.
num_draft_tokens
)
...
...
vllm/v1/spec_decode/utils.py
View file @
ce1d4073
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.torch_utils
import
async_tensor_h2d
@
triton
.
jit
@
triton
.
jit
...
@@ -107,3 +111,74 @@ def eagle_prepare_next_token_padded_kernel(
...
@@ -107,3 +111,74 @@ def eagle_prepare_next_token_padded_kernel(
tl
.
store
(
next_token_ids_ptr
+
req_idx
,
backup_token
)
tl
.
store
(
next_token_ids_ptr
+
req_idx
,
backup_token
)
tl
.
store
(
valid_sampled_tokens_count_ptr
+
req_idx
,
valid_count
)
tl
.
store
(
valid_sampled_tokens_count_ptr
+
req_idx
,
valid_count
)
class DraftProbs(ABC):  # type: ignore[call-arg]
    """Draft probs corresponding to in-progress sequences.

    Row ``i`` (along dim 0) of ``draft_probs`` belongs to request
    ``_req_ids[i]``. ``count`` is incremented on every `update`, and
    ``req_id_to_count`` remembers the value of ``count`` at which each
    request was last stored so that `prune` can evict stale entries.

    All mutable state is created per instance in `__init__`; nothing mutable
    is shared through the class.
    """

    # spec tokens probs, indexed along dim 0 by request.
    draft_probs: torch.Tensor
    # The request id list, aligned with dim 0 of ``draft_probs``.
    _req_ids: list[str]
    # Number of `update` calls so far.
    count: int = 0
    # Request id -> value of ``count`` when the request was last stored.
    req_id_to_count: dict[str, int]
    # Entries not refreshed for this many updates are evicted by `prune`.
    # (Attribute name spelling kept for backward compatibility.)
    prune_threshould = 100

    def __init__(self, draft_probs, req_ids):
        """Store the initial probabilities.

        Args:
            draft_probs: Probabilities with one leading-dimension entry per
                request id.
            req_ids: Request ids, in the same order as ``draft_probs``.
        """
        assert len(req_ids) == len(draft_probs)
        self.draft_probs = draft_probs
        # Copy so later changes to the caller's list cannot desynchronise
        # the id list from the tensor rows.
        self._req_ids = list(req_ids)
        self.count = 0
        # Per-instance dict: a class-level dict would be shared by all
        # instances.
        self.req_id_to_count = {req_id: self.count for req_id in self._req_ids}

    def _first_positions(self) -> dict[str, int]:
        """Map each cached request id to the index of its first row."""
        positions: dict[str, int] = {}
        for row, req_id in enumerate(self._req_ids):
            positions.setdefault(req_id, row)
        return positions

    def _select_rows(self, index: list[int]) -> torch.Tensor:
        """Return the rows of ``draft_probs`` listed in ``index``."""
        index_tensor = async_tensor_h2d(
            index,
            dtype=torch.int32,
            target_device=self.draft_probs.device,
            pin_memory=True,
        )
        return self.draft_probs[index_tensor]

    def update(self, draft_probs: torch.Tensor, tmp_req_ids: list[str]):
        """Replace or add the entries of ``tmp_req_ids``.

        Cached rows of requests that are not in ``tmp_req_ids`` are kept (in
        their existing order) and the new rows are appended after them.
        """
        self.count += 1

        incoming = set(tmp_req_ids)
        positions = self._first_positions()
        kept_req_ids = [r for r in self._req_ids if r not in incoming]
        kept_rows = self._select_rows([positions[r] for r in kept_req_ids])

        self.draft_probs = torch.cat([kept_rows, draft_probs])
        self._req_ids = kept_req_ids + list(tmp_req_ids)
        for req_id in tmp_req_ids:
            self.req_id_to_count[req_id] = self.count
        assert len(self._req_ids) == len(self.draft_probs)

    def prune(self, req_ids: list[str]):
        """Drop the entries of ``req_ids`` and, periodically, stale entries.

        Whenever ``count`` is a multiple of ``prune_threshould``, requests
        that have not been refreshed for at least ``prune_threshould``
        updates are dropped as well. ``req_ids`` itself is not modified.
        """
        to_drop = set(req_ids)
        if self.count % self.prune_threshould == 0:
            to_drop.update(
                req_id
                for req_id, last_count in self.req_id_to_count.items()
                if self.count - last_count >= self.prune_threshould
            )

        self.req_id_to_count = {
            k: v for k, v in self.req_id_to_count.items() if k not in to_drop
        }
        survivors = [r for r in self._req_ids if r not in to_drop]
        if survivors != self._req_ids:
            # Batch contents changed - prune removed sequences.
            positions = self._first_positions()
            self.draft_probs = self._select_rows([positions[r] for r in survivors])
            self._req_ids = survivors

    def get_probs(self, req_ids: list[str]):
        """Return the cached rows for ``req_ids``, in the given order.

        Raises:
            ValueError: If a request id has no cached entry.
        """
        positions = self._first_positions()
        try:
            index = [positions[req_id] for req_id in req_ids]
        except KeyError as err:
            raise ValueError(f"{err.args[0]!r} is not in list") from err
        return self._select_rows(index)
vllm/v1/worker/gpu_model_runner.py
View file @
ce1d4073
...
@@ -12,7 +12,7 @@ from contextlib import contextmanager
...
@@ -12,7 +12,7 @@ from contextlib import contextmanager
from
copy
import
copy
,
deepcopy
from
copy
import
copy
,
deepcopy
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
reduce
from
functools
import
reduce
from
typing
import
TYPE_CHECKING
,
Any
,
NamedTuple
,
TypeAlias
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
NamedTuple
,
TypeAlias
,
cast
,
Optional
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -145,6 +145,7 @@ from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
...
@@ -145,6 +145,7 @@ from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from
vllm.v1.sample.logits_processor.interface
import
LogitsProcessor
from
vllm.v1.sample.logits_processor.interface
import
LogitsProcessor
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
from
vllm.v1.sample.rejection_sampler_opt
import
OptRejectionSampler
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.spec_decode.draft_model
import
DraftModelProposer
from
vllm.v1.spec_decode.draft_model
import
DraftModelProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
...
@@ -181,6 +182,7 @@ from .utils import (
...
@@ -181,6 +182,7 @@ from .utils import (
bind_kv_cache
,
bind_kv_cache
,
sanity_check_mm_encoder_outputs
,
sanity_check_mm_encoder_outputs
,
)
)
from
vllm.v1.spec_decode.utils
import
DraftProbs
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
...
@@ -470,7 +472,11 @@ class GPUModelRunner(
...
@@ -470,7 +472,11 @@ class GPUModelRunner(
"Unknown speculative decoding method: "
"Unknown speculative decoding method: "
f
"
{
self
.
speculative_config
.
method
}
"
f
"
{
self
.
speculative_config
.
method
}
"
)
)
self
.
rejection_sampler
=
RejectionSampler
(
self
.
sampler
)
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
self
.
rejection_sampler
=
RejectionSampler
(
self
.
sampler
)
else
:
self
.
rejection_sampler
=
OptRejectionSampler
(
self
.
sampler
)
self
.
num_spec_tokens
=
0
self
.
num_spec_tokens
=
0
if
self
.
speculative_config
:
if
self
.
speculative_config
:
...
@@ -702,6 +708,8 @@ class GPUModelRunner(
...
@@ -702,6 +708,8 @@ class GPUModelRunner(
self
.
mamba_state_idx
:
dict
[
str
,
int
]
=
{}
self
.
mamba_state_idx
:
dict
[
str
,
int
]
=
{}
self
.
layerwise_nvtx_hooks_registered
=
False
self
.
layerwise_nvtx_hooks_registered
=
False
self
.
draft_probs
:
Optional
[
DraftProbs
]
=
None
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
self
.
max_model_len
=
max_model_len
self
.
max_model_len
=
max_model_len
if
self
.
speculative_config
:
if
self
.
speculative_config
:
...
@@ -874,6 +882,10 @@ class GPUModelRunner(
...
@@ -874,6 +882,10 @@ class GPUModelRunner(
for
req_id
in
scheduler_output
.
finished_req_ids
:
for
req_id
in
scheduler_output
.
finished_req_ids
:
self
.
input_batch
.
remove_request
(
req_id
)
self
.
input_batch
.
remove_request
(
req_id
)
# prune draft probs of finished requests
if
envs
.
VLLM_REJECT_SAMPLE_OPT
and
self
.
draft_probs
is
not
None
and
len
(
scheduler_output
.
finished_req_ids
)
>
0
:
self
.
draft_probs
.
prune
(
list
(
scheduler_output
.
finished_req_ids
))
# Free the cached encoder outputs.
# Free the cached encoder outputs.
for
mm_hash
in
scheduler_output
.
free_encoder_mm_hashes
:
for
mm_hash
in
scheduler_output
.
free_encoder_mm_hashes
:
self
.
encoder_cache
.
pop
(
mm_hash
,
None
)
self
.
encoder_cache
.
pop
(
mm_hash
,
None
)
...
@@ -1616,8 +1628,13 @@ class GPUModelRunner(
...
@@ -1616,8 +1628,13 @@ class GPUModelRunner(
>=
self
.
input_batch
.
num_prompt_tokens
[
req_idx
]
>=
self
.
input_batch
.
num_prompt_tokens
[
req_idx
]
):
):
num_decode_draft_tokens
[
req_idx
]
=
len
(
draft_token_ids
)
num_decode_draft_tokens
[
req_idx
]
=
len
(
draft_token_ids
)
spec_decode_ids
=
None
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
spec_decode_ids
=
scheduler_output
.
scheduled_spec_decode_tokens
.
keys
()
spec_decode_metadata
=
self
.
_calc_spec_decode_metadata
(
spec_decode_metadata
=
self
.
_calc_spec_decode_metadata
(
num_draft_tokens
,
cu_num_tokens
num_draft_tokens
,
cu_num_tokens
,
spec_decode_ids
)
)
logits_indices
=
spec_decode_metadata
.
logits_indices
logits_indices
=
spec_decode_metadata
.
logits_indices
num_sampled_tokens
=
num_draft_tokens
+
1
num_sampled_tokens
=
num_draft_tokens
+
1
...
@@ -2118,6 +2135,7 @@ class GPUModelRunner(
...
@@ -2118,6 +2135,7 @@ class GPUModelRunner(
self
,
self
,
num_draft_tokens
:
np
.
ndarray
,
num_draft_tokens
:
np
.
ndarray
,
cu_num_scheduled_tokens
:
np
.
ndarray
,
cu_num_scheduled_tokens
:
np
.
ndarray
,
spec_decode_ids
:
Optional
[
list
[
str
]]
=
None
)
->
SpecDecodeMetadata
:
)
->
SpecDecodeMetadata
:
# Inputs:
# Inputs:
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
...
@@ -2191,6 +2209,7 @@ class GPUModelRunner(
...
@@ -2191,6 +2209,7 @@ class GPUModelRunner(
target_logits_indices
=
target_logits_indices
,
target_logits_indices
=
target_logits_indices
,
bonus_logits_indices
=
bonus_logits_indices
,
bonus_logits_indices
=
bonus_logits_indices
,
logits_indices
=
logits_indices
,
logits_indices
=
logits_indices
,
spec_decode_ids
=
spec_decode_ids
,
)
)
def
_prepare_kv_sharing_fast_prefill
(
def
_prepare_kv_sharing_fast_prefill
(
...
@@ -2838,7 +2857,8 @@ class GPUModelRunner(
...
@@ -2838,7 +2857,8 @@ class GPUModelRunner(
sampler_output
=
self
.
rejection_sampler
(
sampler_output
=
self
.
rejection_sampler
(
spec_decode_metadata
,
spec_decode_metadata
,
None
,
# draft_probs
None
if
self
.
draft_probs
is
None
else
\
self
.
draft_probs
.
get_probs
(
spec_decode_metadata
.
spec_decode_ids
),
# draft_probs
logits
,
logits
,
sampling_metadata
,
sampling_metadata
,
)
)
...
@@ -3999,7 +4019,7 @@ class GPUModelRunner(
...
@@ -3999,7 +4019,7 @@ class GPUModelRunner(
else
:
else
:
mm_embed_inputs
=
None
mm_embed_inputs
=
None
draft_
token_ids
=
self
.
drafter
.
propose
(
draft_
result
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
target_hidden_states
=
target_hidden_states
,
...
@@ -4012,6 +4032,19 @@ class GPUModelRunner(
...
@@ -4012,6 +4032,19 @@ class GPUModelRunner(
slot_mappings
=
slot_mappings
,
slot_mappings
=
slot_mappings
,
)
)
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_token_ids
=
draft_result
else
:
draft_token_ids
,
draft_probs
=
draft_result
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_req_ids
=
list
(
scheduler_output
.
num_scheduled_tokens
.
keys
())
if
self
.
draft_probs
is
None
:
self
.
draft_probs
=
DraftProbs
(
draft_probs
,
draft_req_ids
)
else
:
self
.
draft_probs
.
update
(
draft_probs
,
draft_req_ids
)
return
draft_token_ids
return
draft_token_ids
def
update_config
(
self
,
overrides
:
dict
[
str
,
Any
])
->
None
:
def
update_config
(
self
,
overrides
:
dict
[
str
,
Any
])
->
None
:
...
@@ -4651,6 +4684,9 @@ class GPUModelRunner(
...
@@ -4651,6 +4684,9 @@ class GPUModelRunner(
inputs_embeds
=
self
.
inputs_embeds
.
gpu
[:
num_tokens_padded
]
inputs_embeds
=
self
.
inputs_embeds
.
gpu
[:
num_tokens_padded
]
model_kwargs
=
self
.
_init_model_kwargs
()
model_kwargs
=
self
.
_init_model_kwargs
()
else
:
else
:
self
.
input_ids
.
gpu
[:
num_tokens_padded
]
=
torch
.
randint
(
0
,
self
.
model_config
.
get_vocab_size
(),
(
num_tokens_padded
,),
dtype
=
torch
.
int32
)
input_ids
=
self
.
input_ids
.
gpu
[:
num_tokens_padded
]
input_ids
=
self
.
input_ids
.
gpu
[:
num_tokens_padded
]
inputs_embeds
=
None
inputs_embeds
=
None
...
@@ -4836,7 +4872,14 @@ class GPUModelRunner(
...
@@ -4836,7 +4872,14 @@ class GPUModelRunner(
# draft_probs = torch.randn(
# draft_probs = torch.randn(
# num_tokens, logits.shape[-1], device=self.device,
# num_tokens, logits.shape[-1], device=self.device,
# dtype=logits.dtype)
# dtype=logits.dtype)
draft_probs
=
None
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_probs
=
None
else
:
draft_probs
=
torch
.
randn
(
num_reqs
,
self
.
speculative_config
.
num_speculative_tokens
,
logits
.
shape
[
-
1
],
device
=
self
.
device
,
dtype
=
logits
.
dtype
)
logits
=
torch
.
randn
(
logits
=
torch
.
randn
(
num_tokens
+
num_reqs
,
num_tokens
+
num_reqs
,
logits
.
shape
[
-
1
],
logits
.
shape
[
-
1
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment