Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dd572c0a
Unverified
Commit
dd572c0a
authored
Jul 18, 2025
by
Woosuk Kwon
Committed by
GitHub
Jul 18, 2025
Browse files
[V0 Deprecation] Remove V0 Spec Decode workers (#21152)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
9ffe905a
Changes
73
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
18 additions
and
3288 deletions
+18
-3288
vllm/spec_decode/metrics.py
vllm/spec_decode/metrics.py
+0
-213
vllm/spec_decode/mlp_speculator_worker.py
vllm/spec_decode/mlp_speculator_worker.py
+0
-94
vllm/spec_decode/mqa_scorer.py
vllm/spec_decode/mqa_scorer.py
+0
-160
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+0
-423
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+0
-196
vllm/spec_decode/proposer_worker_base.py
vllm/spec_decode/proposer_worker_base.py
+0
-59
vllm/spec_decode/smaller_tp_proposer_worker.py
vllm/spec_decode/smaller_tp_proposer_worker.py
+0
-196
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+0
-1326
vllm/spec_decode/target_model_runner.py
vllm/spec_decode/target_model_runner.py
+0
-45
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+0
-275
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+0
-277
vllm/transformers_utils/configs/eagle.py
vllm/transformers_utils/configs/eagle.py
+18
-22
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+0
-2
No files found.
vllm/spec_decode/metrics.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
from
typing
import
Callable
,
Optional
,
Union
import
msgspec
import
torch
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeBaseSampler
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_pin_memory_available
class
SpecDecodeWorkerMetrics
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""Dataclass holding metrics emitted from the spec decode worker.
"""
# The empirical acceptance rate of the proposal method on a per-token basis.
# This is useful for evaluating how well the proposal method aligns with the
# scoring method.
draft_acceptance_rate
:
float
# The empirical efficiency, measured as the number of tokens emitted by the
# system divided by the number of tokens that could be emitted by the system
# if the proposal method were perfect.
system_efficiency
:
float
# The number of speculative tokens produced by the proposal method.
draft_tokens
:
int
# The number of tokens emitted by the entire system.
emitted_tokens
:
int
# The number of tokens accepted by the scoring model and verification
# routine, e.g. Llama2-70B and lossless rejection sampling.
#
# NOTE: Any token accepted by the verification routine is considered
# accepted (regardless of if the speculative prefix is also accepted). The
# user will usually see less accepted tokens. This metric is helpful when
# evaluating alignment of the proposal method with the scoring model.
accepted_tokens
:
int
# The number of speculative tokens per sequence.
num_spec_tokens
:
int
Timer
=
Callable
[[],
float
]
class
AsyncMetricsCollector
:
"""Class which copies rejection/typical-acceptance sampler metrics
from the device to CPU on a non-default Torch stream.
"""
def
__init__
(
self
,
spec_decode_sampler
:
SpecDecodeBaseSampler
,
timer
:
Optional
[
Timer
]
=
None
,
collect_interval_s
:
float
=
5.0
):
self
.
spec_decode_sampler
=
spec_decode_sampler
self
.
_timer
=
time
.
time
if
timer
is
None
else
timer
self
.
_rank
:
Optional
[
int
]
=
None
# We don't have a device set yet.
self
.
_copy_stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
None
self
.
_in_flight_copy
:
Optional
[
torch
.
cuda
.
Event
]
=
None
pin_memory
=
is_pin_memory_available
()
self
.
_aggregate_num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
_aggregate_num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
_aggregate_num_draft_tokens
=
0
self
.
_rejsample_metrics_collect_interval_s
=
collect_interval_s
self
.
_last_metrics_collect_time
=
self
.
_timer
()
def
init_gpu_tensors
(
self
,
rank
:
int
)
->
None
:
self
.
_rank
=
rank
self
.
_copy_stream
=
torch
.
cuda
.
Stream
()
def
init_tensors
(
self
,
rank
:
int
,
device_type
:
Union
[
torch
.
device
,
str
]
=
'cuda'
)
->
None
:
self
.
_rank
=
rank
if
isinstance
(
device_type
,
torch
.
device
):
device_type
=
device_type
.
type
stream
=
current_platform
.
Stream
if
stream
is
not
None
:
self
.
_copy_stream
=
stream
()
def
maybe_collect_rejsample_metrics
(
self
,
k
:
int
)
->
Optional
[
SpecDecodeWorkerMetrics
]:
# Skip for any platform that doesn't have device Event
if
current_platform
.
Event
is
None
:
return
None
# If a copy was initiated in the previous call, collect and return.
if
self
.
_in_flight_copy
is
not
None
:
ready_event
=
self
.
_in_flight_copy
self
.
_in_flight_copy
=
None
return
self
.
_collect_rejsample_metrics
(
k
,
ready_event
)
# Otherwise, check if we should start a new copy.
if
self
.
_should_collect_rejsample_metrics
(
self
.
_timer
()):
assert
self
.
_in_flight_copy
is
None
self
.
_in_flight_copy
=
self
.
_copy_rejsample_metrics_async
()
return
None
def
_should_collect_rejsample_metrics
(
self
,
now
:
float
)
->
bool
:
"""Return whether or not this iteration should print sampling
metrics.
"""
if
self
.
_rank
!=
0
:
return
False
return
now
-
self
.
_last_metrics_collect_time
>=
self
.
_rejsample_metrics_collect_interval_s
# noqa: E501
def
_copy_rejsample_metrics_async
(
self
)
->
torch
.
cuda
.
Event
:
"""Copy rejection/typical-acceptance sampling metrics
(number of accepted tokens, etc) to CPU asynchronously.
Returns a device event recording when the copy is complete.
"""
assert
self
.
_copy_stream
is
not
None
self
.
_copy_stream
.
wait_stream
(
current_platform
.
current_stream
())
with
current_platform
.
stream
(
self
.
_copy_stream
):
self
.
_aggregate_num_accepted_tokens
.
copy_
(
self
.
spec_decode_sampler
.
num_accepted_tokens
,
non_blocking
=
True
)
self
.
_aggregate_num_emitted_tokens
.
copy_
(
self
.
spec_decode_sampler
.
num_emitted_tokens
,
non_blocking
=
True
)
# Number of draft tokens is calculated on CPU, so no copy is
# required.
self
.
_aggregate_num_draft_tokens
=
(
self
.
spec_decode_sampler
.
num_draft_tokens
)
aggregate_metrics_ready
=
current_platform
.
Event
()
aggregate_metrics_ready
.
record
(
self
.
_copy_stream
)
return
aggregate_metrics_ready
def
_collect_rejsample_metrics
(
self
,
k
:
int
,
ready_event
:
torch
.
cuda
.
Event
)
->
SpecDecodeWorkerMetrics
:
"""Create metrics object from statistics copied asynchronously.
Args:
k: int. The number of speculative tokens; used to determine system
efficiency.
ready_event: torch.cuda.Event. The CUDA event recording when the
async GPU->CPU copy is complete.
"""
ready_event
.
synchronize
()
# update time of last collection
self
.
_last_metrics_collect_time
=
self
.
_timer
()
accepted_tokens
=
self
.
_aggregate_num_accepted_tokens
.
item
()
emitted_tokens
=
self
.
_aggregate_num_emitted_tokens
.
item
()
draft_tokens
=
self
.
_aggregate_num_draft_tokens
max_num_emitted_tokens
=
self
.
get_max_num_emitted_tokens
(
draft_tokens
,
k
)
if
draft_tokens
>
0
:
draft_acceptance_rate
=
accepted_tokens
/
draft_tokens
else
:
draft_acceptance_rate
=
float
(
"nan"
)
if
max_num_emitted_tokens
>
0
:
system_efficiency
=
emitted_tokens
/
max_num_emitted_tokens
else
:
system_efficiency
=
float
(
"nan"
)
return
SpecDecodeWorkerMetrics
(
num_spec_tokens
=
k
,
draft_acceptance_rate
=
draft_acceptance_rate
,
system_efficiency
=
system_efficiency
,
accepted_tokens
=
accepted_tokens
,
draft_tokens
=
draft_tokens
,
emitted_tokens
=
emitted_tokens
,
)
@
staticmethod
def
get_max_num_emitted_tokens
(
draft_tokens
:
int
,
k
:
int
)
->
int
:
"""Calculate the number of emitted tokens, assuming all tokens are
accepted.
This is equal to the number of sequences that have been speculated on,
times (speculation len + 1). The +1 comes from the bonus token.
"""
# Determine the number of sequences that have been speculated on. Since
# the batch size can be variable, we divide by k.
assert
draft_tokens
%
k
==
0
total_num_spec_seqs
=
draft_tokens
//
k
# A single sequence may emit k accepted tokens and one bonus token in
# the best case.
num_emitted_per_seq_if_all_accepted
=
k
+
1
# The max num of emitted tokens is the number of speculated sequences
# times the max emitted per seq.
return
total_num_spec_seqs
*
num_emitted_per_seq_if_all_accepted
vllm/spec_decode/mlp_speculator_worker.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
List
,
Optional
,
Set
,
Tuple
import
torch
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
class
MLPSpeculatorWorker
(
NonLLMProposerWorkerBase
,
MultiStepWorker
):
"""Worker for MLPSpeculator models.
Not currently compatible with LoRA or chunked prefill.
"""
@
torch
.
inference_mode
()
def
sampler_output
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
# Unused parameter. MLPSpeculatorWorker does not use the KV Cache and
# therefore does not need this parameter.
seq_ids_with_bonus_token_in_last_step
:
Set
[
int
],
)
->
Tuple
[
List
[
SamplerOutput
],
bool
]:
"""Run the model forward pass to generate sample_len future tokens.
Returns the list of sampler output, one per layer, along with indicator
of whether torch tensor in sampler output need to be transposed in
latter sampler_output_to_torch logic.
For mlp spec worker, this indicator shall be True.
"""
self
.
_raise_if_unsupported
(
execute_model_req
)
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
(
input_tokens
,
seq_lens
,
query_lens
)
=
self
.
_prepare_input_tensors
(
seq_group_metadata_list
)
generators
=
self
.
model_runner
.
get_generators
(
execute_model_req
.
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
self
.
device
,
self
.
model_runner
.
pin_memory
,
generators
)
model_outputs
=
self
.
model_runner
.
model
.
generate_proposals
(
input_ids
=
input_tokens
,
previous_hidden_states
=
execute_model_req
.
previous_hidden_states
.
hidden_states
,
num_predict_tokens
=
sample_len
,
sampling_metadata
=
sampling_metadata
)
assert
len
(
model_outputs
)
==
sample_len
return
model_outputs
,
True
def
_prepare_input_tensors
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
)
->
Tuple
[
torch
.
Tensor
,
List
[
int
],
List
[
int
]]:
if
not
seq_group_metadata_list
:
return
torch
.
empty
(
0
,
device
=
self
.
device
),
[],
[]
input_tokens
:
List
[
int
]
=
[]
seq_lens
:
List
[
int
]
=
[]
query_lens
:
List
[
int
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
is_prompt
=
seq_group_metadata
.
is_prompt
for
seq_data
in
seq_group_metadata
.
seq_data
.
values
():
seq_data_len
=
seq_data
.
get_len
()
if
is_prompt
:
context_len
=
seq_data
.
get_num_computed_tokens
()
seq_len
=
min
(
seq_data_len
,
context_len
+
seq_group_metadata
.
token_chunk_size
)
tokens
=
seq_data
.
get_token_ids
()[
context_len
:
seq_len
]
seq_lens
.
append
(
seq_len
)
input_tokens
.
extend
(
tokens
)
query_lens
.
append
(
seq_len
-
context_len
)
else
:
seq_lens
.
append
(
seq_data_len
)
input_tokens
.
append
(
seq_data
.
get_last_token_id
())
query_lens
.
append
(
1
)
input_tokens_tensor
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
return
input_tokens_tensor
,
seq_lens
,
query_lens
vllm/spec_decode/mqa_scorer.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.sequence
import
(
ExecuteModelRequest
,
SequenceData
,
SequenceGroupMetadata
,
get_all_seq_ids
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
SeqId
=
int
TargetSeqId
=
int
class
MQAScorer
(
SpeculativeScorer
):
def
score_proposals
(
self
,
execute_model_req
:
ExecuteModelRequest
,
proposals
:
SpeculativeProposals
,
)
->
SpeculativeScores
:
target_seq_group_metadata_list
=
[]
target_seq_id_start
=
max
(
get_all_seq_ids
(
execute_model_req
.
seq_group_metadata_list
))
+
1
all_proposal_tokens
=
proposals
.
proposal_token_ids
.
tolist
()
all_proposal_lengths
=
proposals
.
proposal_lens
.
tolist
()
for
i
,
seq_group_metadata
in
enumerate
(
execute_model_req
.
seq_group_metadata_list
):
if
all_proposal_lengths
[
i
]
==
0
:
# Keep prompt seqs untouched (keep computed_tokens for chunks).
target_seq_group_metadata_list
.
append
(
seq_group_metadata
)
continue
seq_data_dict
=
seq_group_metadata
.
seq_data
assert
len
(
seq_data_dict
)
==
1
seq_id
=
next
(
iter
(
seq_data_dict
.
keys
()))
seq_data
:
SequenceData
=
seq_data_dict
[
seq_id
]
prompt_token_ids
=
seq_data
.
get_prompt_token_ids
()
output_token_ids
=
seq_data
.
get_output_token_ids
()
proposal_token_ids
=
all_proposal_tokens
[
i
][:
all_proposal_lengths
[
i
]]
new_output_token_ids
=
[
*
output_token_ids
,
*
proposal_token_ids
]
target_seq_id
=
target_seq_id_start
+
i
new_seq_data
=
SequenceData
.
from_seqs
(
prompt_token_ids
=
prompt_token_ids
,
output_token_ids
=
new_output_token_ids
,
)
new_seq_data
.
update_num_computed_tokens
(
len
(
prompt_token_ids
)
+
len
(
output_token_ids
)
-
1
)
# Ensure that the new decode sequence has at least one token.
assert
len
(
output_token_ids
)
>=
1
new_seq_data_dict
=
{
target_seq_id
:
new_seq_data
}
new_seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
seq_group_metadata
.
request_id
,
is_prompt
=
seq_group_metadata
.
is_prompt
,
seq_data
=
new_seq_data_dict
,
sampling_params
=
seq_group_metadata
.
sampling_params
,
block_tables
=
{
target_seq_id
:
seq_group_metadata
.
block_tables
[
seq_id
],
},
lora_request
=
None
,
)
target_seq_group_metadata_list
.
append
(
new_seq_group_metadata
)
target_sampler_output
=
self
.
_scorer_worker
.
execute_model
(
execute_model_req
=
execute_model_req
.
clone
(
seq_group_metadata_list
=
target_seq_group_metadata_list
))
target_sampler_output
=
target_sampler_output
[
0
]
k
=
execute_model_req
.
num_lookahead_slots
bs
=
len
(
execute_model_req
.
seq_group_metadata_list
)
target_token_ids
=
target_sampler_output
.
sampled_token_ids
target_probs
=
target_sampler_output
.
sampled_token_probs
target_logprobs
=
target_sampler_output
.
logprobs
prompt_logprobs
=
None
# If all requests have the same number of query tokens, we can avoid
# the for loop to build output for better performance.
if
min
(
all_proposal_lengths
)
==
k
:
# Regular decodes only.
assert
all
(
not
sg
.
is_prompt
for
sg
in
target_seq_group_metadata_list
if
sg
.
is_prompt
)
bs
,
_
=
proposals
.
proposal_token_ids
.
shape
all_tokens
=
target_token_ids
.
reshape
(
bs
,
k
+
1
)
all_probs
=
target_probs
.
reshape
(
bs
,
k
+
1
,
self
.
_vocab_size
)
all_logprobs
=
target_logprobs
.
reshape
(
bs
,
k
+
1
,
self
.
_vocab_size
)
else
:
# We either have decodes with different lens or prefill+decodes.
all_tokens
=
target_token_ids
.
new_full
(
size
=
(
bs
,
k
+
1
),
fill_value
=-
1
)
all_probs
=
target_probs
.
new_zeros
(
*
all_tokens
.
shape
,
self
.
_vocab_size
)
all_logprobs
=
target_logprobs
.
new_full
(
size
=
all_probs
.
shape
,
fill_value
=-
float
(
"inf"
))
target_token_ids
=
target_token_ids
.
flatten
()
# When prompt logprobs is enabled, lens of returned tensors go from
# n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
# We adjust stride accordingly to get the generated tokens and
# their probs, but pass on prompt_logprobs as is, since it may be
# that n_prompts >> K.
has_prompt_log
=
any
((
sg
.
sampling_params
.
prompt_logprobs
and
sg
.
sampling_params
.
prompt_logprobs
>
0
)
for
sg
in
target_seq_group_metadata_list
)
# TODO (NickLucche) we should surface `disable_logprobs` as to not
# break abstraction to get its value.
if
(
not
self
.
_scorer_worker
.
model_runner
.
disable_logprobs
\
and
has_prompt_log
):
prompt_logprobs
=
[
o
.
prompt_logprobs
for
o
in
target_sampler_output
.
outputs
]
# Split loop into prefill|decode for readability.
start_loc
,
i
=
0
,
0
while
i
<
len
(
target_seq_group_metadata_list
)
and
target_seq_group_metadata_list
[
i
].
is_prompt
:
seq_meta
=
target_seq_group_metadata_list
[
i
]
end_loc
=
start_loc
if
has_prompt_log
:
end_loc
+=
seq_meta
.
token_chunk_size
elif
seq_meta
.
do_sample
:
end_loc
+=
1
# Skip chunks with no output tokens.
if
seq_meta
.
do_sample
:
# Get sampled token (last position in chunk) and its prob.
all_tokens
[
i
,
0
]
=
target_token_ids
[
end_loc
-
1
]
all_probs
[
i
,
0
]
=
target_probs
[
end_loc
-
1
]
all_logprobs
[
i
,
0
]
=
target_logprobs
[
end_loc
-
1
]
i
+=
1
start_loc
=
end_loc
# Decodes.
while
i
<
len
(
target_seq_group_metadata_list
):
proposed_len
,
seq_meta
=
all_proposal_lengths
[
i
],
target_seq_group_metadata_list
[
i
]
output_len
=
proposed_len
+
1
end_loc
=
start_loc
+
output_len
all_tokens
[
i
,
:
output_len
]
=
target_token_ids
[
start_loc
:
end_loc
]
all_probs
[
i
,
:
output_len
]
=
target_probs
[
start_loc
:
end_loc
]
all_logprobs
[
i
,
:
output_len
]
=
target_logprobs
[
start_loc
:
end_loc
]
start_loc
=
end_loc
i
+=
1
hidden_states
=
None
if
target_sampler_output
.
hidden_states
is
not
None
:
hidden_states
=
target_sampler_output
.
hidden_states
.
reshape
(
bs
,
(
k
+
1
),
-
1
)
return
SpeculativeScores
(
probs
=
all_probs
,
token_ids
=
all_tokens
,
logprobs
=
all_logprobs
,
hidden_states
=
hidden_states
,
prompt_logprobs
=
prompt_logprobs
)
vllm/spec_decode/multi_step_worker.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
import
weakref
from
typing
import
Dict
,
List
,
Set
,
Tuple
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
ExecuteModelRequest
,
HiddenStates
,
SequenceData
,
SequenceGroupMetadata
)
if
current_platform
.
is_cuda_alike
():
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker_base
import
DelegateWorkerBase
class
MultiStepWorker
(
ProposerWorkerBase
,
DelegateWorkerBase
):
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
by invoking the scheduler less.
The MultiStepWorker does not support cache swap operations, or beam search.
Cache swap operations do not require large modifications. On the other hand,
beam search requires memory allocations during sequence forks and thus
requires more thought for MultiStepWorker support.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
DelegateWorkerBase
.
__init__
(
self
,
*
args
,
**
kwargs
)
# Lazy initialization list.
self
.
_proposer
:
SpeculativeProposer
def
init_device
(
self
)
->
None
:
self
.
worker
.
init_device
()
self
.
_proposer
=
Top1Proposer
(
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
self
.
device
,
self
.
vocab_size
,
max_proposal_len
=
self
.
max_model_len
,
)
def
set_include_gpu_probs_tensor
(
self
)
->
None
:
# Need include_gpu_probs_tensor for MultiStepWorker
self
.
model_runner
.
sampler
.
include_gpu_probs_tensor
=
True
if
hasattr
(
self
.
model_runner
.
model
,
"sampler"
):
(
self
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
)
=
True
def
set_should_modify_greedy_probs_inplace
(
self
)
->
None
:
self
.
model_runner
.
sampler
.
should_modify_greedy_probs_inplace
=
True
if
hasattr
(
self
.
model_runner
.
model
,
"sampler"
):
(
self
.
model_runner
.
model
.
sampler
.
should_modify_greedy_probs_inplace
)
=
True
@
torch
.
inference_mode
()
def
sampler_output
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
seq_ids_with_bonus_token_in_last_step
:
Set
[
int
],
)
->
Tuple
[
List
[
SamplerOutput
],
bool
]:
"""Run the model forward pass sample_len times. Returns the list of
sampler output, one per model forward pass, along with indicator of
whether torch tensor in sampler output need to be transposed in latter
sampler_output_to_torch logic.
For multi step worker, this indicator shall be True.
"""
self
.
_raise_if_unsupported
(
execute_model_req
)
# Expand the batch for sequences with a bonus token.
# Perform a forward pass on the expanded batch and filter the
# response to retain only the original sequences' responses.
expanded_request
,
indices_of_seq_with_bonus_tokens
=
\
self
.
_expand_execute_model_request
(
execute_model_req
,
seq_ids_with_bonus_token_in_last_step
)
# Run model sample_len times.
model_outputs
:
List
[
SamplerOutput
]
=
[]
if
current_platform
.
is_cuda_alike
()
and
isinstance
(
self
.
model_runner
,
TP1DraftModelRunner
)
and
self
.
model_runner
.
supports_gpu_multi_step
(
expanded_request
):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request
.
num_steps
=
sample_len
self
.
model_runner
.
set_indices_of_seq_with_bonus_tokens
(
indices_of_seq_with_bonus_tokens
)
model_outputs
=
self
.
execute_model
(
execute_model_req
=
expanded_request
)
else
:
# Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
if
expanded_request
.
previous_hidden_states
is
not
None
:
self
.
worker
.
model_runner
.
return_hidden_states
=
True
for
_
in
range
(
sample_len
):
model_output
:
List
[
SamplerOutput
]
=
self
.
worker
.
execute_model
(
execute_model_req
=
expanded_request
)
assert
(
len
(
model_output
)
==
1
),
"composing multistep workers not supported"
model_output
=
model_output
[
0
]
self
.
_maybe_update_previous_hidden_states
(
model_output
,
expanded_request
)
self
.
_append_new_tokens
(
model_output
,
expanded_request
.
seq_group_metadata_list
,
indices_of_seq_with_bonus_tokens
)
model_outputs
.
append
(
model_output
)
# move indices to device to avoid stream sync
indices_of_seq_with_bonus_tokens
=
torch
.
tensor
(
indices_of_seq_with_bonus_tokens
,
device
=
self
.
device
)
filtered_model_outputs
=
self
.
_filter_model_output
(
model_outputs
,
indices_of_seq_with_bonus_tokens
)
return
filtered_model_outputs
,
True
@
staticmethod
def
_maybe_update_previous_hidden_states
(
model_output
:
SamplerOutput
,
expanded_request
:
ExecuteModelRequest
)
->
None
:
"""
Updates the previous hidden states in an expanded request
in-place with the hidden states from the model output.
"""
if
expanded_request
.
previous_hidden_states
is
not
None
:
expanded_request
.
previous_hidden_states
=
HiddenStates
(
model_output
.
hidden_states
,
expanded_request
.
seq_group_metadata_list
)
@
staticmethod
def
_expand_execute_model_request
(
execute_model_req
:
ExecuteModelRequest
,
seq_with_bonus_token_in_last_step
:
set
,
)
->
Tuple
[
ExecuteModelRequest
,
List
[
int
]]:
"""
Expands the execute model request based on sequences with bonus
tokens.
For each sequence with a bonus token, this method creates a new
sequence without the bonus token and adds it to the execute model
request. The original sequence groups are also retained. The indices
of the original sequence groups are returned for further processing.
Args:
execute_model_req (ExecuteModelRequest): The original execute
model request.
seq_with_bonus_token_in_last_step (set): Set of sequence IDs that
contain bonus tokens.
Returns:
Tuple[ExecuteModelRequest, List[int]]: The updated execute model
request with expanded sequences and a list of indices corresponding
to the original sequence groups.
"""
updated_seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
updated_execute_model_req
=
execute_model_req
.
clone
(
updated_seq_group_metadata_list
)
indices_of_original_sequence_groups
=
[]
for
seq_group
in
execute_model_req
.
seq_group_metadata_list
:
seq_group_has_bonus_tokens
=
False
for
seq_id
,
_
in
seq_group
.
seq_data
.
items
():
# Identify sequences with bonus tokens in the sequence group.
if
seq_id
in
seq_with_bonus_token_in_last_step
:
seq_group_has_bonus_tokens
=
True
break
if
seq_group_has_bonus_tokens
:
#Create new sequences without the last bonus token. These new
# sequence have the same sequence id as the original sequence.
# We create a new sequence group and add them there.
updated_seq_group_without_bonus_token
=
\
MultiStepWorker
.
_copy_seq_metadata_excluding_last_token
(
seq_group
,
seq_with_bonus_token_in_last_step
)
updated_seq_group_metadata_list
.
append
(
updated_seq_group_without_bonus_token
)
# Add the original sequence group.
updated_seq_group_metadata_list
.
append
(
MultiStepWorker
.
_shallow_copy_seq_group_metadata
(
seq_group
))
# Record the index of the original sequence group.
indices_of_original_sequence_groups
.
append
(
len
(
updated_seq_group_metadata_list
)
-
1
)
updated_execute_model_req
.
seq_group_metadata_list
=
\
updated_seq_group_metadata_list
if
isinstance
(
updated_execute_model_req
.
previous_hidden_states
,
HiddenStates
):
updated_execute_model_req
.
previous_hidden_states
\
.
expand_with_bonus_tokens
(
seq_with_bonus_token_in_last_step
)
return
updated_execute_model_req
,
indices_of_original_sequence_groups
@
staticmethod
def
_filter_model_output
(
expanded_batch_outputs
:
List
[
SamplerOutput
],
output_indices_to_retain
:
torch
.
Tensor
)
->
List
[
SamplerOutput
]:
"""
Filters the model output to include only the specified sequence
outputs. This method contracts the expanded batch output from the
model to retain the outputs of only those sequences indicated by the
provided indices.
Args:
expanded_batch_output (List[SamplerOutput]): The expanded output
batch from the model.
output_indices_to_retain (torch.Tensor): Indices of the model
outputs to retain.
Returns:
List[SamplerOutput]: A list containing the filtered model
outputs for the specified indices.
"""
return
[
SamplerOutput
(
outputs
=
[
expanded_batch_output
.
outputs
[
i
]
for
i
in
output_indices_to_retain
]
if
len
(
expanded_batch_output
.
outputs
)
>
0
else
[],
sampled_token_probs
=
(
expanded_batch_output
.
sampled_token_probs
[
output_indices_to_retain
]
if
expanded_batch_output
.
sampled_token_probs
is
not
None
else
None
),
logprobs
=
(
expanded_batch_output
.
logprobs
[
output_indices_to_retain
]
if
expanded_batch_output
.
logprobs
is
not
None
else
None
),
sampled_token_ids
=
(
expanded_batch_output
.
sampled_token_ids
[
output_indices_to_retain
]
if
expanded_batch_output
.
sampled_token_ids
is
not
None
else
None
))
for
expanded_batch_output
in
expanded_batch_outputs
]
def
get_spec_proposals
(
self
,
execute_model_req
:
ExecuteModelRequest
,
seq_ids_with_bonus_token_in_last_step
:
set
,
)
->
SpeculativeProposals
:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return
self
.
_proposer
.
get_spec_proposals
(
execute_model_req
,
seq_ids_with_bonus_token_in_last_step
)
@
staticmethod
def
_append_new_tokens
(
model_output
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
indices_of_seq_with_bonus_tokens
:
List
[
int
])
->
None
:
"""Given model output from a single run, append the tokens to the
sequences. This is normally done outside of the worker, but it is
required if the worker is to perform multiple forward passes.
"""
count
=
0
for
index
,
(
seq_group_metadata
,
sequence_group_outputs
)
in
enumerate
(
zip
(
seq_group_metadata_list
,
model_output
)):
seq_group_metadata
.
is_prompt
=
False
for
seq_output
in
sequence_group_outputs
.
samples
:
# NOTE: Beam search is not supported, so we can assume that
# parent_seq_id == seq_id.
seq
=
seq_group_metadata
.
seq_data
[
seq_output
.
parent_seq_id
]
token_id
=
seq_output
.
output_token
token_logprob
=
seq_output
.
logprobs
[
token_id
]
# Determine the actual token ID to be generated,
# considering bonus tokens
if
index
!=
indices_of_seq_with_bonus_tokens
[
count
]:
bonus_seq_metadata
=
seq_group_metadata_list
[
indices_of_seq_with_bonus_tokens
[
count
]]
_
,
bonus_token_seq_data
=
next
(
iter
(
bonus_seq_metadata
.
seq_data
.
items
()))
token_id
=
bonus_token_seq_data
.
output_token_ids
[
-
1
]
else
:
count
+=
1
seq
.
append_token_id
(
token_id
,
token_logprob
.
logprob
,
seq_output
.
output_embed
)
seq
.
update_num_computed_tokens
(
1
)
@
staticmethod
def
_shallow_copy_seq_group_metadata
(
seq_group_metadata
:
SequenceGroupMetadata
,
)
->
SequenceGroupMetadata
:
"""Copy input data structures to remove side-effects when input data
structures are shared with other modules.
Helpful when the vLLM scheduler runs in the same process as the worker.
The alternative is deep-copying (or other form of deep copy); this has
performance downsides.
"""
# Shallow-copy the SequenceGroupMetadata. This allows us to
# append tokens and change is_prompt without external side-effects.
# We must shallow-copy seq_group_metadata as is_prompt could change.
new_seq_group_metadata
=
copy
.
copy
(
seq_group_metadata
)
# We must shallow-copy seq_data as we will append token ids
new_seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
for
seq_id
,
old_seq_data
in
seq_group_metadata
.
seq_data
.
items
():
new_seq_data
[
seq_id
]
=
copy
.
copy
(
old_seq_data
)
new_seq_data
[
seq_id
].
output_token_ids
=
\
old_seq_data
.
output_token_ids
[:]
new_seq_group_metadata
.
seq_data
=
new_seq_data
return
new_seq_group_metadata
@
staticmethod
def
_copy_seq_metadata_excluding_last_token
(
seq_group_metadata
:
SequenceGroupMetadata
,
seq_ids_to_copy
:
Set
[
int
],
)
->
SequenceGroupMetadata
:
"""
Creates a shallow copy of the given SequenceGroupMetadata, retaining
only the sequence IDs specified in seq_ids_to_copy. For each of these
sequence IDs, all output_token_ids except the last one are copied.
Sequence IDs not in seq_ids_to_copy are excluded from the copy.
Parameters:
seq_group_metadata (SequenceGroupMetadata): The original sequence
group metadata.
seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the
copy.
Returns:
SequenceGroupMetadata: A shallow copy of the sequence group metadata
with the specified modifications.
"""
# Shallow-copy the SequenceGroupMetadata.
new_seq_group_metadata
=
copy
.
copy
(
seq_group_metadata
)
# Shallow-copy seq_data and modify the output_token_ids.
new_seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
for
seq_id
,
old_seq_data
in
seq_group_metadata
.
seq_data
.
items
():
if
(
seq_id
in
seq_ids_to_copy
):
new_seq_data
[
seq_id
]
=
copy
.
copy
(
old_seq_data
)
# Copy all the output token ids except the last.
# Also reduce num_computed_tokens by 1 since we are not
# including the last output token.
# NOTE: num_computed_tokens is not directly used by the
# speculative decoding workers, as it is only relevant for
# chunked prefill, which is disabled for speculative decoding.
# However, to maintain consistency in num_computed_tokens,
# we update it here.
new_seq_data
[
seq_id
].
output_token_ids
=
\
old_seq_data
.
output_token_ids
[:
-
1
]
new_seq_data
[
seq_id
].
update_num_computed_tokens
(
-
1
)
new_seq_group_metadata
.
seq_data
=
new_seq_data
return
new_seq_group_metadata
def
_assert_enough_kv_space
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
num_steps
:
int
)
->
None
:
"""Assert there are enough physical blocks per sequence to store the
current KV plus additional KV from num_steps tokens.
"""
assert
self
.
model_runner
.
block_size
is
not
None
for
seq_group_metadata
in
seq_group_metadata_list
:
# Only one seq_id is guaranteed because there is no beam search.
seq_id
=
list
(
seq_group_metadata
.
seq_data
.
keys
())[
0
]
seq
=
seq_group_metadata
.
seq_data
[
seq_id
]
# After num_steps, the seq len will be the current seq len
# plus one token per step.
final_seq_len
=
seq
.
get_len
()
+
num_steps
# We will have final_seq_len - 1 KV because vLLM saves KV for a
# token in the iteration after the token was generated.
required_num_kv_slots
=
final_seq_len
-
1
# The allocated number of kv slots is the number of allocated blocks
# times the number of slots of block.
number_physical_blocks
=
len
(
seq_group_metadata
.
block_tables
[
seq_id
])
allocated_kv_slots
=
(
number_physical_blocks
*
self
.
model_runner
.
block_size
)
if
required_num_kv_slots
>
allocated_kv_slots
:
request_id
=
seq_group_metadata
.
request_id
raise
ValueError
(
"The worker attempted to run "
f
"
{
num_steps
}
times but found insufficient KV space for "
f
"
{
request_id
=
}
{
seq_id
=
}
. (
{
allocated_kv_slots
=
}
"
f
"
{
required_num_kv_slots
=
}
)."
)
def
_raise_if_unsupported
(
self
,
execute_model_req
:
ExecuteModelRequest
,
)
->
None
:
"""MultiStepWorker does not yet implement support for cache swap
operations or beam search.
"""
if
any
([
execute_model_req
.
blocks_to_swap_in
,
execute_model_req
.
blocks_to_swap_out
,
execute_model_req
.
blocks_to_copy
]):
raise
NotImplementedError
(
"MultiStepWorker does not support cache operations"
)
if
any
(
len
(
seq_group_metadata
.
seq_data
.
keys
())
!=
1
for
seq_group_metadata
in
execute_model_req
.
seq_group_metadata_list
):
raise
NotImplementedError
(
"MultiStepWorker does not support beam search."
)
def
maybe_load_lm_head_weight
(
self
,
lm_head_weight
:
torch
.
Tensor
,
)
->
None
:
weight_loader
=
getattr
(
self
.
worker
.
model_runner
.
model_runner
.
model
.
lm_head
.
weight
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
self
.
worker
.
model_runner
.
model_runner
.
model
.
lm_head
.
weight
,
lm_head_weight
)
vllm/spec_decode/ngram_worker.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
weakref
from
typing
import
List
,
Optional
,
Set
,
Tuple
import
torch
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
class
_DummyModel
(
nn
.
Module
):
pass
class
NGramWorker
(
NonLLMProposerWorkerBase
):
"""NGramWorker provides a light drafter without need for model.
Current NGramWorker only implements prompt lookup decoding,
and in future we may also do RAG type drafter and other scenarios
which don't rely on LLM model to give proposals.
"""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
local_rank
:
int
,
device_type
:
str
=
"cuda"
,
**
kwargs
,
):
super
().
__init__
(
vllm_config
)
# Get local_rank/vocab_size from kwargs attribute
self
.
local_rank
=
local_rank
self
.
device_type
=
device_type
# Lazy initialization list.
self
.
_proposer
:
Top1Proposer
def
set_ngram_window_size
(
self
,
ngram_prompt_lookup_min
:
int
,
ngram_prompt_lookup_max
:
int
):
# Search valid candidate window between
# ngram_prompt_lookup_min/ngram_prompt_lookup_max
self
.
ngram_prompt_lookup_max
=
ngram_prompt_lookup_max
self
.
ngram_prompt_lookup_min
=
ngram_prompt_lookup_min
def
init_device
(
self
):
self
.
device
=
torch
.
device
(
f
"
{
self
.
device_type
}
:
{
self
.
local_rank
}
"
)
# Current NGramWorker only supports Top1Proposer
self
.
_proposer
=
Top1Proposer
(
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
device
=
self
.
device
,
vocab_size
=
self
.
vocab_size
,
)
def
load_model
(
self
)
->
None
:
pass
# Dummy
def
get_model
(
self
)
->
nn
.
Module
:
return
_DummyModel
()
def
sampler_output
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
# Unused parameter. NGramWorker does not use the KV Cache and
# therefore does not need this parameter.
seq_ids_with_bonus_token_in_last_step
:
Set
[
int
],
)
->
Tuple
[
Optional
[
List
[
Optional
[
SamplerOutput
]]],
bool
]:
"""NGram match algo to pick proposal candidate. Returns the list of
sampler output, one per SequenceGroupMetadata.
For ngram worker, we already done needed transposed internal, so the
indicator pass to sampler_output_to_torch shall be False.
"""
self
.
_raise_if_unsupported
(
execute_model_req
)
has_spec_out
=
False
token_id_list
:
List
[
Optional
[
torch
.
Tensor
]]
=
[]
token_prob_list
:
List
[
Optional
[
torch
.
Tensor
]]
=
[]
for
idx
,
seq_group_metadata
in
enumerate
(
execute_model_req
.
seq_group_metadata_list
):
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
seq_len
=
seq_data
.
get_len
()
# When seq_len is less than 3072 (3K), we use CPU to perform
# the ngram match. Otherwise, we use the device specified in
# the model config (normally GPU). 3072 is a rough threshold
# based on profiling on H100, and it can be adjusted based
# on the actual performance on different hardware.
cur_device
=
"cpu"
if
seq_len
<
3072
else
self
.
device
input_ids
=
torch
.
as_tensor
(
seq_data
.
get_token_ids
(),
dtype
=
torch
.
long
,
device
=
cur_device
)
input_length
=
seq_data
.
get_len
()
for
ngram_size
in
range
(
min
(
self
.
ngram_prompt_lookup_max
,
input_length
-
1
),
self
.
ngram_prompt_lookup_min
-
1
,
-
1
,
):
ngram_tensor
=
input_ids
[
-
ngram_size
:]
if
ngram_size
==
1
:
# Do not match itself and do not use unfold and all
matches
=
(
input_ids
[:
-
1
]
==
ngram_tensor
)
else
:
windows
=
input_ids
.
unfold
(
dimension
=
0
,
size
=
ngram_size
,
step
=
1
)
# Do not match itself
matches
=
(
windows
[:
-
1
]
==
ngram_tensor
).
all
(
dim
=-
1
)
# first_match includes "values" (bool), indicating whether
# the match is found, and "indices", indicating the index
# of the first match.
first_match
=
matches
.
max
(
dim
=-
1
)
if
first_match
.
values
.
item
():
proposal_start_idx
=
first_match
.
indices
.
add_
(
ngram_size
)
spec_indices
=
(
proposal_start_idx
).
repeat
(
sample_len
)
+
torch
.
arange
(
sample_len
,
device
=
cur_device
)
spec_indices
.
clamp_
(
max
=
input_ids
.
shape
[
-
1
]
-
1
)
res
=
input_ids
.
gather
(
dim
=-
1
,
index
=
spec_indices
).
to
(
self
.
device
)
token_id_list
.
append
(
res
)
token_prob_list
.
append
(
torch
.
nn
.
functional
.
one_hot
(
res
,
num_classes
=
self
.
vocab_size
).
to
(
torch
.
float32
))
has_spec_out
=
True
break
else
:
token_id_list
.
append
(
None
)
token_prob_list
.
append
(
None
)
if
not
has_spec_out
:
return
None
,
False
outputs
:
List
[
Optional
[
SamplerOutput
]]
=
[]
for
idx
in
range
(
len
(
execute_model_req
.
seq_group_metadata_list
)):
if
token_id_list
[
idx
]
is
None
:
outputs
.
append
(
None
)
else
:
outputs
.
append
(
SamplerOutput
(
outputs
=
None
,
sampled_token_probs
=
token_prob_list
[
idx
],
logprobs
=
torch
.
zeros
((
sample_len
,
self
.
vocab_size
),
dtype
=
torch
.
float32
,
device
=
self
.
device
),
sampled_token_ids
=
token_id_list
[
idx
],
))
return
outputs
,
False
def
get_spec_proposals
(
self
,
execute_model_req
:
ExecuteModelRequest
,
# Unused parameter. NGramWorker does not use the KV Cache and
# therefore does not need this parameter.
seq_ids_with_bonus_token_in_last_step
:
Set
[
int
],
)
->
SpeculativeProposals
:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return
self
.
_proposer
.
get_spec_proposals
(
execute_model_req
,
seq_ids_with_bonus_token_in_last_step
)
def
_raise_if_unsupported
(
self
,
execute_model_req
:
ExecuteModelRequest
,
)
->
None
:
"""NGramWorker does not yet implement support for cache swap
operations or beam search.
"""
if
any
([
execute_model_req
.
blocks_to_swap_in
,
execute_model_req
.
blocks_to_swap_out
,
execute_model_req
.
blocks_to_copy
]):
raise
NotImplementedError
(
"NGramWorker does not support cache operations"
)
if
any
(
len
(
seq_group_metadata
.
seq_data
.
keys
())
!=
1
for
seq_group_metadata
in
execute_model_req
.
seq_group_metadata_list
):
raise
NotImplementedError
(
"NGramWorker does not support beam search."
)
vllm/spec_decode/proposer_worker_base.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
typing
import
List
,
Optional
,
Set
,
Tuple
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.spec_decode.interfaces
import
SpeculativeProposer
from
vllm.worker.worker_base
import
LoRANotSupportedWorkerBase
class
ProposerWorkerBase
(
LoRANotSupportedWorkerBase
,
SpeculativeProposer
):
"""Interface for proposer workers"""
@
abstractmethod
def
sampler_output
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
# A set containing all sequence IDs that were assigned bonus tokens
# in their last forward pass. This set is used to backfill the KV cache
# with the key-value pairs of the penultimate token in the sequences.
# This parameter is only used by the MultiStepWorker, which relies on
# the KV cache for token generation. It is not used by workers that
# do not utilize the KV cache.
seq_ids_with_bonus_token_in_last_step
:
Set
[
int
]
)
->
Tuple
[
Optional
[
List
[
SamplerOutput
]],
bool
]:
raise
NotImplementedError
def
set_include_gpu_probs_tensor
(
self
)
->
None
:
"""Implementation optional"""
pass
def
set_should_modify_greedy_probs_inplace
(
self
)
->
None
:
"""Implementation optional"""
pass
class
NonLLMProposerWorkerBase
(
ProposerWorkerBase
,
ABC
):
"""Proposer worker which does not use a model with kvcache"""
def
execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
"""get_spec_proposals is used to get the proposals"""
return
[]
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""This is never called on the proposer, only the target model"""
raise
NotImplementedError
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
pass
def
get_cache_block_size_bytes
(
self
)
->
int
:
return
0
vllm/spec_decode/smaller_tp_proposer_worker.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
List
,
Optional
,
Set
,
Tuple
import
torch
import
torch.nn
as
nn
from
vllm.distributed.parallel_state
import
(
get_tp_group
,
init_model_parallel_group
,
patch_tensor_parallel_group
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
logger
=
init_logger
(
__name__
)
class
_DummyModel
(
nn
.
Module
):
pass
class
SmallerTpProposerWorker
(
ProposerWorkerBase
):
"""Class which allows a speculative draft model to run with smaller tensor
parallel degree than target model.
This reduces the communication overhead of small draft models.
To implement this feature, this class differs behavior based on is_dummy
flag, where dummy means worker that does not participate draft generation.
Participating workers use a smaller tp group by patching vLLM's tensor
parallel group temporarily during forward passes of draft models.
"""
@
classmethod
def
maybe_wrap_worker
(
cls
,
worker
,
draft_tensor_parallel_size
:
int
,
target_tensor_parallel_size
:
int
):
"""Wrap the worker in a SmallerTpProposerWorker if necessary.
"""
if
draft_tensor_parallel_size
==
target_tensor_parallel_size
:
return
worker
# gpu ranks that will generate draft tokens together
draft_ranks
=
list
(
range
(
draft_tensor_parallel_size
))
logger
.
info
(
"Wrapping {%s} in {%s}"
,
type
(
worker
),
cls
)
return
cls
(
worker
,
draft_ranks
)
def
__init__
(
self
,
worker
:
MultiStepWorker
,
draft_ranks
:
List
[
int
]):
"""Create a SmallerTpProposerWorker.
Args:
worker (~vllm.spec_decode.multi_step_worker.MultiStepWorker): an
actual worker wrapped with this class
draft_ranks (List[int]): if this value is given, only the GPU ranks
written in this value participate in draft generation
"""
self
.
_worker
=
worker
self
.
_draft_ranks
=
draft_ranks
# init during init_device
self
.
_is_dummy
=
False
self
.
_tp_group
=
None
def
_patch_tensor_parallel_group
(
self
):
"""Temporarily patch the global tp group state with its own tp group
state.
"""
return
patch_tensor_parallel_group
(
self
.
_tp_group
)
def
init_device
(
self
)
->
None
:
self
.
_is_dummy
=
get_tp_group
().
rank
not
in
self
.
_draft_ranks
# dummy workers do nothing
if
self
.
_is_dummy
:
return
# creates tp process group containing only a subset of gpu ranks
local_rank
=
get_tp_group
().
local_rank
tp_backend
=
torch
.
distributed
.
get_backend
(
get_tp_group
().
device_group
)
self
.
_tp_group
=
init_model_parallel_group
([
self
.
_draft_ranks
],
local_rank
,
tp_backend
)
with
self
.
_patch_tensor_parallel_group
():
self
.
_worker
.
init_device
()
def
set_include_gpu_probs_tensor
(
self
)
->
None
:
if
self
.
_is_dummy
:
return
# Need include_gpu_probs_tensor for multi_step_worker
self
.
_worker
.
set_include_gpu_probs_tensor
()
def
set_should_modify_greedy_probs_inplace
(
self
)
->
None
:
if
self
.
_is_dummy
:
return
self
.
_worker
.
set_should_modify_greedy_probs_inplace
()
def
load_model
(
self
)
->
None
:
if
self
.
_is_dummy
:
return
with
self
.
_patch_tensor_parallel_group
():
self
.
_worker
.
load_model
()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
if
self
.
_is_dummy
:
# this case is not used now
return
-
1
,
-
1
with
self
.
_patch_tensor_parallel_group
():
return
self
.
_worker
.
determine_num_available_blocks
()
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
if
self
.
_is_dummy
:
return
with
self
.
_patch_tensor_parallel_group
():
self
.
_worker
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
def
sampler_output
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
seq_ids_with_bonus_token_in_last_step
:
Set
[
int
],
)
->
Tuple
[
List
[
SamplerOutput
],
bool
]:
# Do not check _is_dummy, as it's always called by get_spec_proposals
return
self
.
_worker
.
sampler_output
(
execute_model_req
,
sample_len
,
seq_ids_with_bonus_token_in_last_step
)
def
get_spec_proposals
(
self
,
execute_model_req
:
ExecuteModelRequest
,
seq_ids_with_bonus_token_in_last_step
:
Set
[
int
],
)
->
SpeculativeProposals
:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
if
self
.
_is_dummy
:
return
SpeculativeProposals
(
None
,
None
,
None
)
with
self
.
_patch_tensor_parallel_group
():
return
self
.
_worker
.
get_spec_proposals
(
execute_model_req
,
seq_ids_with_bonus_token_in_last_step
)
def
get_model
(
self
)
->
nn
.
Module
:
if
self
.
_is_dummy
:
return
_DummyModel
()
with
self
.
_patch_tensor_parallel_group
():
return
self
.
_worker
.
get_model
()
def
execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
if
self
.
_is_dummy
:
return
[]
with
self
.
_patch_tensor_parallel_group
():
return
self
.
_worker
.
execute_model
(
execute_model_req
)
def
get_cache_block_size_bytes
(
self
)
->
int
:
if
self
.
_is_dummy
:
# by returning zero, target worker can use the entire kv cache space
return
0
return
self
.
_worker
.
get_cache_block_size_bytes
()
@
property
def
vocab_size
(
self
)
->
int
:
return
self
.
_worker
.
vocab_size
def
maybe_load_lm_head_weight
(
self
,
lm_head_weight
:
torch
.
Tensor
,
)
->
None
:
if
self
.
_is_dummy
:
return
with
self
.
_patch_tensor_parallel_group
():
weight_loader
=
getattr
(
self
.
_worker
.
worker
.
model_runner
.
model_runner
.
model
.
\
lm_head
.
weight
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
self
.
_worker
.
worker
.
model_runner
.
model_runner
.
model
.
\
lm_head
.
weight
,
lm_head_weight
)
vllm/spec_decode/spec_decode_worker.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
from
collections
import
defaultdict
from
functools
import
cached_property
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
import
torch
import
torch.nn
as
nn
from
vllm.config
import
ParallelConfig
,
SpeculativeConfig
,
VllmConfig
from
vllm.distributed.communication_op
import
(
broadcast_tensor_dict
,
get_tp_group
,
tensor_model_parallel_gather
)
from
vllm.distributed.parallel_state
import
model_parallel_is_initialized
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeBaseSampler
,
SpecDecodeStochasticBaseSampler
)
from
vllm.model_executor.layers.typical_acceptance_sampler
import
(
TypicalAcceptanceSampler
)
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
CompletionSequenceGroupOutput
,
ExecuteModelRequest
,
HiddenStates
,
SequenceGroupMetadata
,
get_all_seq_ids_and_request_ids
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
if
current_platform
.
is_cuda_alike
():
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.medusa_worker
import
MedusaWorker
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.mlp_speculator_worker
import
MLPSpeculatorWorker
from
vllm.spec_decode.mqa_scorer
import
MQAScorer
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.smaller_tp_proposer_worker
import
SmallerTpProposerWorker
from
vllm.spec_decode.target_model_runner
import
TargetModelRunner
from
vllm.spec_decode.util
import
(
Timer
,
create_logprobs_output
,
create_sequence_group_output
,
get_all_num_logprobs
,
get_sampled_token_logprobs
,
nvtx_range
,
split_batch_by_proposal_len
)
from
vllm.utils
import
resolve_obj_by_qualname
from
vllm.worker.worker_base
import
LoRANotSupportedWorkerBase
,
WorkerBase
logger
=
init_logger
(
__name__
)
def
create_spec_worker
(
*
args
,
**
kwargs
)
->
"SpecDecodeWorker"
:
"""Helper method that is the entrypoint for Executors which use
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
"""
vllm_config
:
VllmConfig
=
kwargs
.
get
(
"vllm_config"
)
speculative_config
:
SpeculativeConfig
=
vllm_config
.
speculative_config
assert
speculative_config
is
not
None
if
vllm_config
.
parallel_config
.
pipeline_parallel_size
>
1
:
raise
NotImplementedError
(
"Speculative decoding is currently "
"incompatible with pipeline parallelism"
)
draft_worker_kwargs
=
kwargs
.
copy
()
kwargs
[
"model_runner_cls"
]
=
TargetModelRunner
target_worker_config
=
copy
.
deepcopy
(
vllm_config
)
target_worker_config
.
parallel_config
.
worker_cls
=
\
target_worker_config
.
parallel_config
.
sd_worker_cls
cls
=
resolve_obj_by_qualname
(
target_worker_config
.
parallel_config
.
worker_cls
)
target_worker
=
cls
(
*
args
,
**
kwargs
)
# Set the disable_logprobs variable in the TargetModelRunner instance
# as per its value specified in the SpeculativeConfig.
target_worker
.
model_runner
.
disable_logprobs
=
\
speculative_config
.
disable_logprobs
draft_worker_config
=
copy
.
deepcopy
(
vllm_config
)
draft_worker_config
.
model_config
=
speculative_config
.
draft_model_config
draft_worker_config
.
quant_config
=
VllmConfig
.
_get_quantization_config
(
draft_worker_config
.
model_config
,
vllm_config
.
load_config
,
)
speculative_config
.
draft_parallel_config
.
worker_cls
=
\
draft_worker_config
.
parallel_config
.
sd_worker_cls
draft_worker_config
.
parallel_config
=
speculative_config
.
draft_parallel_config
# noqa
# TODO allow draft-model specific load config.
# Override draft-model specific worker args.
draft_worker_kwargs
.
update
(
vllm_config
=
draft_worker_config
,
ngram_prompt_lookup_max
=
speculative_config
.
prompt_lookup_max
,
ngram_prompt_lookup_min
=
speculative_config
.
prompt_lookup_min
,
)
spec_decode_worker
=
SpecDecodeWorker
.
create_worker
(
scorer_worker
=
target_worker
,
draft_worker_kwargs
=
draft_worker_kwargs
,
disable_mqa_scorer
=
speculative_config
.
disable_mqa_scorer
,
disable_by_batch_size
=
speculative_config
.
disable_by_batch_size
,
draft_token_acceptance_method
=
speculative_config
.
acceptance_method
,
typical_acceptance_sampler_posterior_threshold
=
speculative_config
.
posterior_threshold
,
typical_acceptance_sampler_posterior_alpha
=
speculative_config
.
posterior_alpha
,
disable_logprobs
=
speculative_config
.
disable_logprobs
,
disable_log_stats
=
speculative_config
.
disable_log_stats
,
num_speculative_tokens
=
speculative_config
.
num_speculative_tokens
,
)
return
spec_decode_worker
# Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo become valid
class
SpecDecodeWorker
(
LoRANotSupportedWorkerBase
):
"""Worker which implements speculative decoding.
Speculative decoding reduces decoding per-token latency by using a proposal
method, such as a small draft model, to speculate ahead of a larger LLM. The
probabilities of the speculative tokens are then determined by the larger
LLM, after which some verification routine determines which (if any) of the
speculative tokens are accepted by the larger LLM.
See https://github.com/vllm-project/vllm/pull/2188 and
https://github.com/vllm-project/vllm/pull/3103 for more info.
The current implementation has the following limitations:
* Only draft-model proposal is implemented (contributions for more forms are
welcome!).
* Only top-1 proposal and scoring are implemented. Tree-attention is left as
future work.
* All sequences in a batch must have the same proposal length, or zero. This
can be improved by having per-sequence speculation in the future.
* The scoring forward pass is done without an MQA kernel, which is
suboptimal especially as the batch size, proposal length, and sequence
lengths grow. Contributions to add a MQA scoring are welcome once
correctness tests pass.
More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
"""
@
classmethod
def
create_worker
(
cls
,
scorer_worker
:
WorkerBase
,
draft_worker_kwargs
:
Dict
[
str
,
Any
],
disable_mqa_scorer
:
bool
,
disable_by_batch_size
:
Optional
[
int
],
draft_token_acceptance_method
:
str
,
typical_acceptance_sampler_posterior_threshold
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
disable_logprobs
:
bool
,
disable_log_stats
:
bool
,
num_speculative_tokens
:
int
,
)
->
"SpecDecodeWorker"
:
allow_zero_draft_token_step
=
True
enable_lm_head_weight_load
=
False
num_spec_prefill_steps
=
1
ngram_prompt_lookup_max
=
(
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_max"
))
ngram_prompt_lookup_min
=
(
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_min"
))
draft_model_config
=
draft_worker_kwargs
[
"vllm_config"
].
model_config
draft_parallel_config
:
ParallelConfig
=
draft_worker_kwargs
[
'vllm_config'
].
parallel_config
if
ngram_prompt_lookup_max
>
0
:
draft_worker_kwargs
[
"device_type"
]
=
scorer_worker
.
device_config
.
device
.
type
proposer_worker
=
NGramWorker
(
**
draft_worker_kwargs
)
proposer_worker
.
set_ngram_window_size
(
ngram_prompt_lookup_min
,
ngram_prompt_lookup_max
)
else
:
draft_tp
=
draft_parallel_config
.
tensor_parallel_size
target_tp
=
scorer_worker
.
parallel_config
.
tensor_parallel_size
if
draft_model_config
.
hf_config
.
model_type
==
"mlp_speculator"
:
proposer_worker
=
MLPSpeculatorWorker
(
**
draft_worker_kwargs
)
elif
draft_model_config
.
hf_config
.
model_type
==
"medusa"
:
proposer_worker
=
MedusaWorker
(
**
draft_worker_kwargs
)
else
:
if
draft_tp
==
1
:
if
current_platform
.
is_cuda_alike
():
draft_worker_kwargs
[
"model_runner_cls"
]
=
TP1DraftModelRunner
else
:
if
draft_model_config
.
hf_config
.
model_type
==
"eagle"
:
raise
NotImplementedError
(
f
"
{
draft_model_config
.
hf_config
.
model_type
}
"
"does not support TP > 1 yet"
)
allow_zero_draft_token_step
=
False
# Load lm_head weight for eagle in init_device
if
draft_model_config
.
hf_config
.
model_type
==
"eagle"
:
enable_lm_head_weight_load
=
True
proposer_worker
=
MultiStepWorker
(
**
draft_worker_kwargs
)
if
draft_model_config
.
hf_config
.
model_type
==
"deepseek_mtp"
:
num_spec_prefill_steps
=
\
draft_model_config
.
hf_config
.
n_predict
proposer_worker
=
SmallerTpProposerWorker
.
maybe_wrap_worker
(
proposer_worker
,
draft_tp
,
target_tp
)
logger
.
info
(
"Configuring SpecDecodeWorker with proposer=%s"
,
type
(
proposer_worker
))
spec_decode_sampler
:
SpecDecodeBaseSampler
=
None
if
draft_token_acceptance_method
==
"rejection_sampler"
:
spec_decode_sampler
=
RejectionSampler
()
elif
draft_token_acceptance_method
==
"typical_acceptance_sampler"
:
spec_decode_sampler
=
TypicalAcceptanceSampler
(
posterior_threshold
=
\
typical_acceptance_sampler_posterior_threshold
,
posterior_alpha
=
typical_acceptance_sampler_posterior_alpha
,
)
logger
.
info
(
"[Speculative Decoding] Configuring"
" SpecDecodeWorker with sampler=%s"
,
type
(
spec_decode_sampler
))
if
not
disable_mqa_scorer
:
if
scorer_worker
.
model_runner
.
attn_backend
.
get_name
(
)
!=
"FLASH_ATTN"
:
disable_mqa_scorer
=
True
logger
.
info
(
"[Speculative Decoding] Disabling MQA scorer as the "
"MQA is only available with flash attn backend."
)
if
draft_model_config
and
\
draft_model_config
.
max_model_len
<
\
scorer_worker
.
model_config
.
max_model_len
:
disable_mqa_scorer
=
True
logger
.
info
(
"[Speculative Decoding] Disabling MQA scorer as the "
"draft model max_model_len is smaller than the target "
"model max_model_len."
)
if
not
scorer_worker
.
model_runner
.
model_config
.
enforce_eager
:
disable_mqa_scorer
=
True
logger
.
info
(
"[Speculative Decoding] Disabling MQA scorer as the "
"target model is not running in eager mode."
)
return
SpecDecodeWorker
(
proposer_worker
,
scorer_worker
,
disable_mqa_scorer
=
disable_mqa_scorer
,
disable_logprobs
=
disable_logprobs
,
disable_log_stats
=
disable_log_stats
,
disable_by_batch_size
=
disable_by_batch_size
,
spec_decode_sampler
=
spec_decode_sampler
,
allow_zero_draft_token_step
=
allow_zero_draft_token_step
,
enable_lm_head_weight_load
=
enable_lm_head_weight_load
,
num_spec_prefill_steps
=
num_spec_prefill_steps
)
def
__init__
(
self
,
proposer_worker
:
ProposerWorkerBase
,
scorer_worker
:
WorkerBase
,
spec_decode_sampler
:
SpecDecodeBaseSampler
,
disable_mqa_scorer
:
bool
=
False
,
disable_logprobs
:
bool
=
False
,
disable_log_stats
:
bool
=
False
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
allow_zero_draft_token_step
:
Optional
[
bool
]
=
True
,
enable_lm_head_weight_load
:
Optional
[
bool
]
=
False
,
num_spec_prefill_steps
:
int
=
1
,
):
"""
Create a SpecDecodeWorker.
Args:
proposer_worker: A worker that can produce speculative tokens for
sequences.
scorer_worker: A worker that produces probabilities of speculative
tokens according to some base model. Typically a vanilla vLLM
Worker.
spec_decode_sampler: A Torch module used to perform acceptance
sampling of the draft tokens in the verification step of
speculative decoding. Currently we support two different
types of sampler namely RejectionSampler and
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
instance of RejectionSampler or TypicalAcceptanceSampler.
disable_mqa_scorer: If set to True, disable the MQA scorer and use
the BatchExpansionTop1Scorer instead.
disable_logprobs: If set to True, token log probabilities will
not be output in both the draft worker and the target worker.
If set to False, log probabilities will be output by both.
disable_log_stats: If set to True, disable periodic printing of
speculative stage times.
disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
for testing purposes.
allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814)
enable_lm_head_weight_load: whether to load lm_head weight for
draft models like eagle.
num_spec_prefill_steps: number of speculative prefill steps to run
before the speculative decoding starts. This is only used when
the draft model is a deepseek_mtp model that requires prefill
kv cache separately for each MTP layer.
"""
self
.
proposer_worker
=
proposer_worker
self
.
scorer_worker
=
scorer_worker
scorer_runner
=
getattr
(
self
.
scorer_worker
,
"model_runner"
,
None
)
self
.
generators
=
scorer_runner
.
get_generators
(
)
if
scorer_runner
else
None
self
.
disable_by_batch_size
=
disable_by_batch_size
or
float
(
"inf"
)
self
.
spec_decode_sampler
=
spec_decode_sampler
self
.
_allow_zero_draft_token_step
=
allow_zero_draft_token_step
self
.
_enable_lm_head_weight_load
=
enable_lm_head_weight_load
self
.
_metrics
=
AsyncMetricsCollector
(
self
.
spec_decode_sampler
)
if
metrics_collector
is
None
else
metrics_collector
# Tracks the sequence IDs that received a bonus token ID in
# their last forward pass. Needed only if KV cache is being
# used for token generation such as in the case of MultiStepWorker.
self
.
_seq_with_bonus_token_in_last_step
:
Set
[
int
]
=
set
()
# Tracks the currently active request ids and the sequence IDs
# corresponding to them
self
.
_request_id_seq_id_mapping
:
Dict
[
str
,
Set
[
int
]]
=
defaultdict
(
set
)
# Tracks if the proposer worker uses the KV cache or not.
self
.
probs_dtype
=
self
.
spec_decode_sampler
.
probs_dtype
self
.
token_id_dtype
=
self
.
spec_decode_sampler
.
token_id_dtype
# Lazy initialization.
self
.
scorer
:
SpeculativeScorer
self
.
disable_mqa_scorer
=
disable_mqa_scorer
# Hidden states from target model to pass to proposer
# in the subsequent step.
self
.
previous_hidden_states
:
Optional
[
HiddenStates
]
=
None
self
.
_disable_logprobs
=
disable_logprobs
self
.
_disable_log_stats
=
disable_log_stats
self
.
_num_spec_prefill_steps
=
num_spec_prefill_steps
def
init_device
(
self
)
->
None
:
"""Initialize both scorer and proposer models.
"""
# The scorer worker model is initialized first in case the proposer
# model has a smaller TP degree than the target worker.
self
.
scorer_worker
.
init_device
()
self
.
proposer_worker
.
init_device
()
# NOTE(cade): load_model is not part of the WorkerBase interface.
self
.
scorer_worker
.
load_model
()
self
.
proposer_worker
.
load_model
()
if
self
.
_enable_lm_head_weight_load
:
# NOTE(Shangming): gather lm_head weight when tp enabled
target_lm_head_weight
:
torch
.
Tensor
=
tensor_model_parallel_gather
(
self
.
scorer_worker
.
model_runner
.
model_runner
.
model
.
lm_head
.
\
weight
.
data
,
dim
=
0
,
)
self
.
proposer_worker
.
maybe_load_lm_head_weight
(
target_lm_head_weight
)
self
.
_metrics
.
init_tensors
(
self
.
rank
,
device_type
=
self
.
device
)
if
model_parallel_is_initialized
():
self
.
spec_decode_sampler
.
init_tensors
(
get_tp_group
().
local_rank
,
device_type
=
self
.
device
)
else
:
self
.
spec_decode_sampler
.
init_tensors
(
self
.
rank
,
device_type
=
self
.
device
)
scorer_cls
:
Type
[
SpeculativeScorer
]
if
self
.
disable_mqa_scorer
:
scorer_cls
=
BatchExpansionTop1Scorer
logger
.
info
(
"[Speculative Decoding] Use batch "
"expansion for scoring proposals."
)
else
:
scorer_cls
=
MQAScorer
logger
.
info
(
"[Speculative Decoding] Use MQA scorer for scoring proposals."
)
self
.
scorer
=
scorer_cls
(
scorer_worker
=
self
.
scorer_worker
,
device
=
self
.
device
,
vocab_size
=
self
.
_vocab_size
)
self
.
_configure_model_sampler_for_spec_decode
()
def
load_model
(
self
,
*
args
,
**
kwargs
):
pass
def
_configure_model_sampler_for_spec_decode
(
self
):
"""Configure model sampler to emit GPU tensors. This allows spec decode
to keep data on device without transferring to CPU and serializing,
which significantly reduces overhead of sampling during verification.
NOTE(cade): This breaks abstraction boundaries pretty badly. The better
design is to have the "move to CPU and serialize" sampling decision be
done outside of the model/sampler; this way the "last-mile" worker
object which interfaces with the scheduler can serialize and incur the
performance hit as necessary. This allows us to run the worker several
iterations in a row without incurring the "move to CPU and serialize"
performance penalty.
Since this requires a large change to vLLM, we defer it to later and
temporarily accept this broken abstraction boundary.
NOTE(cade): This will require a special check if the proposer worker
does not have a sampler (e.g. ngram speculation).
"""
(
self
.
scorer_worker
.
model_runner
.
sampler
.
include_gpu_probs_tensor
)
=
True
(
self
.
scorer_worker
.
model_runner
.
sampler
.
should_modify_greedy_probs_inplace
)
=
True
self
.
proposer_worker
.
set_include_gpu_probs_tensor
()
self
.
proposer_worker
.
set_should_modify_greedy_probs_inplace
()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Determine the number of cache blocks to use.
This is done by profiling the scorer model (which is typically the
larger of the two). Then the total memory which would be used by the
scorer cache is divided evenly between the proposer and scorer model KV,
such that the number of blocks is equal in both KV caches.
"""
num_gpu_blocks
,
num_cpu_blocks
=
(
self
.
scorer_worker
.
determine_num_available_blocks
())
scorer_cache_block_size_bytes
=
(
self
.
scorer_worker
.
get_cache_block_size_bytes
())
proposer_cache_block_size_bytes
=
(
self
.
proposer_worker
.
get_cache_block_size_bytes
())
new_num_gpu_blocks
=
split_num_cache_blocks_evenly
(
scorer_cache_block_size_bytes
,
proposer_cache_block_size_bytes
,
num_gpu_blocks
)
return
new_num_gpu_blocks
,
num_cpu_blocks
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
"""Initialize the cache engine of the scorer and proposer workers.
"""
self
.
scorer_worker
.
initialize_cache
(
num_gpu_blocks
=
num_gpu_blocks
,
num_cpu_blocks
=
num_cpu_blocks
)
self
.
proposer_worker
.
initialize_cache
(
num_gpu_blocks
=
num_gpu_blocks
,
num_cpu_blocks
=
num_cpu_blocks
)
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
scorer_worker
.
get_model
()
@
torch
.
inference_mode
()
def
execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
"""Perform speculative decoding on the input batch.
"""
if
self
.
rank
!=
self
.
_driver_rank
:
self
.
_run_non_driver_rank
()
return
[]
if
execute_model_req
is
None
:
# This signals that there's no more requests to process for now.
# All workers are running infinite loop with broadcast_tensor_dict,
# and it stops the loop when the driver broadcasts an empty input.
# Send an empty input to notify all other workers to stop their
# execution loop.
broadcast_tensor_dict
({},
src
=
0
)
return
[]
self
.
_track_finished_requests
(
execute_model_req
)
disable_all_speculation
=
self
.
_should_disable_all_speculation
(
execute_model_req
)
num_lookahead_slots
=
execute_model_req
.
num_lookahead_slots
all_prompt
=
True
atleast_one_prompt
=
False
all_zero_spec_tokens
=
True
for
sgm
in
execute_model_req
.
seq_group_metadata_list
:
all_prompt
=
all_prompt
and
sgm
.
is_prompt
atleast_one_prompt
=
atleast_one_prompt
or
sgm
.
is_prompt
all_zero_spec_tokens
=
all_zero_spec_tokens
and
(
sgm
.
num_speculative_tokens
==
0
)
if
all_prompt
and
execute_model_req
.
seq_group_metadata_list
:
assert
num_lookahead_slots
==
0
,
(
"Prompt only runs should have num_lookahead_slots equal to 0. "
"This should never happen, please file a bug at "
"https://github.com/vllm-project/vllm/issues"
)
# Speculative decoding is disabled in the following cases:
# 1. Prefill phase: Speculative decoding is not
# used during the prefill phase.
# 2. Auto-disable enabled: The running queue size exceeds
# the specified threshold.
# 3. No request: There are no requests in the batch, or
# none of the requests in the batch have spec decoding enabled.
# In any of these cases, the proposer and scorer workers
# are called normally.
# We expect `num_speculative_tokens` to be None for prefills.
no_spec
=
(
num_lookahead_slots
==
0
or
disable_all_speculation
or
all_zero_spec_tokens
)
# Broadcast how many lookahead slots are scheduled for this step, and
# whether all speculation is disabled, to all non-driver workers.
# This is required as if the number of draft model runs changes
# dynamically, the non-driver workers won't know unless we perform a
# communication to inform them.
# no_spec is used to signal non-driver worker about prefill vs decode
# stage. This is needed to ensure that order of execution of proposer
# and scorer is same in both driver and non-driver workers (i.e.,
# scorer -> proposer for prefill and proposer -> scorer in decode). This
# order is needed to support models like EAGLE that take scorer states
# as inputs.
broadcast_dict
=
dict
(
num_lookahead_slots
=
num_lookahead_slots
,
no_spec
=
no_spec
,
disable_all_speculation
=
disable_all_speculation
,
# When both chunked prefill and speculative decoding are enabled
# it is possible that the same batch contains both prefill
# and decodes. If that happens in the scorer we run the batch
# as one single forward pass. However, in the proposer we
# run them as 2 different batches - one for prefill and
# the other for decodes. The variable indicates to the non-driver
# worker that there are prefills as part of the speculative batch
# and hence it needs to run an extra prefill forward pass.
run_spec_proposer_for_prefill
=
atleast_one_prompt
,
)
broadcast_tensor_dict
(
broadcast_dict
,
src
=
self
.
_driver_rank
)
assert
execute_model_req
.
seq_group_metadata_list
is
not
None
,
(
"speculative decoding requires non-None seq_group_metadata_list"
)
self
.
_maybe_disable_speculative_tokens
(
disable_all_speculation
,
execute_model_req
.
seq_group_metadata_list
)
if
no_spec
:
return
self
.
_run_no_spec
(
execute_model_req
,
skip_proposer
=
disable_all_speculation
)
return
self
.
_run_speculative_decoding_step
(
execute_model_req
,
num_lookahead_slots
)
@
torch
.
inference_mode
()
def
start_worker_execution_loop
(
self
)
->
None
:
"""Execute model loop to perform speculative decoding
in parallel worker."""
while
self
.
_run_non_driver_rank
():
pass
def
_should_disable_all_speculation
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
bool
:
# When the batch size is too large, disable speculative decoding
# to stop trading off throughput for latency.
return
(
execute_model_req
.
running_queue_size
>=
self
.
disable_by_batch_size
)
def
_maybe_disable_speculative_tokens
(
self
,
disable_all_speculation
:
bool
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
])
->
None
:
if
not
disable_all_speculation
:
return
for
seq_group_metadata
in
seq_group_metadata_list
:
# Once num_speculative_tokens is set to 0, the spec decode
# of this request will be disabled forever.
# TODO(comaniac): We currently store spec decoding specific
# state in the global data structure, but we should maintain
# this state within spec decode worker.
seq_group_metadata
.
num_speculative_tokens
=
0
def
_serialize_sampler_output_no_logprobs
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sampler_output
:
SamplerOutput
)
->
List
[
SamplerOutput
]:
"""
Creates and returns a `SamplerOutput` with only the token IDs being
serialized to CPU and populated in `CompletionSequenceGroupOutput`.
All other parameters in `CompletionSequenceGroupOutput` related to log
probabilities are skipped.
Args:
execute_model_req (ExecuteModelRequest): The model request that
was executed.
sampler_output (SamplerOutput): The output from the sampler with
only GPU tensors populated.
Returns:
SamplerOutput: A new `SamplerOutput` instance containing a list of
`CompletionSequenceGroupOutput` objects with only token IDs
populated.
"""
seq_output_prompt_logprobs
=
[
seq
.
is_prompt
and
seq
.
sampling_params
.
prompt_logprobs
is
not
None
and
seq
.
sampling_params
.
prompt_logprobs
>
0
for
seq
in
execute_model_req
.
seq_group_metadata_list
]
# ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID
sampled_token_ids_list
=
(
sampler_output
.
sampled_token_ids
[
torch
.
where
(
# subtracting is faster than testing for equality
sampler_output
.
sampled_token_ids
-
VLLM_INVALID_TOKEN_ID
)[
0
]]
\
if
any
(
seq_output_prompt_logprobs
)
else
\
sampler_output
.
sampled_token_ids
).
tolist
()
seq_data_entries
=
[
(
seq_id
,
seq_data
)
for
sg
in
\
execute_model_req
.
seq_group_metadata_list
\
for
seq_id
,
seq_data
in
sg
.
seq_data
.
items
()
]
completion_seq_group_output_list
:
List
[
CompletionSequenceGroupOutput
]
=
[]
output_index
=
0
# Make sure the non-terminal prefill chunks are still aligned with
# their own empty output.
for
idx
,
seq_group_meta
in
enumerate
(
execute_model_req
.
seq_group_metadata_list
):
needs_prompt_logprobs
=
seq_output_prompt_logprobs
[
idx
]
seq_id
,
seq_data
=
seq_data_entries
[
idx
]
if
needs_prompt_logprobs
:
prompt_token_ids
=
seq_data
.
get_prompt_token_ids
()
# Some of these sequences may belong to non-terminal chunks,
# which may still have to report logprobs for prompts.
start
=
1
if
seq_data
.
_num_computed_tokens
==
0
\
else
seq_data
.
_num_computed_tokens
end
=
(
seq_data
.
_num_computed_tokens
+
\
seq_group_meta
.
token_chunk_size
)
prompt_token_ids
=
prompt_token_ids
[
start
:
end
]
prompt_logprobs
=
[
create_logprobs_output
(
token_id
=
p_token_id
,
token_id_logprob_rank
=-
1
,
token_id_logprob
=
0.0
,
topk_token_ids
=
[],
topk_logprobs
=
[],
)
for
p_token_id
in
prompt_token_ids
]
else
:
prompt_logprobs
=
None
# Since we can get chunks here, we dont always have a sampled token
# (only on last chunk) but we still have to provide an output.
if
not
seq_group_meta
.
do_sample
:
completion_seq_group_output_list
.
append
(
CompletionSequenceGroupOutput
(
samples
=
[],
prompt_logprobs
=
prompt_logprobs
))
continue
# Sequence with output.
completion_seq_group_output_list
.
append
(
create_sequence_group_output
(
token_id
=
sampled_token_ids_list
[
output_index
][
0
],
token_id_logprob_rank
=-
1
,
token_id_logprob
=
0.0
,
seq_id
=
seq_id
,
topk_token_ids
=
[],
topk_logprobs
=
[],
prompt_logprobs
=
prompt_logprobs
))
output_index
+=
1
return
[
SamplerOutput
(
outputs
=
completion_seq_group_output_list
)]
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
def
_run_no_spec
(
self
,
execute_model_req
:
ExecuteModelRequest
,
skip_proposer
:
bool
)
->
List
[
SamplerOutput
]:
"""Run a single generation step without any speculation. The input is
sent to the proposer and scorer model so that the KV cache is consistent
between the two. When skip_proposer is True, the proposer model is
not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding.
"""
sampler_output
=
self
.
scorer_worker
.
execute_model
(
execute_model_req
)
assert
len
(
sampler_output
)
==
1
sampler_output
=
sampler_output
[
0
]
# Store hidden states from target model execution, BxD.
hidden_states
=
sampler_output
.
hidden_states
if
hidden_states
is
not
None
:
# Only decodes and prefill terminal chunks need a hidden state.
seq_group_meta_with_hidden
=
[
sg
for
sg
in
execute_model_req
.
seq_group_metadata_list
if
sg
.
do_sample
]
if
any
(
seq
.
is_prompt
for
seq
in
seq_group_meta_with_hidden
):
# Drop hidden_states with no prediction (eg non-terminal chunks)
hidden_states
=
hidden_states
[
torch
.
where
(
sampler_output
.
sampled_token_ids
-
VLLM_INVALID_TOKEN_ID
)[
0
]]
if
self
.
previous_hidden_states
is
None
and
len
(
seq_group_meta_with_hidden
):
self
.
previous_hidden_states
=
HiddenStates
(
hidden_states
,
seq_group_meta_with_hidden
)
elif
self
.
previous_hidden_states
and
len
(
seq_group_meta_with_hidden
):
self
.
previous_hidden_states
.
update
(
hidden_states
,
seq_group_meta_with_hidden
)
self
.
previous_hidden_states
.
prune
(
seq_group_meta_with_hidden
)
if
not
skip_proposer
:
# We prepare the prefill hidden states here so that there no
# additional complexity in worker for spec_decode vs non_spec_decode
# flow and execute_model doesn't need additional modifications.
execute_model_req
.
previous_hidden_states
=
\
prepare_prefill_hidden_states
(
sampler_output
.
prefill_hidden_states
)
for
i
in
range
(
self
.
_num_spec_prefill_steps
):
execute_model_req
.
spec_step_idx
=
i
self
.
proposer_worker
.
execute_model
(
execute_model_req
)
sampler_output_to_return
=
(
self
.
_serialize_sampler_output_no_logprobs
(
execute_model_req
=
execute_model_req
,
sampler_output
=
sampler_output
)
if
self
.
_disable_logprobs
else
[
sampler_output
])
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
sampler_output
.
sampled_token_probs
=
None
sampler_output
.
sampled_token_ids
=
None
sampler_output
.
logprobs
=
None
return
sampler_output_to_return
def
_run_non_driver_rank
(
self
)
->
bool
:
"""Run proposer and verifier model in non-driver workers. This is used
for both speculation cases (num_lookahead_slots>0) and non-speculation
cases (e.g. prefill).
Returns True if there are remaining sequences to process.
"""
assert
self
.
rank
!=
self
.
_driver_rank
data
=
broadcast_tensor_dict
(
src
=
self
.
_driver_rank
)
if
not
data
:
return
False
num_lookahead_slots
=
data
[
"num_lookahead_slots"
]
# In case of prefill, scorer_worker has to be run before proposer so
# that the hidden states can be propagated to proposer when needed.
if
data
[
"no_spec"
]:
self
.
scorer_worker
.
execute_model
()
if
not
data
[
"disable_all_speculation"
]:
# Even if num_lookahead_slots is zero, we want to run the
# proposer model as it may have KV.
#
# We run the proposer once per lookahead slot. In the future we
# should delegate how many times it runs to the proposer.
for
_
in
range
(
max
(
num_lookahead_slots
,
1
)):
self
.
proposer_worker
.
execute_model
()
if
not
data
[
"no_spec"
]:
self
.
scorer_worker
.
execute_model
()
if
data
[
"run_spec_proposer_for_prefill"
]:
self
.
proposer_worker
.
execute_model
()
return
True
@
nvtx_range
(
"spec_decode_worker._run_speculative_decoding_step"
)
def
_run_speculative_decoding_step
(
self
,
execute_model_req
:
ExecuteModelRequest
,
num_lookahead_slots
:
int
)
->
List
[
SamplerOutput
]:
"""Execute a single step of speculative decoding.
This invokes the proposer worker to get k speculative tokens for each
sequence, then scores each speculative token using the scoring worker.
When `enable_chunked_prefill` is set, scorer will batch decodes and
prefills, while proposer will sync its KV-cache by running an extra
forward on prefills.
Returns a list of SamplerOutput, each containing a single token per
sequence.
"""
# With prefill chunking, expect requests to have prompts first
# so that backend gets prefill|decode.
assert
num_lookahead_slots
==
execute_model_req
.
num_lookahead_slots
# Pass last hidden states from target model to proposer
execute_model_req
.
previous_hidden_states
=
self
.
previous_hidden_states
self
.
previous_hidden_states
=
None
with
Timer
()
as
proposal_timer
:
# Generate proposals using draft worker.
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
execute_model_req
,
self
.
_seq_with_bonus_token_in_last_step
)
if
not
self
.
_allow_zero_draft_token_step
and
proposals
.
no_proposals
:
#TODO: Fix it #5814
raise
RuntimeError
(
"Cannot handle cases where distributed draft "
"workers generate no tokens"
)
execute_model_req
.
previous_hidden_states
=
None
with
Timer
()
as
scoring_timer
:
proposal_scores
=
self
.
scorer
.
score_proposals
(
execute_model_req
,
proposals
,
)
_
,
(
non_spec_seqs
,
non_spec_indices
)
=
split_batch_by_proposal_len
(
execute_model_req
.
seq_group_metadata_list
,
proposals
.
proposal_lens
)
# With prefill chunking enabled, `non_spec_seqs` contains prefills too:
# discard decodes that have already been processed by proposer.
non_spec_indices
=
[
idx
for
idx
in
non_spec_indices
if
execute_model_req
.
seq_group_metadata_list
[
idx
].
is_prompt
]
if
len
(
non_spec_indices
):
all_hidden_states
=
proposal_scores
.
hidden_states
if
all_hidden_states
is
not
None
:
prefill_hidden_states
=
all_hidden_states
[
non_spec_indices
]
execute_model_req
.
previous_hidden_states
=
\
prepare_prefill_hidden_states
(
prefill_hidden_states
)
# Sync proposer KV cache for prefills.
prefill_req
=
execute_model_req
.
clone
(
non_spec_seqs
)
# TODO avoid sampling here?
self
.
proposer_worker
.
execute_model
(
prefill_req
)
with
Timer
()
as
verification_timer
:
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
proposals
,
execute_model_req
.
num_lookahead_slots
)
stage_times
=
(
proposal_timer
.
elapsed_time_ms
/
num_lookahead_slots
,
scoring_timer
.
elapsed_time_ms
,
verification_timer
.
elapsed_time_ms
)
return
self
.
_create_output_sampler_list
(
execute_model_req
.
seq_group_metadata_list
,
accepted_token_ids
,
target_logprobs
=
target_logprobs
,
prompt_logprobs
=
proposal_scores
.
prompt_logprobs
if
not
self
.
_disable_logprobs
else
None
,
k
=
execute_model_req
.
num_lookahead_slots
,
stage_times
=
stage_times
)
@
nvtx_range
(
"spec_decode_worker._verify_tokens"
)
def
_verify_tokens
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_scores
:
SpeculativeScores
,
proposals
:
SpeculativeProposals
,
max_proposal_len
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models.
Returns a tuple of Tensors, one for the accepted token ids and one for
the logprobs according to the scoring model.
"""
proposal_lens_list
=
proposals
.
proposal_lens
.
tolist
()
# vLLM currently only supports proposal lens equal to zero or the batch
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
(
_
,
spec_indices
),
(
_
,
non_spec_indices
)
=
split_batch_by_proposal_len
(
seq_group_metadata_list
,
proposal_lens_list
)
original_indices
=
spec_indices
+
non_spec_indices
# Get probabilities of target model, including bonus tokens.
proposal_verifier_probs
=
proposal_scores
.
probs
[
spec_indices
]
# Get non-speculative sampled tokens from target model.
non_spec_token_ids
=
proposal_scores
.
token_ids
[
non_spec_indices
]
# Get bonus tokens from target model.
bonus_token_ids
=
proposal_scores
.
token_ids
[
spec_indices
,
-
1
:]
# Get probabilities according to proposal method.
proposal_probs
=
proposals
.
proposal_probs
[
spec_indices
]
# Get proposed tokens.
proposal_token_ids
=
proposals
.
proposal_token_ids
[
spec_indices
]
# Sampler arguments
sampler_extra_kwargs
:
Dict
[
str
,
Any
]
=
{}
if
self
.
generators
and
isinstance
(
self
.
spec_decode_sampler
,
SpecDecodeStochasticBaseSampler
):
sampler_extra_kwargs
[
"seeded_seqs"
]
=
{
idx
:
self
.
generators
[
sgm
.
request_id
]
for
idx
,
sgm
in
enumerate
(
seq_group_metadata_list
)
if
sgm
.
sampling_params
.
seed
is
not
None
}
accepted_token_ids
=
self
.
spec_decode_sampler
(
target_with_bonus_probs
=
proposal_verifier_probs
,
bonus_token_ids
=
bonus_token_ids
,
draft_probs
=
proposal_probs
,
draft_token_ids
=
proposal_token_ids
,
**
sampler_extra_kwargs
,
)
# Append output tokens from non-speculative sequences to
# the accepted token ids tensor.
non_spec_token_ids
=
non_spec_token_ids
.
expand
(
-
1
,
max_proposal_len
+
1
).
clone
()
non_spec_token_ids
[:,
1
:]
=
-
1
accepted_token_ids
=
torch
.
cat
(
[
accepted_token_ids
,
non_spec_token_ids
])
logprobs
=
proposal_scores
.
logprobs
# Rearrange so that results are in the order of the original seq group
# metadata.
accepted_token_ids
[
original_indices
]
=
accepted_token_ids
.
clone
()
# B x K+1 x D
hidden_states
=
proposal_scores
.
hidden_states
if
hidden_states
is
not
None
:
# Only get terminal hidden states for next step
terminal_metadata
=
[
sg
for
sg
in
seq_group_metadata_list
if
sg
.
do_sample
]
# Contract hidden states based on accepted tokens
hs_size
=
hidden_states
.
shape
[
-
1
]
accepted_index
=
accepted_token_ids
+
1
# Convert -1 to 0
accepted_index
=
accepted_index
.
count_nonzero
(
dim
=
1
).
add_
(
-
1
)
# b
# Drop non-terminal prefill chunks hidden states.
hidden_states
=
hidden_states
[
accepted_index
!=
VLLM_INVALID_TOKEN_ID
]
accepted_index
=
accepted_index
[
accepted_index
!=
VLLM_INVALID_TOKEN_ID
]
assert
len
(
accepted_index
)
==
hidden_states
.
shape
[
0
]
==
len
(
terminal_metadata
)
index
=
accepted_index
[:,
None
,
None
].
expand
(
-
1
,
1
,
hs_size
)
# b x 1 x d
second_last_token_hidden_states
=
hidden_states
[:,
-
2
]
# b x d
hidden_states
=
hidden_states
.
gather
(
1
,
index
).
squeeze
(
1
)
# b x d
# Store hidden states from target model for subsequent decode step
self
.
previous_hidden_states
=
HiddenStates
(
hidden_states
,
terminal_metadata
,
second_last_token_hidden_states
)
return
accepted_token_ids
,
logprobs
def
_create_output_sampler_list
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
accepted_token_ids
:
torch
.
Tensor
,
# shape: [batch_size, k+1]
target_logprobs
:
torch
.
Tensor
,
# shape: [batch_size, k+1, vocab_size]
prompt_logprobs
:
Optional
[
torch
.
Tensor
],
# shape: [nprompt_tokens, vocab_size]
k
:
int
,
stage_times
:
Tuple
[
float
,
float
,
float
],
)
->
List
[
SamplerOutput
]:
"""Given the accepted token ids, create a list of SamplerOutput.
The output is padded with -1 tokens such that each sequence has
the same number of outputs.
"""
batch_size
,
num_steps
=
accepted_token_ids
.
shape
accepted_token_ids_by_step
=
accepted_token_ids
.
transpose
(
0
,
1
)
if
self
.
_disable_logprobs
:
# We are skipping the logprobs. Hence don't serialize the
# logprobs related tensors from the GPU. Instead create
# empty/dummy lists.
(
accepted_token_id_ranks_by_step
,
accepted_token_id_logprobs_by_step
,
topk_logprobs_by_step
,
topk_indices_by_step
)
=
\
self
.
_create_dummy_logprob_lists
(
batch_size
,
num_steps
,
self
.
scorer_worker
.
model_config
.
max_logprobs
)
else
:
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step
=
target_logprobs
.
transpose
(
0
,
1
)
# Serialize all tensors into Python lists.
(
accepted_token_id_ranks_by_step
,
accepted_token_id_logprobs_by_step
,
topk_logprobs_by_step
,
topk_indices_by_step
)
=
\
self
.
_create_logprob_lists_from_tensors
(
target_logprobs_by_step
,
accepted_token_ids_by_step
,
self
.
scorer_worker
.
model_config
.
max_logprobs
)
# Get the sequence ids and num_logprobs (sampling parameter) in the
# batch.
seq_ids
,
request_ids_seq_ids_mapping
=
get_all_seq_ids_and_request_ids
(
seq_group_metadata_list
)
num_logprobs_per_seq
=
get_all_num_logprobs
(
seq_group_metadata_list
)
# Serialize tensor to CPU Python list.
accepted_token_ids_by_step
=
accepted_token_ids_by_step
.
tolist
()
# Construct the output on a per-step, per-sequence basis.
# Non-terminal prefill chunks will end up here as rows with just -1s
# i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while
# terminal chunks will only have one generated token at time 0.
sampler_output_list
:
List
[
SamplerOutput
]
=
[]
# Prefills are not multi-step (return at most 1 token), in order to
# avoid padding or repetition to fit decodes, we separate them.
for
i
,
sg
in
enumerate
(
seq_group_metadata_list
):
if
not
sg
.
is_prompt
:
# Requests are ordered as prefills|decodes=>no more prefills.
break
num_logprobs
=
num_logprobs_per_seq
[
i
]
seq_kwargs
=
dict
(
token_id
=-
1
,
token_id_logprob_rank
=
0
,
token_id_logprob
=-
float
(
'inf'
),
topk_token_ids
=
[
-
1
]
*
num_logprobs
,
topk_logprobs
=
[
-
float
(
'inf'
)]
*
num_logprobs
,
seq_id
=
seq_ids
[
i
])
# Terminal chunk, has token.
if
sg
.
do_sample
:
seq_kwargs
.
update
(
dict
(
token_id
=
accepted_token_ids
[
i
][
0
].
item
(),
token_id_logprob_rank
=
accepted_token_id_ranks_by_step
[
0
][
i
],
token_id_logprob
=
accepted_token_id_logprobs_by_step
[
0
]
[
i
],
topk_token_ids
=
topk_indices_by_step
[
0
][
i
]
[:
num_logprobs
],
# output only so step is 0
topk_logprobs
=
topk_logprobs_by_step
[
0
][
i
]
[:
num_logprobs
],
))
needs_plogs
=
(
sg
.
sampling_params
.
prompt_logprobs
and
sg
.
sampling_params
.
prompt_logprobs
>
0
)
plogs
=
None
if
prompt_logprobs
is
not
None
:
# Even non-terminal prompt chunks can have logprobs here.
plogs
=
prompt_logprobs
[
i
]
elif
needs_plogs
:
# Prompt logprobs are requested but `_disable_logprobs` is set.
seq_data
=
next
(
iter
(
sg
.
seq_data
.
values
()))
# Get only the tokens in this chunk!
prompt_token_ids
=
seq_data
.
get_prompt_token_ids
()
prompt_token_ids
=
prompt_token_ids
[
seq_data
.
_num_computed_tokens
:
seq_data
.
_num_computed_tokens
+
sg
.
token_chunk_size
]
is_first_chunk
=
seq_data
.
_num_computed_tokens
==
0
# There's no prob generated for the first token in a sequence.
if
is_first_chunk
:
prompt_token_ids
=
prompt_token_ids
[
1
:]
plogs
=
[
create_logprobs_output
(
token_id
=
p_token_id
,
token_id_logprob_rank
=-
1
,
token_id_logprob
=
0.0
,
topk_token_ids
=
[],
topk_logprobs
=
[],
)
for
p_token_id
in
prompt_token_ids
]
seq_kwargs
.
update
(
dict
(
prompt_logprobs
=
plogs
))
sampler_output_list
.
append
(
SamplerOutput
(
outputs
=
[
create_sequence_group_output
(
**
seq_kwargs
)]))
# type: ignore
# Decodes, create one SamplerOutput per-step (at most K+1).
for
step_index
in
range
(
num_steps
):
if
all
(
token_id
==
-
1
for
sg
,
token_id
in
zip
(
seq_group_metadata_list
,
accepted_token_ids_by_step
[
step_index
])
if
not
sg
.
is_prompt
):
break
step_output_token_ids
:
List
[
CompletionSequenceGroupOutput
]
=
[]
for
sequence_index
in
range
(
batch_size
):
seq_meta
=
seq_group_metadata_list
[
sequence_index
]
# Prompts already processed above.
if
seq_meta
.
is_prompt
:
continue
# Each sequence may have a different num_logprobs; retrieve it.
num_logprobs
=
num_logprobs_per_seq
[
sequence_index
]
step_output_token_ids
.
append
(
create_sequence_group_output
(
token_id
=
accepted_token_ids_by_step
[
step_index
]
[
sequence_index
],
token_id_logprob_rank
=
accepted_token_id_ranks_by_step
[
step_index
][
sequence_index
],
token_id_logprob
=
accepted_token_id_logprobs_by_step
[
step_index
][
sequence_index
],
seq_id
=
seq_ids
[
sequence_index
],
topk_token_ids
=
topk_indices_by_step
[
step_index
]
[
sequence_index
][:
num_logprobs
],
topk_logprobs
=
topk_logprobs_by_step
[
step_index
]
[
sequence_index
][:
num_logprobs
],
step_index
=
step_index
))
sampler_output_list
.
append
(
SamplerOutput
(
outputs
=
step_output_token_ids
))
# Populate the data structures needed to keep track of sequences with
# bonus tokens.
self
.
_track_sequences_with_bonus_tokens
(
seq_ids
,
request_ids_seq_ids_mapping
,
accepted_token_ids_by_step
)
maybe_rejsample_metrics
=
(
self
.
_metrics
.
maybe_collect_rejsample_metrics
(
k
))
if
maybe_rejsample_metrics
is
not
None
:
sampler_output_list
[
0
].
spec_decode_worker_metrics
=
maybe_rejsample_metrics
# Log time spent in each stage periodically.
# This is periodic because the rejection sampler emits metrics
# periodically.
self
.
_maybe_log_stage_times
(
*
stage_times
)
# First `n_prefills` entries will contain prefills SamplerOutput when
# chunked prefill is enabled, the rest is decodes in multi-step format.
return
sampler_output_list
def
_maybe_log_stage_times
(
self
,
average_time_per_proposal_tok_ms
:
float
,
scoring_time_ms
:
float
,
verification_time_ms
:
float
)
->
None
:
"""Log the speculative stage times. If stat logging is disabled, do
nothing.
"""
if
self
.
_disable_log_stats
:
return
logger
.
info
(
"SpecDecodeWorker stage times: "
"average_time_per_proposal_tok_ms=%.02f "
"scoring_time_ms=%.02f verification_time_ms=%.02f"
,
average_time_per_proposal_tok_ms
,
scoring_time_ms
,
verification_time_ms
)
def
_create_dummy_logprob_lists
(
self
,
batch_size
:
int
,
num_steps
:
int
,
num_top_k
:
int
,
)
->
Tuple
[
List
[
List
[
int
]],
List
[
List
[
float
]],
List
[
List
[
List
[
Optional
[
float
]]]],
List
[
List
[
List
[
Optional
[
int
]]]]]:
"""
Creates and returns four dummy lists representing token probabilities
and their ranks.
This method initializes and returns:
- The ranks of the accepted tokens, shaped (num_steps, batch_size)
- The log probabilities of the accepted tokens,
shaped (num_steps, batch_size)
- The log probabilities of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
- The token IDs of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
Args:
batch_size (int): The size of the batch.
num_steps (int): The number of steps in the sequence.
num_top_k (int): The number of top-k token log probabilities to
return.
Returns:
A tuple containing four dummy lists as described above.
"""
accepted_token_id_ranks_by_step
=
[[
-
1
]
*
batch_size
for
_
in
range
(
num_steps
)]
accepted_token_id_logprobs_by_step
=
[[
0.0
]
*
batch_size
for
_
in
range
(
num_steps
)]
topk_logprobs_by_step
:
List
[
List
[
List
[
Optional
[
float
]]]]
=
[[
[
None
]
*
num_top_k
for
_
in
range
(
batch_size
)
]
for
_
in
range
(
num_steps
)]
topk_indices_by_step
:
List
[
List
[
List
[
Optional
[
int
]]]]
=
[[
[
None
]
*
num_top_k
for
_
in
range
(
batch_size
)
]
for
_
in
range
(
num_steps
)]
return
(
accepted_token_id_ranks_by_step
,
accepted_token_id_logprobs_by_step
,
topk_logprobs_by_step
,
topk_indices_by_step
)
def
_create_logprob_lists_from_tensors
(
self
,
target_logprobs_by_step
:
torch
.
Tensor
,
accepted_token_ids_by_step
:
torch
.
Tensor
,
num_top_k
:
int
,
)
->
Tuple
[
List
[
List
[
int
]],
List
[
List
[
float
]],
List
[
List
[
List
[
Optional
[
float
]]]],
List
[
List
[
List
[
Optional
[
int
]]]]]:
"""
Creates and returns four lists representing token probabilities and
their ranks.
This method initializes and returns four lists containing:
- The ranks of the accepted tokens, shaped (num_steps, batch_size)
- The log probabilities of the accepted tokens,
shaped (num_steps, batch_size)
- The log probabilities of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
- The token IDs of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
Args:
target_logprobs_by_step (torch.Tensor): Tensor representing the
log probabilities of the target model,
shaped (num_steps, batch_size, vocab_size)
accepted_token_ids_by_step (torch.Tensor): Tensor representing
the accepted token_ids, shaped (num_steps, batch_size)
num_top_k (int): The number of top-k token log probabilities to
return.
Returns:
A tuple containing the lists as described above.
"""
# Serialize all tensors to CPU Python lists.
# Get the logprobs/rank of the accepted tokens.
(
accepted_token_id_ranks_by_step_tensor
,
accepted_token_id_logprobs_by_step_tensor
)
=
get_sampled_token_logprobs
(
logprob_tensor
=
target_logprobs_by_step
,
sampled_token_ids
=
accepted_token_ids_by_step
,
)
# Get the top-k logprobs (which may or may not include the
# logprob of the accepted token).
(
topk_logprobs_by_step_tensor
,
topk_indices_by_step_tensor
)
=
target_logprobs_by_step
.
topk
(
k
=
num_top_k
,
dim
=-
1
,
)
accepted_token_id_ranks_by_step
=
(
accepted_token_id_ranks_by_step_tensor
.
tolist
())
accepted_token_id_logprobs_by_step
=
(
accepted_token_id_logprobs_by_step_tensor
.
tolist
())
topk_logprobs_by_step
=
topk_logprobs_by_step_tensor
.
tolist
()
topk_indices_by_step
=
topk_indices_by_step_tensor
.
tolist
()
return
(
accepted_token_id_ranks_by_step
,
accepted_token_id_logprobs_by_step
,
topk_logprobs_by_step
,
topk_indices_by_step
)
def
_track_finished_requests
(
self
,
execute_model_req
:
ExecuteModelRequest
):
"""
Removes the finished requests and their associated sequence ids from
internal book keeping data structures.
"""
for
finished_request
in
execute_model_req
.
finished_requests_ids
:
for
seq_id
in
self
.
_request_id_seq_id_mapping
[
finished_request
]:
self
.
_seq_with_bonus_token_in_last_step
.
discard
(
seq_id
)
del
self
.
_request_id_seq_id_mapping
[
finished_request
]
def
_track_sequences_with_bonus_tokens
(
self
,
seq_ids
:
List
[
int
],
request_ids_seq_ids_mapping
:
Dict
[
str
,
Set
[
int
]],
accepted_token_ids_by_step
:
List
[
List
[
int
]]):
"""
Updates the internal data structures which keep track of sequences
which have been assigned bonus tokens in their last forward pass.
"""
for
seq_index
,
seq_id
in
enumerate
(
seq_ids
):
last_token_id
=
accepted_token_ids_by_step
[
-
1
][
seq_index
]
if
last_token_id
==
-
1
:
self
.
_seq_with_bonus_token_in_last_step
.
discard
(
seq_id
)
else
:
self
.
_seq_with_bonus_token_in_last_step
.
add
(
seq_id
)
for
request_id
,
sequences
in
request_ids_seq_ids_mapping
.
items
():
self
.
_request_id_seq_id_mapping
[
request_id
].
update
(
sequences
)
@
cached_property
def
_vocab_size
(
self
)
->
int
:
"""Get the vocab size of the model and make sure it's consistent between
draft and target workers.
"""
vocab_sizes
=
[
worker
.
vocab_size
for
worker
in
[
self
.
proposer_worker
,
self
.
scorer_worker
]
]
assert
all
(
vocab_sizes
[
0
]
==
vocab_size
for
vocab_size
in
vocab_sizes
)
return
vocab_sizes
[
0
]
@
property
def
rank
(
self
):
return
self
.
scorer_worker
.
rank
@
property
def
device
(
self
):
return
self
.
scorer_worker
.
device
@
property
def
_driver_rank
(
self
)
->
int
:
return
0
def
get_cache_block_size_bytes
(
self
):
"""Return the size of a cache block in bytes.
This function is only used to compose workers within a SpecDecodeWorker.
We leave composing a SpecDecodeWorker within a SpecDecodeWorker
undefined for now, although it could be implemented in the future.
See https://arxiv.org/abs/2308.04623.
"""
raise
NotImplementedError
def
start_profile
(
self
):
if
isinstance
(
self
.
scorer_worker
,
WorkerBase
):
self
.
scorer_worker
.
start_profile
()
def
stop_profile
(
self
):
if
isinstance
(
self
.
scorer_worker
,
WorkerBase
):
self
.
scorer_worker
.
stop_profile
()
def
split_num_cache_blocks_evenly
(
scorer_cache_block_size_bytes
:
int
,
proposer_cache_block_size_bytes
:
int
,
total_num_gpu_blocks
:
int
)
->
int
:
"""Given total_num_gpu_blocks, the number of GPU blocks that could be
allocate to the target model, this function calculates how many blocks
should be given to the draft and target model.
Note that usually the block size, in bytes, of each model is different,
as it's a function of number of KV/layer, number of heads, and hidden
dimension size.
Since the target and draft models allocate the same number of blocks, we
simply calculate the number of blocks where if allocated by both models,
the total memory usage from KV cache is no larger than the number of
blocks allocatable by the target model alone.
"""
new_num_gpu_blocks
=
int
(
total_num_gpu_blocks
*
scorer_cache_block_size_bytes
/
(
proposer_cache_block_size_bytes
+
scorer_cache_block_size_bytes
))
return
new_num_gpu_blocks
def
prepare_prefill_hidden_states
(
prefill_hidden_states
:
torch
.
Tensor
)
->
HiddenStates
:
# For prefill step in proposer, we run the model for N-1 tokens
# because Nth token will be processed in the first decode step. For
# N-1 tokens, the input should be 0:N-1 hidden states which should
# be concatanated with 1:N token (since output of scorer has to be
# the input for proposer). Therefore, we shift the hidden states to
# align n-1th hidden state with nth token.
return
HiddenStates
(
prefill_hidden_states
.
roll
(
shifts
=
1
,
dims
=
0
))
if
prefill_hidden_states
is
not
None
else
None
vllm/spec_decode/target_model_runner.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
List
,
Optional
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerWrapperBase
)
class
TargetModelRunner
(
ModelRunnerWrapperBase
):
"""Specialized model runner for speculative decoding target model.
In speculative decoding, the log probabilities selected finally may not
be the same ones as selected by the target model sampling. This means
that the time spent in the log probability calculation of the target model
is time wasted, since we calculate log probabilities after deciding which
tokens are accepted. For this reason disabling log probabilities in the
target model will make decode faster. The model runner sets the
SamplingMetadata parameters according to whether log probabilities are
requested or not.
"""
def
__init__
(
self
,
model_runner
:
ModelRunnerBase
):
# An internal boolean member variable to indicate if token log
# probabilities are needed or not.
super
().
__init__
(
model_runner
)
self
.
disable_logprobs
=
True
def
prepare_model_input
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
)
->
ModelRunnerInputBase
:
model_input
:
ModelRunnerInputBase
=
\
self
.
model_runner
.
prepare_model_input
(
seq_group_metadata_list
,
virtual_engine
,
finished_requests_ids
)
# If token log probabilities is disabled then skip generating sampler
# CPU output. We directly serialize the GPU sampled_token_id tensors
# as needed. If log probabilities is enabled then synchronize all the
# sampling related tensors which includes the logprobs tensors.
model_input
.
sampling_metadata
.
skip_sampler_cpu_output
=
(
self
.
disable_logprobs
)
return
model_input
vllm/spec_decode/top1_proposer.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
List
,
Optional
,
Set
,
Tuple
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.util
import
sampler_output_to_torch
class
Top1Proposer
(
SpeculativeProposer
):
"""Helper class which separates out sequences which would exceed the max
model length when speculated upon.
This allows combinations of models such as JackFram/llama-68m draft with
meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
2048 while Llama2-13b has max_position_embeddings of 4096.
We treat the sequences which exceed the proposal draft model length as
"non-spec sequences". Essentially they skip the draft model and go through
normal decoding in the target model.
Currently, only proposal_lens of 0 and k are supported, where k is a global
batch proposal length. In the future vLLM should support per-sequence
proposal lengths.
"""
def
__init__
(
self
,
worker
:
ProposerWorkerBase
,
device
:
str
,
vocab_size
:
int
,
max_proposal_len
:
Optional
[
int
]
=
None
,
):
self
.
_worker
=
worker
self
.
_device
=
device
self
.
max_proposal_len
=
max_proposal_len
self
.
_vocab_size
=
vocab_size
def
get_spec_proposals
(
self
,
execute_model_req
:
ExecuteModelRequest
,
seq_ids_with_bonus_token_in_last_step
:
Set
[
int
],
)
->
SpeculativeProposals
:
"""Get speculative proposals given the input batch.
Sequences which would exceed the max model length are skipped during
speculation.
"""
proposal_len
=
execute_model_req
.
num_lookahead_slots
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
# Split speculative- and non-speculative- sequences.
(
proposal_lens
,
nonzero_proposal_len_seqs
,
nonzero_proposal_len_indices
,
)
=
self
.
_split_by_proposal_len
(
seq_group_metadata_list
,
proposal_len
)
if
nonzero_proposal_len_seqs
:
# Speculate tokens using the draft worker for the speculative
# sequences.
# If sampler_transposed is true, then maybe_sampler_output's
# token_ids is like [batch] format in proposal_len size list,
# while if it is false, the format would be [proposal_len]
# in batch size list
hidden_states
=
execute_model_req
.
previous_hidden_states
if
hidden_states
is
not
None
:
hidden_states
.
prune
(
nonzero_proposal_len_seqs
)
nonzero_execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
nonzero_proposal_len_seqs
,
num_lookahead_slots
=
proposal_len
,
previous_hidden_states
=
hidden_states
,
)
maybe_sampler_output
,
transposed
=
self
.
_worker
.
sampler_output
(
execute_model_req
=
nonzero_execute_model_req
,
sample_len
=
proposal_len
,
seq_ids_with_bonus_token_in_last_step
=
\
seq_ids_with_bonus_token_in_last_step
,
)
(
proposal_lens
,
maybe_sampler_output
,
nonzero_proposal_len_indices
,
)
=
self
.
_remove_no_proposal_seqs
(
proposal_lens
,
maybe_sampler_output
,
nonzero_proposal_len_indices
,
transposed
)
else
:
# If no sequences can be speculated, set sampler output to None.
maybe_sampler_output
=
None
transposed
=
False
# Combine speculative- and non-speculative sequences into the same
# representation.
proposal_tokens
,
proposal_probs
,
proposal_lens
=
self
.
_merge_outputs
(
batch_size
=
len
(
seq_group_metadata_list
),
proposal_len
=
proposal_len
,
maybe_sampler_output
=
maybe_sampler_output
,
proposal_lens
=
proposal_lens
,
nonzero_proposal_len_indices
=
nonzero_proposal_len_indices
,
sampler_transposed
=
transposed
,
)
proposals
=
SpeculativeProposals
(
proposal_token_ids
=
proposal_tokens
,
proposal_probs
=
proposal_probs
,
proposal_lens
=
proposal_lens
,
no_proposals
=
maybe_sampler_output
is
None
)
return
proposals
def
_split_by_proposal_len
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_len
:
int
,
)
->
Tuple
[
List
[
int
],
List
[
SequenceGroupMetadata
],
List
[
int
]]:
"""Split sequences by two groups:
1. Sequences with non-zero proposal length.
2. Sequences with zero proposal length (due to disabled speculation
or exceed the maximum model length).
"""
proposal_lens
:
List
[
int
]
=
[]
nonzero_proposal_len_seqs
:
List
[
SequenceGroupMetadata
]
=
[]
nonzero_proposal_len_indices
:
List
[
int
]
=
[]
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
# The speculative decoding for this request has either been disabled
# (e.g. due to high traffic) or this is a prompt request.
if
(
seq_group_metadata
.
is_prompt
or
seq_group_metadata
.
num_speculative_tokens
==
0
):
proposal_lens
.
append
(
0
)
continue
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
seq_len
=
seq_data
.
get_len
()
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
# If max_proposal_len is defined, then we shall not exceed this
# quota for nonzero_proposal
new_k
=
0
if
(
self
.
max_proposal_len
is
None
or
seq_len
+
proposal_len
<
self
.
max_proposal_len
):
new_k
=
proposal_len
nonzero_proposal_len_seqs
.
append
(
seq_group_metadata
)
nonzero_proposal_len_indices
.
append
(
i
)
proposal_lens
.
append
(
new_k
)
seq_group_metadata
.
num_speculative_tokens
=
new_k
return
(
proposal_lens
,
nonzero_proposal_len_seqs
,
nonzero_proposal_len_indices
,
)
@
staticmethod
def
_remove_no_proposal_seqs
(
proposal_lens
,
maybe_sampler_output
,
nonzero_proposal_len_indices
,
transposed
):
"""Remove sequences from nonzero_proposal_len_indices and reset
their proposal_len to 0 the draft worker does not provide a proposal
(maybe_sampler_output=None). This can avoid scoring overheads.
"""
# If maybe_sampler_output is None, then the draft worker did not
# provide a proposal for any sequence and thus no action needed.
# Also we do not support transposed maybe_sampler_output for now
# because it seems not straightforward for draft workers outputting
# transposed sampler outputs to handle the case of no proposal.
if
maybe_sampler_output
is
None
or
transposed
:
return
(
proposal_lens
,
maybe_sampler_output
,
nonzero_proposal_len_indices
)
new_proposal_lens
:
List
[
int
]
=
[]
new_nonzero_proposal_len_indices
:
List
[
int
]
=
[]
new_maybe_sampler_output
:
List
[
SamplerOutput
]
=
[]
nonzero_proposal_len_idx_ptr
=
0
seq_idx
=
0
while
seq_idx
<
len
(
proposal_lens
)
and
nonzero_proposal_len_idx_ptr
<
len
(
nonzero_proposal_len_indices
):
if
seq_idx
<
nonzero_proposal_len_indices
[
nonzero_proposal_len_idx_ptr
]:
# Sequence is not in the original nonzero_proposal_len_indices,
# meaning that it has a proposal length of 0 before sending to
# the draft worker.
assert
proposal_lens
[
seq_idx
]
==
0
new_proposal_lens
.
append
(
0
)
else
:
# Sequence is in the original nonzero_proposal_len_indices
if
maybe_sampler_output
[
nonzero_proposal_len_idx_ptr
]
is
None
:
# but does not have a proposal from the draft worker.
new_proposal_lens
.
append
(
0
)
else
:
# and has a proposal from the draft worker. Add it to the
# new nonzero proposal list and keep the sampler output.
new_proposal_lens
.
append
(
proposal_lens
[
seq_idx
])
new_nonzero_proposal_len_indices
.
append
(
seq_idx
)
new_maybe_sampler_output
.
append
(
maybe_sampler_output
[
nonzero_proposal_len_idx_ptr
])
nonzero_proposal_len_idx_ptr
+=
1
seq_idx
+=
1
# The remaining sequences should have proposal length of 0.
new_proposal_lens
.
extend
(
proposal_lens
[
seq_idx
:])
# We assume sampler_output will not be a list of all Nones.
# In this case this function should not be called.
assert
new_maybe_sampler_output
return
(
new_proposal_lens
,
new_maybe_sampler_output
,
new_nonzero_proposal_len_indices
)
def
_merge_outputs
(
self
,
batch_size
:
int
,
proposal_len
:
int
,
maybe_sampler_output
:
Optional
[
List
[
SamplerOutput
]],
proposal_lens
:
List
[
int
],
nonzero_proposal_len_indices
:
List
[
int
],
sampler_transposed
:
bool
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
if
maybe_sampler_output
is
None
:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens
=
torch
.
tensor
(
-
1
,
dtype
=
torch
.
long
,
device
=
self
.
_device
).
expand
(
batch_size
,
proposal_len
)
proposal_probs
=
torch
.
tensor
(
0
,
dtype
=
torch
.
float32
,
device
=
self
.
_device
).
expand
(
batch_size
,
proposal_len
,
self
.
_vocab_size
)
proposal_lens_tensor
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
self
.
_device
).
expand
(
len
(
proposal_lens
))
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
sampler_output
=
maybe_sampler_output
proposal_tokens
,
proposal_probs
,
*
_
=
sampler_output_to_torch
(
sampler_output
,
sampler_transposed
)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens
=
proposal_tokens
.
new_full
(
size
=
(
batch_size
,
*
proposal_tokens
.
shape
[
1
:]),
fill_value
=-
1
,
)
entire_proposal_tokens
[
nonzero_proposal_len_indices
]
=
proposal_tokens
entire_proposal_probs
=
proposal_probs
.
new_zeros
(
batch_size
,
*
proposal_probs
.
shape
[
1
:],
)
entire_proposal_probs
[
nonzero_proposal_len_indices
]
=
proposal_probs
proposal_tokens
,
proposal_probs
=
(
entire_proposal_tokens
,
entire_proposal_probs
,
)
proposal_lens_tensor
=
torch
.
zeros
(
batch_size
,
dtype
=
torch
.
long
,
device
=
self
.
_device
)
proposal_lens_tensor
[
nonzero_proposal_len_indices
]
=
proposal_len
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
vllm/spec_decode/util.py
deleted
100644 → 0
View file @
9ffe905a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
from
contextlib
import
contextmanager
from
typing
import
Dict
,
List
,
Optional
,
Sequence
,
Tuple
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
PromptLogprobs
,
SequenceGroupMetadata
,
SequenceOutput
)
SeqId
=
int
def
get_all_num_logprobs
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
])
->
List
[
int
]:
"""Given a list of SequenceGroupMetadata, create a list of all num_logprobs.
If the sampling params do not call for any logprobs, return 0 for that
sequence.
"""
all_num_logprobs
:
List
[
int
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
num_logprobs
=
seq_group_metadata
.
sampling_params
.
logprobs
if
num_logprobs
is
None
:
num_logprobs
=
0
all_num_logprobs
.
append
(
num_logprobs
)
return
all_num_logprobs
def
get_sampled_token_logprobs
(
# shape [num_steps, batch_size, vocab_size]
logprob_tensor
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
# shape [num_steps, batch_size]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Get the logprobs for the sampled tokens. Returns the ranks and logprobs.
"""
num_steps
,
batch_size
,
vocab_size
=
logprob_tensor
.
shape
selected_logprobs
=
logprob_tensor
[
torch
.
arange
(
num_steps
).
unsqueeze
(
1
),
torch
.
arange
(
batch_size
),
sampled_token_ids
,
]
expanded_selected_logprobs
=
selected_logprobs
.
unsqueeze
(
-
1
).
expand
(
-
1
,
-
1
,
vocab_size
)
sampled_token_ids_ranks
=
(
logprob_tensor
>
expanded_selected_logprobs
).
sum
(
-
1
).
add_
(
1
)
return
sampled_token_ids_ranks
,
selected_logprobs
def
create_logprobs_output
(
token_id
:
int
,
token_id_logprob_rank
:
int
,
token_id_logprob
:
float
,
topk_token_ids
:
List
[
Optional
[
int
]],
topk_logprobs
:
List
[
Optional
[
float
]],
)
->
Dict
[
int
,
Logprob
]:
"""Create a Logprob Dict for a token given the sampling results.
Args:
token_id (int): The sampled token for the sequence.
token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token.
topk_token_ids (List[Optional[int]]): The list of top-k token ids.
topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
"""
# vLLM logprobs always include the sampled token. In addition, the user may
# request topk-logprobs (where top-k varies per user up to max_logprobs).
logprobs
:
Dict
[
int
,
Logprob
]
=
{
token_id
:
Logprob
(
logprob
=
token_id_logprob
,
rank
=
token_id_logprob_rank
,
),
}
logprobs
.
update
({
topk_token_id
:
Logprob
(
logprob
=
topk_logprob
if
topk_logprob
is
not
None
else
0.0
,
rank
=
topk_index
+
1
,
)
for
topk_index
,
(
topk_token_id
,
topk_logprob
)
\
in
enumerate
(
zip
(
topk_token_ids
,
topk_logprobs
))
\
if
topk_token_id
is
not
None
})
return
logprobs
def
create_sequence_group_output
(
token_id
:
int
,
token_id_logprob_rank
:
int
,
token_id_logprob
:
float
,
seq_id
:
SeqId
,
topk_token_ids
:
List
[
Optional
[
int
]],
topk_logprobs
:
List
[
Optional
[
float
]],
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
,
step_index
:
Optional
[
int
]
=
0
)
->
CompletionSequenceGroupOutput
:
"""Create a SequenceGroupOutput given the sampling results.
Args:
token_id (int): The sampled token for the sequence.
token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token.
seq_id (int): The sequence id.
topk_token_ids (List[Optional[int]]): The list of top-k token ids.
topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
step_index: (Optional[int]): The index of the speculative token.
"""
logprobs
=
create_logprobs_output
(
token_id
,
token_id_logprob_rank
,
token_id_logprob
,
topk_token_ids
,
topk_logprobs
,
)
return
CompletionSequenceGroupOutput
(
samples
=
[
SequenceOutput
(
parent_seq_id
=
seq_id
,
output_token
=
token_id
,
logprobs
=
logprobs
)
],
prompt_logprobs
=
prompt_logprobs
,
step_index
=
step_index
)
def
split_batch_by_proposal_len
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_lens
:
List
[
int
],
)
->
Tuple
[
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]],
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]]]:
"""Utility function that splits a batch based on whether the proposal len is
zero or not. We should remove this once vLLM supports per-sequence proposal
lens in a batch.
"""
nonzero_lists
:
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]]
=
([],
[])
zero_lists
:
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]]
=
([],
[])
for
i
,
(
seq_group
,
proposal_len
)
in
enumerate
(
zip
(
seq_group_metadata_list
,
proposal_lens
)):
seq_groups
,
indices
=
nonzero_lists
if
proposal_len
else
zero_lists
seq_groups
.
append
(
seq_group
)
indices
.
append
(
i
)
return
nonzero_lists
,
zero_lists
def
sampler_output_to_torch
(
sampler_output_list
:
Sequence
[
SamplerOutput
],
sampler_transposed
:
bool
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed here is used as the indicator for whether
we need do additional tensor transpose logic here.
Returns:
sampled_token_ids: torch.Tensor
shape: [batch_size, len(sampler_output_list)]
sampled_token_probs: torch.Tensor
shape: [batch_size, len(sampler_output_list), vocab_size]
"""
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_probs
=
torch
.
stack
(
[
sampler_output
.
sampled_token_probs
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
)
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_logprobs
=
torch
.
stack
(
[
sampler_output
.
logprobs
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
)
# shape: [batch_size, num_sampler_output]
sampled_token_ids
=
torch
.
stack
(
[
sampler_output
.
sampled_token_ids
.
flatten
()
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
)
if
sampler_transposed
:
sampled_token_probs
=
sampled_token_probs
.
transpose
(
0
,
1
)
sampled_token_logprobs
=
sampled_token_logprobs
.
transpose
(
0
,
1
)
sampled_token_ids
=
sampled_token_ids
.
transpose
(
0
,
1
)
if
sampler_output_list
[
0
].
hidden_states
is
not
None
:
# shape: [batch_size, num_sampler_output, hidden_dim]
sampled_hidden_states
=
torch
.
stack
(
[
sampler_output
.
hidden_states
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
)
if
sampler_transposed
:
sampled_hidden_states
=
sampled_hidden_states
.
transpose
(
0
,
1
)
else
:
sampled_hidden_states
=
None
return
(
sampled_token_ids
,
sampled_token_probs
,
sampled_token_logprobs
,
sampled_hidden_states
)
def
maybe_mock_device_tensors
(
sampler_output
:
SamplerOutput
,
batch_size
:
int
,
vocab_size
:
int
,
device
:
str
)
->
None
:
"""Helper method which mocks out the GPU tensors in SamplerOutput with dummy
values. This will be removed in PR 7/9.
https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
"""
values
=
[
sampler_output
.
sampled_token_probs
,
sampler_output
.
sampled_token_ids
]
assert
all
(
v
is
None
for
v
in
values
)
or
not
any
(
v
is
None
for
v
in
values
)
if
not
any
(
v
is
None
for
v
in
values
):
# Do nothing if the tensors are already created (usually in unit tests).
return
# Softmax to ensure valid probs.
sampler_output
.
sampled_token_probs
=
torch
.
nn
.
functional
.
softmax
(
torch
.
rand
(
batch_size
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
device
),
dim
=-
1
)
sampler_output
.
sampled_token_ids
=
torch
.
randint
(
low
=
10
,
high
=
100
,
size
=
(
batch_size
,
),
dtype
=
torch
.
long
,
device
=
device
)
@
contextmanager
def
nvtx_range
(
msg
,
*
args
,
**
kwargs
):
"""
Context manager / decorator that pushes an NVTX range at the beginning
of its scope, and pops it at the end. If extra arguments are given,
they are passed as arguments to msg.format().
If running with cuda graphs, you must enable nsys cuda graph profiling.
Arguments:
msg (string): message to associate with the range
"""
if
current_platform
.
is_cuda_alike
():
torch
.
cuda
.
nvtx
.
range_push
(
msg
.
format
(
*
args
,
**
kwargs
))
try
:
yield
finally
:
torch
.
cuda
.
nvtx
.
range_pop
()
else
:
yield
class
Timer
:
"""Basic timer context manager for measuring CPU time.
"""
def
__enter__
(
self
):
self
.
start_time
=
time
.
time
()
return
self
def
__exit__
(
self
,
exc_type
,
exc_value
,
traceback
):
self
.
end_time
=
time
.
time
()
self
.
elapsed_time_s
=
self
.
end_time
-
self
.
start_time
self
.
elapsed_time_ms
=
self
.
elapsed_time_s
*
1000
vllm/transformers_utils/configs/eagle.py
View file @
dd572c0a
...
...
@@ -6,7 +6,6 @@ from typing import Optional, Union
from
transformers
import
AutoConfig
,
PretrainedConfig
import
vllm.envs
as
envs
from
vllm.transformers_utils.configs.deepseek_vl2
import
DeepseekV2Config
...
...
@@ -44,28 +43,25 @@ class EAGLEConfig(PretrainedConfig):
self
.
truncated_vocab_size
=
self
.
model
.
vocab_size
if
\
truncated_vocab_size
is
None
else
truncated_vocab_size
if
not
envs
.
VLLM_USE_V1
:
kwargs
[
"architectures"
]
=
[
"EAGLEModel"
]
# Eagle model name should follow naming convention of
# LlamaForCausalLM -> EagleLlamaForCausalLM
if
method
==
"eagle"
:
assert
self
.
model
is
not
None
,
\
"model should not be None when method is eagle"
kwargs
[
"architectures"
]
=
[
f
"Eagle
{
arch
}
"
if
not
arch
.
startswith
(
"Eagle"
)
\
else
arch
for
arch
in
self
.
model
.
architectures
]
elif
method
==
"eagle3"
:
assert
self
.
model
is
not
None
,
\
"model should not be None when method is eagle3"
kwargs
[
"architectures"
]
=
[
f
"Eagle3
{
arch
}
"
if
not
arch
.
startswith
(
"Eagle3"
)
\
else
arch
for
arch
in
self
.
model
.
architectures
]
else
:
# Eagle model name should follow naming convention of
# LlamaForCausalLM -> EagleLlamaForCausalLM
if
method
==
"eagle"
:
assert
self
.
model
is
not
None
,
\
"model should not be None when method is eagle"
kwargs
[
"architectures"
]
=
[
f
"Eagle
{
arch
}
"
if
not
arch
.
startswith
(
"Eagle"
)
\
else
arch
for
arch
in
self
.
model
.
architectures
]
elif
method
==
"eagle3"
:
assert
self
.
model
is
not
None
,
\
"model should not be None when method is eagle3"
kwargs
[
"architectures"
]
=
[
f
"Eagle3
{
arch
}
"
if
not
arch
.
startswith
(
"Eagle3"
)
\
else
arch
for
arch
in
self
.
model
.
architectures
]
else
:
raise
ValueError
(
f
"Invalid method
{
method
}
.
\
Supported methods are eagle and eagle3."
)
raise
ValueError
(
f
"Invalid method
{
method
}
.
\
Supported methods are eagle and eagle3."
)
super
().
__init__
(
**
kwargs
)
...
...
vllm/worker/worker_base.py
View file @
dd572c0a
...
...
@@ -397,8 +397,6 @@ class LocalOrDistributedWorkerBase(WorkerBase):
model_input
,
worker_input
,
kwargs
=
inputs
num_steps
=
worker_input
.
num_steps
if
execute_model_req
is
not
None
and
execute_model_req
.
spec_step_idx
:
kwargs
[
"spec_step_idx"
]
=
execute_model_req
.
spec_step_idx
self
.
execute_worker
(
worker_input
)
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment