Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6d646d08
Unverified
Commit
6d646d08
authored
Sep 03, 2024
by
Alexander Matveev
Committed by
GitHub
Sep 03, 2024
Browse files
[Core] Optimize Async + Multi-step (#8050)
parent
95a178f8
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
326 additions
and
248 deletions
+326
-248
tests/multi_step/test_correctness_async_llm.py
tests/multi_step/test_correctness_async_llm.py
+2
-2
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+54
-55
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+97
-125
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+36
-26
vllm/sequence.py
vllm/sequence.py
+1
-3
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+3
-1
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+132
-33
vllm/worker/multi_step_worker.py
vllm/worker/multi_step_worker.py
+1
-3
No files found.
tests/multi_step/test_correctness_async_llm.py
View file @
6d646d08
...
@@ -103,13 +103,13 @@ async def test_multi_step(
...
@@ -103,13 +103,13 @@ async def test_multi_step(
model
,
model
,
server_args
+
distributed_args
,
server_args
+
distributed_args
,
num_logprobs
,
num_logprobs
,
max_wait_seconds
=
3
*
240
)
max_wait_seconds
=
5
*
240
)
test_completions
=
await
completions_with_server_args
(
test_completions
=
await
completions_with_server_args
(
prompts
,
prompts
,
model
,
model
,
ms_server_args
+
distributed_args
,
ms_server_args
+
distributed_args
,
num_logprobs
,
num_logprobs
,
max_wait_seconds
=
3
*
240
)
max_wait_seconds
=
5
*
240
)
# Assert multi-step scheduling produces identical tokens
# Assert multi-step scheduling produces identical tokens
# to single-step scheduling.
# to single-step scheduling.
...
...
vllm/engine/async_llm_engine.py
View file @
6d646d08
...
@@ -280,40 +280,27 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -280,40 +280,27 @@ class _AsyncLLMEngine(LLMEngine):
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
# Detect async + multi-step
use_async_and_multi_step
=
(
self
.
scheduler_config
.
is_multi_step
and
allow_async_output_proc
)
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
# Clear outputs for each new scheduler iteration
ctx
.
request_outputs
.
clear
()
# skip the scheduler if there are any remaining steps in the seq groups.
# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# This ensures that the scheduler is only called again when the current
# batch has completed.
# batch has completed.
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# Clear outputs on scheduler iteration start
ctx
.
request_outputs
.
clear
()
# Schedule iteration
# Schedule iteration
(
seq_group_metadata_list
,
scheduler_outputs
,
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
# Detect async + multi-step
ctx
.
seq_group_metadata_list
=
seq_group_metadata_list
use_async_and_multi_step
=
(
self
.
scheduler_config
.
is_multi_step
ctx
.
scheduler_outputs
=
scheduler_outputs
and
allow_async_output_proc
)
# Maybe switch from async mode to sync mode
# Maybe switch from async mode to sync mode
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
self
.
_process_model_outputs
(
ctx
=
ctx
)
is_async
=
True
)
# For async + multi-step, init the queue
if
use_async_and_multi_step
:
assert
len
(
ctx
.
output_queue
)
==
0
assert
seq_group_metadata_list
is
not
None
ctx
.
output_queue
.
append
(
(
None
,
seq_group_metadata_list
,
scheduler_outputs
))
if
(
self
.
scheduler_config
.
is_multi_step
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
...
@@ -351,26 +338,20 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -351,26 +338,20 @@ class _AsyncLLMEngine(LLMEngine):
last_sampled_token_ids
=
last_sampled_token_ids
)
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
if
allow_async_output_proc
:
async_callback
=
self
.
async_callback_multi_step
[
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
if
use_async_and_multi_step
\
virtual_engine
]
else
self
.
async_callback
[
virtual_engine
]
execute_model_req
.
async_callback
=
async_callback
execute_model_req
.
use_async_and_multi_step
=
\
use_async_and_multi_step
# Execute the model.
# Execute the model.
output
=
await
self
.
model_executor
.
execute_model_async
(
output
=
await
self
.
model_executor
.
execute_model_async
(
execute_model_req
)
execute_model_req
)
# we need to do this here so that last step's sampled_token_ids can
# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
# be passed to the next iteration for PP.
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
else
:
else
:
if
not
use_async_and_multi_step
and
len
(
ctx
.
output_queue
)
>
0
:
if
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
ctx
=
ctx
)
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
output
=
[]
output
=
[]
# Finish the current step for all the sequence groups.
# Finish the current step for all the sequence groups.
...
@@ -384,24 +365,22 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -384,24 +365,22 @@ class _AsyncLLMEngine(LLMEngine):
self
.
cached_scheduler_outputs
[
self
.
cached_scheduler_outputs
[
virtual_engine
]
=
SchedulerOutputState
()
virtual_engine
]
=
SchedulerOutputState
()
if
use_async_and_multi_step
:
is_async
=
allow_async_output_proc
# For async + multi-step, clear the queue
is_last_step
=
True
ctx
.
output_queue
.
clear
()
ctx
.
output_queue
.
append
(
else
:
(
output
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
ctx
.
output_queue
.
append
(
is_last_step
))
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
if
output
and
allow_async_output_proc
:
if
output
and
allow_async_output_proc
:
assert
len
(
assert
len
(
output
output
)
==
1
,
"
Multi step decoding does not work with async output processing."
# noqa: E501
)
==
1
,
"
Async postprocessor expects only a single output set"
self
.
_advance_to_next_step
(
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
output
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
)
if
not
allow_async_output_proc
:
if
not
allow_async_output_proc
:
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
self
.
_process_model_outputs
(
ctx
=
ctx
)
is_async
=
False
)
# Log stats.
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
do_log_stats
(
scheduler_outputs
,
output
)
...
@@ -411,17 +390,12 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -411,17 +390,12 @@ class _AsyncLLMEngine(LLMEngine):
else
:
else
:
# Multi-step case
# Multi-step case
if
use_async_and_multi_step
:
return
ctx
.
request_outputs
return
[]
else
:
ctx
.
request_outputs
=
[]
if
not
self
.
has_unfinished_requests
():
if
not
self
.
has_unfinished_requests
():
# Drain async postprocessor (if exists)
# Drain async postprocessor (if exists)
if
len
(
ctx
.
output_queue
)
>
0
:
if
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
ctx
=
ctx
)
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
assert
len
(
ctx
.
output_queue
)
==
0
assert
len
(
ctx
.
output_queue
)
==
0
return
ctx
.
request_outputs
return
ctx
.
request_outputs
...
@@ -640,6 +614,17 @@ class AsyncLLMEngine:
...
@@ -640,6 +614,17 @@ class AsyncLLMEngine:
self
.
log_requests
=
log_requests
self
.
log_requests
=
log_requests
self
.
engine
=
self
.
_init_engine
(
*
args
,
**
kwargs
)
self
.
engine
=
self
.
_init_engine
(
*
args
,
**
kwargs
)
# This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed,
# especially for multi-step.
#
# TODO: Currently, disabled for engine_use_ray, ask
# Cody/Will/Woosuk about this case.
self
.
use_process_request_outputs_callback
=
not
self
.
engine_use_ray
if
self
.
use_process_request_outputs_callback
:
self
.
engine
.
process_request_outputs_callback
=
\
self
.
process_request_outputs
if
self
.
engine_use_ray
:
if
self
.
engine_use_ray
:
print_warning_once
(
print_warning_once
(
"DEPRECATED. `--engine-use-ray` is deprecated and will "
"DEPRECATED. `--engine-use-ray` is deprecated and will "
...
@@ -883,13 +868,27 @@ class AsyncLLMEngine:
...
@@ -883,13 +868,27 @@ class AsyncLLMEngine:
request_outputs
=
await
self
.
engine
.
step_async
(
virtual_engine
)
request_outputs
=
await
self
.
engine
.
step_async
(
virtual_engine
)
# Put the outputs into the corresponding streams.
# Put the outputs into the corresponding streams.
finished
=
True
# If used as a callback, then already invoked inside
# LLMEngine's _process_model_outputs
if
not
self
.
use_process_request_outputs_callback
:
all_finished
=
self
.
process_request_outputs
(
request_outputs
)
else
:
# For callback case, we only need to detect when all
# requests are finished
all_finished
=
all
(
request_output
.
finished
for
request_output
in
request_outputs
)
return
not
all_finished
def
process_request_outputs
(
self
,
request_outputs
)
->
bool
:
# Put the outputs into the corresponding streams.
all_finished
=
True
for
request_output
in
request_outputs
:
for
request_output
in
request_outputs
:
self
.
_request_tracker
.
process_request_output
(
self
.
_request_tracker
.
process_request_output
(
request_output
,
verbose
=
self
.
log_requests
)
request_output
,
verbose
=
self
.
log_requests
)
finished
=
finished
and
request_output
.
finished
all_
finished
=
all_
finished
and
request_output
.
finished
return
not
finished
return
all_
finished
async
def
_engine_abort
(
self
,
request_ids
:
Iterable
[
str
]):
async
def
_engine_abort
(
self
,
request_ids
:
Iterable
[
str
]):
if
self
.
engine_use_ray
:
if
self
.
engine_use_ray
:
...
...
vllm/engine/llm_engine.py
View file @
6d646d08
...
@@ -93,13 +93,14 @@ class SchedulerOutputState:
...
@@ -93,13 +93,14 @@ class SchedulerOutputState:
@
dataclass
@
dataclass
class
SchedulerContext
:
class
SchedulerContext
:
output_queue
:
Deque
[
Tuple
[
Optional
[
List
[
SamplerOutput
]],
output_queue
:
Deque
[
Tuple
[
Optional
[
List
[
SamplerOutput
]],
List
[
SequenceGroupMetadata
],
List
[
SequenceGroupMetadata
],
SchedulerOutputs
,
SchedulerOutputs
]]
=
field
(
bool
,
default_factory
=
lambda
:
deque
())
bool
]]
=
field
(
default_factory
=
lambda
:
deque
())
request_outputs
:
List
[
Union
[
RequestOutput
,
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
field
(
EmbeddingRequestOutput
]]
=
field
(
default_factory
=
lambda
:
[])
default_factory
=
lambda
:
[])
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
class
LLMEngine
:
class
LLMEngine
:
...
@@ -357,6 +358,26 @@ class LLMEngine:
...
@@ -357,6 +358,26 @@ class LLMEngine:
# different process.
# different process.
self
.
tokenizer
.
ping
()
self
.
tokenizer
.
ping
()
self
.
cached_scheduler_outputs
=
[
SchedulerOutputState
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
scheduler_contexts
=
[
SchedulerContext
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
async_callbacks
=
[
functools
.
partial
(
self
.
_process_model_outputs
,
ctx
=
self
.
scheduler_contexts
[
v_id
])
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
self
.
process_request_outputs_callback
=
None
# Create the scheduler.
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
# GPU and CPU blocks, which are profiled in the distributed executor.
...
@@ -364,9 +385,7 @@ class LLMEngine:
...
@@ -364,9 +385,7 @@ class LLMEngine:
Scheduler
(
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
,
scheduler_config
,
cache_config
,
lora_config
,
parallel_config
.
pipeline_parallel_size
,
parallel_config
.
pipeline_parallel_size
,
functools
.
partial
(
self
.
_process_model_outputs
,
self
.
async_callbacks
[
v_id
]
virtual_engine
=
v_id
,
is_async
=
True
)
if
model_config
.
use_async_output_proc
else
None
)
if
model_config
.
use_async_output_proc
else
None
)
for
v_id
in
range
(
parallel_config
.
pipeline_parallel_size
)
for
v_id
in
range
(
parallel_config
.
pipeline_parallel_size
)
]
]
...
@@ -417,30 +436,6 @@ class LLMEngine:
...
@@ -417,30 +436,6 @@ class LLMEngine:
),
),
))
))
self
.
cached_scheduler_outputs
=
[
SchedulerOutputState
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
scheduler_contexts
=
[
SchedulerContext
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
async_callback
=
[
functools
.
partial
(
self
.
_process_model_outputs
,
virtual_engine
=
v_id
,
is_async
=
True
)
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
async_callback_multi_step
=
[
functools
.
partial
(
self
.
_process_model_outputs
,
virtual_engine
=
v_id
,
is_async
=
False
)
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
def
_initialize_kv_caches
(
self
)
->
None
:
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
"""Initialize the KV cache in the worker(s).
...
@@ -1249,11 +1244,7 @@ class LLMEngine:
...
@@ -1249,11 +1244,7 @@ class LLMEngine:
return
return
def
_process_model_outputs
(
self
,
def
_process_model_outputs
(
self
,
ctx
:
SchedulerContext
)
->
None
:
virtual_engine
:
int
,
is_async
:
bool
,
sampler_output
:
Optional
[
SamplerOutput
]
=
None
,
is_last_output
:
bool
=
False
)
->
None
:
"""Apply the model output to the sequences in the scheduled seq groups.
"""Apply the model output to the sequences in the scheduled seq groups.
virtual_engine: The engine id to operate on
virtual_engine: The engine id to operate on
...
@@ -1273,24 +1264,12 @@ class LLMEngine:
...
@@ -1273,24 +1264,12 @@ class LLMEngine:
"""
"""
now
=
time
.
time
()
now
=
time
.
time
()
is_multi_step
=
sampler_output
is
not
None
ctx
:
SchedulerContext
=
self
.
scheduler_contexts
[
virtual_engine
]
if
len
(
ctx
.
output_queue
)
==
0
:
if
len
(
ctx
.
output_queue
)
==
0
:
return
None
return
None
if
is_multi_step
:
# Get pending async postprocessor
# Async + multi-step case
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
(
outputs
,
seq_group_metadata_list
,
is_last_step
)
=
ctx
.
output_queue
.
popleft
()
scheduler_outputs
)
=
ctx
.
output_queue
[
0
]
assert
outputs
is
None
outputs
=
[
sampler_output
]
else
:
# Async standard case
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
)
=
ctx
.
output_queue
.
popleft
()
assert
outputs
is
not
None
assert
outputs
is
not
None
# Sanity check
# Sanity check
...
@@ -1306,6 +1285,7 @@ class LLMEngine:
...
@@ -1306,6 +1285,7 @@ class LLMEngine:
outputs_by_sequence_group
=
outputs
outputs_by_sequence_group
=
outputs
finished_before
:
List
[
int
]
=
[]
finished_before
:
List
[
int
]
=
[]
finished_now
:
List
[
int
]
=
[]
for
i
,
seq_group_meta
in
enumerate
(
seq_group_metadata_list
):
for
i
,
seq_group_meta
in
enumerate
(
seq_group_metadata_list
):
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
...
@@ -1343,26 +1323,44 @@ class LLMEngine:
...
@@ -1343,26 +1323,44 @@ class LLMEngine:
if
self
.
model_config
.
embedding_mode
:
if
self
.
model_config
.
embedding_mode
:
self
.
_process_sequence_group_outputs
(
seq_group
,
output
)
self
.
_process_sequence_group_outputs
(
seq_group
,
output
)
continue
else
:
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
output
)
if
seq_group_meta
.
do_sample
:
self
.
output_processor
.
process_outputs
(
seq_group
,
output
,
is_async
)
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
output
)
if
seq_group
.
is_finished
():
if
seq_group_meta
.
do_sample
:
finished_now
.
append
(
i
)
self
.
output_processor
.
process_outputs
(
seq_group
,
output
,
is_async
)
# For async + multi-step, free finished seqs and create outputs
# Generate outputs for the requests that finished this iteration
# only on the final step.
for
i
in
finished_now
:
if
is_multi_step
and
not
is_last_output
:
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
return
for
scheduler
in
self
.
scheduler
:
seq_group
=
scheduled_seq_group
.
seq_group
scheduler
.
free_finished_seq_groups
()
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
ctx
.
request_outputs
.
append
(
request_output
)
# Create the outputs.
# Free currently finished requests
for
i
,
_
in
enumerate
(
seq_group_metadata_list
):
if
finished_now
:
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_finished_seq_groups
()
# For multi-step, do not create outputs each iteration
if
not
is_last_step
:
# Immediately process request outputs here (if callback is given)
if
(
finished_now
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
return
# Create the outputs
# Note: scheduled_seq_groups and seq_group_metadata_list
# must match with the indices
for
i
,
scheduled_seq_group
in
enumerate
(
scheduler_outputs
.
scheduled_seq_groups
):
if
not
is_multi_step
and
i
in
finished_
before
:
if
i
in
finished_before
or
i
in
finished_
now
:
continue
# Avoids double processing
continue
# Avoids double processing
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
...
@@ -1376,11 +1374,15 @@ class LLMEngine:
...
@@ -1376,11 +1374,15 @@ class LLMEngine:
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
ctx
.
request_outputs
.
append
(
request_output
)
ctx
.
request_outputs
.
append
(
request_output
)
# For async + multi-step, do stats only on the last output.
# Immediately process request outputs here (if callback is given)
# Otherwise, do stats if the execution is async
if
(
ctx
.
request_outputs
do_stats
=
is_multi_step
or
is_async
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
if
do_stats
:
# For async case, we need to record the stats here.
# For non-async case, the stats are done in the
# LLMEngine/AsyncLLMEngine directly
if
is_async
:
# Log stats.
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
outputs
,
finished_before
)
self
.
do_log_stats
(
scheduler_outputs
,
outputs
,
finished_before
)
...
@@ -1485,40 +1487,26 @@ class LLMEngine:
...
@@ -1485,40 +1487,26 @@ class LLMEngine:
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
# Detect async + multi-step
use_async_and_multi_step
=
(
self
.
scheduler_config
.
is_multi_step
and
allow_async_output_proc
)
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
# Clear outputs for each new scheduler iteration
ctx
.
request_outputs
.
clear
()
# Skip the scheduler if there are any remaining steps in the seq groups.
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# This ensures that the scheduler is only called again when the current
# batch has completed.
# batch has completed.
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# Clear outputs on scheduler iteration start
ctx
.
request_outputs
.
clear
()
# Schedule iteration
# Schedule iteration
(
seq_group_metadata_list
,
scheduler_outputs
,
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
# Detect async + multi-step
ctx
.
seq_group_metadata_list
=
seq_group_metadata_list
use_async_and_multi_step
=
(
self
.
scheduler_config
.
is_multi_step
ctx
.
scheduler_outputs
=
scheduler_outputs
and
allow_async_output_proc
)
# Maybe switch from async mode to sync mode
# Maybe switch from async mode to sync mode
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
self
.
_process_model_outputs
(
ctx
=
ctx
)
is_async
=
True
)
# For async + multi-step, init the queue
if
use_async_and_multi_step
:
assert
len
(
ctx
.
output_queue
)
==
0
assert
seq_group_metadata_list
is
not
None
ctx
.
output_queue
.
append
(
(
None
,
seq_group_metadata_list
,
scheduler_outputs
))
if
(
self
.
scheduler_config
.
is_multi_step
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
...
@@ -1555,13 +1543,8 @@ class LLMEngine:
...
@@ -1555,13 +1543,8 @@ class LLMEngine:
last_sampled_token_ids
=
last_sampled_token_ids
)
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
if
allow_async_output_proc
:
async_callback
=
self
.
async_callback_multi_step
[
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
if
use_async_and_multi_step
\
virtual_engine
]
else
self
.
async_callback
[
virtual_engine
]
execute_model_req
.
async_callback
=
async_callback
execute_model_req
.
use_async_and_multi_step
=
\
use_async_and_multi_step
output
=
self
.
model_executor
.
execute_model
(
output
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
execute_model_req
=
execute_model_req
)
...
@@ -1573,10 +1556,8 @@ class LLMEngine:
...
@@ -1573,10 +1556,8 @@ class LLMEngine:
else
:
else
:
# Nothing scheduled => If there is pending async postprocessor,
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
# then finish it here.
if
not
use_async_and_multi_step
and
len
(
ctx
.
output_queue
)
>
0
:
if
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
ctx
=
ctx
)
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
# No outputs in this case
# No outputs in this case
output
=
[]
output
=
[]
...
@@ -1590,28 +1571,24 @@ class LLMEngine:
...
@@ -1590,28 +1571,24 @@ class LLMEngine:
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
self
.
cached_scheduler_outputs
[
0
]
=
SchedulerOutputState
()
self
.
cached_scheduler_outputs
[
0
]
=
SchedulerOutputState
()
if
use_async_and_multi_step
:
# Add results to the output_queue
# For async + multi-step, clear the queue
is_async
=
allow_async_output_proc
ctx
.
output_queue
.
clear
()
is_last_step
=
True
else
:
ctx
.
output_queue
.
append
(
# Add results to the output_queue
(
output
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
# (for async or non-async postprocessing)
is_last_step
))
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
if
output
and
allow_async_output_proc
:
if
output
and
allow_async_output_proc
:
assert
len
(
output
)
==
1
,
(
assert
len
(
output
)
==
1
,
(
"Multi step decoding does not work "
"Async postprocessor expects only a single output set"
)
"with async output processing."
)
self
.
_advance_to_next_step
(
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
output
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
)
# Check if need to run the usual non-async path
# Check if need to run the usual non-async path
if
not
allow_async_output_proc
:
if
not
allow_async_output_proc
:
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
self
.
_process_model_outputs
(
ctx
=
ctx
)
is_async
=
False
)
# Log stats.
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
do_log_stats
(
scheduler_outputs
,
output
)
...
@@ -1620,17 +1597,12 @@ class LLMEngine:
...
@@ -1620,17 +1597,12 @@ class LLMEngine:
self
.
do_tracing
(
scheduler_outputs
)
self
.
do_tracing
(
scheduler_outputs
)
else
:
else
:
# Multi-step case
# Multi-step case
if
use_async_and_multi_step
:
return
ctx
.
request_outputs
return
[]
else
:
ctx
.
request_outputs
=
[]
if
not
self
.
has_unfinished_requests
():
if
not
self
.
has_unfinished_requests
():
# Drain async postprocessor (if exists)
# Drain async postprocessor (if exists)
if
len
(
ctx
.
output_queue
)
>
0
:
if
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
ctx
=
ctx
)
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
assert
len
(
ctx
.
output_queue
)
==
0
assert
len
(
ctx
.
output_queue
)
==
0
# Stop the execute model loop in parallel workers until there are
# Stop the execute model loop in parallel workers until there are
...
...
vllm/engine/output_processor/multi_step.py
View file @
6d646d08
...
@@ -85,9 +85,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -85,9 +85,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
no tokens need to be appended since it is already done
no tokens need to be appended since it is already done
externally (before the next schedule() call)
externally (before the next schedule() call)
"""
"""
# TODO: Add support for async if necessary
assert
not
is_async
# Sequences can be in RUNNING or FINISHED_ABORTED state
# Sequences can be in RUNNING or FINISHED_ABORTED state
# once scheduled, as a sequence is moved to FINSIHED_ABORTED
# once scheduled, as a sequence is moved to FINSIHED_ABORTED
# if a client disconnects from the api server.
# if a client disconnects from the api server.
...
@@ -101,19 +98,41 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -101,19 +98,41 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"Beam search not supported in multi-step decoding."
)
"Beam search not supported in multi-step decoding."
)
seq
=
seqs
[
0
]
seq
=
seqs
[
0
]
# Since there's only one sequence per sequence group, we can take the
if
is_async
:
# first sample.
# Async case: We process tokens one by one. Here, we know the token
samples
=
[
output
.
samples
[
0
]
for
output
in
outputs
]
# was already appended, so we only need to do the rest of the
# postprocessor: Detokenization + stopping logic
# -1 means the output token is not valid (eg. due to spec decode
self
.
_process_decode_and_stop
(
seq
,
sequence_group
.
sampling_params
)
# rejecting tokens).
else
:
valid_samples
=
[
# Standard multi-step case
sample
for
sample
in
samples
if
sample
.
output_token
!=
-
1
]
# Since there's only one sequence per sequence group,
assert
valid_samples
# we can take the first sample.
samples
=
[
output
.
samples
[
0
]
for
output
in
outputs
]
self
.
_process_seq_outputs
(
seq
,
valid_samples
,
sequence_group
.
sampling_params
)
# -1 means the output token is not valid (eg. due to spec decode
# rejecting tokens).
valid_samples
=
[
sample
for
sample
in
samples
if
sample
.
output_token
!=
-
1
]
assert
valid_samples
self
.
_process_seq_outputs
(
seq
,
valid_samples
,
sequence_group
.
sampling_params
)
def
_process_decode_and_stop
(
self
,
seq
:
Sequence
,
sampling_params
:
SamplingParams
)
->
None
:
new_char_count
=
0
if
sampling_params
.
detokenize
:
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
sampling_params
)
# TODO(sang): Support lora.
self
.
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
=
new_char_count
,
sampling_params
=
sampling_params
,
)
def
_process_seq_outputs
(
self
,
seq
:
Sequence
,
def
_process_seq_outputs
(
self
,
seq
:
Sequence
,
valid_samples
:
List
[
SequenceOutput
],
valid_samples
:
List
[
SequenceOutput
],
...
@@ -151,16 +170,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -151,16 +170,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
logprobs
=
output_logprob
,
logprobs
=
output_logprob
,
)
)
new_char_count
=
0
self
.
_process_decode_and_stop
(
seq
,
sampling_params
)
if
sampling_params
.
detokenize
:
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
sampling_params
)
# TODO(sang): Support lora.
self
.
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
=
new_char_count
,
sampling_params
=
sampling_params
,
)
if
seq
.
is_finished
():
if
seq
.
is_finished
():
break
break
vllm/sequence.py
View file @
6d646d08
...
@@ -1225,7 +1225,6 @@ class ExecuteModelRequest(
...
@@ -1225,7 +1225,6 @@ class ExecuteModelRequest(
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
# Async callback
# Async callback
async_callback
:
Optional
[
Callable
]
=
None
async_callback
:
Optional
[
Callable
]
=
None
use_async_and_multi_step
:
bool
=
False
@
property
@
property
def
is_first_multi_step
(
self
)
->
bool
:
def
is_first_multi_step
(
self
)
->
bool
:
...
@@ -1272,5 +1271,4 @@ class ExecuteModelRequest(
...
@@ -1272,5 +1271,4 @@ class ExecuteModelRequest(
finished_requests_ids
=
self
.
finished_requests_ids
,
finished_requests_ids
=
self
.
finished_requests_ids
,
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
clone
()
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
clone
()
if
self
.
last_sampled_token_ids
is
not
None
else
None
,
if
self
.
last_sampled_token_ids
is
not
None
else
None
,
async_callback
=
self
.
async_callback
,
async_callback
=
self
.
async_callback
)
use_async_and_multi_step
=
self
.
use_async_and_multi_step
)
vllm/worker/model_runner.py
View file @
6d646d08
...
@@ -21,6 +21,7 @@ from vllm.attention.backends.utils import CommonAttentionState
...
@@ -21,6 +21,7 @@ from vllm.attention.backends.utils import CommonAttentionState
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.distributed
import
get_pp_group
from
vllm.distributed
import
get_pp_group
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
...
@@ -96,7 +97,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
...
@@ -96,7 +97,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
virtual_engine
:
int
=
0
virtual_engine
:
int
=
0
async_callback
:
Optional
[
Callable
]
=
None
async_callback
:
Optional
[
Callable
]
=
None
use_async_and_multi_step
:
bool
=
False
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
tensor_dict
=
{
...
...
vllm/worker/multi_step_model_runner.py
View file @
6d646d08
...
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
...
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
get_pythonized_sample_results
)
get_pythonized_sample_results
)
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
Logprob
,
SequenceGroupMetadata
,
SequenceOutput
)
Logprob
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.utils
import
PyObjectCache
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUWithSamplingMetadata
)
ModelInputForGPUWithSamplingMetadata
)
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
...
@@ -37,6 +38,29 @@ if TYPE_CHECKING:
...
@@ -37,6 +38,29 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
seq_output_builder
():
return
SequenceOutput
(
0
,
0
,
{
0
:
Logprob
(
logprob
=
float
(
'inf'
),
rank
=
None
,
decoded_token
=
None
)})
def
completion_seq_group_output_builder
():
return
CompletionSequenceGroupOutput
([],
None
)
# Used by pythonization to reduce python object allocations
class
PythonizationCache
:
def
__init__
(
self
):
self
.
cached_seq_output
=
PyObjectCache
(
seq_output_builder
)
self
.
cached_completion_seq_group_output
=
PyObjectCache
(
completion_seq_group_output_builder
)
def
reset
(
self
):
self
.
cached_seq_output
.
reset
()
self
.
cached_completion_seq_group_output
.
reset
()
@
dataclass
@
dataclass
class
ModelOutput
:
class
ModelOutput
:
"""The output of a single model forward pass.
"""The output of a single model forward pass.
...
@@ -59,6 +83,7 @@ class ModelOutput:
...
@@ -59,6 +83,7 @@ class ModelOutput:
pythonized
:
bool
=
False
pythonized
:
bool
=
False
# On-device tensor containing the logprobs of each token.
# On-device tensor containing the logprobs of each token.
logprobs
:
Optional
[
"torch.Tensor"
]
=
None
logprobs
:
Optional
[
"torch.Tensor"
]
=
None
pythonization_cache
:
Optional
[
PythonizationCache
]
=
None
def
pythonize
(
self
,
input_metadata
:
"StatefulModelInput"
,
def
pythonize
(
self
,
input_metadata
:
"StatefulModelInput"
,
copy_stream
:
torch
.
cuda
.
Stream
,
copy_stream
:
torch
.
cuda
.
Stream
,
...
@@ -97,7 +122,8 @@ class ModelOutput:
...
@@ -97,7 +122,8 @@ class ModelOutput:
with
torch
.
cuda
.
stream
(
copy_stream
):
with
torch
.
cuda
.
stream
(
copy_stream
):
_pythonize_sampler_output
(
input_metadata
,
self
.
sampler_output
,
_pythonize_sampler_output
(
input_metadata
,
self
.
sampler_output
,
pinned_sampled_token_buffer
,
pinned_sampled_token_buffer
,
self
.
sampled_token_ids
,
self
.
logprobs
)
self
.
sampled_token_ids
,
self
.
logprobs
,
self
.
pythonization_cache
)
# Erase the logprobs GPU-side tensor.
# Erase the logprobs GPU-side tensor.
# Note that although _pythonize_sampler_output() runs in its
# Note that although _pythonize_sampler_output() runs in its
...
@@ -209,6 +235,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -209,6 +235,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
self
.
_copy_stream
=
torch
.
cuda
.
Stream
()
self
.
_copy_stream
=
torch
.
cuda
.
Stream
()
self
.
pinned_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
self
.
pinned_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
self
.
pythonization_cache
=
PythonizationCache
()
def
make_model_input_from_broadcasted_tensor_dict
(
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
tensor_dict
:
Dict
[
str
,
Any
])
->
StatefulModelInput
:
self
,
tensor_dict
:
Dict
[
str
,
Any
])
->
StatefulModelInput
:
model_input
=
(
StatefulModelInput
.
from_broadcasted_tensor_dict
(
model_input
=
(
StatefulModelInput
.
from_broadcasted_tensor_dict
(
...
@@ -237,14 +265,22 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -237,14 +265,22 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
output_proc_callback
:
Callable
):
output_proc_callback
:
Callable
):
# Proceed with pythonization and output_proc in order.
# Proceed with pythonization and output_proc in order.
# Stop on the first one that fails to pythonize
# Stop on the first one that fails to pythonize
output_proc_callback
()
cont
=
True
cont
=
True
for
model_output
in
model_input
.
cached_outputs
:
for
model_output
in
model_input
.
cached_outputs
:
if
not
model_output
.
pythonized
:
if
not
model_output
.
pythonized
:
model_output
.
maybe_pythonize
(
model_input
,
self
.
_copy_stream
,
model_output
.
maybe_pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
self
.
pinned_sampled_token_ids
)
if
model_output
.
pythonized
:
if
model_output
.
pythonized
:
output_proc_callback
(
ctx
=
output_proc_callback
.
keywords
[
"ctx"
]
sampler_output
=
model_output
.
sampler_output
)
is_async
=
False
is_last_step
=
False
ctx
.
output_queue
.
append
(
([
model_output
.
sampler_output
],
ctx
.
seq_group_metadata_list
,
ctx
.
scheduler_outputs
,
is_async
,
is_last_step
))
output_proc_callback
()
else
:
else
:
cont
=
False
cont
=
False
...
@@ -255,21 +291,46 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -255,21 +291,46 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
output_proc_callback
:
Optional
[
Callable
]):
output_proc_callback
:
Optional
[
Callable
]):
assert
model_input
.
frozen_model_input
is
not
None
assert
model_input
.
frozen_model_input
is
not
None
has_async_callback
=
output_proc_callback
is
not
None
outputs
=
[]
outputs
=
[]
for
output_id
in
range
(
len
(
model_input
.
cached_outputs
)):
for
output_id
in
range
(
len
(
model_input
.
cached_outputs
)):
is_last_output
=
output_id
==
len
(
model_input
.
cached_outputs
)
-
1
output
=
model_input
.
cached_outputs
[
output_id
]
output
=
model_input
.
cached_outputs
[
output_id
]
if
not
output
.
pythonized
:
is_last_step
=
output_id
==
len
(
model_input
.
cached_outputs
)
-
1
# For non-async case:
# -- We simply add the outputs
# For async case:
# -- Invoke callback, pythonize, add to callback queue and repeat
# -- For last output, just add to callback queue
if
has_async_callback
:
assert
output_proc_callback
is
not
None
# Invoke callback before pythonize (to overlap with GPU)
output_proc_callback
()
# Pythonize
if
not
output
.
pythonized
:
output
.
pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
# For non last step, add to callback queue to chain
# callbacks=>pythonize pairs (for GPU overlap)
if
not
is_last_step
:
ctx
=
output_proc_callback
.
keywords
[
# type: ignore
"ctx"
]
# type: ignore
is_async
=
False
is_last_step
=
False
ctx
.
output_queue
.
append
(
([
output
.
sampler_output
],
ctx
.
seq_group_metadata_list
,
ctx
.
scheduler_outputs
,
is_async
,
is_last_step
))
else
:
outputs
.
append
(
output
.
sampler_output
)
else
:
output
.
pythonize
(
model_input
,
self
.
_copy_stream
,
output
.
pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
self
.
pinned_sampled_token_ids
)
outputs
.
append
(
output
.
sampler_output
)
if
model_input
.
frozen_model_input
.
use_async_and_multi_step
:
assert
output_proc_callback
is
not
None
output_proc_callback
(
sampler_output
=
output
.
sampler_output
,
is_last_output
=
is_last_output
)
outputs
.
append
(
output
.
sampler_output
)
return
outputs
return
outputs
...
@@ -330,7 +391,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -330,7 +391,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
model_input
,
model_input
.
cached_outputs
[
-
1
].
sampler_output
)
model_input
,
model_input
.
cached_outputs
[
-
1
].
sampler_output
)
output_proc_callback
=
None
output_proc_callback
=
None
if
frozen_model_input
.
use_
async_
and_multi_step
:
if
frozen_model_input
.
async_
callback
is
not
None
:
output_proc_callback
=
frozen_model_input
.
async_callback
output_proc_callback
=
frozen_model_input
.
async_callback
assert
output_proc_callback
is
not
None
assert
output_proc_callback
is
not
None
async_callback
=
functools
.
partial
(
async_callback
=
functools
.
partial
(
...
@@ -367,7 +428,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -367,7 +428,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
model_input
.
cached_outputs
.
append
(
model_input
.
cached_outputs
.
append
(
ModelOutput
(
output
[
0
],
output_ready_event
,
ModelOutput
(
output
[
0
],
output_ready_event
,
output
[
0
].
sampled_token_ids
,
False
,
output
[
0
].
sampled_token_ids
,
False
,
output
[
0
].
logprobs
))
output
[
0
].
logprobs
,
self
.
pythonization_cache
))
# These GPU tensors are not required by multi-step;
# These GPU tensors are not required by multi-step;
# erase them to ensure they are not pythonized or
# erase them to ensure they are not pythonized or
...
@@ -378,7 +439,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -378,7 +439,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# Pythonize the output if CPU is ahead and the previous step is
# Pythonize the output if CPU is ahead and the previous step is
# ready.
# ready.
if
not
frozen_model_input
.
use_
async_
and_multi_step
:
if
frozen_model_input
.
async_
callback
is
None
:
for
model_output
in
model_input
.
cached_outputs
:
for
model_output
in
model_input
.
cached_outputs
:
model_output
.
maybe_pythonize
(
model_input
,
model_output
.
maybe_pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
_copy_stream
,
...
@@ -397,6 +458,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -397,6 +458,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
if
model_input
.
is_last_step
:
if
model_input
.
is_last_step
:
outputs
=
self
.
_final_process_outputs
(
model_input
,
outputs
=
self
.
_final_process_outputs
(
model_input
,
output_proc_callback
)
output_proc_callback
)
self
.
pythonization_cache
.
reset
()
return
outputs
return
outputs
# should be [SamplerOutput]
# should be [SamplerOutput]
...
@@ -537,6 +599,7 @@ def _pythonize_sampler_output(
...
@@ -537,6 +599,7 @@ def _pythonize_sampler_output(
pinned_sampled_token_buffer
:
torch
.
Tensor
,
pinned_sampled_token_buffer
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
logprobs_tensor
:
Optional
[
torch
.
Tensor
],
logprobs_tensor
:
Optional
[
torch
.
Tensor
],
cache
:
Optional
[
PythonizationCache
],
)
->
None
:
)
->
None
:
""" This function is only called when the output tensors are ready.
""" This function is only called when the output tensors are ready.
See :class:`ModelOutput`.
See :class:`ModelOutput`.
...
@@ -597,6 +660,9 @@ def _pythonize_sampler_output(
...
@@ -597,6 +660,9 @@ def _pythonize_sampler_output(
for
sgdx
,
(
seq_group
,
for
sgdx
,
(
seq_group
,
sample_result
)
in
enumerate
(
zip
(
seq_groups
,
samples_list
)):
sample_result
)
in
enumerate
(
zip
(
seq_groups
,
samples_list
)):
if
seq_group
.
sampling_params
.
logits_processors
:
assert
len
(
seq_group
.
sampling_params
.
logits_processors
)
==
0
,
(
"Logits Processors are not supported in multi-step decoding"
)
if
do_pythonize_logprobs
:
if
do_pythonize_logprobs
:
assert
prompt_logprobs
is
not
None
assert
prompt_logprobs
is
not
None
...
@@ -621,23 +687,56 @@ def _pythonize_sampler_output(
...
@@ -621,23 +687,56 @@ def _pythonize_sampler_output(
seq_ids
=
seq_group
.
seq_ids
seq_ids
=
seq_group
.
seq_ids
next_token_ids
=
sample_result
next_token_ids
=
sample_result
parent_ids
=
[
0
]
parent_ids
=
[
0
]
seq_outputs
:
List
[
SequenceOutput
]
=
[]
if
seq_group
.
sampling_params
.
logits_processors
:
if
cache
is
not
None
:
assert
len
(
seq_group
.
sampling_params
.
logits_processors
)
==
0
,
(
completion_seq_group_output
:
CompletionSequenceGroupOutput
=
\
"Logits Processors are not supported in multi-step decoding"
)
cache
.
cached_completion_seq_group_output
.
get_object
()
completion_seq_group_output
.
samples
.
clear
()
seq_outputs
:
List
[
SequenceOutput
]
=
completion_seq_group_output
.
samples
else
:
seq_outputs
=
[]
for
tdx
,
(
parent_id
,
for
tdx
,
(
parent_id
,
next_token_id
)
in
enumerate
(
zip
(
parent_ids
,
next_token_ids
)):
next_token_id
)
in
enumerate
(
zip
(
parent_ids
,
next_token_ids
)):
seq_outputs
.
append
(
if
cache
is
not
None
:
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
seq_output
:
SequenceOutput
=
cache
.
cached_seq_output
.
get_object
(
(
group_sample_logprobs
[
tdx
]
)
if
logprobs_are_requested
else
{
seq_output
.
parent_seq_id
=
seq_ids
[
parent_id
]
next_token_id
:
seq_output
.
output_token
=
next_token_id
Logprob
(
logprob
=
float
(
'inf'
),
rank
=
None
,
if
logprobs_are_requested
:
decoded_token
=
None
)
seq_output
.
logprobs
=
group_sample_logprobs
[
tdx
]
})))
else
:
output
.
outputs
.
append
(
logprobs
=
next
(
iter
(
seq_output
.
logprobs
.
values
()))
CompletionSequenceGroupOutput
(
seq_output
.
logprobs
.
clear
()
seq_outputs
,
(
group_prompt_logprobs
if
logprobs_are_requested
else
None
)))
logprobs
.
logprob
=
float
(
'inf'
)
logprobs
.
rank
=
None
logprobs
.
decoded_token
=
None
seq_output
.
logprobs
[
next_token_id
]
=
logprobs
seq_outputs
.
append
(
seq_output
)
else
:
seq_outputs
.
append
(
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
(
group_sample_logprobs
[
tdx
]
if
logprobs_are_requested
else
{
next_token_id
:
Logprob
(
logprob
=
float
(
'inf'
),
rank
=
None
,
decoded_token
=
None
)
})))
if
cache
is
not
None
:
completion_seq_group_output
.
prompt_logprobs
=
\
group_prompt_logprobs
if
logprobs_are_requested
else
None
output
.
outputs
.
append
(
completion_seq_group_output
)
else
:
output
.
outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
(
group_prompt_logprobs
if
logprobs_are_requested
else
None
)))
assert
len
(
output
.
outputs
)
>
0
assert
len
(
output
.
outputs
)
>
0
vllm/worker/multi_step_worker.py
View file @
6d646d08
...
@@ -67,9 +67,7 @@ class MultiStepWorker(Worker):
...
@@ -67,9 +67,7 @@ class MultiStepWorker(Worker):
if
execute_model_req
.
async_callback
:
if
execute_model_req
.
async_callback
:
model_input
.
frozen_model_input
=
dataclasses
.
replace
(
# type: ignore
model_input
.
frozen_model_input
=
dataclasses
.
replace
(
# type: ignore
model_input
.
frozen_model_input
,
model_input
.
frozen_model_input
,
async_callback
=
execute_model_req
.
async_callback
,
async_callback
=
execute_model_req
.
async_callback
)
use_async_and_multi_step
=
execute_model_req
.
use_async_and_multi_step
)
else
:
else
:
# on subsequent steps we reuse the worker input and model input
# on subsequent steps we reuse the worker input and model input
multi_step_state
=
self
.
multi_step_states
[
virtual_engine
]
multi_step_state
=
self
.
multi_step_states
[
virtual_engine
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment