Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
01c30741
Commit
01c30741
authored
Apr 27, 2025
by
lizhigong
Browse files
add spec decode zero overhead
parent
b01c8270
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
905 additions
and
30 deletions
+905
-30
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+2
-2
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+2
-2
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+2
-2
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+31
-13
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+2
-2
vllm/zero_overhead/llm_engine.py
vllm/zero_overhead/llm_engine.py
+5
-5
vllm/zero_overhead/model_runner.py
vllm/zero_overhead/model_runner.py
+2
-2
vllm/zero_overhead/sampler.py
vllm/zero_overhead/sampler.py
+0
-0
vllm/zero_overhead/sequence.py
vllm/zero_overhead/sequence.py
+0
-0
vllm/zero_overhead/spec_decode/batch_expansion.py
vllm/zero_overhead/spec_decode/batch_expansion.py
+142
-0
vllm/zero_overhead/spec_decode/muti_step_worker.py
vllm/zero_overhead/spec_decode/muti_step_worker.py
+138
-0
vllm/zero_overhead/spec_decode/spec_decode_worker.py
vllm/zero_overhead/spec_decode/spec_decode_worker.py
+482
-0
vllm/zero_overhead/spec_decode/top1_proproser.py
vllm/zero_overhead/spec_decode/top1_proproser.py
+83
-0
vllm/zero_overhead/stop_check.py
vllm/zero_overhead/stop_check.py
+1
-1
vllm/zero_overhead/tokenizer.py
vllm/zero_overhead/tokenizer.py
+1
-1
vllm/zero_overhead/utils.py
vllm/zero_overhead/utils.py
+12
-0
No files found.
vllm/engine/multiprocessing/engine.py
View file @
01c30741
...
@@ -6,8 +6,8 @@ from contextlib import contextmanager
...
@@ -6,8 +6,8 @@ from contextlib import contextmanager
from
typing
import
Iterator
,
List
,
Optional
,
Union
from
typing
import
Iterator
,
List
,
Optional
,
Union
import
cloudpickle
import
cloudpickle
from
vllm.zero_overhead.
v0.
llm_engine
import
ZeroOverheadEngine
from
vllm.zero_overhead.llm_engine
import
ZeroOverheadEngine
from
vllm.zero_overhead.
v0.
utils
import
is_zero_overhead
from
vllm.zero_overhead.utils
import
is_zero_overhead
import
zmq
import
zmq
from
vllm
import
AsyncEngineArgs
,
SamplingParams
from
vllm
import
AsyncEngineArgs
,
SamplingParams
...
...
vllm/entrypoints/llm.py
View file @
01c30741
...
@@ -43,8 +43,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
...
@@ -43,8 +43,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
Counter
,
Device
,
deprecate_args
,
deprecate_kwargs
,
from
vllm.utils
import
(
Counter
,
Device
,
deprecate_args
,
deprecate_kwargs
,
is_list_of
)
is_list_of
)
from
vllm.zero_overhead.
v0.
llm_engine
import
ZeroOverheadEngine
from
vllm.zero_overhead.llm_engine
import
ZeroOverheadEngine
from
vllm.zero_overhead.
v0.
utils
import
is_zero_overhead
from
vllm.zero_overhead.utils
import
is_zero_overhead
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/layers/sampler.py
View file @
01c30741
...
@@ -21,7 +21,7 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
...
@@ -21,7 +21,7 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput
,
Logprob
,
CompletionSequenceGroupOutput
,
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
SequenceOutput
)
PromptLogprobs
,
SampleLogprobs
,
SequenceOutput
)
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
from
vllm.zero_overhead.
v0.
utils
import
is_zero_overhead
from
vllm.zero_overhead.utils
import
is_zero_overhead
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
import
flashinfer.sampling
import
flashinfer.sampling
...
@@ -40,7 +40,7 @@ def get_sampler() -> torch.nn.Module:
...
@@ -40,7 +40,7 @@ def get_sampler() -> torch.nn.Module:
from
vllm.v1.sample.sampler
import
Sampler
as
V1Sampler
from
vllm.v1.sample.sampler
import
Sampler
as
V1Sampler
return
V1Sampler
()
return
V1Sampler
()
if
is_zero_overhead
():
if
is_zero_overhead
():
from
vllm.zero_overhead.
v0.
sampler
import
ZeroOverheadSampler
from
vllm.zero_overhead.sampler
import
ZeroOverheadSampler
return
ZeroOverheadSampler
()
return
ZeroOverheadSampler
()
return
Sampler
()
return
Sampler
()
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
01c30741
...
@@ -54,6 +54,7 @@ from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
...
@@ -54,6 +54,7 @@ from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.zero_overhead.utils
import
is_zero_overhead
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -206,8 +207,11 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
...
@@ -206,8 +207,11 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
# Load lm_head weight for eagle in init_device
# Load lm_head weight for eagle in init_device
if
draft_model_config
.
hf_config
.
model_type
==
"eagle"
:
if
draft_model_config
.
hf_config
.
model_type
==
"eagle"
:
enable_lm_head_weight_load
=
True
enable_lm_head_weight_load
=
True
if
is_zero_overhead
():
proposer_worker
=
MultiStepWorker
(
**
draft_worker_kwargs
)
from
vllm.zero_overhead.spec_decode.muti_step_worker
import
ZeroOverheadMultiStepWorker
proposer_worker
=
ZeroOverheadMultiStepWorker
(
**
draft_worker_kwargs
)
else
:
proposer_worker
=
MultiStepWorker
(
**
draft_worker_kwargs
)
if
draft_model_config
.
hf_config
.
model_type
==
"deepseek_mtp"
:
if
draft_model_config
.
hf_config
.
model_type
==
"deepseek_mtp"
:
num_spec_prefill_steps
=
\
num_spec_prefill_steps
=
\
draft_model_config
.
hf_config
.
n_predict
draft_model_config
.
hf_config
.
n_predict
...
@@ -254,17 +258,31 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
...
@@ -254,17 +258,31 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
"[Speculative Decoding] Disabling MQA scorer as the "
"[Speculative Decoding] Disabling MQA scorer as the "
"target model is not running in eager mode."
)
"target model is not running in eager mode."
)
return
SpecDecodeWorker
(
if
is_zero_overhead
():
proposer_worker
,
from
vllm.zero_overhead.spec_decode.spec_decode_worker
import
ZeroOverheadSpecDecodeWorker
scorer_worker
,
return
ZeroOverheadSpecDecodeWorker
(
disable_mqa_scorer
=
disable_mqa_scorer
,
proposer_worker
,
disable_logprobs
=
disable_logprobs
,
scorer_worker
,
disable_log_stats
=
disable_log_stats
,
disable_mqa_scorer
=
disable_mqa_scorer
,
disable_by_batch_size
=
disable_by_batch_size
,
disable_logprobs
=
disable_logprobs
,
spec_decode_sampler
=
spec_decode_sampler
,
disable_log_stats
=
disable_log_stats
,
allow_zero_draft_token_step
=
allow_zero_draft_token_step
,
disable_by_batch_size
=
disable_by_batch_size
,
enable_lm_head_weight_load
=
enable_lm_head_weight_load
,
spec_decode_sampler
=
spec_decode_sampler
,
num_spec_prefill_steps
=
num_spec_prefill_steps
)
allow_zero_draft_token_step
=
allow_zero_draft_token_step
,
enable_lm_head_weight_load
=
enable_lm_head_weight_load
,
num_spec_prefill_steps
=
num_spec_prefill_steps
)
else
:
return
SpecDecodeWorker
(
proposer_worker
,
scorer_worker
,
disable_mqa_scorer
=
disable_mqa_scorer
,
disable_logprobs
=
disable_logprobs
,
disable_log_stats
=
disable_log_stats
,
disable_by_batch_size
=
disable_by_batch_size
,
spec_decode_sampler
=
spec_decode_sampler
,
allow_zero_draft_token_step
=
allow_zero_draft_token_step
,
enable_lm_head_weight_load
=
enable_lm_head_weight_load
,
num_spec_prefill_steps
=
num_spec_prefill_steps
)
def
__init__
(
def
__init__
(
self
,
self
,
...
...
vllm/worker/model_runner.py
View file @
01c30741
...
@@ -60,7 +60,7 @@ from vllm.worker.model_runner_base import (
...
@@ -60,7 +60,7 @@ from vllm.worker.model_runner_base import (
_add_sampling_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_sampling_metadata_from_tensor_dict
)
_init_sampling_metadata_from_tensor_dict
)
from
vllm.zero_overhead.
v0.
utils
import
is_zero_overhead
from
vllm.zero_overhead.utils
import
is_zero_overhead
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
...
@@ -1638,7 +1638,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1638,7 +1638,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
ModelInputForGPUWithSamplingMetadata
)
ModelInputForGPUWithSamplingMetadata
)
_builder_cls
:
Type
[
ModelInputForGPUBuilder
]
=
ModelInputForGPUBuilder
_builder_cls
:
Type
[
ModelInputForGPUBuilder
]
=
ModelInputForGPUBuilder
if
is_zero_overhead
():
if
is_zero_overhead
():
from
vllm.zero_overhead.
v0.
model_runner
import
ZeroOverheadModelInputForGpuBuilder
from
vllm.zero_overhead.model_runner
import
ZeroOverheadModelInputForGpuBuilder
_builder_cls
=
ZeroOverheadModelInputForGpuBuilder
_builder_cls
=
ZeroOverheadModelInputForGpuBuilder
def
make_model_input_from_broadcasted_tensor_dict
(
def
make_model_input_from_broadcasted_tensor_dict
(
...
...
vllm/zero_overhead/
v0/
llm_engine.py
→
vllm/zero_overhead/llm_engine.py
View file @
01c30741
...
@@ -33,14 +33,14 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
...
@@ -33,14 +33,14 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.usage.usage_lib
import
UsageContext
,
is_usage_stats_enabled
from
vllm.usage.usage_lib
import
UsageContext
,
is_usage_stats_enabled
from
vllm.utils
import
resolve_obj_by_qualname
,
weak_bind
,
Counter
from
vllm.utils
import
resolve_obj_by_qualname
,
weak_bind
,
Counter
from
vllm.zero_overhead.
v0.
sampler
import
SampleRecorder
,
get_last_sampler
from
vllm.zero_overhead.sampler
import
SampleRecorder
,
get_last_sampler
from
vllm.zero_overhead.
v0.
sequence
import
ZeroOverheadSequence
from
vllm.zero_overhead.sequence
import
ZeroOverheadSequence
from
vllm.zero_overhead.
v0.
stop_check
import
ZeroOverheadStopChecker
from
vllm.zero_overhead.stop_check
import
ZeroOverheadStopChecker
from
vllm.zero_overhead.
v0.
tokenizer
import
ZeroOverheadDetokenizer
from
vllm.zero_overhead.tokenizer
import
ZeroOverheadDetokenizer
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
usage_message
)
from
vllm.profiler.prof
import
profile
from
vllm.profiler.prof
import
profile
from
vllm.zero_overhead.
v0.
utils
import
is_zero_no_thread
from
vllm.zero_overhead.utils
import
is_zero_no_thread
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/zero_overhead/
v0/
model_runner.py
→
vllm/zero_overhead/model_runner.py
View file @
01c30741
...
@@ -10,8 +10,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
...
@@ -10,8 +10,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
async_tensor_h2d
,
flatten_2d_lists
from
vllm.utils
import
async_tensor_h2d
,
flatten_2d_lists
from
vllm.worker.model_runner
import
ModelInputForGPU
,
ModelInputForGPUBuilder
from
vllm.worker.model_runner
import
ModelInputForGPU
,
ModelInputForGPUBuilder
from
vllm.zero_overhead.
v0.
sampler
import
get_last_sampler
from
vllm.zero_overhead.sampler
import
get_last_sampler
from
vllm.zero_overhead.
v0.
update_input
import
UpdateInputTokens
from
vllm.zero_overhead.update_input
import
UpdateInputTokens
class
ZeroOverheadModelInputForGpuBuilder
(
ModelInputForGPUBuilder
):
class
ZeroOverheadModelInputForGpuBuilder
(
ModelInputForGPUBuilder
):
...
...
vllm/zero_overhead/
v0/
sampler.py
→
vllm/zero_overhead/sampler.py
View file @
01c30741
File moved
vllm/zero_overhead/
v0/
sequence.py
→
vllm/zero_overhead/sequence.py
View file @
01c30741
File moved
vllm/zero_overhead/spec_decode/batch_expansion.py
0 → 100644
View file @
01c30741
from
array
import
array
import
numpy
as
np
from
itertools
import
chain
,
count
from
typing
import
Iterator
,
List
,
Optional
,
Tuple
import
torch
from
vllm
import
SamplingParams
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
VLLM_TOKEN_ID_ARRAY_TYPE
,
ExecuteModelRequest
,
SequenceData
,
SequenceGroupMetadata
,
get_all_seq_ids
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
nvtx_range
,
split_batch_by_proposal_len
from
vllm.utils
import
async_tensor_h2d
SeqId
=
int
TargetSeqId
=
int
TokenId
=
int
DEFAULT_SIMPLE_SAMPLING_PARAMS
=
SamplingParams
()
class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):
    """Zero-overhead variant of ``BatchExpansionTop1Scorer``.

    Only ``score_proposals`` and ``_contract_non_speculative`` are defined
    here. ``_expand_batch``, ``_contract_batch`` and
    ``_contract_batch_all_spec`` are called on ``self`` and are expected to
    be provided by the base class.
    """

    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens via the scorer model.

        This converts each input sequence to a set of k+1 target sequences.
        The target sequences have the unique continuations to be scored and
        a unique sequence ID that is different from all input sequence ids.

        If a speculative sequence length would exceed the max model length,
        then no speculation is produced for that sequence.

        Args:
            execute_model_req: The execution request.
            proposals: The speculative proposals to score.

        Returns:
            SpeculativeScores: The scores of each speculative token, along
                with which sequences were ignored during scoring.
        """
        # TODO(zero_overhead): fix. Only the *shapes* of the proposal tensors
        # are read here; the host-side lists are all-zero placeholders, not
        # the real proposal lengths / token ids.
        # NOTE(review): presumably this avoids a device-to-host copy of the
        # proposal tensors -- confirm before relying on these lists.
        proposal_lens_list = np.zeros(proposals.proposal_lens.shape,
                                      dtype=int).tolist()
        proposal_token_ids_list = np.zeros(
            proposals.proposal_token_ids.shape, dtype=int).tolist()

        # Filter the list to ignore invalid proposals.
        proposal_token_ids_list_without_skips = [
            token_ids for token_ids in proposal_token_ids_list
            if VLLM_INVALID_TOKEN_ID not in token_ids
        ]

        (spec_indices, non_spec_indices, target_seq_group_metadata_list,
         num_scoring_tokens) = self._expand_batch(
             seq_group_metadata_list=execute_model_req.
             seq_group_metadata_list,
             proposal_token_ids_list=proposal_token_ids_list_without_skips,
             proposal_lens_list=proposal_lens_list,
         )

        target_sampler_output = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=target_seq_group_metadata_list))
        assert len(target_sampler_output) == 1, "expected single-step output"
        target_sampler_output = target_sampler_output[0]

        if not non_spec_indices:
            # All sequence groups in batch have spec decoding enabled
            return self._contract_batch_all_spec(
                target_sampler_output=target_sampler_output,
                proposals=proposals,
            )

        # Batch has a mix of spec decode enabled and disabled seq groups
        return self._contract_batch(
            execute_model_req.seq_group_metadata_list,
            target_sampler_output=target_sampler_output,
            proposals=proposals,
            num_scoring_tokens=num_scoring_tokens,
            non_spec_indices=non_spec_indices,
            spec_indices=spec_indices,
            k=execute_model_req.num_lookahead_slots,
        )

    def _contract_non_speculative(
            self, scores: SpeculativeScores,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
            has_prompt_log: bool) -> SpeculativeScores:
        """Augment input `scores` with non-speculative requests outputs.

        This includes decode requests with speculation turned off, as well
        as prefill requests when `enable_chunked_prefill` is set.
        For the latter, prefills are further separated into terminal and
        non-terminal chunks (from which no token is sampled).

        Args:
            scores: Scores to update in place; its first column is
                overwritten for the rows listed in `non_spec_indices`.
            seq_group_metadata_list: Metadata of the batch, indexed by the
                entries of `non_spec_indices`.
            non_spec_indices: Batch positions of the non-speculative
                requests.
            non_spec_outputs: Scorer outputs of the non-speculative requests.
            has_prompt_log: Whether `non_spec_outputs` also contains one row
                per prompt token (prompt logprobs enabled).

        Returns:
            The same `scores` object, updated in place.
        """
        if not non_spec_indices:
            return scores

        sampled_rows: List[int]
        if has_prompt_log:
            # When prompt_logprobs is enabled, prefills yield output token
            # (and respective prob) in the last entry (prompt|out):
            # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
            # With chunked prefill, non-terminal chunks have -1 on each
            # position: they're still picked, but they're discarded later.
            #
            # The running sum is kept as a plain Python list (instead of a
            # CPU tensor from torch.cumsum) so that both branches hand the
            # same kind of object to async_tensor_h2d below.
            sampled_rows = []
            rows_consumed = 0
            for batch_idx in non_spec_indices:
                seq_meta = seq_group_metadata_list[batch_idx]
                rows_consumed += (seq_meta.token_chunk_size
                                  if seq_meta.is_prompt else 1)
                sampled_rows.append(rows_consumed - 1)
        else:
            # In this case only sampled tokens are returned, select all.
            sampled_rows = list(range(len(non_spec_outputs.token_ids)))

        # Move both index lists to the scorer device once, up front.
        src_rows = async_tensor_h2d(sampled_rows, torch.int32, self._device,
                                    True)
        dst_rows = async_tensor_h2d(non_spec_indices, torch.int32,
                                    self._device, True)

        scores.token_ids[dst_rows, :1] = \
            non_spec_outputs.token_ids[src_rows].unsqueeze(1)
        scores.probs[dst_rows, :1, :] = \
            non_spec_outputs.probs[src_rows].unsqueeze(1)
        scores.logprobs[dst_rows, :1, :] = \
            non_spec_outputs.logprobs[src_rows].unsqueeze(1)
        if scores.hidden_states is not None:
            assert non_spec_outputs.hidden_states is not None
            scores.hidden_states[dst_rows, :1, :] = \
                non_spec_outputs.hidden_states[src_rows].unsqueeze(1)
        return scores
\ No newline at end of file
vllm/zero_overhead/spec_decode/muti_step_worker.py
0 → 100644
View file @
01c30741
import
copy
import
weakref
from
typing
import
Dict
,
List
,
Set
,
Tuple
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
ExecuteModelRequest
,
HiddenStates
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.utils
import
async_tensor_h2d
from
vllm.zero_overhead.spec_decode.top1_proproser
import
ZeroOverheadTop1Proposer
if
current_platform
.
is_cuda_alike
():
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.worker.worker_base
import
DelegateWorkerBase
class ZeroOverheadMultiStepWorker(MultiStepWorker):
    """``MultiStepWorker`` variant used by the zero-overhead path.

    It installs a ``ZeroOverheadTop1Proposer`` in ``init_device`` and filters
    the expanded-batch outputs with an index tensor that is moved to the
    worker device via ``async_tensor_h2d``.
    """

    def init_device(self) -> None:
        """Initialize the wrapped worker and create the proposer."""
        self.worker.init_device()
        self._proposer = ZeroOverheadTop1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            self.device,
            self.vocab_size,
            max_proposal_len=self.max_model_len,
        )

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass sample_len times.

        Returns the list of sampler output, one per model forward pass,
        along with indicator of whether torch tensor in sampler output need
        to be transposed in latter sampler_output_to_torch logic.

        For multi step worker, this indicator shall be True.

        Args:
            execute_model_req: The request to run the draft model on.
            sample_len: Number of forward passes (draft tokens) to produce.
            seq_ids_with_bonus_token_in_last_step: Sequence ids forwarded to
                ``_expand_execute_model_request`` for batch expansion.

        Returns:
            ``(filtered_model_outputs, True)``.
        """
        self._raise_if_unsupported(execute_model_req)

        # Expand the batch for sequences with a bonus token.
        # Perform a forward pass on the expanded batch and filter the
        # response to retain only the original sequences' responses.
        expanded_request, indices_of_seq_with_bonus_tokens = \
            self._expand_execute_model_request(
                execute_model_req, seq_ids_with_bonus_token_in_last_step)

        # Run model sample_len times.
        model_outputs: List[SamplerOutput] = []
        if current_platform.is_cuda_alike() and isinstance(
                self.model_runner, TP1DraftModelRunner
        ) and self.model_runner.supports_gpu_multi_step(expanded_request):
            # Here we run the draft_model_runner with multi-step prepare
            # on the GPU directly
            expanded_request.num_steps = sample_len
            self.model_runner.set_indices_of_seq_with_bonus_tokens(
                indices_of_seq_with_bonus_tokens)
            model_outputs = self.execute_model(
                execute_model_req=expanded_request)
        else:
            # Here we run multi-step directly, with every step prepared
            # on the CPU.
            # TODO: Remove this branch once DraftModelRunner supports TP>1
            # and other restrictions that are part of DraftModelRunner's
            # supports_gpu_multi_step(..)
            for _ in range(sample_len):
                model_output: List[SamplerOutput] = self.worker.execute_model(
                    execute_model_req=expanded_request)
                assert (len(model_output) == 1
                        ), "composing multistep workers not supported"
                model_output = model_output[0]

                self._append_new_tokens(
                    model_output, expanded_request.seq_group_metadata_list,
                    indices_of_seq_with_bonus_tokens)
                model_outputs.append(model_output)

        filtered_model_outputs = self._filter_model_output_zero_overhead(
            model_outputs, indices_of_seq_with_bonus_tokens)
        return filtered_model_outputs, True

    def _filter_model_output_zero_overhead(
            self, expanded_batch_outputs: List[SamplerOutput],
            output_indices_to_retain: List[int]) -> List[SamplerOutput]:
        """Filter the model output to include only the specified sequences.

        This method contracts the expanded batch output from the model to
        retain the outputs of only those sequences indicated by the
        provided indices.

        Args:
            expanded_batch_outputs (List[SamplerOutput]): The expanded
                output batches from the model, one per step.
            output_indices_to_retain (List[int]): Indices of the model
                outputs to retain.

        Returns:
            List[SamplerOutput]: A list containing the filtered model
                outputs for the specified indices.
        """
        # One index tensor on the worker device, shared by every step and
        # every tensor field below.
        retain_idx = async_tensor_h2d(output_indices_to_retain, torch.int32,
                                      self.device, True)

        def _select(tensor):
            # Optional tensor fields stay None when the step did not set them.
            return tensor[retain_idx] if tensor is not None else None

        return [
            SamplerOutput(
                outputs=[
                    step_output.outputs[i] for i in output_indices_to_retain
                ] if len(step_output.outputs) > 0 else [],
                sampled_token_probs=_select(step_output.sampled_token_probs),
                logprobs=_select(step_output.logprobs),
                sampled_token_ids=_select(step_output.sampled_token_ids),
            ) for step_output in expanded_batch_outputs
        ]
\ No newline at end of file
vllm/zero_overhead/spec_decode/spec_decode_worker.py
0 → 100644
View file @
01c30741
import
os
import
copy
from
collections
import
defaultdict
from
functools
import
cached_property
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
import
torch
import
torch.nn
as
nn
from
vllm.config
import
ParallelConfig
,
SpeculativeConfig
,
VllmConfig
from
vllm.distributed.communication_op
import
(
broadcast_tensor_dict
,
get_tp_group
,
tensor_model_parallel_gather
)
from
vllm.distributed.parallel_state
import
model_parallel_is_initialized
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeBaseSampler
,
SpecDecodeStochasticBaseSampler
)
from
vllm.model_executor.layers.typical_acceptance_sampler
import
(
TypicalAcceptanceSampler
)
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
CompletionSequenceGroupOutput
,
ExecuteModelRequest
,
HiddenStates
,
SequenceGroupMetadata
,
get_all_seq_ids_and_request_ids
,
Logits
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTreeStyleScorer
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.spec_decode_worker
import
SpecDecodeWorker
from
vllm.zero_overhead.spec_decode.batch_expansion
import
ZeroOverheadBatchExpansionTop1Scorer
if
current_platform
.
is_cuda_alike
():
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.medusa_worker
import
MedusaWorker
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.mlp_speculator_worker
import
MLPSpeculatorWorker
from
vllm.spec_decode.mqa_scorer
import
MQAScorer
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.smaller_tp_proposer_worker
import
SmallerTpProposerWorker
from
vllm.spec_decode.target_model_runner
import
TargetModelRunner
from
vllm.spec_decode.util
import
(
Timer
,
create_logprobs_output
,
create_sequence_group_output
,
get_all_num_logprobs
,
get_sampled_token_logprobs
,
nvtx_range
,
split_batch_by_proposal_len
)
from
vllm.utils
import
resolve_obj_by_qualname
from
vllm.worker.worker_base
import
LoRANotSupportedWorkerBase
,
WorkerBase
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
logger
=
init_logger
(
__name__
)
class
ZeroOverheadSpecDecodeWorker
(
SpecDecodeWorker
):
def init_device(self) -> None:
    """Initialize both scorer and proposer models, then build the scorer.

    Steps, in order: initialize devices, load both models, optionally hand
    the gathered target ``lm_head`` weight to the proposer, initialize the
    metrics and spec-decode sampler tensors, construct ``self.scorer`` and
    configure the model sampler for speculative decoding.
    """
    target_worker = self.scorer_worker
    draft_worker = self.proposer_worker

    # The scorer worker model is initialized first in case the proposer
    # model has a smaller TP degree than the target worker.
    target_worker.init_device()
    draft_worker.init_device()

    # NOTE(cade): load_model is not part of the WorkerBase interface.
    target_worker.load_model()
    draft_worker.load_model()

    if self._enable_lm_head_weight_load:
        # NOTE(Shangming): gather lm_head weight when tp enabled
        lm_head = target_worker.model_runner.model_runner.model.lm_head
        target_lm_head_weight: torch.Tensor = tensor_model_parallel_gather(
            lm_head.weight.data,
            dim=0,
        )
        draft_worker.maybe_load_lm_head_weight(target_lm_head_weight)

    self._metrics.init_tensors(self.rank, device_type=self.device)

    # The sampler is keyed by the TP-group local rank when model parallelism
    # has been set up, and by this worker's own rank otherwise.
    sampler_rank = (get_tp_group().local_rank
                    if model_parallel_is_initialized() else self.rank)
    self.spec_decode_sampler.init_tensors(sampler_rank,
                                          device_type=self.device)

    scorer_cls: Type[SpeculativeScorer]
    if self.disable_mqa_scorer:
        scorer_cls = ZeroOverheadBatchExpansionTop1Scorer
        logger.info("[Speculative Decoding] Use batch "
                    "expansion for scoring proposals.")
    else:
        scorer_cls = MQAScorer
        logger.info(
            "[Speculative Decoding] Use MQA scorer for scoring proposals.")

    # Tree decoding always uses the tree-style batch-expansion scorer,
    # regardless of the class selected (and logged) above.
    if self.tree_decoding:
        scorer_cls = BatchExpansionTreeStyleScorer

    self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
                             device=self.device,
                             vocab_size=self._vocab_size)

    self._configure_model_sampler_for_spec_decode()
@nvtx_range("spec_decode_worker._verify_tokens")
def _verify_tokens(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_scores: SpeculativeScores,
    proposals: SpeculativeProposals,
    max_proposal_len: int,
) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]], List[int]]:
    """Determine which speculative tokens are accepted using the
    probabilities of each token according to the proposer and scorer models.

    Besides computing the accepted tokens, this stores state for the next
    step on ``self``: ``previous_hidden_states`` (both paths, when available)
    and ``previous_logits`` (tree path only).

    Args:
        seq_group_metadata_list: Metadata of every sequence group in the
            batch, in the original batch order.
        proposal_scores: Scorer output; ``probs``, ``token_ids``,
            ``logprobs``, ``hidden_states`` and (tree path) ``logits`` are
            read from it.
        proposals: Proposer output; ``proposal_lens``, ``proposal_probs``,
            ``proposal_token_ids``, ``cart_candidates`` and
            ``retrieve_indices`` are read from it.
        max_proposal_len: Width used when padding the rows of
            non-speculative sequences.

    Returns:
        A 4-tuple ``(accepted_token_ids, logprobs, select_indices_list,
        accept_lengths)``. ``select_indices_list`` is empty and
        ``accept_lengths`` is ``None`` unless ``cart_candidates`` is set
        (tree path).
    """
    proposal_lens_list = proposals.proposal_lens

    # vLLM currently only supports proposal lens equal to zero or the batch
    # proposal len. This adds some complexity (splitting the batch into spec
    # and non spec sequences) and should be removed in the future. It can be
    # done by supporting per-sequence proposal lens.
    (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
        seq_group_metadata_list, proposal_lens_list)
    # Speculative rows first, then non-speculative rows: this is the order in
    # which the rows are concatenated further down.
    original_indices = spec_indices + non_spec_indices

    # Get probabilities of target model, including bonus tokens.
    if non_spec_indices:
        proposal_verifier_probs = proposal_scores.probs[spec_indices]
    else:
        proposal_verifier_probs = proposal_scores.probs

    if self.tree_decoding:
        # Re-index the second dimension of the target probabilities with the
        # proposer-provided retrieve indices.
        retrieve_indices = proposals.retrieve_indices
        proposal_verifier_probs = proposal_verifier_probs[:, retrieve_indices]

    # Get non-speculative sampled tokens from target model.
    non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

    # Get bonus tokens from target model.
    bonus_token_ids = proposal_scores.token_ids[:, -1:]
    if non_spec_indices:
        bonus_token_ids = bonus_token_ids[spec_indices, :]

    # Get probabilities according to proposal method.
    proposal_probs = proposals.proposal_probs if proposals.proposal_probs is not None else None
    if proposal_probs is not None and non_spec_indices:
        proposal_probs = proposal_probs[spec_indices]

    # Get proposed tokens.
    proposal_token_ids = proposals.proposal_token_ids
    if non_spec_indices:
        proposal_token_ids = proposal_token_ids[spec_indices]

    # Get tree buffers.
    cart_candidates = proposals.cart_candidates if proposals.cart_candidates is not None else None
    if cart_candidates is not None and non_spec_indices:
        cart_candidates = cart_candidates[spec_indices]

    # Sampler arguments
    sampler_extra_kwargs: Dict[str, Any] = {}
    if self.generators and isinstance(self.spec_decode_sampler,
                                      SpecDecodeStochasticBaseSampler):
        # Map batch position -> per-request generator, only for requests
        # that specified a seed.
        sampler_extra_kwargs["seeded_seqs"] = {
            idx: self.generators[sgm.request_id]
            for idx, sgm in enumerate(seq_group_metadata_list)
            if sgm.sampling_params.seed is not None
        }
    if isinstance(self.spec_decode_sampler, TypicalAcceptanceSampler):
        # NOTE(review): the extent of this branch is inferred; all
        # tree-related kwargs are assumed to be passed only to
        # TypicalAcceptanceSampler -- confirm against the original file.
        sampler_extra_kwargs["cart_candidates"] = cart_candidates
        # Empty lists handed to the sampler; the tree path below reads them
        # back after the call, so the sampler presumably fills them in place.
        sampler_extra_kwargs["best_candidates"] = []
        sampler_extra_kwargs["accept_lengths"] = []
        # One boolean per sequence group, taken from the first sequence's
        # `get_first_step_flag()`.
        first_step_flags = []
        for i, sgm in enumerate(seq_group_metadata_list):
            seq = next(iter(sgm.seq_data.values()))
            first_step_flags.append(True if seq.get_first_step_flag() else False)
        sampler_extra_kwargs["first_step_flags"] = first_step_flags

    accepted_token_ids = self.spec_decode_sampler(
        target_with_bonus_probs=proposal_verifier_probs,
        bonus_token_ids=bonus_token_ids,
        draft_probs=proposal_probs,
        draft_token_ids=proposal_token_ids,
        **sampler_extra_kwargs,
    )

    # Append output tokens from non-speculative sequences to
    # the accepted token ids tensor.
    # The padded row width is max_proposal_len + 1 normally and
    # max_proposal_len when tree decoding is enabled.
    if not self.tree_decoding:
        non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
                                                       1).clone()
    else:
        non_spec_token_ids = non_spec_token_ids.expand(
            -1, max_proposal_len).clone()
    # Only the first column keeps the sampled token; the rest is padding.
    non_spec_token_ids[:, 1:] = -1
    accepted_token_ids = torch.cat([accepted_token_ids, non_spec_token_ids])
    logprobs = proposal_scores.logprobs

    # Rearrange so that results are in the order of the original seq group
    # metadata.
    # The index list is moved to the device with `async_tensor_h2d` instead
    # of being used directly as a Python list.
    original_indices = async_tensor_h2d(original_indices, torch.int32,
                                        self.device, True)
    accepted_token_ids[original_indices] = accepted_token_ids.clone()

    # B x K+1 x D
    hidden_states = proposal_scores.hidden_states
    select_indices = None
    accept_lengths = None
    select_indices_list = []
    if cart_candidates is None:
        # Non-tree path.
        if hidden_states is not None:
            # Only get terminal hidden states for next step
            terminal_metadata = [
                sg for sg in seq_group_metadata_list if sg.do_sample
            ]

            # Contract hidden states based on accepted tokens
            hs_size = hidden_states.shape[-1]
            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
            # Number of non-padding tokens per row, minus one.
            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b
            # Drop non-terminal prefill chunks hidden states.
            hidden_states = hidden_states[accepted_index !=
                                          VLLM_INVALID_TOKEN_ID]
            accepted_index = accepted_index[accepted_index !=
                                            VLLM_INVALID_TOKEN_ID]
            assert len(accepted_index) == hidden_states.shape[0] == len(
                terminal_metadata)
            index = accepted_index[:, None, None].expand(-1, 1,
                                                         hs_size)  # b x 1 x d
            second_last_token_hidden_states = hidden_states[:, -2]  # b x d
            hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
            # Store hidden states from target model for subsequent decode step
            self.previous_hidden_states = HiddenStates(
                hidden_states, terminal_metadata,
                second_last_token_hidden_states)
    else:
        # Tree path: read back the lists that were passed to the sampler.
        # NOTE(review): `hidden_states` is used here without a None check,
        # unlike the branch above -- confirm it is always set on this path.
        retrieve_indices = proposals.retrieve_indices
        batch_size = len(seq_group_metadata_list)
        best_candidates = sampler_extra_kwargs["best_candidates"]
        accept_lengths = sampler_extra_kwargs["accept_lengths"]

        # Contract hidden states based on accepted tokens
        hs_size = hidden_states.shape[-1]
        hidden_states = hidden_states.view(batch_size, -1, hs_size)

        # Store logits from target model for subsequent proposal
        logits = proposal_scores.logits
        logits = logits.view(batch_size, -1, logits.shape[-1])
        logits = logits[:, retrieve_indices]  # [batch_size, retrieve_size, max_depth, vocab_size]

        previous_logits_list = []
        previous_hidden_state_list = []
        # NOTE(review): `.cpu()` copies the indices to host memory; this
        # presumably waits on the device -- confirm that is acceptable on
        # the zero-overhead path.
        retrieve_indices = retrieve_indices.cpu()
        for i in range(batch_size):
            # Logits at the accepted depth of the best candidate.
            logit = logits[i, best_candidates[i], accept_lengths[i]].unsqueeze(0)
            previous_logits_list.append(logit)
            # Retrieve indices of the best candidate, up to and including
            # the accepted depth.
            select_indices = retrieve_indices[best_candidates[i], :accept_lengths[i] + 1]
            # Hidden state at the last selected index.
            hidden_state = hidden_states[i, select_indices[-1]].unsqueeze(0)
            select_indices_list.append(select_indices)
            previous_hidden_state_list.append(hidden_state)

        logits = torch.cat(previous_logits_list, dim=0)
        self.previous_logits = Logits(logits, seq_group_metadata_list)
        hidden_states = torch.cat(previous_hidden_state_list, dim=0)  # [batch_size, 1, vocab_size]
        self.previous_hidden_states = HiddenStates(hidden_states, seq_group_metadata_list,)

    return accepted_token_ids, logprobs, select_indices_list, accept_lengths
def _create_output_sampler_list(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
    target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
    prompt_logprobs: Optional[
        torch.Tensor],  # shape: [nprompt_tokens, vocab_size]
    k: int,
    stage_times: Tuple[float, float, float],
) -> List[SamplerOutput]:
    """Given the accepted token ids, create a list of SamplerOutput.

    The output is padded with -1 tokens such that each sequence has
    the same number of outputs.

    Args:
        seq_group_metadata_list: Metadata of every sequence group in the
            batch; prompt entries are expected before decode entries.
        accepted_token_ids: Accepted token ids per sequence and step.
        target_logprobs: Target-model logprobs; only read when logprobs are
            not disabled.
        prompt_logprobs: Optional prompt logprobs, indexed per prompt entry.
        k: Passed to the metrics collector.
        stage_times: Passed to `_maybe_log_stage_times` when metrics are
            emitted.

    Returns:
        One SamplerOutput per prompt entry, followed by one SamplerOutput
        per decode step (`num_steps` of them).
    """
    batch_size, num_steps = accepted_token_ids.shape
    accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
    if self._disable_logprobs:
        # We are skipping the logprobs. Hence don't serialize the
        # logprobs related tensors from the GPU. Instead create
        # empty/dummy lists.
        (accepted_token_id_ranks_by_step, accepted_token_id_logprobs_by_step,
         topk_logprobs_by_step, topk_indices_by_step) =\
            self._create_dummy_logprob_lists(
                batch_size, num_steps,
                self.scorer_worker.model_config.max_logprobs)
    else:
        # Organize input tensors by step instead of by sequence.
        target_logprobs_by_step = target_logprobs.transpose(0, 1)
        # Serialize all tensors into Python lists.
        (accepted_token_id_ranks_by_step, accepted_token_id_logprobs_by_step,
         topk_logprobs_by_step, topk_indices_by_step) =\
            self._create_logprob_lists_from_tensors(
                target_logprobs_by_step, accepted_token_ids_by_step,
                self.scorer_worker.model_config.max_logprobs)

    # Get the sequence ids and num_logprobs (sampling parameter) in the
    # batch.
    seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
        seq_group_metadata_list)

    num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

    # NOTE: `accepted_token_ids_by_step` is not converted to a Python list
    # here; it remains the result of `transpose(0, 1)` above, so the decode
    # loop below indexes it directly.

    # Construct the output on a per-step, per-sequence basis.
    # Non-terminal prefill chunks will end up here as rows with just -1s
    # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while
    # terminal chunks will only have one generated token at time 0.
    sampler_output_list: List[SamplerOutput] = []

    # Prefills are not multi-step (return at most 1 token), in order to
    # avoid padding or repetition to fit decodes, we separate them.
    for i, sg in enumerate(seq_group_metadata_list):
        if not sg.is_prompt:
            # Requests are ordered as prefills|decodes=>no more prefills.
            break
        num_logprobs = num_logprobs_per_seq[i]
        # Placeholder output used for non-terminal chunks (no sampled token).
        seq_kwargs = dict(token_id=-1,
                          token_id_logprob_rank=0,
                          token_id_logprob=-float('inf'),
                          topk_token_ids=[-1] * num_logprobs,
                          topk_logprobs=[-float('inf')] * num_logprobs,
                          seq_id=seq_ids[i])
        # Terminal chunk, has token.
        if sg.do_sample:
            seq_kwargs.update(
                dict(
                    # `.item()` reads the token value back as a Python
                    # number.
                    token_id=accepted_token_ids[i][0].item(),
                    token_id_logprob_rank=accepted_token_id_ranks_by_step[0]
                    [i],
                    token_id_logprob=accepted_token_id_logprobs_by_step[0]
                    [i],
                    topk_token_ids=topk_indices_by_step[0][i]
                    [:num_logprobs],
                    # output only so step is 0
                    topk_logprobs=topk_logprobs_by_step[0][i]
                    [:num_logprobs],
                ))
        needs_plogs = (sg.sampling_params.prompt_logprobs
                       and sg.sampling_params.prompt_logprobs > 0)
        plogs = None
        if prompt_logprobs is not None:
            # Even non-terminal prompt chunks can have logprobs here.
            plogs = prompt_logprobs[i]
        elif needs_plogs:
            # Prompt logprobs are requested but `_disable_logprobs` is set.
            seq_data = next(iter(sg.seq_data.values()))
            # Get only the tokens in this chunk!
            prompt_token_ids = seq_data.get_prompt_token_ids()
            prompt_token_ids = prompt_token_ids[
                seq_data.
                _num_computed_tokens:seq_data._num_computed_tokens +
                sg.token_chunk_size]

            is_first_chunk = seq_data._num_computed_tokens == 0
            # There's no prob generated for the first token in a sequence.
            if is_first_chunk:
                prompt_token_ids = prompt_token_ids[1:]
            # Dummy entries: rank -1, logprob 0.0 and no top-k data.
            plogs = [
                create_logprobs_output(
                    token_id=p_token_id,
                    token_id_logprob_rank=-1,
                    token_id_logprob=0.0,
                    topk_token_ids=[],
                    topk_logprobs=[],
                ) for p_token_id in prompt_token_ids
            ]
        seq_kwargs.update(dict(prompt_logprobs=plogs))

        sampler_output_list.append(
            SamplerOutput(
                outputs=[create_sequence_group_output(
                    **seq_kwargs)]))  # type: ignore

    # Decodes, create one SamplerOutput per-step (at most K+1).
    for step_index in range(num_steps):
        # The early exit on an all -1 step is disabled, so every step
        # produces a SamplerOutput:
        # if all(token_id == -1 for sg, token_id in zip(
        #         seq_group_metadata_list,
        #         accepted_token_ids_by_step[step_index])
        #        if not sg.is_prompt):
        #     break

        step_output_token_ids: List[CompletionSequenceGroupOutput] = []
        for sequence_index in range(batch_size):
            seq_meta = seq_group_metadata_list[sequence_index]
            # Prompts already processed above.
            if seq_meta.is_prompt:
                continue

            # Each sequence may have a different num_logprobs; retrieve it.
            num_logprobs = num_logprobs_per_seq[sequence_index]
            step_output_token_ids.append(
                create_sequence_group_output(
                    # Element of the transposed ids, passed without
                    # conversion to a Python int.
                    token_id=accepted_token_ids_by_step[step_index]
                    [sequence_index],
                    token_id_logprob_rank=accepted_token_id_ranks_by_step[
                        step_index][sequence_index],
                    token_id_logprob=accepted_token_id_logprobs_by_step[
                        step_index][sequence_index],
                    seq_id=seq_ids[sequence_index],
                    topk_token_ids=topk_indices_by_step[step_index]
                    [sequence_index][:num_logprobs],
                    topk_logprobs=topk_logprobs_by_step[step_index]
                    [sequence_index][:num_logprobs],
                ))
        sampler_output_list.append(
            SamplerOutput(outputs=step_output_token_ids))

    # Populate the data structures needed to keep track of sequences with
    # bonus tokens.
    self._track_sequences_with_bonus_tokens(seq_ids,
                                            request_ids_seq_ids_mapping,
                                            accepted_token_ids_by_step)
    maybe_rejsample_metrics = (
        self._metrics.maybe_collect_rejsample_metrics(k))
    if maybe_rejsample_metrics is not None and sampler_output_list:
        # Metrics are attached to the first output only.
        sampler_output_list[0].spec_decode_worker_metrics = maybe_rejsample_metrics

        # Log time spent in each stage periodically.
        # This is periodic because the rejection sampler emits metrics
        # periodically.
        # NOTE(review): assumed to sit inside this `if` -- confirm against
        # the original file.
        self._maybe_log_stage_times(*stage_times)
    # First `n_prefills` entries will contain prefills SamplerOutput when
    # chunked prefill is enabled, the rest is decodes in multi-step format.
    return sampler_output_list
def
_track_sequences_with_bonus_tokens
(
self
,
seq_ids
:
List
[
int
],
request_ids_seq_ids_mapping
:
Dict
[
str
,
Set
[
int
]],
accepted_token_ids_by_step
:
List
[
List
[
int
]]):
"""
Updates the internal data structures which keep track of sequences
which have been assigned bonus tokens in their last forward pass.
"""
for
seq_index
,
seq_id
in
enumerate
(
seq_ids
):
# last_token_id = accepted_token_ids_by_step[-1][seq_index]
# if last_token_id == -1:
# self._seq_with_bonus_token_in_last_step.discard(seq_id)
# else:
self
.
_seq_with_bonus_token_in_last_step
.
add
(
seq_id
)
for
request_id
,
sequences
in
request_ids_seq_ids_mapping
.
items
():
self
.
_request_id_seq_id_mapping
[
request_id
].
update
(
sequences
)
\ No newline at end of file
vllm/zero_overhead/spec_decode/top1_proproser.py
0 → 100644
View file @
01c30741
import
os
from
typing
import
List
,
Optional
,
Set
,
Tuple
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.util
import
sampler_output_to_torch
from
vllm.utils
import
async_tensor_h2d
class ZeroOverheadTop1Proposer(Top1Proposer):
    """Top1Proposer variant whose merge step moves its index and length
    lists to the device with ``async_tensor_h2d``."""

    def _merge_outputs(
        self,
        batch_size: int,
        proposal_len: int,
        maybe_sampler_output: Optional[List[SamplerOutput]],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
        sampler_transposed: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """After speculations are produced, merge the speculation results with
        the skipped sequences.

        Args:
            batch_size: Number of sequences in the full batch.
            proposal_len: Proposal length given to every speculated sequence.
            maybe_sampler_output: Proposer sampler output, or ``None`` when
                nothing was speculated.
            proposal_lens: Per-sequence proposal lengths; only its length is
                read, and only when ``maybe_sampler_output`` is ``None``.
            nonzero_proposal_len_indices: Batch positions of the sequences
                that were speculated.
            sampler_transposed: Forwarded to ``sampler_output_to_torch``.

        Returns:
            ``(proposal_tokens, proposal_probs, proposal_lens_tensor)`` with
            one row per sequence of the batch. Rows of skipped sequences hold
            -1 tokens, zero probabilities and a length of 0.
        """
        if maybe_sampler_output is None:
            # If no speculative tokens, the sampler output will be None.
            # In this case we return empty proposals.
            proposal_tokens = torch.tensor(-1,
                                           dtype=torch.long,
                                           device=self._device).expand(
                                               batch_size, proposal_len)
            proposal_probs = torch.tensor(0,
                                          dtype=torch.float32,
                                          device=self._device).expand(
                                              batch_size, proposal_len,
                                              self._vocab_size)
            proposal_lens_tensor = torch.tensor(0,
                                                dtype=torch.long,
                                                device=self._device).expand(
                                                    len(proposal_lens))
            return proposal_tokens, proposal_probs, proposal_lens_tensor

        proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
            maybe_sampler_output, sampler_transposed)

        # One length entry per *speculated* row. The value written through
        # the index tensor below must match the number of indexed rows;
        # sizing it by batch_size fails with a shape mismatch as soon as any
        # sequence in the batch is skipped.
        speculated_lens = async_tensor_h2d(
            [proposal_len] * len(nonzero_proposal_len_indices), torch.long,
            self._device, True)
        speculated_rows = async_tensor_h2d(nonzero_proposal_len_indices,
                                           torch.int32, self._device, True)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. the proposal can be empty, e.g. [-1, -1, -1]
        entire_proposal_tokens = proposal_tokens.new_full(
            size=(batch_size, *proposal_tokens.shape[1:]),
            fill_value=-1,
        )
        entire_proposal_tokens[speculated_rows] = proposal_tokens

        entire_proposal_probs = proposal_probs.new_zeros(
            batch_size,
            *proposal_probs.shape[1:],
        )
        entire_proposal_probs[speculated_rows] = proposal_probs

        proposal_lens_tensor = torch.zeros(batch_size,
                                           dtype=torch.long,
                                           device=self._device)
        proposal_lens_tensor[speculated_rows] = speculated_lens

        return (entire_proposal_tokens, entire_proposal_probs,
                proposal_lens_tensor)
\ No newline at end of file
vllm/zero_overhead/
v0/
stop_check.py
→
vllm/zero_overhead/stop_check.py
View file @
01c30741
...
@@ -5,7 +5,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
...
@@ -5,7 +5,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
SequenceStatus
from
vllm.sequence
import
SequenceStatus
from
vllm.zero_overhead.
v0.
sequence
import
ZeroOverheadSequence
from
vllm.zero_overhead.sequence
import
ZeroOverheadSequence
class
ZeroOverheadStopChecker
(
StopChecker
):
class
ZeroOverheadStopChecker
(
StopChecker
):
...
...
vllm/zero_overhead/
v0/
tokenizer.py
→
vllm/zero_overhead/tokenizer.py
View file @
01c30741
...
@@ -4,7 +4,7 @@ from vllm.sampling_params import SamplingParams
...
@@ -4,7 +4,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
VLLM_INVALID_TOKEN_ID
from
vllm.sequence
import
VLLM_INVALID_TOKEN_ID
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer_utils
import
convert_prompt_ids_to_tokens
,
detokenize_incrementally
from
vllm.transformers_utils.detokenizer_utils
import
convert_prompt_ids_to_tokens
,
detokenize_incrementally
from
vllm.zero_overhead.
v0.
sequence
import
ZeroOverheadSequence
from
vllm.zero_overhead.sequence
import
ZeroOverheadSequence
class
ZeroOverheadDetokenizer
(
Detokenizer
):
class
ZeroOverheadDetokenizer
(
Detokenizer
):
...
...
vllm/zero_overhead/
v0/
utils.py
→
vllm/zero_overhead/utils.py
View file @
01c30741
...
@@ -9,12 +9,4 @@ def is_zero_overhead():
...
@@ -9,12 +9,4 @@ def is_zero_overhead():
return
zero_overhead
return
zero_overhead
def
is_zero_no_thread
():
def
is_zero_no_thread
():
return
zero_no_thread
and
zero_overhead
return
zero_no_thread
and
zero_overhead
\ No newline at end of file
def UpdateInputTokens(input_tokens, last_sample, indices):
    """Launch the `_update_input_tokens` kernel over `input_tokens`.

    The launch grid is ``[input_tokens.shape[0], 1, 1]`` and the same count
    is passed as the last kernel argument. The first call launches
    ``_update_input_tokens`` and stores what that launch returns in the
    module-level ``_update_input_tokens_ptr``; later calls launch through
    the stored object instead.
    """
    # presumably `_update_input_tokens` is a JIT-compiled kernel and the
    # stored handle lets later calls skip its dispatch path -- TODO confirm
    global _update_input_tokens_ptr
    num_tokens = input_tokens.shape[0]
    launch_grid = [num_tokens, 1, 1]
    if _update_input_tokens_ptr is not None:
        _update_input_tokens_ptr[launch_grid](last_sample, input_tokens,
                                              indices, num_tokens)
        return
    _update_input_tokens_ptr = _update_input_tokens[launch_grid](
        last_sample, input_tokens, indices, num_tokens)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment