Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
01c30741
Commit
01c30741
authored
Apr 27, 2025
by
lizhigong
Browse files
add spec decode zero overhead
parent
b01c8270
Changes
16
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
905 additions
and
30 deletions
+905
-30
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+2
-2
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+2
-2
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+2
-2
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+31
-13
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+2
-2
vllm/zero_overhead/llm_engine.py
vllm/zero_overhead/llm_engine.py
+5
-5
vllm/zero_overhead/model_runner.py
vllm/zero_overhead/model_runner.py
+2
-2
vllm/zero_overhead/sampler.py
vllm/zero_overhead/sampler.py
+0
-0
vllm/zero_overhead/sequence.py
vllm/zero_overhead/sequence.py
+0
-0
vllm/zero_overhead/spec_decode/batch_expansion.py
vllm/zero_overhead/spec_decode/batch_expansion.py
+142
-0
vllm/zero_overhead/spec_decode/muti_step_worker.py
vllm/zero_overhead/spec_decode/muti_step_worker.py
+138
-0
vllm/zero_overhead/spec_decode/spec_decode_worker.py
vllm/zero_overhead/spec_decode/spec_decode_worker.py
+482
-0
vllm/zero_overhead/spec_decode/top1_proproser.py
vllm/zero_overhead/spec_decode/top1_proproser.py
+83
-0
vllm/zero_overhead/stop_check.py
vllm/zero_overhead/stop_check.py
+1
-1
vllm/zero_overhead/tokenizer.py
vllm/zero_overhead/tokenizer.py
+1
-1
vllm/zero_overhead/utils.py
vllm/zero_overhead/utils.py
+12
-0
No files found.
vllm/engine/multiprocessing/engine.py
View file @
01c30741
...
...
@@ -6,8 +6,8 @@ from contextlib import contextmanager
from
typing
import
Iterator
,
List
,
Optional
,
Union
import
cloudpickle
from
vllm.zero_overhead.
v0.
llm_engine
import
ZeroOverheadEngine
from
vllm.zero_overhead.
v0.
utils
import
is_zero_overhead
from
vllm.zero_overhead.llm_engine
import
ZeroOverheadEngine
from
vllm.zero_overhead.utils
import
is_zero_overhead
import
zmq
from
vllm
import
AsyncEngineArgs
,
SamplingParams
...
...
vllm/entrypoints/llm.py
View file @
01c30741
...
...
@@ -43,8 +43,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
Counter
,
Device
,
deprecate_args
,
deprecate_kwargs
,
is_list_of
)
from
vllm.zero_overhead.
v0.
llm_engine
import
ZeroOverheadEngine
from
vllm.zero_overhead.
v0.
utils
import
is_zero_overhead
from
vllm.zero_overhead.llm_engine
import
ZeroOverheadEngine
from
vllm.zero_overhead.utils
import
is_zero_overhead
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/layers/sampler.py
View file @
01c30741
...
...
@@ -21,7 +21,7 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput
,
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
SequenceOutput
)
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
from
vllm.zero_overhead.
v0.
utils
import
is_zero_overhead
from
vllm.zero_overhead.utils
import
is_zero_overhead
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
import
flashinfer.sampling
...
...
@@ -40,7 +40,7 @@ def get_sampler() -> torch.nn.Module:
from
vllm.v1.sample.sampler
import
Sampler
as
V1Sampler
return
V1Sampler
()
if
is_zero_overhead
():
from
vllm.zero_overhead.
v0.
sampler
import
ZeroOverheadSampler
from
vllm.zero_overhead.sampler
import
ZeroOverheadSampler
return
ZeroOverheadSampler
()
return
Sampler
()
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
01c30741
...
...
@@ -54,6 +54,7 @@ from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.zero_overhead.utils
import
is_zero_overhead
logger
=
init_logger
(
__name__
)
...
...
@@ -206,7 +207,10 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
# Load lm_head weight for eagle in init_device
if
draft_model_config
.
hf_config
.
model_type
==
"eagle"
:
enable_lm_head_weight_load
=
True
if
is_zero_overhead
():
from
vllm.zero_overhead.spec_decode.muti_step_worker
import
ZeroOverheadMultiStepWorker
proposer_worker
=
ZeroOverheadMultiStepWorker
(
**
draft_worker_kwargs
)
else
:
proposer_worker
=
MultiStepWorker
(
**
draft_worker_kwargs
)
if
draft_model_config
.
hf_config
.
model_type
==
"deepseek_mtp"
:
num_spec_prefill_steps
=
\
...
...
@@ -254,6 +258,20 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
"[Speculative Decoding] Disabling MQA scorer as the "
"target model is not running in eager mode."
)
if
is_zero_overhead
():
from
vllm.zero_overhead.spec_decode.spec_decode_worker
import
ZeroOverheadSpecDecodeWorker
return
ZeroOverheadSpecDecodeWorker
(
proposer_worker
,
scorer_worker
,
disable_mqa_scorer
=
disable_mqa_scorer
,
disable_logprobs
=
disable_logprobs
,
disable_log_stats
=
disable_log_stats
,
disable_by_batch_size
=
disable_by_batch_size
,
spec_decode_sampler
=
spec_decode_sampler
,
allow_zero_draft_token_step
=
allow_zero_draft_token_step
,
enable_lm_head_weight_load
=
enable_lm_head_weight_load
,
num_spec_prefill_steps
=
num_spec_prefill_steps
)
else
:
return
SpecDecodeWorker
(
proposer_worker
,
scorer_worker
,
...
...
vllm/worker/model_runner.py
View file @
01c30741
...
...
@@ -60,7 +60,7 @@ from vllm.worker.model_runner_base import (
_add_sampling_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_sampling_metadata_from_tensor_dict
)
from
vllm.zero_overhead.
v0.
utils
import
is_zero_overhead
from
vllm.zero_overhead.utils
import
is_zero_overhead
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
...
...
@@ -1638,7 +1638,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
ModelInputForGPUWithSamplingMetadata
)
_builder_cls
:
Type
[
ModelInputForGPUBuilder
]
=
ModelInputForGPUBuilder
if
is_zero_overhead
():
from
vllm.zero_overhead.
v0.
model_runner
import
ZeroOverheadModelInputForGpuBuilder
from
vllm.zero_overhead.model_runner
import
ZeroOverheadModelInputForGpuBuilder
_builder_cls
=
ZeroOverheadModelInputForGpuBuilder
def
make_model_input_from_broadcasted_tensor_dict
(
...
...
vllm/zero_overhead/
v0/
llm_engine.py
→
vllm/zero_overhead/llm_engine.py
View file @
01c30741
...
...
@@ -33,14 +33,14 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.usage.usage_lib
import
UsageContext
,
is_usage_stats_enabled
from
vllm.utils
import
resolve_obj_by_qualname
,
weak_bind
,
Counter
from
vllm.zero_overhead.
v0.
sampler
import
SampleRecorder
,
get_last_sampler
from
vllm.zero_overhead.
v0.
sequence
import
ZeroOverheadSequence
from
vllm.zero_overhead.
v0.
stop_check
import
ZeroOverheadStopChecker
from
vllm.zero_overhead.
v0.
tokenizer
import
ZeroOverheadDetokenizer
from
vllm.zero_overhead.sampler
import
SampleRecorder
,
get_last_sampler
from
vllm.zero_overhead.sequence
import
ZeroOverheadSequence
from
vllm.zero_overhead.stop_check
import
ZeroOverheadStopChecker
from
vllm.zero_overhead.tokenizer
import
ZeroOverheadDetokenizer
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.profiler.prof
import
profile
from
vllm.zero_overhead.
v0.
utils
import
is_zero_no_thread
from
vllm.zero_overhead.utils
import
is_zero_no_thread
logger
=
init_logger
(
__name__
)
...
...
vllm/zero_overhead/
v0/
model_runner.py
→
vllm/zero_overhead/model_runner.py
View file @
01c30741
...
...
@@ -10,8 +10,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
async_tensor_h2d
,
flatten_2d_lists
from
vllm.worker.model_runner
import
ModelInputForGPU
,
ModelInputForGPUBuilder
from
vllm.zero_overhead.
v0.
sampler
import
get_last_sampler
from
vllm.zero_overhead.
v0.
update_input
import
UpdateInputTokens
from
vllm.zero_overhead.sampler
import
get_last_sampler
from
vllm.zero_overhead.update_input
import
UpdateInputTokens
class
ZeroOverheadModelInputForGpuBuilder
(
ModelInputForGPUBuilder
):
...
...
vllm/zero_overhead/
v0/
sampler.py
→
vllm/zero_overhead/sampler.py
View file @
01c30741
File moved
vllm/zero_overhead/
v0/
sequence.py
→
vllm/zero_overhead/sequence.py
View file @
01c30741
File moved
vllm/zero_overhead/spec_decode/batch_expansion.py
0 → 100644
View file @
01c30741
from
array
import
array
import
numpy
as
np
from
itertools
import
chain
,
count
from
typing
import
Iterator
,
List
,
Optional
,
Tuple
import
torch
from
vllm
import
SamplingParams
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
VLLM_TOKEN_ID_ARRAY_TYPE
,
ExecuteModelRequest
,
SequenceData
,
SequenceGroupMetadata
,
get_all_seq_ids
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
nvtx_range
,
split_batch_by_proposal_len
from
vllm.utils
import
async_tensor_h2d
# Integer aliases that label the different kinds of ids in type hints.
SeqId = int
TargetSeqId = int
TokenId = int

# A SamplingParams instance built with every constructor default.
DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):
    """Batch-expansion top-1 scorer variant for the zero-overhead path.

    Overrides ``score_proposals`` and ``_contract_non_speculative`` of
    ``BatchExpansionTop1Scorer``. Index lists are passed through
    ``async_tensor_h2d`` (targeting ``self._device``) before they are used
    to index the score tensors.
    """

    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Run the scorer model over the expanded batch and contract it back.

        The request's sequence groups are expanded with ``_expand_batch``,
        executed once on the scorer worker, and the single-step output is
        folded back into per-sequence scores.

        Args:
            execute_model_req: The execution request.
            proposals: The speculative proposals to score.

        Returns:
            SpeculativeScores: the result of ``_contract_batch_all_spec``
            when ``_expand_batch`` reports no non-speculative indices,
            otherwise the result of ``_contract_batch``.
        """
        # NOTE(review): the original marks this "zero_overhead todo fix".
        # Both lists are zero-filled placeholders that only mirror the
        # *shape* of the proposal tensors; the actual proposal lengths and
        # token ids are not read here. Presumably this avoids a
        # device-to-host copy -- confirm before relying on these values.
        lens_placeholder = np.zeros(proposals.proposal_lens.shape,
                                    dtype=int).tolist()
        token_ids_placeholder = np.zeros(proposals.proposal_token_ids.shape,
                                         dtype=int).tolist()

        # Keep only the rows that do not contain the invalid-token marker.
        valid_token_id_rows = [
            row for row in token_ids_placeholder
            if VLLM_INVALID_TOKEN_ID not in row
        ]

        expanded = self._expand_batch(
            seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
            proposal_token_ids_list=valid_token_id_rows,
            proposal_lens_list=lens_placeholder,
        )
        (spec_indices, non_spec_indices, target_seq_group_metadata_list,
         num_scoring_tokens) = expanded

        scorer_request = execute_model_req.clone(
            seq_group_metadata_list=target_seq_group_metadata_list)
        step_outputs = self._scorer_worker.execute_model(
            execute_model_req=scorer_request)
        assert len(step_outputs) == 1, "expected single-step output"
        target_sampler_output = step_outputs[0]

        if not non_spec_indices:
            # Every sequence group in the batch has spec decoding enabled.
            return self._contract_batch_all_spec(
                target_sampler_output=target_sampler_output,
                proposals=proposals,
            )

        # The batch mixes spec-decode enabled and disabled sequence groups.
        return self._contract_batch(
            execute_model_req.seq_group_metadata_list,
            target_sampler_output=target_sampler_output,
            proposals=proposals,
            num_scoring_tokens=num_scoring_tokens,
            non_spec_indices=non_spec_indices,
            spec_indices=spec_indices,
            k=execute_model_req.num_lookahead_slots,
        )

    def _contract_non_speculative(
            self, scores: SpeculativeScores,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
            has_prompt_log: bool) -> SpeculativeScores:
        """Write the non-speculative requests' outputs into ``scores``.

        For every index in ``non_spec_indices`` the first position of
        ``scores.token_ids``, ``scores.probs``, ``scores.logprobs`` (and
        ``scores.hidden_states`` when present) is overwritten with the
        matching row of ``non_spec_outputs``. ``scores`` is modified in
        place and returned.

        Args:
            scores: Scores to augment.
            seq_group_metadata_list: Metadata of the original batch; only
                consulted when ``has_prompt_log`` is true.
            non_spec_indices: Batch positions of non-speculative requests.
            non_spec_outputs: Outputs produced for those requests.
            has_prompt_log: Whether prompt positions are included in
                ``non_spec_outputs`` (see the comment in the body).

        Returns:
            The same ``scores`` object.
        """
        if not non_spec_indices:
            return scores

        if not has_prompt_log:
            # Only sampled tokens were returned, so every row is selected.
            picked = list(range(len(non_spec_outputs.token_ids)))
        else:
            # With prompt logprobs, a prefill contributes one row per prompt
            # token and its output token sits in the last of those rows:
            # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
            # Non-terminal chunks of a chunked prefill carry -1 everywhere;
            # they are still selected here and discarded later.
            chunk_sizes = torch.tensor([
                (seq_group_metadata_list[idx].token_chunk_size
                 if seq_group_metadata_list[idx].is_prompt else 1)
                for idx in non_spec_indices
            ])
            # Running total minus one gives the last row of each request.
            picked = torch.cumsum(chunk_sizes, 0).add_(-1)

        src_rows = async_tensor_h2d(picked, torch.int32, self._device, True)
        dst_rows = async_tensor_h2d(non_spec_indices, torch.int32,
                                    self._device, True)

        scores.token_ids[dst_rows, :1] = \
            non_spec_outputs.token_ids[src_rows].unsqueeze(1)
        scores.probs[dst_rows, :1, :] = \
            non_spec_outputs.probs[src_rows].unsqueeze(1)
        scores.logprobs[dst_rows, :1, :] = \
            non_spec_outputs.logprobs[src_rows].unsqueeze(1)

        if scores.hidden_states is None:
            return scores

        assert non_spec_outputs.hidden_states is not None
        scores.hidden_states[dst_rows, :1, :] = \
            non_spec_outputs.hidden_states[src_rows].unsqueeze(1)
        return scores
\ No newline at end of file
vllm/zero_overhead/spec_decode/muti_step_worker.py
0 → 100644
View file @
01c30741
import
copy
import
weakref
from
typing
import
Dict
,
List
,
Set
,
Tuple
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
ExecuteModelRequest
,
HiddenStates
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.utils
import
async_tensor_h2d
from
vllm.zero_overhead.spec_decode.top1_proproser
import
ZeroOverheadTop1Proposer
if
current_platform
.
is_cuda_alike
():
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.worker.worker_base
import
DelegateWorkerBase
class ZeroOverheadMultiStepWorker(MultiStepWorker):
    """MultiStepWorker variant used when zero-overhead mode is enabled.

    It installs a ``ZeroOverheadTop1Proposer`` as the proposer and filters
    the expanded-batch outputs with an index tensor created through
    ``async_tensor_h2d``.
    """

    def init_device(self) -> None:
        """Initialise the wrapped worker's device and build the proposer."""
        self.worker.init_device()
        # A weak proxy is handed to the proposer, so the proposer does not
        # hold a strong reference back to this worker.
        self._proposer = ZeroOverheadTop1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            self.device,
            self.vocab_size,
            max_proposal_len=self.max_model_len,
        )

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass ``sample_len`` times.

        Args:
            execute_model_req: The request to run.
            sample_len: Number of forward passes (one sampler output each).
            seq_ids_with_bonus_token_in_last_step: Sequence ids passed to
                ``_expand_execute_model_request`` to expand the batch.

        Returns:
            A tuple ``(outputs, transposed)``: the filtered sampler outputs,
            one per forward pass, and a flag telling the later
            ``sampler_output_to_torch`` logic whether the tensors need to be
            transposed. For the multi-step worker the flag is always True.
        """
        # The leftover '###...' debug print() calls were removed: they wrote
        # the whole request and every step's output to stdout on each call.
        self._raise_if_unsupported(execute_model_req)

        # Expand the batch for sequences with a bonus token. The forward
        # pass runs on the expanded batch and the response is filtered
        # afterwards to keep only the original sequences' outputs.
        expanded_request, indices_of_seq_with_bonus_tokens = \
            self._expand_execute_model_request(
                execute_model_req, seq_ids_with_bonus_token_in_last_step)

        # Run model sample_len times.
        model_outputs: List[SamplerOutput] = []
        if current_platform.is_cuda_alike() and isinstance(
                self.model_runner, TP1DraftModelRunner
        ) and self.model_runner.supports_gpu_multi_step(expanded_request):
            # Run the draft model runner with multi-step prepare directly
            # on the GPU.
            expanded_request.num_steps = sample_len
            self.model_runner.set_indices_of_seq_with_bonus_tokens(
                indices_of_seq_with_bonus_tokens)
            model_outputs = self.execute_model(
                execute_model_req=expanded_request)
        else:
            # Run multi-step directly, with every step prepared on the CPU.
            # TODO: Remove this branch once DraftModelRunner supports TP>1
            # and other restrictions that are part of DraftModelRunner's
            # supports_gpu_multi_step(..)
            for _ in range(sample_len):
                model_output: List[SamplerOutput] = self.worker.execute_model(
                    execute_model_req=expanded_request)
                assert (len(model_output) == 1
                        ), "composing multistep workers not supported"
                model_output = model_output[0]

                self._append_new_tokens(
                    model_output, expanded_request.seq_group_metadata_list,
                    indices_of_seq_with_bonus_tokens)
                model_outputs.append(model_output)

        filtered_model_outputs = self._filter_model_output_zero_overhead(
            model_outputs, indices_of_seq_with_bonus_tokens)
        return filtered_model_outputs, True

    def _filter_model_output_zero_overhead(
            self, expanded_batch_outputs: List[SamplerOutput],
            output_indices_to_retain: List[int]) -> List[SamplerOutput]:
        """Contract expanded-batch outputs to the requested rows.

        Args:
            expanded_batch_outputs (List[SamplerOutput]): The expanded
                output batch from the model, one entry per step.
            output_indices_to_retain (List[int]): Indices of the model
                outputs to retain.

        Returns:
            List[SamplerOutput]: One new ``SamplerOutput`` per input entry,
            holding only the retained rows. Tensor fields that are ``None``
            in the input stay ``None``.
        """
        # The Python list still indexes the ``outputs`` list below; this
        # tensor copy of the same indices is used for the tensor fields.
        retained_rows = async_tensor_h2d(output_indices_to_retain,
                                         torch.int32, self.device, True)
        return [
            SamplerOutput(
                outputs=[
                    expanded_batch_output.outputs[i]
                    for i in output_indices_to_retain
                ] if len(expanded_batch_output.outputs) > 0 else [],
                sampled_token_probs=(
                    expanded_batch_output.sampled_token_probs[retained_rows]
                    if expanded_batch_output.sampled_token_probs is not None
                    else None),
                logprobs=(
                    expanded_batch_output.logprobs[retained_rows]
                    if expanded_batch_output.logprobs is not None else None),
                sampled_token_ids=(
                    expanded_batch_output.sampled_token_ids[retained_rows]
                    if expanded_batch_output.sampled_token_ids is not None
                    else None))
            for expanded_batch_output in expanded_batch_outputs
        ]
\ No newline at end of file
vllm/zero_overhead/spec_decode/spec_decode_worker.py
0 → 100644
View file @
01c30741
This diff is collapsed.
Click to expand it.
vllm/zero_overhead/spec_decode/top1_proproser.py
0 → 100644
View file @
01c30741
import
os
from
typing
import
List
,
Optional
,
Set
,
Tuple
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.util
import
sampler_output_to_torch
from
vllm.utils
import
async_tensor_h2d
class ZeroOverheadTop1Proposer(Top1Proposer):
    """Top1Proposer variant whose output merge indexes with device tensors.

    The index list and the proposal-length values are converted with
    ``async_tensor_h2d`` before they are used in indexed assignments.
    """

    def _merge_outputs(
        self,
        batch_size: int,
        proposal_len: int,
        maybe_sampler_output: Optional[List[SamplerOutput]],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
        sampler_transposed: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Merge the speculation results with the skipped sequences.

        Args:
            batch_size: Number of sequences in the full batch.
            proposal_len: Proposal length assigned to every sequence that
                was speculated.
            maybe_sampler_output: Sampler outputs of the speculated
                sequences, or ``None`` when nothing was speculated.
            proposal_lens: Per-sequence proposal lengths; only its length is
                used, and only in the ``None`` case.
            nonzero_proposal_len_indices: Batch positions of the sequences
                that were speculated.
            sampler_transposed: Forwarded to ``sampler_output_to_torch``.

        Returns:
            ``(proposal_tokens, proposal_probs, proposal_lens_tensor)`` with
            one row per sequence in the batch. Rows of skipped sequences
            hold -1 tokens, zero probabilities and a length of 0.
        """
        if maybe_sampler_output is None:
            # If no speculative tokens, the sampler output will be None.
            # In this case we return empty proposals.
            proposal_tokens = torch.tensor(-1,
                                           dtype=torch.long,
                                           device=self._device).expand(
                                               batch_size, proposal_len)
            proposal_probs = torch.tensor(0,
                                          dtype=torch.float32,
                                          device=self._device).expand(
                                              batch_size, proposal_len,
                                              self._vocab_size)
            proposal_lens_tensor = torch.tensor(0,
                                                dtype=torch.long,
                                                device=self._device).expand(
                                                    len(proposal_lens))
            return proposal_tokens, proposal_probs, proposal_lens_tensor

        sampler_output = maybe_sampler_output
        proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
            sampler_output, sampler_transposed)

        # The value tensor written into the selected rows must have exactly
        # one entry per selected row. It was previously built with
        # ``batch_size`` entries, which makes the indexed assignment below
        # fail with a shape mismatch whenever some sequences are skipped.
        num_speculated = len(nonzero_proposal_len_indices)
        nonzero_proposal_len_indices = async_tensor_h2d(
            nonzero_proposal_len_indices, torch.int32, self._device, True)
        proposal_len_values = async_tensor_h2d(
            [proposal_len] * num_speculated, torch.long, self._device, True)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. The proposal can be empty, e.g. [-1, -1, -1].
        entire_proposal_tokens = proposal_tokens.new_full(
            size=(batch_size, *proposal_tokens.shape[1:]),
            fill_value=-1,
        )
        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
        entire_proposal_probs = proposal_probs.new_zeros(
            batch_size,
            *proposal_probs.shape[1:],
        )
        entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

        proposal_tokens, proposal_probs = (
            entire_proposal_tokens,
            entire_proposal_probs,
        )

        proposal_lens_tensor = torch.zeros(batch_size,
                                           dtype=torch.long,
                                           device=self._device)
        proposal_lens_tensor[nonzero_proposal_len_indices] = \
            proposal_len_values

        return proposal_tokens, proposal_probs, proposal_lens_tensor
\ No newline at end of file
vllm/zero_overhead/
v0/
stop_check.py
→
vllm/zero_overhead/stop_check.py
View file @
01c30741
...
...
@@ -5,7 +5,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
SequenceStatus
from
vllm.zero_overhead.
v0.
sequence
import
ZeroOverheadSequence
from
vllm.zero_overhead.sequence
import
ZeroOverheadSequence
class
ZeroOverheadStopChecker
(
StopChecker
):
...
...
vllm/zero_overhead/
v0/
tokenizer.py
→
vllm/zero_overhead/tokenizer.py
View file @
01c30741
...
...
@@ -4,7 +4,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
VLLM_INVALID_TOKEN_ID
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer_utils
import
convert_prompt_ids_to_tokens
,
detokenize_incrementally
from
vllm.zero_overhead.
v0.
sequence
import
ZeroOverheadSequence
from
vllm.zero_overhead.sequence
import
ZeroOverheadSequence
class
ZeroOverheadDetokenizer
(
Detokenizer
):
...
...
vllm/zero_overhead/
v0/
utils.py
→
vllm/zero_overhead/utils.py
View file @
01c30741
...
...
@@ -10,11 +10,3 @@ def is_zero_overhead():
def is_zero_no_thread():
    """Return the conjunction of the ``zero_no_thread`` and ``zero_overhead`` globals."""
    # Same value as ``zero_no_thread and zero_overhead``: the first operand
    # when it is falsy, otherwise the second.
    return zero_overhead if zero_no_thread else zero_no_thread
\ No newline at end of file
def UpdateInputTokens(input_tokens, last_sample, indices):
    """Launch the ``_update_input_tokens`` kernel over ``input_tokens``.

    The launch grid is ``[input_tokens.shape[0], 1, 1]``. The first call
    launches ``_update_input_tokens`` and stores what that launch returns in
    the module-level ``_update_input_tokens_ptr``; later calls launch
    through the stored object instead.

    NOTE(review): ``_update_input_tokens`` is defined outside this view --
    presumably a Triton kernel whose launch returns a reusable compiled
    handle; confirm against its definition.
    """
    global _update_input_tokens_ptr
    launch_grid = [input_tokens.shape[0], 1, 1]
    if _update_input_tokens_ptr is not None:
        # Reuse the handle cached by the first launch.
        _update_input_tokens_ptr[launch_grid](last_sample, input_tokens,
                                              indices, input_tokens.shape[0])
        return
    _update_input_tokens_ptr = _update_input_tokens[launch_grid](
        last_sample, input_tokens, indices, input_tokens.shape[0])
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment