Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6d646d08
Unverified
Commit
6d646d08
authored
Sep 03, 2024
by
Alexander Matveev
Committed by
GitHub
Sep 03, 2024
Browse files
[Core] Optimize Async + Multi-step (#8050)
parent
95a178f8
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
326 additions
and
248 deletions
+326
-248
tests/multi_step/test_correctness_async_llm.py
tests/multi_step/test_correctness_async_llm.py
+2
-2
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+54
-55
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+97
-125
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+36
-26
vllm/sequence.py
vllm/sequence.py
+1
-3
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+3
-1
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+132
-33
vllm/worker/multi_step_worker.py
vllm/worker/multi_step_worker.py
+1
-3
No files found.
tests/multi_step/test_correctness_async_llm.py
View file @
6d646d08
...
...
@@ -103,13 +103,13 @@ async def test_multi_step(
model
,
server_args
+
distributed_args
,
num_logprobs
,
max_wait_seconds
=
3
*
240
)
max_wait_seconds
=
5
*
240
)
test_completions
=
await
completions_with_server_args
(
prompts
,
model
,
ms_server_args
+
distributed_args
,
num_logprobs
,
max_wait_seconds
=
3
*
240
)
max_wait_seconds
=
5
*
240
)
# Assert multi-step scheduling produces identical tokens
# to single-step scheduling.
...
...
vllm/engine/async_llm_engine.py
View file @
6d646d08
...
...
@@ -280,40 +280,27 @@ class _AsyncLLMEngine(LLMEngine):
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
# Detect async + multi-step
use_async_and_multi_step
=
(
self
.
scheduler_config
.
is_multi_step
and
allow_async_output_proc
)
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
# Clear outputs for each new scheduler iteration
ctx
.
request_outputs
.
clear
()
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# Clear outputs on scheduler iteration start
ctx
.
request_outputs
.
clear
()
# Schedule iteration
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
# Detect async + multi-step
use_async_and_multi_step
=
(
self
.
scheduler_config
.
is_multi_step
and
allow_async_output_proc
)
ctx
.
seq_group_metadata_list
=
seq_group_metadata_list
ctx
.
scheduler_outputs
=
scheduler_outputs
# Maybe switch from async mode to sync mode
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
# For async + multi-step, init the queue
if
use_async_and_multi_step
:
assert
len
(
ctx
.
output_queue
)
==
0
assert
seq_group_metadata_list
is
not
None
ctx
.
output_queue
.
append
(
(
None
,
seq_group_metadata_list
,
scheduler_outputs
))
self
.
_process_model_outputs
(
ctx
=
ctx
)
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
...
...
@@ -351,26 +338,20 @@ class _AsyncLLMEngine(LLMEngine):
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
async_callback
=
self
.
async_callback_multi_step
[
virtual_engine
]
if
use_async_and_multi_step
\
else
self
.
async_callback
[
virtual_engine
]
execute_model_req
.
async_callback
=
async_callback
execute_model_req
.
use_async_and_multi_step
=
\
use_async_and_multi_step
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
# Execute the model.
output
=
await
self
.
model_executor
.
execute_model_async
(
execute_model_req
)
# We need to do this here so that the last step's sampled_token_ids can
# be passed to the next iteration for PP.
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
else
:
if
not
use_async_and_multi_step
and
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
if
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
output
=
[]
# Finish the current step for all the sequence groups.
...
...
@@ -384,24 +365,22 @@ class _AsyncLLMEngine(LLMEngine):
self
.
cached_scheduler_outputs
[
virtual_engine
]
=
SchedulerOutputState
()
if
use_async_and_multi_step
:
# For async + multi-step, clear the queue
ctx
.
output_queue
.
clear
()
else
:
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
is_async
=
allow_async_output_proc
is_last_step
=
True
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
))
if
output
and
allow_async_output_proc
:
assert
len
(
output
)
==
1
,
"
Multi step decoding does not work with async output processing."
# noqa: E501
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
if
output
and
allow_async_output_proc
:
assert
len
(
output
)
==
1
,
"
Async postprocessor expects only a single output set"
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
if
not
allow_async_output_proc
:
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
False
)
self
.
_process_model_outputs
(
ctx
=
ctx
)
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
...
...
@@ -411,17 +390,12 @@ class _AsyncLLMEngine(LLMEngine):
else
:
# Multi-step case
if
use_async_and_multi_step
:
return
[]
else
:
ctx
.
request_outputs
=
[]
return
ctx
.
request_outputs
if
not
self
.
has_unfinished_requests
():
# Drain async postprocessor (if exists)
if
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
self
.
_process_model_outputs
(
ctx
=
ctx
)
assert
len
(
ctx
.
output_queue
)
==
0
return
ctx
.
request_outputs
...
...
@@ -640,6 +614,17 @@ class AsyncLLMEngine:
self
.
log_requests
=
log_requests
self
.
engine
=
self
.
_init_engine
(
*
args
,
**
kwargs
)
# This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed,
# especially for multi-step.
#
# TODO: Currently, disabled for engine_use_ray, ask
# Cody/Will/Woosuk about this case.
self
.
use_process_request_outputs_callback
=
not
self
.
engine_use_ray
if
self
.
use_process_request_outputs_callback
:
self
.
engine
.
process_request_outputs_callback
=
\
self
.
process_request_outputs
if
self
.
engine_use_ray
:
print_warning_once
(
"DEPRECATED. `--engine-use-ray` is deprecated and will "
...
...
@@ -883,13 +868,27 @@ class AsyncLLMEngine:
request_outputs
=
await
self
.
engine
.
step_async
(
virtual_engine
)
# Put the outputs into the corresponding streams.
finished
=
True
# If used as a callback, then already invoked inside
# LLMEngine's _process_model_outputs
if
not
self
.
use_process_request_outputs_callback
:
all_finished
=
self
.
process_request_outputs
(
request_outputs
)
else
:
# For callback case, we only need to detect when all
# requests are finished
all_finished
=
all
(
request_output
.
finished
for
request_output
in
request_outputs
)
return
not
all_finished
def
process_request_outputs
(
self
,
request_outputs
)
->
bool
:
# Put the outputs into the corresponding streams.
all_finished
=
True
for
request_output
in
request_outputs
:
self
.
_request_tracker
.
process_request_output
(
request_output
,
verbose
=
self
.
log_requests
)
finished
=
finished
and
request_output
.
finished
all_
finished
=
all_
finished
and
request_output
.
finished
return
not
finished
return
all_
finished
async
def
_engine_abort
(
self
,
request_ids
:
Iterable
[
str
]):
if
self
.
engine_use_ray
:
...
...
vllm/engine/llm_engine.py
View file @
6d646d08
...
...
@@ -93,13 +93,14 @@ class SchedulerOutputState:
@
dataclass
class
SchedulerContext
:
output_queue
:
Deque
[
Tuple
[
Optional
[
List
[
SamplerOutput
]],
List
[
SequenceGroupMetadata
],
SchedulerOutputs
]]
=
field
(
default_factory
=
lambda
:
deque
())
List
[
SequenceGroupMetadata
],
SchedulerOutputs
,
bool
,
bool
]]
=
field
(
default_factory
=
lambda
:
deque
())
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
field
(
default_factory
=
lambda
:
[])
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
class
LLMEngine
:
...
...
@@ -357,6 +358,26 @@ class LLMEngine:
# different process.
self
.
tokenizer
.
ping
()
self
.
cached_scheduler_outputs
=
[
SchedulerOutputState
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
scheduler_contexts
=
[
SchedulerContext
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
async_callbacks
=
[
functools
.
partial
(
self
.
_process_model_outputs
,
ctx
=
self
.
scheduler_contexts
[
v_id
])
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
self
.
process_request_outputs_callback
=
None
# Create the scheduler.
# NOTE: the cache_config here has been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
...
...
@@ -364,9 +385,7 @@ class LLMEngine:
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
,
parallel_config
.
pipeline_parallel_size
,
functools
.
partial
(
self
.
_process_model_outputs
,
virtual_engine
=
v_id
,
is_async
=
True
)
self
.
async_callbacks
[
v_id
]
if
model_config
.
use_async_output_proc
else
None
)
for
v_id
in
range
(
parallel_config
.
pipeline_parallel_size
)
]
...
...
@@ -417,30 +436,6 @@ class LLMEngine:
),
))
self
.
cached_scheduler_outputs
=
[
SchedulerOutputState
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
scheduler_contexts
=
[
SchedulerContext
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
async_callback
=
[
functools
.
partial
(
self
.
_process_model_outputs
,
virtual_engine
=
v_id
,
is_async
=
True
)
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
async_callback_multi_step
=
[
functools
.
partial
(
self
.
_process_model_outputs
,
virtual_engine
=
v_id
,
is_async
=
False
)
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
...
...
@@ -1249,11 +1244,7 @@ class LLMEngine:
return
def
_process_model_outputs
(
self
,
virtual_engine
:
int
,
is_async
:
bool
,
sampler_output
:
Optional
[
SamplerOutput
]
=
None
,
is_last_output
:
bool
=
False
)
->
None
:
def
_process_model_outputs
(
self
,
ctx
:
SchedulerContext
)
->
None
:
"""Apply the model output to the sequences in the scheduled seq groups.
virtual_engine: The engine id to operate on
...
...
@@ -1273,24 +1264,12 @@ class LLMEngine:
"""
now
=
time
.
time
()
is_multi_step
=
sampler_output
is
not
None
ctx
:
SchedulerContext
=
self
.
scheduler_contexts
[
virtual_engine
]
if
len
(
ctx
.
output_queue
)
==
0
:
return
None
if
is_multi_step
:
# Async + multi-step case
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
)
=
ctx
.
output_queue
[
0
]
assert
outputs
is
None
outputs
=
[
sampler_output
]
else
:
# Async standard case
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
)
=
ctx
.
output_queue
.
popleft
()
# Get pending async postprocessor
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
)
=
ctx
.
output_queue
.
popleft
()
assert
outputs
is
not
None
# Sanity check
...
...
@@ -1306,6 +1285,7 @@ class LLMEngine:
outputs_by_sequence_group
=
outputs
finished_before
:
List
[
int
]
=
[]
finished_now
:
List
[
int
]
=
[]
for
i
,
seq_group_meta
in
enumerate
(
seq_group_metadata_list
):
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
...
...
@@ -1343,26 +1323,44 @@ class LLMEngine:
if
self
.
model_config
.
embedding_mode
:
self
.
_process_sequence_group_outputs
(
seq_group
,
output
)
continue
else
:
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
output
)
if
seq_group_meta
.
do_sample
:
self
.
output_processor
.
process_outputs
(
seq_group
,
output
,
is_async
)
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
output
)
if
seq_group_meta
.
do_sample
:
self
.
output_processor
.
process_outputs
(
seq_group
,
output
,
is_async
)
if
seq_group
.
is_finished
():
finished_now
.
append
(
i
)
# For async + multi-step, free finished seqs and create outputs
# only on the final step.
if
is_multi_step
and
not
is_last_output
:
return
# Generate outputs for the requests that finished this iteration
for
i
in
finished_now
:
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_finished_seq_groups
()
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
ctx
.
request_outputs
.
append
(
request_output
)
# Create the outputs.
for
i
,
_
in
enumerate
(
seq_group_metadata_list
):
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
# Free currently finished requests
if
finished_now
:
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_finished_seq_groups
()
# For multi-step, do not create outputs each iteration
if
not
is_last_step
:
# Immediately process request outputs here (if callback is given)
if
(
finished_now
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
return
# Create the outputs
# Note: scheduled_seq_groups and seq_group_metadata_list
# must match with the indices
for
i
,
scheduled_seq_group
in
enumerate
(
scheduler_outputs
.
scheduled_seq_groups
):
if
not
is_multi_step
and
i
in
finished_
before
:
if
i
in
finished_before
or
i
in
finished_
now
:
continue
# Avoids double processing
seq_group
=
scheduled_seq_group
.
seq_group
...
...
@@ -1376,11 +1374,15 @@ class LLMEngine:
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
ctx
.
request_outputs
.
append
(
request_output
)
# For async + multi-step, do stats only on the last output.
# Otherwise, do stats if the execution is async
do_stats
=
is_multi_step
or
is_async
# Immediately process request outputs here (if callback is given)
if
(
ctx
.
request_outputs
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
if
do_stats
:
# For async case, we need to record the stats here.
# For non-async case, the stats are done in the
# LLMEngine/AsyncLLMEngine directly
if
is_async
:
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
outputs
,
finished_before
)
...
...
@@ -1485,40 +1487,26 @@ class LLMEngine:
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
# Detect async + multi-step
use_async_and_multi_step
=
(
self
.
scheduler_config
.
is_multi_step
and
allow_async_output_proc
)
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
# Clear outputs for each new scheduler iteration
ctx
.
request_outputs
.
clear
()
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# Clear outputs on scheduler iteration start
ctx
.
request_outputs
.
clear
()
# Schedule iteration
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
# Detect async + multi-step
use_async_and_multi_step
=
(
self
.
scheduler_config
.
is_multi_step
and
allow_async_output_proc
)
ctx
.
seq_group_metadata_list
=
seq_group_metadata_list
ctx
.
scheduler_outputs
=
scheduler_outputs
# Maybe switch from async mode to sync mode
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
# For async + multi-step, init the queue
if
use_async_and_multi_step
:
assert
len
(
ctx
.
output_queue
)
==
0
assert
seq_group_metadata_list
is
not
None
ctx
.
output_queue
.
append
(
(
None
,
seq_group_metadata_list
,
scheduler_outputs
))
self
.
_process_model_outputs
(
ctx
=
ctx
)
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
...
...
@@ -1555,13 +1543,8 @@ class LLMEngine:
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
async_callback
=
self
.
async_callback_multi_step
[
virtual_engine
]
if
use_async_and_multi_step
\
else
self
.
async_callback
[
virtual_engine
]
execute_model_req
.
async_callback
=
async_callback
execute_model_req
.
use_async_and_multi_step
=
\
use_async_and_multi_step
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
output
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
...
...
@@ -1573,10 +1556,8 @@ class LLMEngine:
else
:
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
if
not
use_async_and_multi_step
and
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
if
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
# No outputs in this case
output
=
[]
...
...
@@ -1590,28 +1571,24 @@ class LLMEngine:
if
self
.
scheduler_config
.
is_multi_step
:
self
.
cached_scheduler_outputs
[
0
]
=
SchedulerOutputState
()
if
use_async_and_multi_step
:
# For async + multi-step, clear the queue
ctx
.
output_queue
.
clear
()
else
:
# Add results to the output_queue
# (for async or non-async postprocessing)
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
# Add results to the output_queue
is_async
=
allow_async_output_proc
is_last_step
=
True
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
))
if
output
and
allow_async_output_proc
:
assert
len
(
output
)
==
1
,
(
"Multi step decoding does not work "
"with async output processing."
)
if
output
and
allow_async_output_proc
:
assert
len
(
output
)
==
1
,
(
"Async postprocessor expects only a single output set"
)
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
# Check if need to run the usual non-async path
if
not
allow_async_output_proc
:
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
False
)
self
.
_process_model_outputs
(
ctx
=
ctx
)
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
...
...
@@ -1620,17 +1597,12 @@ class LLMEngine:
self
.
do_tracing
(
scheduler_outputs
)
else
:
# Multi-step case
if
use_async_and_multi_step
:
return
[]
else
:
ctx
.
request_outputs
=
[]
return
ctx
.
request_outputs
if
not
self
.
has_unfinished_requests
():
# Drain async postprocessor (if exists)
if
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
self
.
_process_model_outputs
(
ctx
=
ctx
)
assert
len
(
ctx
.
output_queue
)
==
0
# Stop the execute model loop in parallel workers until there are
...
...
vllm/engine/output_processor/multi_step.py
View file @
6d646d08
...
...
@@ -85,9 +85,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
no tokens need to be appended since it is already done
externally (before the next schedule() call)
"""
# TODO: Add support for async if necessary
assert
not
is_async
# Sequences can be in RUNNING or FINISHED_ABORTED state
# once scheduled, as a sequence is moved to FINISHED_ABORTED
# if a client disconnects from the api server.
...
...
@@ -101,19 +98,41 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"Beam search not supported in multi-step decoding."
)
seq
=
seqs
[
0
]
# Since there's only one sequence per sequence group, we can take the
# first sample.
samples
=
[
output
.
samples
[
0
]
for
output
in
outputs
]
# -1 means the output token is not valid (e.g., due to spec decode
# rejecting tokens).
valid_samples
=
[
sample
for
sample
in
samples
if
sample
.
output_token
!=
-
1
]
assert
valid_samples
self
.
_process_seq_outputs
(
seq
,
valid_samples
,
sequence_group
.
sampling_params
)
if
is_async
:
# Async case: We process tokens one by one. Here, we know the token
# was already appended, so we only need to do the rest of the
# postprocessor: Detokenization + stopping logic
self
.
_process_decode_and_stop
(
seq
,
sequence_group
.
sampling_params
)
else
:
# Standard multi-step case
# Since there's only one sequence per sequence group,
# we can take the first sample.
samples
=
[
output
.
samples
[
0
]
for
output
in
outputs
]
# -1 means the output token is not valid (e.g., due to spec decode
# rejecting tokens).
valid_samples
=
[
sample
for
sample
in
samples
if
sample
.
output_token
!=
-
1
]
assert
valid_samples
self
.
_process_seq_outputs
(
seq
,
valid_samples
,
sequence_group
.
sampling_params
)
def
_process_decode_and_stop
(
self
,
seq
:
Sequence
,
sampling_params
:
SamplingParams
)
->
None
:
new_char_count
=
0
if
sampling_params
.
detokenize
:
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
sampling_params
)
# TODO(sang): Support lora.
self
.
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
=
new_char_count
,
sampling_params
=
sampling_params
,
)
def
_process_seq_outputs
(
self
,
seq
:
Sequence
,
valid_samples
:
List
[
SequenceOutput
],
...
...
@@ -151,16 +170,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
logprobs
=
output_logprob
,
)
new_char_count
=
0
if
sampling_params
.
detokenize
:
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
sampling_params
)
self
.
_process_decode_and_stop
(
seq
,
sampling_params
)
# TODO(sang): Support lora.
self
.
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
=
new_char_count
,
sampling_params
=
sampling_params
,
)
if
seq
.
is_finished
():
break
vllm/sequence.py
View file @
6d646d08
...
...
@@ -1225,7 +1225,6 @@ class ExecuteModelRequest(
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
# Async callback
async_callback
:
Optional
[
Callable
]
=
None
use_async_and_multi_step
:
bool
=
False
@
property
def
is_first_multi_step
(
self
)
->
bool
:
...
...
@@ -1272,5 +1271,4 @@ class ExecuteModelRequest(
finished_requests_ids
=
self
.
finished_requests_ids
,
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
clone
()
if
self
.
last_sampled_token_ids
is
not
None
else
None
,
async_callback
=
self
.
async_callback
,
use_async_and_multi_step
=
self
.
use_async_and_multi_step
)
async_callback
=
self
.
async_callback
)
vllm/worker/model_runner.py
View file @
6d646d08
...
...
@@ -21,6 +21,7 @@ from vllm.attention.backends.utils import CommonAttentionState
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.distributed
import
get_pp_group
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
...
...
@@ -96,7 +97,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
virtual_engine
:
int
=
0
async_callback
:
Optional
[
Callable
]
=
None
use_async_and_multi_step
:
bool
=
False
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
...
...
vllm/worker/multi_step_model_runner.py
View file @
6d646d08
...
...
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
get_pythonized_sample_results
)
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
Logprob
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.utils
import
PyObjectCache
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUWithSamplingMetadata
)
from
vllm.worker.model_runner_base
import
(
...
...
@@ -37,6 +38,29 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
def
seq_output_builder
():
return
SequenceOutput
(
0
,
0
,
{
0
:
Logprob
(
logprob
=
float
(
'inf'
),
rank
=
None
,
decoded_token
=
None
)})
def
completion_seq_group_output_builder
():
return
CompletionSequenceGroupOutput
([],
None
)
# Used by pythonization to reduce python object allocations
class
PythonizationCache
:
def
__init__
(
self
):
self
.
cached_seq_output
=
PyObjectCache
(
seq_output_builder
)
self
.
cached_completion_seq_group_output
=
PyObjectCache
(
completion_seq_group_output_builder
)
def
reset
(
self
):
self
.
cached_seq_output
.
reset
()
self
.
cached_completion_seq_group_output
.
reset
()
@
dataclass
class
ModelOutput
:
"""The output of a single model forward pass.
...
...
@@ -59,6 +83,7 @@ class ModelOutput:
pythonized
:
bool
=
False
# On-device tensor containing the logprobs of each token.
logprobs
:
Optional
[
"torch.Tensor"
]
=
None
pythonization_cache
:
Optional
[
PythonizationCache
]
=
None
def
pythonize
(
self
,
input_metadata
:
"StatefulModelInput"
,
copy_stream
:
torch
.
cuda
.
Stream
,
...
...
@@ -97,7 +122,8 @@ class ModelOutput:
with
torch
.
cuda
.
stream
(
copy_stream
):
_pythonize_sampler_output
(
input_metadata
,
self
.
sampler_output
,
pinned_sampled_token_buffer
,
self
.
sampled_token_ids
,
self
.
logprobs
)
self
.
sampled_token_ids
,
self
.
logprobs
,
self
.
pythonization_cache
)
# Erase the logprobs GPU-side tensor.
# Note that although _pythonize_sampler_output() runs in its
...
...
@@ -209,6 +235,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
self
.
_copy_stream
=
torch
.
cuda
.
Stream
()
self
.
pinned_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
self
.
pythonization_cache
=
PythonizationCache
()
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
tensor_dict
:
Dict
[
str
,
Any
])
->
StatefulModelInput
:
model_input
=
(
StatefulModelInput
.
from_broadcasted_tensor_dict
(
...
...
@@ -237,14 +265,22 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
output_proc_callback
:
Callable
):
# Proceed with pythonization and output_proc in order.
# Stop on the first one that fails to pythonize
output_proc_callback
()
cont
=
True
for
model_output
in
model_input
.
cached_outputs
:
if
not
model_output
.
pythonized
:
model_output
.
maybe_pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
if
model_output
.
pythonized
:
output_proc_callback
(
sampler_output
=
model_output
.
sampler_output
)
ctx
=
output_proc_callback
.
keywords
[
"ctx"
]
is_async
=
False
is_last_step
=
False
ctx
.
output_queue
.
append
(
([
model_output
.
sampler_output
],
ctx
.
seq_group_metadata_list
,
ctx
.
scheduler_outputs
,
is_async
,
is_last_step
))
output_proc_callback
()
else
:
cont
=
False
...
...
@@ -255,21 +291,46 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
output_proc_callback
:
Optional
[
Callable
]):
assert
model_input
.
frozen_model_input
is
not
None
has_async_callback
=
output_proc_callback
is
not
None
outputs
=
[]
for
output_id
in
range
(
len
(
model_input
.
cached_outputs
)):
is_last_output
=
output_id
==
len
(
model_input
.
cached_outputs
)
-
1
output
=
model_input
.
cached_outputs
[
output_id
]
if
not
output
.
pythonized
:
is_last_step
=
output_id
==
len
(
model_input
.
cached_outputs
)
-
1
# For non-async case:
# -- We simply add the outputs
# For async case:
# -- Invoke callback, pythonize, add to callback queue and repeat
# -- For last output, just add to callback queue
if
has_async_callback
:
assert
output_proc_callback
is
not
None
# Invoke callback before pythonize (to overlap with GPU)
output_proc_callback
()
# Pythonize
if
not
output
.
pythonized
:
output
.
pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
# For non last step, add to callback queue to chain
# callbacks=>pythonize pairs (for GPU overlap)
if
not
is_last_step
:
ctx
=
output_proc_callback
.
keywords
[
# type: ignore
"ctx"
]
# type: ignore
is_async
=
False
is_last_step
=
False
ctx
.
output_queue
.
append
(
([
output
.
sampler_output
],
ctx
.
seq_group_metadata_list
,
ctx
.
scheduler_outputs
,
is_async
,
is_last_step
))
else
:
outputs
.
append
(
output
.
sampler_output
)
else
:
output
.
pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
if
model_input
.
frozen_model_input
.
use_async_and_multi_step
:
assert
output_proc_callback
is
not
None
output_proc_callback
(
sampler_output
=
output
.
sampler_output
,
is_last_output
=
is_last_output
)
outputs
.
append
(
output
.
sampler_output
)
outputs
.
append
(
output
.
sampler_output
)
return
outputs
...
...
@@ -330,7 +391,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
model_input
,
model_input
.
cached_outputs
[
-
1
].
sampler_output
)
output_proc_callback
=
None
if
frozen_model_input
.
use_
async_
and_multi_step
:
if
frozen_model_input
.
async_
callback
is
not
None
:
output_proc_callback
=
frozen_model_input
.
async_callback
assert
output_proc_callback
is
not
None
async_callback
=
functools
.
partial
(
...
...
@@ -367,7 +428,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
model_input
.
cached_outputs
.
append
(
ModelOutput
(
output
[
0
],
output_ready_event
,
output
[
0
].
sampled_token_ids
,
False
,
output
[
0
].
logprobs
))
output
[
0
].
logprobs
,
self
.
pythonization_cache
))
# These GPU tensors are not required by multi-step;
# erase them to ensure they are not pythonized or
...
...
@@ -378,7 +439,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# Pythonize the output if CPU is ahead and the previous step is
# ready.
if
not
frozen_model_input
.
use_
async_
and_multi_step
:
if
frozen_model_input
.
async_
callback
is
None
:
for
model_output
in
model_input
.
cached_outputs
:
model_output
.
maybe_pythonize
(
model_input
,
self
.
_copy_stream
,
...
...
@@ -397,6 +458,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
if
model_input
.
is_last_step
:
outputs
=
self
.
_final_process_outputs
(
model_input
,
output_proc_callback
)
self
.
pythonization_cache
.
reset
()
return
outputs
# should be [SamplerOutput]
...
...
@@ -537,6 +599,7 @@ def _pythonize_sampler_output(
pinned_sampled_token_buffer
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
logprobs_tensor
:
Optional
[
torch
.
Tensor
],
cache
:
Optional
[
PythonizationCache
],
)
->
None
:
""" This function is only called when the output tensors are ready.
See :class:`ModelOutput`.
...
...
@@ -597,6 +660,9 @@ def _pythonize_sampler_output(
for
sgdx
,
(
seq_group
,
sample_result
)
in
enumerate
(
zip
(
seq_groups
,
samples_list
)):
if
seq_group
.
sampling_params
.
logits_processors
:
assert
len
(
seq_group
.
sampling_params
.
logits_processors
)
==
0
,
(
"Logits Processors are not supported in multi-step decoding"
)
if
do_pythonize_logprobs
:
assert
prompt_logprobs
is
not
None
...
...
@@ -621,23 +687,56 @@ def _pythonize_sampler_output(
seq_ids
=
seq_group
.
seq_ids
next_token_ids
=
sample_result
parent_ids
=
[
0
]
seq_outputs
:
List
[
SequenceOutput
]
=
[]
if
seq_group
.
sampling_params
.
logits_processors
:
assert
len
(
seq_group
.
sampling_params
.
logits_processors
)
==
0
,
(
"Logits Processors are not supported in multi-step decoding"
)
if
cache
is
not
None
:
completion_seq_group_output
:
CompletionSequenceGroupOutput
=
\
cache
.
cached_completion_seq_group_output
.
get_object
()
completion_seq_group_output
.
samples
.
clear
()
seq_outputs
:
List
[
SequenceOutput
]
=
completion_seq_group_output
.
samples
else
:
seq_outputs
=
[]
for
tdx
,
(
parent_id
,
next_token_id
)
in
enumerate
(
zip
(
parent_ids
,
next_token_ids
)):
seq_outputs
.
append
(
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
(
group_sample_logprobs
[
tdx
]
if
logprobs_are_requested
else
{
next_token_id
:
Logprob
(
logprob
=
float
(
'inf'
),
rank
=
None
,
decoded_token
=
None
)
})))
output
.
outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
(
group_prompt_logprobs
if
logprobs_are_requested
else
None
)))
if
cache
is
not
None
:
seq_output
:
SequenceOutput
=
cache
.
cached_seq_output
.
get_object
(
)
seq_output
.
parent_seq_id
=
seq_ids
[
parent_id
]
seq_output
.
output_token
=
next_token_id
if
logprobs_are_requested
:
seq_output
.
logprobs
=
group_sample_logprobs
[
tdx
]
else
:
logprobs
=
next
(
iter
(
seq_output
.
logprobs
.
values
()))
seq_output
.
logprobs
.
clear
()
logprobs
.
logprob
=
float
(
'inf'
)
logprobs
.
rank
=
None
logprobs
.
decoded_token
=
None
seq_output
.
logprobs
[
next_token_id
]
=
logprobs
seq_outputs
.
append
(
seq_output
)
else
:
seq_outputs
.
append
(
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
(
group_sample_logprobs
[
tdx
]
if
logprobs_are_requested
else
{
next_token_id
:
Logprob
(
logprob
=
float
(
'inf'
),
rank
=
None
,
decoded_token
=
None
)
})))
if
cache
is
not
None
:
completion_seq_group_output
.
prompt_logprobs
=
\
group_prompt_logprobs
if
logprobs_are_requested
else
None
output
.
outputs
.
append
(
completion_seq_group_output
)
else
:
output
.
outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
(
group_prompt_logprobs
if
logprobs_are_requested
else
None
)))
assert
len
(
output
.
outputs
)
>
0
vllm/worker/multi_step_worker.py
View file @
6d646d08
...
...
@@ -67,9 +67,7 @@ class MultiStepWorker(Worker):
if
execute_model_req
.
async_callback
:
model_input
.
frozen_model_input
=
dataclasses
.
replace
(
# type: ignore
model_input
.
frozen_model_input
,
async_callback
=
execute_model_req
.
async_callback
,
use_async_and_multi_step
=
execute_model_req
.
use_async_and_multi_step
)
async_callback
=
execute_model_req
.
async_callback
)
else
:
# on subsequent steps we reuse the worker input and model input
multi_step_state
=
self
.
multi_step_states
[
virtual_engine
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment