Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0ecda6d1
Commit
0ecda6d1
authored
May 08, 2025
by
lizhigong
Browse files
debug spec decode zero overhead
parent
01c30741
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
367 additions
and
58 deletions
+367
-58
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+4
-2
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+2
-0
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+1
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+8
-1
vllm/zero_overhead/llm_engine.py
vllm/zero_overhead/llm_engine.py
+52
-11
vllm/zero_overhead/model_runner.py
vllm/zero_overhead/model_runner.py
+131
-15
vllm/zero_overhead/sampler.py
vllm/zero_overhead/sampler.py
+1
-0
vllm/zero_overhead/sequence.py
vllm/zero_overhead/sequence.py
+5
-0
vllm/zero_overhead/spec_decode/batch_expansion.py
vllm/zero_overhead/spec_decode/batch_expansion.py
+4
-5
vllm/zero_overhead/spec_decode/muti_step_worker.py
vllm/zero_overhead/spec_decode/muti_step_worker.py
+5
-6
vllm/zero_overhead/spec_decode/spec_decode_worker.py
vllm/zero_overhead/spec_decode/spec_decode_worker.py
+91
-7
vllm/zero_overhead/spec_decode/top1_proproser.py
vllm/zero_overhead/spec_decode/top1_proproser.py
+10
-9
vllm/zero_overhead/utils.py
vllm/zero_overhead/utils.py
+53
-1
No files found.
vllm/model_executor/layers/sampler.py
View file @
0ecda6d1
...
...
@@ -699,7 +699,7 @@ def _sample_with_torch(
if
sampling_type
==
SamplingType
.
GREEDY
:
greedy_samples
=
torch
.
argmax
(
logprobs
[
long_sample_indices
],
dim
=-
1
)
sampled_token_ids_
=
greedy_samples
.
unsqueeze
(
-
1
)
if
sampled_token_ids_tensor
is
not
None
:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor
[
...
...
@@ -736,7 +736,8 @@ def _sample_with_torch(
probs
[
long_sample_indices
],
max_n_in_batch
,
seq_groups
=
seq_groups_arg
)
sampled_token_ids_
=
\
multinomial_samples
[
sampling_type
].
to
(
torch
.
long
)
if
sampled_token_ids_tensor
is
not
None
:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor
[
long_sample_indices
]
=
\
...
...
@@ -745,6 +746,7 @@ def _sample_with_torch(
else
:
raise
ValueError
(
f
"Unsupported sampling type:
{
sampling_type
}
"
)
print
(
'###sampled_token_ids'
,
sampled_token_ids_
)
# Encapsulate arguments for computing Pythonized sampler
# results, whether deferred or otherwise.
maybe_deferred_args
=
SampleResultArgsType
(
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
0ecda6d1
...
...
@@ -910,6 +910,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
accepted_token_ids
,
target_logprobs
,
select_indices_list
,
accept_lengths
=
self
.
_verify_tokens
(
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
proposals
,
execute_model_req
.
num_lookahead_slots
)
print
(
'###accepted_token_ids'
,
accepted_token_ids
)
# move kv_caches of selected tokens to right positions
if
self
.
tree_decoding
:
...
...
@@ -1340,6 +1341,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
self
.
_maybe_log_stage_times
(
*
stage_times
)
# First `n_prefills` entries will contain prefills SamplerOutput when
# chunked prefill is enabled, the rest is decodes in multi-step format.
print
(
'###sampler_output_list'
,
sampler_output_list
)
return
sampler_output_list
def
_maybe_log_stage_times
(
self
,
average_time_per_proposal_tok_ms
:
float
,
...
...
vllm/spec_decode/util.py
View file @
0ecda6d1
...
...
@@ -11,6 +11,7 @@ from vllm.platforms import current_platform
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
PromptLogprobs
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.zero_overhead.utils
import
is_zero_overhead
SeqId
=
int
...
...
@@ -139,7 +140,6 @@ def split_batch_by_proposal_len(
zero or not. We should remove this once vLLM supports per-sequence proposal
lens in a batch.
"""
nonzero_lists
:
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]]
=
([],
[])
zero_lists
:
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]]
=
([],
[])
for
i
,
(
seq_group
,
proposal_len
)
in
enumerate
(
...
...
vllm/worker/model_runner.py
View file @
0ecda6d1
...
...
@@ -902,6 +902,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Tokens and positions.
if
cuda_graph_pad_size
:
input_tokens
.
extend
(
itertools
.
repeat
(
0
,
cuda_graph_pad_size
))
print
(
'###input_tokens'
,
input_tokens
)
assert
self
.
runner
.
device
is
not
None
input_tokens_tensor
=
async_tensor_h2d
(
input_tokens
,
torch
.
long
,
self
.
runner
.
device
,
...
...
@@ -916,12 +917,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for
idx
in
range
(
3
):
mrope_input_positions
[
idx
].
extend
(
itertools
.
repeat
(
0
,
cuda_graph_pad_size
))
print
(
'###mrope_input_positions'
,
mrope_input_positions
)
input_positions_tensor
=
async_tensor_h2d
(
mrope_input_positions
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
else
:
input_positions
.
extend
(
itertools
.
repeat
(
0
,
cuda_graph_pad_size
))
print
(
'###input_positions'
,
input_positions
)
input_positions_tensor
=
async_tensor_h2d
(
input_positions
,
torch
.
long
,
self
.
runner
.
device
,
...
...
@@ -929,6 +932,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Sequence and query lengths.
if
cuda_graph_pad_size
:
seq_lens
.
extend
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
print
(
'###seq_lens'
,
seq_lens
)
# Attention metadata.
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
...
...
@@ -987,7 +991,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
]
multi_modal_kwargs
=
MultiModalKwargs
.
batch
(
multi_modal_kwargs_list
)
ret
urn
self
.
model_input_cls
(
ret
=
self
.
model_input_cls
(
input_tokens
=
input_tokens_tensor
,
input_positions
=
input_positions_tensor
,
token_types
=
token_types_tensor
,
...
...
@@ -1001,6 +1005,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
finished_requests_ids
=
self
.
finished_requests_ids
,
prompt_adapter_mapping
=
prompt_adapter_mapping
,
prompt_adapter_requests
=
prompt_adapter_requests
)
print
(
'###model_input'
,
ret
)
return
ret
class
GPUModelRunnerBase
(
ModelRunnerBase
[
TModelInputForGPU
]):
...
...
vllm/zero_overhead/llm_engine.py
View file @
0ecda6d1
...
...
@@ -40,7 +40,7 @@ from vllm.zero_overhead.tokenizer import ZeroOverheadDetokenizer
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.profiler.prof
import
profile
from
vllm.zero_overhead.utils
import
is_zero_no_thread
from
vllm.zero_overhead.utils
import
SpecStepKind
,
get_accepted_token_ids
,
get_spec_step
,
is_zero_no_thread
,
set_spec_step
logger
=
init_logger
(
__name__
)
...
...
@@ -301,7 +301,10 @@ class ZeroOverheadEngine(LLMEngine):
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
if
self
.
last_record
is
not
None
:
last_sampler
=
self
.
last_record
[
1
]
self
.
async_d2h
=
last_sampler
.
sampled_token_ids_tensor
.
to
(
'cpu'
,
non_blocking
=
True
)
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
self
.
async_d2h
=
last_sampler
.
sampled_token_ids_tensor
.
to
(
'cpu'
,
non_blocking
=
True
)
elif
spec_step
==
SpecStepKind
.
SCORE_DECODE
:
self
.
async_d2h
=
last_sampler
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
async_event
.
record
()
self
.
q_recorder
.
put
(
self
.
last_record
)
else
:
...
...
@@ -332,13 +335,18 @@ class ZeroOverheadEngine(LLMEngine):
outputs
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
if
len
(
outputs
)
==
1
:
for
output
in
outputs
:
self
.
_advance_to_next_step
(
output
s
[
0
]
,
seq_group_metadata_list
,
output
,
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
=
[
item
for
item
in
scheduler_outputs
.
scheduled_seq_groups
]
#deep copy
last_sampler
=
get_last_sampler
()
self
.
last_record
=
[
outputs
,
last_sampler
,
seq_group_metadata_list
,
scheduler_outputs
]
last_sampler
=
None
spec_step
=
get_spec_step
()
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
last_sampler
=
get_last_sampler
()
elif
spec_step
==
SpecStepKind
.
SCORE_DECODE
:
last_sampler
,
_
=
get_accepted_token_ids
()
self
.
last_record
=
[
outputs
,
last_sampler
,
seq_group_metadata_list
,
scheduler_outputs
,
spec_step
]
except
Exception
as
e
:
print
(
f
"thread_zero_overhead error :
{
e
}
"
)
...
...
@@ -357,13 +365,19 @@ class ZeroOverheadEngine(LLMEngine):
virtual_engine
=
0
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
ctx
.
request_outputs
.
clear
()
outputs
,
last_sampler
,
seq_group_metadata_list
,
scheduler_outputs
=
recode_output
outputs
,
last_sampler
,
seq_group_metadata_list
,
scheduler_outputs
,
spec_step
=
recode_output
ctx
.
seq_group_metadata_list
=
seq_group_metadata_list
ctx
.
scheduler_outputs
=
scheduler_outputs
self
.
async_event
.
synchronize
()
self
.
_fix_last_step
(
outputs
,
last_sampler
,
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
self
.
async_event
.
synchronize
()
self
.
_fix_last_step
(
outputs
,
last_sampler
,
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
elif
spec_step
==
SpecStepKind
.
SCORE_DECODE
:
self
.
async_event
.
synchronize
()
self
.
_fix_spec_decode_steps
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
...
...
@@ -430,6 +444,33 @@ class ZeroOverheadEngine(LLMEngine):
sample
.
output_token
=
token_id
seq
.
fix_last_token_id
(
sample
.
output_token
)
break
def
_fix_spec_decode_steps
(
self
,
output
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
]):
sample_out_list
=
self
.
async_d2h
.
tolist
()
group_idx
=
0
for
seq_group_metadata
,
accept_token_ids
,
scheduled_seq_group
in
\
zip
(
seq_group_metadata_list
,
sample_out_list
,
scheduled_seq_groups
):
seq_group
=
scheduled_seq_group
.
seq_group
if
seq_group
.
is_finished
():
group_idx
+=
1
continue
if
seq_group_metadata
.
do_sample
:
assert
len
(
seq_group
.
seqs
)
==
1
seq
:
ZeroOverheadSequence
=
seq_group
.
seqs
[
0
]
remove_count
=
0
for
token_id
in
accept_token_ids
:
if
token_id
==
-
1
:
remove_count
+=
1
else
:
seq
.
fix_last_token_id
(
token_id
)
seq
.
remove_last_place_holder
(
remove_count
)
group_idx
+=
1
def
no_thread_step
(
self
):
virtual_engine
=
0
...
...
vllm/zero_overhead/model_runner.py
View file @
0ecda6d1
...
...
@@ -11,7 +11,65 @@ from vllm.sequence import SequenceGroupMetadata
from
vllm.utils
import
async_tensor_h2d
,
flatten_2d_lists
from
vllm.worker.model_runner
import
ModelInputForGPU
,
ModelInputForGPUBuilder
from
vllm.zero_overhead.sampler
import
get_last_sampler
from
vllm.zero_overhead.update_input
import
UpdateInputTokens
from
vllm.zero_overhead.utils
import
SpecStepKind
,
get_accepted_token_ids
,
get_proposal_token_ids
,
get_spec_last_step
,
get_spec_step
import
triton
import
triton.language
as
tl
# Device-side fix-up of the draft model's inputs after a scoring step.
#
# For every request pair in ``chidren_req_ids`` (the loop reads index
# ``2 * seq_id_idx``, so entries are consumed two at a time) the kernel looks
# up the matching row in ``accepted_req_ids``, counts the accepted tokens up
# to the first ``-1`` and rewrites the two input slots of that pair, then
# rebuilds ``seq_start_loc`` from ``seq_lens``.
#
# NOTE(review): the nesting below is reconstructed from a flattened diff view;
# confirm against the real file before relying on it.
# NOTE(review): several constructs look unlikely to compile under Triton as
# written (see the inline notes) -- this appears to be work in progress.
@triton.jit
def _update_input_tokens(accepted_req_ids, accepted_req_ids_len, accepted_token_ids, accepted_token_len,
                         chidren_req_ids, chidren_req_ids_len, input_tokens, input_tokens_len,
                         input_positions, seq_lens, seq_lens_meta, seq_lens_tensor, slot_mapping,
                         seq_start_loc, context_lens_tensor,):
    # NOTE(review): Triton documents tl.arange bounds as compile-time
    # constants; the *_len arguments are plain runtime arguments here
    # (no tl.constexpr annotation) -- confirm.
    chidren_req_ids_ = tl.load(chidren_req_ids + tl.arange(0, chidren_req_ids_len))
    # NOTE(review): this load is sized with chidren_req_ids_len, not
    # accepted_req_ids_len -- presumably a copy/paste slip; verify.
    accepted_req_ids_ = tl.load(accepted_req_ids + tl.arange(0, chidren_req_ids_len))
    # NOTE(review): "/" is true division; range() likely needs "//" -- confirm.
    for seq_id_idx in range(chidren_req_ids_len / 2):
        seq_id = chidren_req_ids_[2 * seq_id_idx]
        for i in range(accepted_req_ids_len):
            if seq_id == accepted_req_ids_[i]:
                # NOTE(review): the second argument of the outer tl.arange is
                # itself a tl.arange; presumably the intent is the row
                # offsets i * accepted_token_len + arange(0, accepted_token_len).
                accepted_token_ids_ = tl.load(accepted_token_ids + tl.arange(i * accepted_token_len, tl.arange(0, accepted_token_len)))
                # Count accepted tokens: everything before the first -1.
                accepted_token_counter = 0
                for j in range(accepted_token_len):
                    if accepted_token_ids_[j] == -1:
                        break
                    accepted_token_counter += 1
                if accepted_token_counter == accepted_token_len:
                    # Whole row accepted: feed the last two accepted tokens
                    # into the pair's two input slots.
                    tl.store(input_tokens + seq_id_idx * 2 + tl.arange(0, 2), accepted_token_ids_[-2:])
                else:
                    # Partial acceptance: first slot gets token 0, second slot
                    # gets the last accepted token.
                    # NOTE(review): accepted_token_counter == 0 would index
                    # [-1] here -- confirm that case cannot occur.
                    tl.store(input_tokens + seq_id_idx * 2, 0)
                    tl.store(input_tokens + seq_id_idx * 2 + 1, accepted_token_ids_[accepted_token_counter - 1])
                    input_pos = tl.load(input_positions + seq_id_idx * 2 + tl.arange(0, 2))
                    input_pos[0] = 0
                    # NOTE(review): subtracts accepted_req_ids_len (number of
                    # requests), not accepted_token_len -- verify intent.
                    input_pos[1] = input_pos[1] - (accepted_req_ids_len - accepted_token_counter)
                    tl.store(input_positions + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                    tl.store(context_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                    # The same two-element vector is reused: slot 0 becomes -1
                    # before it is written to slot_mapping.
                    # NOTE(review): slot 1 of slot_mapping receives a position
                    # value here -- confirm that is the intended slot id.
                    input_pos[0] = -1
                    tl.store(slot_mapping + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                    # Turn the vector into lengths: 1 for slot 0, position + 1
                    # for slot 1, and write it to all three seq_lens buffers.
                    input_pos[0] = 1
                    input_pos[1] = input_pos[1] + 1
                    tl.store(seq_lens + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                    tl.store(seq_lens_meta + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                    tl.store(seq_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
    # Rebuild seq_start_loc as the running sum of seq_lens.
    seq_lens_ = tl.load(seq_lens + tl.arange(0, input_tokens_len))
    # NOTE(review): triton.language exposes ``zeros_like``; confirm that
    # ``zero_like`` resolves, and that it is meant to receive a pointer.
    seq_start_loc_ = tl.zero_like(seq_start_loc)
    for i in range(input_tokens_len):
        seq_start_loc_[i + 1] = seq_start_loc_[i] + seq_lens_[i]
    tl.store(seq_start_loc + tl.arange(0, input_tokens_len + 1), seq_start_loc_)
class
ZeroOverheadModelInputForGpuBuilder
(
ModelInputForGPUBuilder
):
...
...
@@ -34,22 +92,80 @@ class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
def
build
(
self
)
->
ModelInputForGPU
:
model_input
=
super
().
build
()
print
(
'###model_input'
,
model_input
)
last_sampler
=
get_last_sampler
()
spec_step
=
get_spec_step
()
last_step
=
get_spec_last_step
()
if
last_sampler
is
not
None
:
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
update_indices
=
[]
select_indices
=
[]
for
i
,
seq_id
in
enumerate
(
self
.
req_ids
):
for
j
,
seq_id_
in
enumerate
(
last_sampler
.
seq_ids
):
if
seq_id
==
seq_id_
:
select_indices
.
append
(
j
)
update_indices
.
append
(
i
)
break
if
len
(
select_indices
)
>
0
:
select_indices
=
async_tensor_h2d
(
select_indices
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
update_indices
=
async_tensor_h2d
(
update_indices
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
model_input
.
input_tokens
[
update_indices
]
=
last_sampler
.
sampled_token_ids_tensor
[
select_indices
,
0
]
if
spec_step
==
SpecStepKind
.
OTHER_PROPOSAL
:
if
last_step
==
SpecStepKind
.
OTHER_PROPOSAL
:
# copy last sampled token ids to input tokens directly.
update_indices
=
[
i
for
i
in
range
(
len
(
self
.
req_ids
))]
update_indices
=
async_tensor_h2d
(
update_indices
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
model_input
.
input_tokens
[
update_indices
]
=
last_sampler
.
sampled_token_ids_tensor
[
update_indices
,
0
]
if
last_step
==
SpecStepKind
.
FIRST_PROPOSAL
:
# TODO: adjust the number of input tokens to 1 per request.
update_indices
=
[
i
for
i
in
range
(
len
(
self
.
req_ids
))]
update_indices
=
async_tensor_h2d
(
update_indices
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
model_input
.
input_tokens
[
update_indices
]
=
last_sampler
.
sampled_token_ids_tensor
[
update_indices
,
0
]
if
spec_step
==
SpecStepKind
.
SCORE_DECODE
:
proposal_token_ids
=
get_proposal_token_ids
()
shape
=
proposal_token_ids
.
shape
batch_size
=
shape
[
0
]
proposal_len
=
shape
[
1
]
update_indices
=
[]
select_indices
=
[]
for
i
,
seq_id
in
enumerate
(
self
.
req_ids
):
for
j
,
seq_id_
in
enumerate
(
last_sampler
.
seq_ids
):
if
seq_id
==
seq_id_
:
select_indices
.
append
(
j
)
update_indices
.
append
(
i
)
break
select_indices
=
async_tensor_h2d
(
select_indices
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
for
i
in
range
(
batch_size
):
for
j
in
range
(
proposal_len
):
update_indices
.
append
(
i
*
(
proposal_len
+
1
)
+
j
+
1
)
update_indices
=
async_tensor_h2d
(
update_indices
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
if
len
(
select_indices
)
>
0
:
model_input
.
input_tokens
[
update_indices
]
=
last_sampler
.
sampled_token_ids_tensor
[
select_indices
,
0
]
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
model_input
.
input_tokens
[
update_indices
]
=
proposal_token_ids
.
view
(
-
1
)
if
spec_step
==
SpecStepKind
.
FIRST_PROPOSAL
:
if
last_step
==
SpecStepKind
.
PREFILL
:
# TODO: when the last step is prefill, only update the input ids for the last sequence id.
pass
if
last_step
==
SpecStepKind
.
SCORE_DECODE
:
# TODO: when the last step is score decode, fix input ids, seq_lens and input_positions using the accepted token ids.
accept_token_ids
,
accept_seq_ids
=
get_accepted_token_ids
()
chidren_req_ids
=
async_tensor_h2d
(
self
.
req_ids
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
grid
=
[
1
,
1
,
1
]
_update_input_tokens
[
grid
](
accept_seq_ids
,
accept_seq_ids
.
shape
[
0
],
accept_token_ids
,
accept_token_ids
.
shape
[
1
],
chidren_req_ids
,
chidren_req_ids
.
shape
[
0
],
model_input
.
input_tokens
,
model_input
.
input_tokens
.
shape
[
0
],
model_input
.
input_positions
,
model_input
.
seq_lens
,
model_input
.
attn_metadata
.
seq_lens_tensor
,
model_input
.
attn_metadata
.
seq_lens
,
model_input
.
attn_metadata
.
slot_mapping
,
model_input
.
attn_metadata
.
seq_start_loc
,
model_input
.
attn_metadata
.
context_lens_tensor
,
)
print
(
'###zero_model_input'
,
model_input
)
return
model_input
vllm/zero_overhead/sampler.py
View file @
0ecda6d1
...
...
@@ -359,6 +359,7 @@ def _sample_with_torch(
sampled_token_ids_tensor
[
long_sample_indices
]
=
\
multinomial_samples
[
sampling_type
].
to
(
torch
.
long
)
print
(
'###sampled_token_ids'
,
last_sampler
.
sampled_token_ids_tensor
)
# Encapsulate arguments for computing Pythonized sampler
# results, whether deferred or otherwise.
maybe_deferred_args
=
SampleResultArgsType
(
...
...
vllm/zero_overhead/sequence.py
View file @
0ecda6d1
...
...
@@ -19,6 +19,11 @@ class ZeroOverheadSequence(Sequence):
self
.
data
.
_cached_all_token_ids
[
effect_offset
]
=
token_id
self
.
effective_output_len
+=
1
def remove_last_place_holder(self, count):
    """Drop the trailing ``count`` placeholder tokens from ``self.data``.

    Trims ``_output_token_ids``, ``_new_appended_tokens`` and
    ``_cached_all_token_ids`` by ``count`` entries from the end and lowers
    ``_num_computed_tokens`` by the same amount.

    Args:
        count: Number of trailing placeholder tokens to remove. Values
            less than or equal to zero leave the sequence untouched.
    """
    if count <= 0:
        # ``seq[:-0]`` is ``seq[:0]``: slicing with a zero count would
        # silently discard every token instead of removing none.
        return
    data = self.data
    data._output_token_ids = data._output_token_ids[:-count]
    data._new_appended_tokens = data._new_appended_tokens[:-count]
    data._cached_all_token_ids = data._cached_all_token_ids[:-count]
    data._num_computed_tokens -= count
def
zero_overhead_get_output_token_ids
(
self
)
->
tuple
[
int
,
...]:
return
self
.
data
.
output_token_ids
[:
self
.
effective_output_len
]
...
...
vllm/zero_overhead/spec_decode/batch_expansion.py
View file @
0ecda6d1
...
...
@@ -15,6 +15,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
nvtx_range
,
split_batch_by_proposal_len
from
vllm.utils
import
async_tensor_h2d
from
vllm.zero_overhead.utils
import
get_proposal_lens_list
,
record_proposal_token_ids
SeqId
=
int
TargetSeqId
=
int
...
...
@@ -48,8 +49,9 @@ class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):
which sequences were ignored during scoring.
"""
proposal_lens_list
=
np
.
zeros
(
proposals
.
proposal_lens
.
shape
,
dtype
=
int
).
tolist
()
#zero_overhead todo fix
proposal_token_ids_list
=
np
.
zeros
(
proposals
.
proposal_token_ids
.
shape
,
dtype
=
int
).
tolist
()
proposal_lens_list
=
get_proposal_lens_list
()
record_proposal_token_ids
(
proposals
.
proposal_token_ids
)
proposal_token_ids_list
=
np
.
zeros
(
proposals
.
proposal_token_ids
.
shape
,
dtype
=
int
).
tolist
()
# place holder tokens
# Filter the list to ignore invalid proposals.
proposal_token_ids_list_without_skips
=
[
...
...
@@ -64,14 +66,11 @@ class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):
proposal_lens_list
=
proposal_lens_list
,
)
#print('###execute_model_req', execute_model_req)
#print('###target_seq_group_metadata_list', target_seq_group_metadata_list)
target_sampler_output
=
self
.
_scorer_worker
.
execute_model
(
execute_model_req
=
execute_model_req
.
clone
(
seq_group_metadata_list
=
target_seq_group_metadata_list
))
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
target_sampler_output
=
target_sampler_output
[
0
]
#print('###target_sampler_output', target_sampler_output)
if
not
non_spec_indices
:
# All sequence groups in batch have spec decoding enabled
return
self
.
_contract_batch_all_spec
(
...
...
vllm/zero_overhead/spec_decode/muti_step_worker.py
View file @
0ecda6d1
...
...
@@ -11,6 +11,7 @@ from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.utils
import
async_tensor_h2d
from
vllm.zero_overhead.spec_decode.top1_proproser
import
ZeroOverheadTop1Proposer
from
vllm.zero_overhead.utils
import
SpecStepKind
,
set_spec_step
if
current_platform
.
is_cuda_alike
():
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
...
...
@@ -45,7 +46,6 @@ class ZeroOverheadMultiStepWorker(MultiStepWorker):
For multi step worker, this indicator shall be True.
"""
print
(
'###execute_model_req'
,
execute_model_req
)
self
.
_raise_if_unsupported
(
execute_model_req
)
# Expand the batch for sequences with a bonus token.
# Perform a forward pass on the expanded batch and filter the
...
...
@@ -53,7 +53,6 @@ class ZeroOverheadMultiStepWorker(MultiStepWorker):
expanded_request
,
indices_of_seq_with_bonus_tokens
=
\
self
.
_expand_execute_model_request
(
execute_model_req
,
seq_ids_with_bonus_token_in_last_step
)
# Run model sample_len times.
model_outputs
:
List
[
SamplerOutput
]
=
[]
if
current_platform
.
is_cuda_alike
()
and
isinstance
(
...
...
@@ -72,20 +71,20 @@ class ZeroOverheadMultiStepWorker(MultiStepWorker):
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
set_spec_step
(
SpecStepKind
.
FIRST_PROPOSAL
)
for
_
in
range
(
sample_len
):
print
(
'###self.worker.execute_model'
,
sample_len
)
print
(
'###expanded_request'
,
expanded_request
)
model_output
:
List
[
SamplerOutput
]
=
self
.
worker
.
execute_model
(
execute_model_req
=
expanded_request
)
assert
(
len
(
model_output
)
==
1
),
"composing multistep workers not supported"
model_output
=
model_output
[
0
]
print
(
'###model_output'
,
model_output
)
set_spec_step
(
SpecStepKind
.
OTHER_PROPOSAL
)
self
.
_append_new_tokens
(
model_output
,
expanded_request
.
seq_group_metadata_list
,
indices_of_seq_with_bonus_tokens
)
model_outputs
.
append
(
model_output
)
set_spec_step
(
SpecStepKind
.
SCORE_DECODE
)
filtered_model_outputs
=
self
.
_filter_model_output_zero_overhead
(
model_outputs
,
indices_of_seq_with_bonus_tokens
)
...
...
vllm/zero_overhead/spec_decode/spec_decode_worker.py
View file @
0ecda6d1
...
...
@@ -27,8 +27,9 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
get_all_seq_ids_and_request_ids
,
Logits
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTreeStyleScorer
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.spec_decode_worker
import
SpecDecodeWorker
from
vllm.spec_decode.spec_decode_worker
import
SpecDecodeWorker
,
prepare_prefill_hidden_states
from
vllm.zero_overhead.spec_decode.batch_expansion
import
ZeroOverheadBatchExpansionTop1Scorer
from
vllm.zero_overhead.utils
import
SpecStepKind
,
record_accepted_token_ids
,
set_spec_step
if
current_platform
.
is_cuda_alike
():
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
...
...
@@ -49,7 +50,7 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output,
get_all_num_logprobs
,
get_sampled_token_logprobs
,
nvtx_range
,
split_batch_by_proposal_len
)
from
vllm.utils
import
resolve_obj_by_qualname
from
vllm.utils
import
async_tensor_h2d
,
resolve_obj_by_qualname
from
vllm.worker.worker_base
import
LoRANotSupportedWorkerBase
,
WorkerBase
from
vllm.worker.cache_engine
import
CacheEngine
...
...
@@ -113,6 +114,90 @@ class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):
self
.
_configure_model_sampler_for_spec_decode
()
# NOTE(review): statement nesting below is reconstructed from a flattened
# diff view; confirm against the real file.
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                 skip_proposer: bool) -> List[SamplerOutput]:
    """Run a single generation step without any speculation. The input is
    sent to the proposer and scorer model so that the KV cache is consistent
    between the two. When skip_proposer is True, the proposer model is
    not called, meaning that the kv-cache in proposer for requests is not
    updated, so they cannot enable spec decode in the rest decoding.

    Marks the step as ``SpecStepKind.PREFILL`` via ``set_spec_step`` before
    the scorer runs.

    Args:
        execute_model_req: Request forwarded to the scorer worker and,
            unless ``skip_proposer`` is set, to the proposer worker.
        skip_proposer: When True the proposer worker is not executed.

    Returns:
        ``[sampler_output]``, or the result of
        ``_serialize_sampler_output_no_logprobs`` when
        ``self._disable_logprobs`` is set.
    """
    # Hand a pending KV-cache move over to this request exactly once.
    if self.tree_decoding and self.kvcache_slot_to_be_moved is not None:
        execute_model_req.kvcache_slot_to_be_moved = self.kvcache_slot_to_be_moved
        self.kvcache_slot_to_be_moved = None
    set_spec_step(SpecStepKind.PREFILL)
    sampler_output = self.scorer_worker.execute_model(execute_model_req)
    assert len(sampler_output) == 1
    sampler_output = sampler_output[0]

    # Store hidden states from target model execution, BxD.
    hidden_states = sampler_output.hidden_states
    if hidden_states is not None:
        # Only decodes and prefill terminal chunks need a hidden state.
        seq_group_meta_with_hidden = [
            sg for sg in execute_model_req.seq_group_metadata_list
            if sg.do_sample
        ]
        if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
            # Drop hidden_states with no prediction (eg non-terminal chunks)
            # by keeping the rows whose sampled token id differs from
            # VLLM_INVALID_TOKEN_ID.
            hidden_states = hidden_states[
                torch.where(sampler_output.sampled_token_ids -
                            VLLM_INVALID_TOKEN_ID)[0]]
        # NOTE(review): the commented-out variant below gated this update on
        # ``not skip_proposer``; the live code runs it unconditionally.
        # if not skip_proposer:
        #     if self.previous_hidden_states is None and len(
        #             seq_group_meta_with_hidden):
        #         self.previous_hidden_states = HiddenStates(
        #             hidden_states, seq_group_meta_with_hidden)
        #     elif self.previous_hidden_states and len(
        #             seq_group_meta_with_hidden):
        #         self.previous_hidden_states.update(hidden_states,
        #                                            seq_group_meta_with_hidden)
        if self.previous_hidden_states is None and len(
                seq_group_meta_with_hidden):
            self.previous_hidden_states = HiddenStates(
                hidden_states, seq_group_meta_with_hidden)
        elif self.previous_hidden_states and len(
                seq_group_meta_with_hidden):
            self.previous_hidden_states.update(hidden_states,
                                               seq_group_meta_with_hidden)

    # Store logits from target model execution.
    if self.tree_decoding:
        logits = sampler_output.logits
        if logits is not None:
            if self.previous_logits is None:
                self.previous_logits = Logits(
                    logits, execute_model_req.seq_group_metadata_list)
            else:
                self.previous_logits.update(
                    logits, execute_model_req.seq_group_metadata_list)

    if not skip_proposer:
        # We prepare the prefill hidden states here so that there no
        # additional complexity in worker for spec_decode vs non_spec_decode
        # flow and execute_model doesn't need additional modifications.
        execute_model_req.previous_hidden_states = \
            prepare_prefill_hidden_states(
                sampler_output.prefill_hidden_states)
        # Run the proposer once per configured prefill step, tagging the
        # request with the step index each time.
        for i in range(self._num_spec_prefill_steps):
            execute_model_req.spec_step_idx = i
            self.proposer_worker.execute_model(execute_model_req)

    sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
        execute_model_req=execute_model_req, sampler_output=sampler_output)
                                if self._disable_logprobs else
                                [sampler_output])

    # Clear device tensors from sampler output. This reduces communication
    # overhead when the engine runs in a different process than the workers.
    sampler_output.sampled_token_probs = None
    sampler_output.sampled_token_ids = None
    sampler_output.logprobs = None
    return sampler_output_to_return
@
nvtx_range
(
"spec_decode_worker._verify_tokens"
)
def
_verify_tokens
(
self
,
...
...
@@ -338,7 +423,8 @@ class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):
num_logprobs_per_seq
=
get_all_num_logprobs
(
seq_group_metadata_list
)
# Serialize tensor to CPU Python list.
#print('###accepted_token_ids_by_step', accepted_token_ids_by_step)
#accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
record_accepted_token_ids
(
accepted_token_ids
,
seq_ids
)
# Construct the output on a per-step, per-sequence basis.
# Non-terminal prefill chunks will end up here as rows with just -1s
...
...
@@ -428,8 +514,7 @@ class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):
num_logprobs
=
num_logprobs_per_seq
[
sequence_index
]
step_output_token_ids
.
append
(
create_sequence_group_output
(
token_id
=
accepted_token_ids_by_step
[
step_index
]
[
sequence_index
],
token_id
=
0
,
token_id_logprob_rank
=
accepted_token_id_ranks_by_step
[
step_index
][
sequence_index
],
token_id_logprob
=
accepted_token_id_logprobs_by_step
[
...
...
@@ -460,9 +545,8 @@ class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):
self
.
_maybe_log_stage_times
(
*
stage_times
)
# First `n_prefills` entries will contain prefills SamplerOutput when
# chunked prefill is enabled, the rest is decodes in multi-step format.
print
(
'###sampler_output_list'
,
sampler_output_list
)
return
sampler_output_list
def
_track_sequences_with_bonus_tokens
(
self
,
seq_ids
:
List
[
int
],
...
...
vllm/zero_overhead/spec_decode/top1_proproser.py
View file @
0ecda6d1
...
...
@@ -11,6 +11,7 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.util
import
sampler_output_to_torch
from
vllm.utils
import
async_tensor_h2d
from
vllm.zero_overhead.utils
import
record_proposal_lens_list
class
ZeroOverheadTop1Proposer
(
Top1Proposer
):
...
...
@@ -48,13 +49,14 @@ class ZeroOverheadTop1Proposer(Top1Proposer):
proposal_tokens
,
proposal_probs
,
*
_
=
sampler_output_to_torch
(
sampler_output
,
sampler_transposed
)
proposal_lens_list
=
[
0
for
i
in
range
(
batch_size
)]
for
indices
in
nonzero_proposal_len_indices
:
proposal_lens_list
[
indices
]
=
proposal_len
record_proposal_lens_list
(
proposal_lens_list
)
nonzero_proposal_len_indices
=
async_tensor_h2d
(
nonzero_proposal_len_indices
,
torch
.
int32
,
self
.
_device
,
True
)
proposal_len
=
[
proposal_len
for
i
in
range
(
batch_size
)]
proposal_len
=
async_tensor_h2d
(
proposal_len
,
torch
.
long
,
self
.
_device
,
True
)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
...
...
@@ -74,10 +76,9 @@ class ZeroOverheadTop1Proposer(Top1Proposer):
entire_proposal_tokens
,
entire_proposal_probs
,
)
proposal_lens_tensor
=
torch
.
zeros
(
batch_size
,
dtype
=
torch
.
long
,
device
=
self
.
_device
)
proposal_lens_tensor
[
nonzero_proposal_len_indices
]
=
proposal_len
proposal_lens_tensor
=
async_tensor_h2d
(
proposal_lens_list
,
torch
.
long
,
self
.
_device
,
True
)
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
\ No newline at end of file
vllm/zero_overhead/utils.py
View file @
0ecda6d1
from
enum
import
Enum
import
os
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
...
...
@@ -9,4 +10,55 @@ def is_zero_overhead():
return
zero_overhead
def
is_zero_no_thread
():
return
zero_no_thread
and
zero_overhead
\ No newline at end of file
return
zero_no_thread
and
zero_overhead
class SpecStepKind(Enum):
    """Kind of model step recorded for the zero-overhead spec-decode path."""

    KIND_DEFAULT = 0  # initial value of a fresh context
    PREFILL = 1
    FIRST_PROPOSAL = 2
    OTHER_PROPOSAL = 3
    SCORE_DECODE = 4


class ZeroOverheadSpecContext:
    """Mutable record shared between the spec-decode workers and the engine.

    Holds the current and previous step kind plus the most recently recorded
    proposal lengths, proposal token ids and accepted token/sequence ids.
    """

    def __init__(self):
        # Both step markers start out as KIND_DEFAULT.
        self.step_kind = SpecStepKind.KIND_DEFAULT
        self.last_step = SpecStepKind.KIND_DEFAULT
        # Nothing has been recorded yet.
        self.proposal_lens_list = None
        self.proposal_token_ids = None
        self.accepted_token_ids = None
        self.accepted_seq_ids = None


# Module-wide singleton that every accessor below reads and mutates.
spec_context = ZeroOverheadSpecContext()


def set_spec_step(_step):
    """Make ``_step`` the current step kind, remembering the previous one."""
    spec_context.last_step, spec_context.step_kind = (
        spec_context.step_kind, _step)


def get_spec_step():
    """Return the current step kind."""
    return spec_context.step_kind


def get_spec_last_step():
    """Return the step kind that was current before the latest change."""
    return spec_context.last_step


def record_proposal_lens_list(list):
    """Remember the proposal-length list for later retrieval.

    The parameter name shadows the builtin; it is kept as-is so the
    function's signature stays unchanged.
    """
    spec_context.proposal_lens_list = list


def get_proposal_lens_list():
    """Return the most recently recorded proposal-length list."""
    return spec_context.proposal_lens_list


def record_proposal_token_ids(tensor):
    """Remember the proposal token ids for later retrieval."""
    spec_context.proposal_token_ids = tensor


def get_proposal_token_ids():
    """Return the most recently recorded proposal token ids."""
    return spec_context.proposal_token_ids


def record_accepted_token_ids(tensor, seq_ids):
    """Remember the accepted token ids together with their sequence ids."""
    spec_context.accepted_token_ids = tensor
    spec_context.accepted_seq_ids = seq_ids


def get_accepted_token_ids():
    """Return ``(accepted_token_ids, accepted_seq_ids)`` as last recorded."""
    return spec_context.accepted_token_ids, spec_context.accepted_seq_ids
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment