Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ade7db0c
Commit
ade7db0c
authored
Nov 26, 2025
by
zhuwenwen
Browse files
Merge branch 'v0.9.2-dev-wm-1126' into 'v0.9.2-dev'
[feat]支持宽松mtp See merge request dcutoolkit/deeplearing/vllm!269
parents
9aadeed6
b9bc84e2
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
509 additions
and
56 deletions
+509
-56
vllm/envs.py
vllm/envs.py
+7
-1
vllm/v1/sample/rejection_sampler_opt.py
vllm/v1/sample/rejection_sampler_opt.py
+257
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+14
-0
vllm/v1/spec_decode/metadata.py
vllm/v1/spec_decode/metadata.py
+3
-0
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+43
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+170
-55
vllm/zero_overhead/v1/eagle.py
vllm/zero_overhead/v1/eagle.py
+15
-0
No files found.
vllm/envs.py
View file @
ade7db0c
...
...
@@ -185,6 +185,7 @@ if TYPE_CHECKING:
VLLM_USE_ZERO_MTP
:
bool
=
False
VLLM_USE_CUDA_GRAPH_SIZES
:
bool
=
False
VLLM_USE_CAT_MLA
:
bool
=
False
VLLM_REJECT_SAMPLE_OPT
:
bool
=
False
def
get_default_cache_root
():
return
os
.
getenv
(
...
...
@@ -1200,6 +1201,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_CAT_MLA"
:
lambda
:
(
os
.
getenv
(
'VLLM_USE_CAT_MLA'
,
'False'
).
lower
()
in
(
"true"
,
"1"
)),
# vllm will use fused cat and mla
"VLLM_REJECT_SAMPLE_OPT"
:
lambda
:
(
os
.
getenv
(
'VLLM_REJECT_SAMPLE_OPT'
,
'False'
).
lower
()
in
(
"true"
,
"1"
)),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/v1/sample/rejection_sampler_opt.py
0 → 100644
View file @
ade7db0c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
from
vllm.logger
import
init_logger
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
logger
=
init_logger
(
__name__
)
PLACEHOLDER_TOKEN_ID
:
tl
.
constexpr
=
-
1
GREEDY_TEMPERATURE
:
tl
.
constexpr
=
-
1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN
=
32
class OptRejectionSampler(nn.Module):
    """Rejection sampler for speculative decoding that is also given the
    tokens the target model sampled at every draft position.

    The scheme is modelled on https://arxiv.org/abs/2211.17192. The
    accept/reject decision itself is made by `rejection_sample`, which this
    module wraps.

    Terminology used by this implementation:

    accepted tokens:
        Draft tokens that are kept after being compared against the target
        model's output for the same position.
    replacement tokens:
        At the first rejected position the token taken from `target_tokens`
        is emitted instead of the draft token; later positions of that
        request stay unfilled.
    bonus tokens:
        Appended after the draft tokens of a request when every one of them
        was accepted. They are sampled by the caller and passed in, so the
        caller is free to use top_p / top_k for them.
    output tokens:
        accepted tokens + replacement token (if any) + bonus token (if any).
    """

    def forward(
        self,
        metadata: SpecDecodeMetadata,
        # Draft probabilities; `rejection_sample` asserts a 3-D tensor
        # when this is not None.
        draft_probs: Optional[torch.Tensor],
        # [num_tokens, vocab_size]
        target_logits: torch.Tensor,
        # One sampled target token id per draft position.
        target_tokens: torch.Tensor,
        # [batch_size, 1]
        bonus_token_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Run rejection sampling for one step.

        Args:
            metadata: Spec-decode metadata. `draft_token_ids`,
                `num_draft_tokens`, `max_spec_len` and
                `cu_num_draft_tokens` are read from it.
            draft_probs: Probabilities produced by the draft model, or
                None when the proposer does not provide them.
            target_logits: Target-model logits for the draft positions of
                all requests, flattened into [num_tokens, vocab_size].
                Converted to float32 probabilities with a softmax here.
            target_tokens: Token ids sampled from the target model at the
                draft positions.
            bonus_token_ids: Bonus token for each request, shape
                [batch_size, 1].
            sampling_metadata: Forwarded unchanged to `rejection_sample`.

        Returns:
            The tensor produced by `rejection_sample`: shape
            [batch_size, max_spec_len + 1], with `PLACEHOLDER_TOKEN_ID` in
            every slot that received no token.
        """
        assert metadata.max_spec_len <= MAX_SPEC_LEN

        # Draft ids equal to -1 are placeholders; map them to token 0 so they
        # are valid indices into the probability tables. Per the original
        # (translated) comment this keeps the first decode step working.
        raw_draft_ids = metadata.draft_token_ids
        is_placeholder = raw_draft_ids.eq(-1)
        safe_draft_ids = raw_draft_ids.masked_fill(is_placeholder,
                                                   0).to(torch.long)

        target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)

        return rejection_sample(
            draft_token_ids=safe_draft_ids,
            num_draft_tokens=metadata.num_draft_tokens,
            max_spec_len=metadata.max_spec_len,
            cu_num_draft_tokens=metadata.cu_num_draft_tokens,
            draft_probs=draft_probs,
            target_probs=target_probs,
            target_tokens=target_tokens,
            bonus_token_ids=bonus_token_ids,
            sampling_metadata=sampling_metadata,
        )

    @staticmethod
    def parse_output(
        output_token_ids: torch.Tensor,
        vocab_size: int,
    ) -> list[list[int]]:
        """Convert the sampler output into per-request Python lists.

        Args:
            output_token_ids: Sampled token ids of shape
                [batch_size, max_spec_len + 1]; unfilled slots hold
                `PLACEHOLDER_TOKEN_ID`.
            vocab_size: Size of the vocabulary; ids greater than or equal
                to it are dropped as well.

        Returns:
            One list of valid token ids per row of `output_token_ids`.
        """
        token_rows = output_token_ids.cpu().numpy()
        # A slot is kept when it is neither a placeholder nor out of vocab.
        keep = (token_rows != PLACEHOLDER_TOKEN_ID) & (token_rows < vocab_size)
        return [
            tokens[flags].tolist() for tokens, flags in zip(token_rows, keep)
        ]
def rejection_sample(
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [batch_size]
    num_draft_tokens: list[int],
    max_spec_len: int,
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # 3-D when provided (asserted below); may be None
    draft_probs: Optional[torch.Tensor],
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    # one target token id per draft position
    target_tokens,
    # [batch_size, 1]
    bonus_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Validate the inputs, allocate the output and launch the Triton kernel.

    Args:
        draft_token_ids: Flattened draft token ids of all requests.
        num_draft_tokens: Number of draft tokens per request; only its
            length (the batch size) is used here.
        max_spec_len: Largest number of draft tokens of any request.
        cu_num_draft_tokens: Cumulative draft-token counts per request.
        draft_probs: Draft-model probabilities, or None.
        target_probs: Target-model probabilities per draft position.
        target_tokens: Token ids sampled from the target model.
        bonus_token_ids: Bonus token for each request.
        sampling_metadata: Accepted for interface parity; not used here.

    Returns:
        An int32 tensor of shape [batch_size, max_spec_len + 1] that starts
        filled with `PLACEHOLDER_TOKEN_ID` and is written by the kernel.
    """
    # Rank checks on the flattened inputs.
    assert draft_token_ids.ndim == 1
    assert draft_probs is None or draft_probs.ndim == 3
    assert cu_num_draft_tokens.ndim == 1
    assert target_probs.ndim == 2

    num_reqs = len(num_draft_tokens)
    total_draft_tokens = draft_token_ids.shape[0]
    vocab = target_probs.shape[-1]
    dev = target_probs.device

    # The kernel addresses these buffers with flat pointer arithmetic.
    assert draft_token_ids.is_contiguous()
    assert draft_probs is None or draft_probs.is_contiguous()
    assert target_probs.is_contiguous()
    assert bonus_token_ids.is_contiguous()
    assert target_probs.shape == (total_draft_tokens, vocab)

    # Output buffer; int32 is consistent with SamplerOutput.sampled_token_ids.
    sampled = torch.full(
        size=(num_reqs, max_spec_len + 1),
        fill_value=PLACEHOLDER_TOKEN_ID,
        dtype=torch.int32,
        device=dev,
    )

    # One acceptance threshold per draft position, drawn from [0.1, 0.2)
    # instead of [0, 1).
    thresholds = torch.rand(
        total_draft_tokens,
        dtype=torch.float32,
        device=dev,
    ).mul(0.1).add(0.1)

    # One program instance per request.
    launch_grid = (num_reqs, )
    rejection_random_sample_kernel[launch_grid](
        sampled,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        target_tokens,
        bonus_token_ids,
        thresholds,
        max_spec_len,
        vocab,
        NO_DRAFT_PROBS=draft_probs is None,
        num_warps=1,
    )
    return sampled
# Triton kernel: one program instance handles one request. It walks the
# request's draft tokens in order, keeps each one while it is accepted, and on
# the first rejection writes the target model's token and stops. If nothing
# was rejected the request's bonus token is appended.
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # indexed as flat [num_tokens, vocab_size], or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    target_token_ids_ptr,  # read as one token id per draft position
    bonus_token_ids_ptr,  # [batch_size]
    uniform_probs_ptr,  # [num_tokens]
    max_spec_len,
    vocab_size,
    NO_DRAFT_PROBS: tl.constexpr,
):
    req_idx = tl.program_id(0)
    # The cumulative counts give this request's slice [start_idx, end_idx) of
    # the flattened per-token buffers; request 0 starts at offset 0.
    if req_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    rejected = False
    for pos in range(num_draft_tokens):
        # Once a position has been rejected, the remaining positions are
        # skipped and keep whatever the output buffer already holds.
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            if NO_DRAFT_PROBS:
                # No draft distribution available: use probability 1.
                draft_prob = 1
            else:
                draft_prob = tl.load(draft_probs_ptr +
                                     (start_idx + pos) * vocab_size +
                                     draft_token_id)
            target_prob = tl.load(target_probs_ptr +
                                  (start_idx + pos) * vocab_size +
                                  draft_token_id)
            target_token_id = tl.load(target_token_ids_ptr +
                                      (start_idx + pos))
            # Widen to int64 before comparing with the draft token id.
            target_token_id = target_token_id.to(tl.int64)
            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
            # Accept when the draft token equals the target model's sampled
            # token, or when the probability ratio reaches the threshold.
            # NOTE(woosuk): While the draft probability should never be 0,
            # we check it to avoid NaNs. If it happens to be 0, we reject.
            if (draft_token_id == target_token_id) or (
                    target_prob / draft_prob >= uniform_prob
                    and draft_prob > 0):
                token_id = draft_token_id
            else:
                # Rejected: emit the target model's token for this position.
                rejected = True
                token_id = target_token_id
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                     token_id)

    if not rejected:
        # If all tokens are accepted, append the bonus token.
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
            num_draft_tokens, bonus_token_id)
vllm/v1/spec_decode/eagle.py
View file @
ade7db0c
...
...
@@ -5,7 +5,9 @@ import numpy as np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
vllm.envs
as
envs
from
vllm.attention.layer
import
Attention
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
get_layers_from_vllm_config
)
...
...
@@ -235,6 +237,10 @@ class EagleProposer:
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
draft_probs_list
=
[
draft_prob
]
# Early exit if there is only one draft token to be generated.
if
self
.
num_speculative_tokens
==
1
:
# [batch_size, 1]
...
...
@@ -385,9 +391,17 @@ class EagleProposer:
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids_list
.
append
(
draft_token_ids
)
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
draft_probs_list
.
append
(
draft_prob
)
# [batch_size, num_speculative_tokens]
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_probs
=
torch
.
stack
(
draft_probs_list
,
dim
=
1
).
contiguous
()
return
draft_token_ids
,
draft_probs
return
draft_token_ids
# @staticmethod
...
...
vllm/v1/spec_decode/metadata.py
View file @
ade7db0c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
Optional
import
numpy
as
np
import
torch
...
...
@@ -21,6 +22,8 @@ class SpecDecodeMetadata:
bonus_logits_indices
:
torch
.
Tensor
# [num_tokens + batch_size]
logits_indices
:
torch
.
Tensor
# [batch_size]
spec_decode_ids
:
Optional
[
list
[
str
]]
=
None
def
__post_init__
(
self
):
self
.
max_spec_len
=
max
(
self
.
num_draft_tokens
)
...
...
vllm/v1/spec_decode/utils.py
View file @
ade7db0c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
msgspec
from
abc
import
ABC
import
torch
from
vllm.sampling_params
import
SamplingParams
from
vllm.triton_utils
import
tl
,
triton
...
...
@@ -39,3 +43,42 @@ def prepare_eagle_input_kernel(
index_start
+
offset
,
mask
=
offset
<
num_tokens
,
)
class DraftProbs(ABC):  # type: ignore[call-arg]
    """Draft probabilities of in-progress requests, addressed by request id.

    Row ``i`` (along dim 0) of ``draft_probs`` belongs to ``_req_ids[i]``;
    every method keeps the two aligned.
    """

    # Probabilities of the speculative (draft) tokens, one row per request.
    draft_probs: torch.Tensor
    # Request ids, aligned with dim 0 of `draft_probs`.
    _req_ids: list[str]

    def __init__(self, draft_probs, req_ids):
        """Store `draft_probs` together with the ids that own its rows.

        Args:
            draft_probs: Tensor whose first dimension has one entry per id.
            req_ids: Request ids, in row order.
        """
        assert len(req_ids) == len(draft_probs)
        self.draft_probs = draft_probs
        # Take a private copy: the caller may hand in a list it keeps
        # mutating, which would silently break the id <-> row alignment.
        self._req_ids = list(req_ids)

    def _rows_by_id(self) -> dict[str, int]:
        """Map each stored request id to its (first) row index."""
        rows: dict[str, int] = {}
        for row, req_id in enumerate(self._req_ids):
            rows.setdefault(req_id, row)
        return rows

    def update(self, draft_probs: torch.Tensor, tmp_req_ids: list[str]):
        """Insert or replace the rows of `tmp_req_ids`.

        Rows of ids that are not in `tmp_req_ids` are kept (in their current
        order); the new rows are appended after them.

        Args:
            draft_probs: New rows, one per entry of `tmp_req_ids`.
            tmp_req_ids: Ids owning the rows of `draft_probs`.
        """
        incoming = set(tmp_req_ids)
        kept = [(row, req_id) for row, req_id in enumerate(self._req_ids)
                if req_id not in incoming]
        kept_rows = [row for row, _ in kept]

        self.draft_probs = torch.cat(
            [self.draft_probs[kept_rows], draft_probs])
        self._req_ids = [req_id for _, req_id in kept]
        self._req_ids.extend(tmp_req_ids)
        assert len(self._req_ids) == len(self.draft_probs)

    def prune(self, req_ids: list[str]):
        """Drop the rows of `req_ids`; ids that are not stored are ignored."""
        finished = set(req_ids)
        kept = [(row, req_id) for row, req_id in enumerate(self._req_ids)
                if req_id not in finished]
        if len(kept) == len(self._req_ids):
            # Nothing to remove - leave the tensor untouched.
            return
        # Batch contents changed - prune removed sequences.
        self.draft_probs = self.draft_probs[[row for row, _ in kept]]
        self._req_ids = [req_id for _, req_id in kept]

    def get_probs(self, req_ids: list[str]):
        """Return the rows of `req_ids`, in the order given.

        Raises:
            ValueError: If one of `req_ids` is not stored.
        """
        rows = self._rows_by_id()
        index = []
        for req_id in req_ids:
            if req_id not in rows:
                raise ValueError(
                    f"No draft probs stored for request id {req_id!r}")
            index.append(rows[req_id])
        return self.draft_probs[index]
vllm/v1/worker/gpu_model_runner.py
View file @
ade7db0c
...
...
@@ -59,6 +59,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
from
vllm.v1.sample.rejection_sampler_opt
import
OptRejectionSampler
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
...
...
@@ -75,6 +76,7 @@ from ..sample.logits_processor import LogitsProcessorManager
from
.utils
import
(
gather_mm_placeholders
,
initialize_kv_cache_for_kv_sharing
,
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
from
vllm.zero_overhead.v1.eagle
import
V1ZeroEagleProposer
from
vllm.v1.spec_decode.utils
import
DraftProbs
if
TYPE_CHECKING
:
import
xgrammar
as
xgr
...
...
@@ -197,7 +199,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
else
:
raise
ValueError
(
"Unknown speculative decoding method: "
f
"
{
self
.
speculative_config
.
method
}
"
)
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
self
.
rejection_sampler
=
RejectionSampler
()
else
:
self
.
rejection_sampler
=
OptRejectionSampler
()
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
...
...
@@ -324,6 +329,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self
.
shared_kv_cache_layers
:
dict
[
str
,
str
]
=
{}
self
.
draft_probs
:
Optional
[
DraftProbs
]
=
None
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""
Update the order of requests in the batch based on the attention
...
...
@@ -383,6 +390,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
for
req_id
in
scheduler_output
.
finished_req_ids
:
self
.
input_batch
.
remove_request
(
req_id
)
# prune draft probs of finished requests
if
envs
.
VLLM_REJECT_SAMPLE_OPT
and
self
.
draft_probs
is
not
None
and
len
(
scheduler_output
.
finished_req_ids
)
>
0
:
self
.
draft_probs
.
prune
(
list
(
scheduler_output
.
finished_req_ids
))
# Free the cached encoder outputs.
for
req_id
,
input_id
in
scheduler_output
.
free_encoder_input_ids
:
encoder_outputs
=
self
.
encoder_cache
.
get
(
req_id
)
...
...
@@ -762,13 +773,18 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens
=
np
.
zeros
(
num_reqs
,
dtype
=
np
.
int32
)
for
req_id
,
draft_token_ids
in
(
scheduler_output
.
scheduled_spec_decode_tokens
.
items
()):
req_idx
=
self
.
input_batch
.
req_id_to_index
[
req_id
]
num_draft_tokens
[
req_idx
]
=
len
(
draft_token_ids
)
spec_decode_ids
=
None
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
spec_decode_ids
=
scheduler_output
.
scheduled_spec_decode_tokens
.
keys
()
spec_decode_metadata
=
self
.
_calc_spec_decode_metadata
(
num_draft_tokens
,
cu_num_tokens
)
num_draft_tokens
,
cu_num_tokens
,
spec_decode_ids
)
logits_indices
=
spec_decode_metadata
.
logits_indices
# Hot-Swap lora model
...
...
@@ -922,6 +938,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self
,
num_draft_tokens
:
np
.
ndarray
,
cu_num_scheduled_tokens
:
np
.
ndarray
,
spec_decode_ids
:
Optional
[
list
[
str
]]
=
None
)
->
SpecDecodeMetadata
:
# Inputs:
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
...
...
@@ -993,6 +1010,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
target_logits_indices
=
target_logits_indices
,
bonus_logits_indices
=
bonus_logits_indices
,
logits_indices
=
logits_indices
,
spec_decode_ids
=
spec_decode_ids
,
)
return
metadata
...
...
@@ -1491,6 +1509,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert
logits
is
not
None
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
bonus_logits
=
logits
[
spec_decode_metadata
.
bonus_logits_indices
]
sampler_output
=
self
.
sampler
(
logits
=
bonus_logits
,
...
...
@@ -1510,6 +1530,26 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
sampling_metadata
,
)
sampler_output
.
sampled_token_ids
=
output_token_ids
else
:
sampler_output
=
self
.
sampler
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
)
target_token_ids
=
sampler_output
.
sampled_token_ids
[
spec_decode_metadata
.
target_logits_indices
]
target_logits
=
logits
[
spec_decode_metadata
.
target_logits_indices
]
bonus_token_ids
=
sampler_output
.
sampled_token_ids
[
spec_decode_metadata
.
bonus_logits_indices
]
output_token_ids
=
self
.
rejection_sampler
(
spec_decode_metadata
,
self
.
draft_probs
.
get_probs
(
spec_decode_metadata
.
spec_decode_ids
),
target_logits
,
target_token_ids
,
bonus_token_ids
,
sampling_metadata
,
)
sampler_output
.
sampled_token_ids
=
output_token_ids
num_nans_in_logits
=
{}
if
envs
.
VLLM_COMPUTE_NANS_IN_LOGITS
:
...
...
@@ -1590,7 +1630,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Speculative decoding is not enabled.
spec_token_ids
=
None
else
:
spec_
token_ids
=
self
.
propose_draft_token_ids
(
spec_
result
=
self
.
propose_draft_token_ids
(
scheduler_output
,
valid_sampled_token_ids
,
sampling_metadata
,
...
...
@@ -1600,6 +1640,15 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
spec_decode_metadata
,
attn_metadata
,
)
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
spec_token_ids
=
spec_result
else
:
spec_token_ids
,
draft_probs
=
spec_result
if
self
.
draft_probs
is
None
:
self
.
draft_probs
=
DraftProbs
(
draft_probs
,
self
.
input_batch
.
req_ids
)
else
:
self
.
draft_probs
.
update
(
draft_probs
,
self
.
input_batch
.
req_ids
)
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
...
...
@@ -1722,7 +1771,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
target_hidden_states
=
hidden_states
[
token_indices
]
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
token_indices
]
draft_
token_ids
=
self
.
drafter
.
propose
(
draft_
result
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
...
...
@@ -1733,7 +1782,16 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
sampling_metadata
=
sampling_metadata
,
decoding
=
spec_decode_metadata
is
not
None
)
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_token_ids
=
draft_result
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
draft_token_ids
,
draft_probs
=
draft_result
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
,
draft_probs
return
spec_token_ids
def
kv_connector_no_forward
(
...
...
@@ -2190,15 +2248,20 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
else
:
raise
e
if
self
.
speculative_config
:
draft_token_ids
=
[[
0
]
for
_
in
range
(
num_reqs
)]
draft_token_ids
=
[[
0
]
*
self
.
speculative_config
.
num_lookahead_slots
for
_
in
range
(
num_reqs
)]
dummy_spec_decode_metadata
=
SpecDecodeMetadata
.
make_dummy
(
draft_token_ids
,
self
.
device
)
num_tokens
=
sum
(
len
(
ids
)
for
ids
in
draft_token_ids
)
# draft_probs = torch.randn(
# num_tokens, logits.shape[-1], device=self.device,
# dtype=logits.dtype)
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_probs
=
None
else
:
draft_probs
=
torch
.
randn
(
num_reqs
,
self
.
speculative_config
.
num_lookahead_slots
,
logits
.
shape
[
-
1
],
device
=
self
.
device
,
dtype
=
logits
.
dtype
)
target_token_ids
=
torch
.
zeros
(
num_tokens
,
device
=
self
.
device
,
dtype
=
torch
.
int32
)
target_logits
=
torch
.
randn
(
num_tokens
,
logits
.
shape
[
-
1
],
device
=
self
.
device
,
...
...
@@ -2209,10 +2272,20 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
bonus_token_ids
=
torch
.
zeros
(
num_reqs
,
device
=
self
.
device
,
dtype
=
torch
.
int32
)
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
self
.
rejection_sampler
(
dummy_spec_decode_metadata
,
draft_probs
,
target_logits
,
bonus_token_ids
,
dummy_metadata
,
)
else
:
self
.
rejection_sampler
(
dummy_spec_decode_metadata
,
draft_probs
,
target_logits
,
target_token_ids
,
bonus_token_ids
,
dummy_metadata
,
)
...
...
@@ -3050,8 +3123,12 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
req_idx
=
self
.
input_batch
.
req_id_to_index
[
req_id
]
num_draft_tokens
[
req_idx
]
=
len
(
draft_token_ids
)
spec_decode_ids
=
None
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
spec_decode_ids
=
scheduler_output
.
scheduled_spec_decode_tokens
.
keys
()
spec_decode_metadata
=
self
.
_calc_spec_decode_metadata
(
num_draft_tokens
,
cu_num_tokens
)
num_draft_tokens
,
cu_num_tokens
,
spec_decode_ids
)
logits_indices
=
spec_decode_metadata
.
logits_indices
# Hot-Swap lora model
...
...
@@ -3258,6 +3335,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert
logits
is
not
None
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
bonus_logits
=
logits
[
spec_decode_metadata
.
bonus_logits_indices
]
sampler_output
=
self
.
sampler
(
logits
=
bonus_logits
,
...
...
@@ -3277,6 +3356,25 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
sampling_metadata
,
)
sampler_output
.
sampled_token_ids
=
output_token_ids
else
:
sampler_output
=
self
.
sampler
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
)
target_token_ids
=
sampler_output
.
sampled_token_ids
[
spec_decode_metadata
.
target_logits_indices
]
target_logits
=
logits
[
spec_decode_metadata
.
target_logits_indices
]
bonus_token_ids
=
sampler_output
.
sampled_token_ids
[
spec_decode_metadata
.
bonus_logits_indices
]
output_token_ids
=
self
.
rejection_sampler
(
spec_decode_metadata
,
self
.
draft_probs
.
get_probs
(
spec_decode_metadata
.
spec_decode_ids
),
target_logits
,
target_token_ids
,
bonus_token_ids
,
sampling_metadata
,
)
sampler_output
.
sampled_token_ids
=
output_token_ids
num_nans_in_logits
=
{}
if
envs
.
VLLM_COMPUTE_NANS_IN_LOGITS
:
...
...
@@ -3325,7 +3423,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
mask_int
=
mask
.
int
()
first_neg_one_indices
=
torch
.
argmax
(
mask_int
,
dim
=
1
)
num_accepted_tokens_tensor
=
torch
.
where
(
torch
.
any
(
mask
,
dim
=
1
),
first_neg_one_indices
,
sampled_token_ids
.
size
(
1
))
-
1
spec_token_ids
=
self
.
zero_propose_draft_token_ids
(
spec_result
=
self
.
zero_propose_draft_token_ids
(
scheduler_output
,
num_accepted_tokens_tensor
,
sampled_token_ids
,
...
...
@@ -3336,6 +3435,15 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
spec_decode_metadata
,
attn_metadata
,
)
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
spec_token_ids
=
spec_result
else
:
spec_token_ids
,
draft_probs
=
spec_result
if
self
.
draft_probs
is
None
:
self
.
draft_probs
=
DraftProbs
(
draft_probs
,
self
.
input_batch
.
req_ids
)
else
:
self
.
draft_probs
.
update
(
draft_probs
,
self
.
input_batch
.
req_ids
)
if
max_gen_len
==
1
:
# No spec decode tokens.
valid_sampled_token_ids
=
sampled_token_ids
.
tolist
()
...
...
@@ -3479,7 +3587,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
token_indices
]
self
.
drafter
.
spec_scheduler_max_num_tokens
=
spec_scheduler_max_num_tokens
draft_
token_ids
=
self
.
drafter
.
propose
(
draft_
result
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
...
...
@@ -3494,7 +3602,14 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
# self.last_draft_token_ids = draft_token_ids
# self.last_draft_host_tokens = draft_token_ids.to('cpu', non_blocking=True)
# self.last_draft_event.record()
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_token_ids
=
draft_result
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
draft_token_ids
,
draft_probs
=
draft_result
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
,
draft_probs
return
spec_token_ids
# TODO: replace GPUModelRunner with GPUModelRunnerMTP once it is stable.
if
envs
.
VLLM_USE_ZERO_MTP
:
...
...
vllm/zero_overhead/v1/eagle.py
View file @
ade7db0c
import
torch
import
torch.nn.functional
as
F
import
vllm.envs
as
envs
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
...
...
@@ -161,6 +164,10 @@ class V1ZeroEagleProposer(EagleProposer):
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
draft_probs_list
=
[
draft_prob
]
# Early exit if there is only one draft token to be generated.
if
self
.
num_speculative_tokens
==
1
:
# [batch_size, 1]
...
...
@@ -311,7 +318,15 @@ class V1ZeroEagleProposer(EagleProposer):
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids_list
.
append
(
draft_token_ids
)
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
draft_probs_list
.
append
(
draft_prob
)
# [batch_size, num_speculative_tokens]
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_probs
=
torch
.
stack
(
draft_probs_list
,
dim
=
1
).
contiguous
()
return
draft_token_ids
,
draft_probs
return
draft_token_ids
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment