Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3f60f224
Unverified
Commit
3f60f224
authored
Aug 29, 2024
by
Alexander Matveev
Committed by
GitHub
Aug 29, 2024
Browse files
[Core] Combine async postprocessor and multi-step (#7921)
parent
f205c098
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
215 additions
and
65 deletions
+215
-65
tests/multi_step/test_correctness_async_llm.py
tests/multi_step/test_correctness_async_llm.py
+6
-4
vllm/core/scheduler.py
vllm/core/scheduler.py
+1
-4
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+44
-21
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+88
-26
vllm/sequence.py
vllm/sequence.py
+3
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+1
-0
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+64
-9
vllm/worker/multi_step_worker.py
vllm/worker/multi_step_worker.py
+8
-0
No files found.
tests/multi_step/test_correctness_async_llm.py
View file @
3f60f224
...
@@ -47,10 +47,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
...
@@ -47,10 +47,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
@
pytest
.
mark
.
parametrize
(
"eager_mode"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"eager_mode"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"is_async"
,
[
False
,
True
])
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_multi_step
(
example_prompts
,
model
:
str
,
tp_size
:
int
,
async
def
test_multi_step
(
example_prompts
,
model
:
str
,
tp_size
:
int
,
pp_size
:
int
,
eager_mode
:
int
,
pp_size
:
int
,
eager_mode
:
int
,
num_scheduler_steps
:
int
,
num_prompts
:
int
):
num_scheduler_steps
:
int
,
num_prompts
:
int
,
is_async
:
bool
):
prompts
=
example_prompts
prompts
=
example_prompts
if
len
(
prompts
)
<
num_prompts
:
if
len
(
prompts
)
<
num_prompts
:
...
@@ -62,9 +64,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
...
@@ -62,9 +64,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
ms_server_args
=
DEFAULT_SERVER_ARGS
+
\
ms_server_args
=
DEFAULT_SERVER_ARGS
+
\
[
"--num-scheduler-steps"
,
f
"
{
num_scheduler_steps
}
"
]
[
"--num-scheduler-steps"
,
f
"
{
num_scheduler_steps
}
"
]
# Disable output proc callback as its not supported
if
not
is_async
:
# with multi-step right now
ms_server_args
+=
[
"--disable-async-output-proc"
]
ms_server_args
+=
[
"--disable-async-output-proc"
]
if
eager_mode
:
if
eager_mode
:
ms_server_args
.
append
(
"--enforce-eager"
)
ms_server_args
.
append
(
"--enforce-eager"
)
...
...
vllm/core/scheduler.py
View file @
3f60f224
...
@@ -1107,10 +1107,7 @@ class Scheduler:
...
@@ -1107,10 +1107,7 @@ class Scheduler:
if
not
self
.
cache_config
.
enable_prefix_caching
:
if
not
self
.
cache_config
.
enable_prefix_caching
:
common_computed_block_nums
=
[]
common_computed_block_nums
=
[]
# TODO: Combine multi-step and async postprocessor
allow_async_output_proc
:
bool
=
self
.
use_async_output_proc
allow_async_output_proc
:
bool
=
(
self
.
use_async_output_proc
and
not
self
.
scheduler_config
.
is_multi_step
)
# Create input data structures.
# Create input data structures.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
...
...
vllm/engine/async_llm_engine.py
View file @
3f60f224
...
@@ -279,6 +279,10 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -279,6 +279,10 @@ class _AsyncLLMEngine(LLMEngine):
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
# Detect async + multi-step
use_async_and_multi_step
=
(
self
.
scheduler_config
.
is_multi_step
and
allow_async_output_proc
)
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
# skip the scheduler if there are any remaining steps in the seq groups.
# skip the scheduler if there are any remaining steps in the seq groups.
...
@@ -289,17 +293,27 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -289,17 +293,27 @@ class _AsyncLLMEngine(LLMEngine):
# Clear outputs on scheduler iteration start
# Clear outputs on scheduler iteration start
ctx
.
request_outputs
.
clear
()
ctx
.
request_outputs
.
clear
()
# Schedule iteration
(
seq_group_metadata_list
,
scheduler_outputs
,
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
# If current scheduler iteration has no async postprocessor,
# Detect async + multi-step
# then we need first to drain the pending async postprocessor
use_async_and_multi_step
=
(
self
.
scheduler_config
.
is_multi_step
# before moving forward
and
allow_async_output_proc
)
# Maybe switch from async mode to sync mode
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
is_async
=
True
)
# For async + multi-step, init the queue
if
use_async_and_multi_step
:
assert
len
(
ctx
.
output_queue
)
==
0
assert
seq_group_metadata_list
is
not
None
ctx
.
output_queue
.
append
(
(
None
,
seq_group_metadata_list
,
scheduler_outputs
))
if
(
self
.
scheduler_config
.
is_multi_step
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
# cache the scheduler outputs for the next iteration if we have
# cache the scheduler outputs for the next iteration if we have
...
@@ -311,9 +325,6 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -311,9 +325,6 @@ class _AsyncLLMEngine(LLMEngine):
assert
seq_group_metadata_list
is
not
None
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
assert
scheduler_outputs
is
not
None
assert
not
(
self
.
scheduler_config
.
is_multi_step
and
\
allow_async_output_proc
)
if
not
scheduler_outputs
.
is_empty
():
if
not
scheduler_outputs
.
is_empty
():
finished_requests_ids
=
self
.
scheduler
[
finished_requests_ids
=
self
.
scheduler
[
virtual_engine
].
get_and_reset_finished_requests_ids
()
virtual_engine
].
get_and_reset_finished_requests_ids
()
...
@@ -339,8 +350,13 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -339,8 +350,13 @@ class _AsyncLLMEngine(LLMEngine):
last_sampled_token_ids
=
last_sampled_token_ids
)
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
if
allow_async_output_proc
:
execute_model_req
.
async_callback
=
self
.
async_callback
[
async_callback
=
self
.
async_callback_multi_step
[
virtual_engine
]
virtual_engine
]
if
use_async_and_multi_step
\
else
self
.
async_callback
[
virtual_engine
]
execute_model_req
.
async_callback
=
async_callback
execute_model_req
.
use_async_and_multi_step
=
\
use_async_and_multi_step
# Execute the model.
# Execute the model.
output
=
await
self
.
model_executor
.
execute_model_async
(
output
=
await
self
.
model_executor
.
execute_model_async
(
...
@@ -350,7 +366,7 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -350,7 +366,7 @@ class _AsyncLLMEngine(LLMEngine):
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
else
:
else
:
if
len
(
ctx
.
output_queue
)
>
0
:
if
not
use_async_and_multi_step
and
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
is_async
=
True
)
...
@@ -362,12 +378,15 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -362,12 +378,15 @@ class _AsyncLLMEngine(LLMEngine):
seq_group
.
finish_step
()
seq_group
.
finish_step
()
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
#
c
lear the cache if we have finished all the steps
#
C
lear the cache if we have finished all the steps
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
self
.
cached_scheduler_outputs
[
self
.
cached_scheduler_outputs
[
virtual_engine
]
=
SchedulerOutputState
()
virtual_engine
]
=
SchedulerOutputState
()
# Cache results in engine
if
use_async_and_multi_step
:
# For async + multi-step, clear the queue
ctx
.
output_queue
.
clear
()
else
:
ctx
.
output_queue
.
append
(
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
...
@@ -389,6 +408,10 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -389,6 +408,10 @@ class _AsyncLLMEngine(LLMEngine):
# Tracing
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
self
.
do_tracing
(
scheduler_outputs
)
else
:
# Multi-step case
if
use_async_and_multi_step
:
return
[]
else
:
else
:
ctx
.
request_outputs
=
[]
ctx
.
request_outputs
=
[]
...
...
vllm/engine/llm_engine.py
View file @
3f60f224
...
@@ -91,7 +91,8 @@ class SchedulerOutputState:
...
@@ -91,7 +91,8 @@ class SchedulerOutputState:
@
dataclass
@
dataclass
class
SchedulerContext
:
class
SchedulerContext
:
output_queue
:
Deque
[
Tuple
[
List
[
SamplerOutput
],
List
[
SequenceGroupMetadata
],
output_queue
:
Deque
[
Tuple
[
Optional
[
List
[
SamplerOutput
]],
List
[
SequenceGroupMetadata
],
SchedulerOutputs
]]
=
field
(
SchedulerOutputs
]]
=
field
(
default_factory
=
lambda
:
deque
())
default_factory
=
lambda
:
deque
())
...
@@ -432,6 +433,13 @@ class LLMEngine:
...
@@ -432,6 +433,13 @@ class LLMEngine:
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
]
self
.
async_callback_multi_step
=
[
functools
.
partial
(
self
.
_process_model_outputs
,
virtual_engine
=
v_id
,
is_async
=
False
)
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
def
_initialize_kv_caches
(
self
)
->
None
:
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
"""Initialize the KV cache in the worker(s).
...
@@ -1240,29 +1248,50 @@ class LLMEngine:
...
@@ -1240,29 +1248,50 @@ class LLMEngine:
return
return
def
_process_model_outputs
(
self
,
virtual_engine
:
int
,
def
_process_model_outputs
(
self
,
is_async
:
bool
)
->
None
:
virtual_engine
:
int
,
is_async
:
bool
,
sampler_output
:
Optional
[
SamplerOutput
]
=
None
,
is_last_output
:
bool
=
False
)
->
None
:
"""Apply the model output to the sequences in the scheduled seq groups.
"""Apply the model output to the sequences in the scheduled seq groups.
virtual_engine: The engine id to operate on
virtual_engine: The engine id to operate on
is_async: Indicates whether this postprocessor runs in
is_async: Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
no tokens need to be appended since it is already done
externally (before the next schedule() call)
externally (before the next schedule() call)
sampler_output: Used with multi-step execution to provide
sampler_output of each step
is_last_output: Used with multi-step execution to indicate
the last step (of each multi-step group)
Returns RequestOutputs that can be returned to the client.
Returns RequestOutputs that can be returned to the client.
"""
"""
now
=
time
.
time
()
now
=
time
.
time
()
is_multi_step
=
sampler_output
is
not
None
ctx
:
SchedulerContext
=
self
.
scheduler_contexts
[
virtual_engine
]
ctx
:
SchedulerContext
=
self
.
scheduler_contexts
[
virtual_engine
]
if
len
(
ctx
.
output_queue
)
==
0
:
if
len
(
ctx
.
output_queue
)
==
0
:
return
None
return
None
if
is_multi_step
:
# Async + multi-step case
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
)
=
ctx
.
output_queue
[
0
]
assert
outputs
is
None
outputs
=
[
sampler_output
]
else
:
# Async standard case
(
outputs
,
seq_group_metadata_list
,
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
)
=
ctx
.
output_queue
.
popleft
()
scheduler_outputs
)
=
ctx
.
output_queue
.
popleft
()
assert
outputs
is
not
None
# Sanity check
# Sanity check
assert
len
(
seq_group_metadata_list
)
==
len
(
assert
len
(
seq_group_metadata_list
)
==
len
(
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
)
...
@@ -1320,7 +1349,11 @@ class LLMEngine:
...
@@ -1320,7 +1349,11 @@ class LLMEngine:
self
.
output_processor
.
process_outputs
(
seq_group
,
output
,
self
.
output_processor
.
process_outputs
(
seq_group
,
output
,
is_async
)
is_async
)
# Free the finished sequence groups.
# For async + multi-step, free finished seqs and create outputs
# only on the final step.
if
is_multi_step
and
not
is_last_output
:
return
for
scheduler
in
self
.
scheduler
:
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_finished_seq_groups
()
scheduler
.
free_finished_seq_groups
()
...
@@ -1328,7 +1361,7 @@ class LLMEngine:
...
@@ -1328,7 +1361,7 @@ class LLMEngine:
for
i
,
_
in
enumerate
(
seq_group_metadata_list
):
for
i
,
_
in
enumerate
(
seq_group_metadata_list
):
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
if
i
in
finished_before
:
if
not
is_multi_step
and
i
in
finished_before
:
continue
# Avoids double processing
continue
# Avoids double processing
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
...
@@ -1342,7 +1375,11 @@ class LLMEngine:
...
@@ -1342,7 +1375,11 @@ class LLMEngine:
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
ctx
.
request_outputs
.
append
(
request_output
)
ctx
.
request_outputs
.
append
(
request_output
)
if
is_async
:
# For async + multi-step, do stats only on the last output.
# Otherwise, do stats if the execution is async
do_stats
=
is_multi_step
or
is_async
if
do_stats
:
# Log stats.
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
outputs
,
finished_before
)
self
.
do_log_stats
(
scheduler_outputs
,
outputs
,
finished_before
)
...
@@ -1437,7 +1474,7 @@ class LLMEngine:
...
@@ -1437,7 +1474,7 @@ class LLMEngine:
"as performance will be severely degraded otherwise."
)
"as performance will be severely degraded otherwise."
)
# For llm_engine, there is no pipeline parallel support, so the engine
# For llm_engine, there is no pipeline parallel support, so the engine
# used is always 0
# used is always 0
.
virtual_engine
=
0
virtual_engine
=
0
# These are cached outputs from previous iterations. None if on first
# These are cached outputs from previous iterations. None if on first
...
@@ -1447,6 +1484,10 @@ class LLMEngine:
...
@@ -1447,6 +1484,10 @@ class LLMEngine:
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
# Detect async + multi-step
use_async_and_multi_step
=
(
self
.
scheduler_config
.
is_multi_step
and
allow_async_output_proc
)
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
# Skip the scheduler if there are any remaining steps in the seq groups.
# Skip the scheduler if there are any remaining steps in the seq groups.
...
@@ -1462,11 +1503,22 @@ class LLMEngine:
...
@@ -1462,11 +1503,22 @@ class LLMEngine:
allow_async_output_proc
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
# Detect async + multi-step
use_async_and_multi_step
=
(
self
.
scheduler_config
.
is_multi_step
and
allow_async_output_proc
)
# Maybe switch from async mode to sync mode
# Maybe switch from async mode to sync mode
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
is_async
=
True
)
# For async + multi-step, init the queue
if
use_async_and_multi_step
:
assert
len
(
ctx
.
output_queue
)
==
0
assert
seq_group_metadata_list
is
not
None
ctx
.
output_queue
.
append
(
(
None
,
seq_group_metadata_list
,
scheduler_outputs
))
if
(
self
.
scheduler_config
.
is_multi_step
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
# cache the scheduler outputs for the next iteration if we have
# cache the scheduler outputs for the next iteration if we have
...
@@ -1478,9 +1530,6 @@ class LLMEngine:
...
@@ -1478,9 +1530,6 @@ class LLMEngine:
assert
seq_group_metadata_list
is
not
None
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
assert
scheduler_outputs
is
not
None
assert
not
(
self
.
scheduler_config
.
is_multi_step
and
\
allow_async_output_proc
)
if
not
scheduler_outputs
.
is_empty
():
if
not
scheduler_outputs
.
is_empty
():
finished_requests_ids
=
self
.
scheduler
[
finished_requests_ids
=
self
.
scheduler
[
virtual_engine
].
get_and_reset_finished_requests_ids
()
virtual_engine
].
get_and_reset_finished_requests_ids
()
...
@@ -1505,8 +1554,13 @@ class LLMEngine:
...
@@ -1505,8 +1554,13 @@ class LLMEngine:
last_sampled_token_ids
=
last_sampled_token_ids
)
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
if
allow_async_output_proc
:
execute_model_req
.
async_callback
=
self
.
async_callback
[
async_callback
=
self
.
async_callback_multi_step
[
virtual_engine
]
virtual_engine
]
if
use_async_and_multi_step
\
else
self
.
async_callback
[
virtual_engine
]
execute_model_req
.
async_callback
=
async_callback
execute_model_req
.
use_async_and_multi_step
=
\
use_async_and_multi_step
output
=
self
.
model_executor
.
execute_model
(
output
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
execute_model_req
=
execute_model_req
)
...
@@ -1518,7 +1572,7 @@ class LLMEngine:
...
@@ -1518,7 +1572,7 @@ class LLMEngine:
else
:
else
:
# Nothing scheduled => If there is pending async postprocessor,
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
# then finish it here.
if
len
(
ctx
.
output_queue
)
>
0
:
if
not
use_async_and_multi_step
and
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
is_async
=
True
)
...
@@ -1535,13 +1589,18 @@ class LLMEngine:
...
@@ -1535,13 +1589,18 @@ class LLMEngine:
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
self
.
cached_scheduler_outputs
[
0
]
=
SchedulerOutputState
()
self
.
cached_scheduler_outputs
[
0
]
=
SchedulerOutputState
()
if
use_async_and_multi_step
:
# For async + multi-step, clear the queue
ctx
.
output_queue
.
clear
()
else
:
# Add results to the output_queue
# Add results to the output_queue
# (for async or non-async postprocessing)
# (for async or non-async postprocessing)
ctx
.
output_queue
.
append
(
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
if
output
and
allow_async_output_proc
:
if
output
and
allow_async_output_proc
:
assert
len
(
output
)
==
1
,
(
"Multi step decoding does not work "
assert
len
(
output
)
==
1
,
(
"Multi step decoding does not work "
"with async output processing."
)
"with async output processing."
)
self
.
_advance_to_next_step
(
self
.
_advance_to_next_step
(
...
@@ -1560,6 +1619,9 @@ class LLMEngine:
...
@@ -1560,6 +1619,9 @@ class LLMEngine:
self
.
do_tracing
(
scheduler_outputs
)
self
.
do_tracing
(
scheduler_outputs
)
else
:
else
:
# Multi-step case
# Multi-step case
if
use_async_and_multi_step
:
return
[]
else
:
ctx
.
request_outputs
=
[]
ctx
.
request_outputs
=
[]
if
not
self
.
has_unfinished_requests
():
if
not
self
.
has_unfinished_requests
():
...
...
vllm/sequence.py
View file @
3f60f224
...
@@ -1295,6 +1295,7 @@ class ExecuteModelRequest(
...
@@ -1295,6 +1295,7 @@ class ExecuteModelRequest(
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
# Async callback
# Async callback
async_callback
:
Optional
[
Callable
]
=
None
async_callback
:
Optional
[
Callable
]
=
None
use_async_and_multi_step
:
bool
=
False
@
property
@
property
def
is_first_multi_step
(
self
)
->
bool
:
def
is_first_multi_step
(
self
)
->
bool
:
...
@@ -1341,4 +1342,5 @@ class ExecuteModelRequest(
...
@@ -1341,4 +1342,5 @@ class ExecuteModelRequest(
finished_requests_ids
=
self
.
finished_requests_ids
,
finished_requests_ids
=
self
.
finished_requests_ids
,
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
clone
()
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
clone
()
if
self
.
last_sampled_token_ids
is
not
None
else
None
,
if
self
.
last_sampled_token_ids
is
not
None
else
None
,
async_callback
=
self
.
async_callback
)
async_callback
=
self
.
async_callback
,
use_async_and_multi_step
=
self
.
use_async_and_multi_step
)
vllm/worker/model_runner.py
View file @
3f60f224
...
@@ -92,6 +92,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
...
@@ -92,6 +92,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
virtual_engine
:
int
=
0
virtual_engine
:
int
=
0
async_callback
:
Optional
[
Callable
]
=
None
async_callback
:
Optional
[
Callable
]
=
None
use_async_and_multi_step
:
bool
=
False
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
tensor_dict
=
{
...
...
vllm/worker/multi_step_model_runner.py
View file @
3f60f224
import
dataclasses
import
functools
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
try
:
try
:
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
...
@@ -215,6 +217,46 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -215,6 +217,46 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
)
)
return
model_input
return
model_input
def
_async_process_outputs
(
self
,
model_input
:
StatefulModelInput
,
output_proc_callback
:
Callable
):
# Proceed with pythonization and output_proc in order.
# Stop on the first one that fails to pythonize
cont
=
True
for
model_output
in
model_input
.
cached_outputs
:
if
not
model_output
.
pythonized
:
model_output
.
maybe_pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
if
model_output
.
pythonized
:
output_proc_callback
(
sampler_output
=
model_output
.
sampler_output
)
else
:
cont
=
False
if
not
cont
:
break
def
_final_process_outputs
(
self
,
model_input
:
StatefulModelInput
,
output_proc_callback
:
Optional
[
Callable
]):
assert
model_input
.
frozen_model_input
is
not
None
outputs
=
[]
for
output_id
in
range
(
len
(
model_input
.
cached_outputs
)):
is_last_output
=
output_id
==
len
(
model_input
.
cached_outputs
)
-
1
output
=
model_input
.
cached_outputs
[
output_id
]
if
not
output
.
pythonized
:
output
.
pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
if
model_input
.
frozen_model_input
.
use_async_and_multi_step
:
assert
output_proc_callback
is
not
None
output_proc_callback
(
sampler_output
=
output
.
sampler_output
,
is_last_output
=
is_last_output
)
outputs
.
append
(
output
.
sampler_output
)
return
outputs
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
...
@@ -271,6 +313,20 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -271,6 +313,20 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
model_input
=
self
.
_advance_step
(
model_input
=
self
.
_advance_step
(
model_input
,
model_input
.
cached_outputs
[
-
1
].
sampler_output
)
model_input
,
model_input
.
cached_outputs
[
-
1
].
sampler_output
)
output_proc_callback
=
None
if
frozen_model_input
.
use_async_and_multi_step
:
output_proc_callback
=
frozen_model_input
.
async_callback
assert
output_proc_callback
is
not
None
async_callback
=
functools
.
partial
(
self
.
_async_process_outputs
,
model_input
=
model_input
,
output_proc_callback
=
output_proc_callback
)
frozen_model_input
=
dataclasses
.
replace
(
# type: ignore
model_input
.
frozen_model_input
,
async_callback
=
async_callback
)
assert
frozen_model_input
is
not
None
# Execute the model
# Execute the model
output
=
self
.
_base_model_runner
.
execute_model
(
frozen_model_input
,
output
=
self
.
_base_model_runner
.
execute_model
(
frozen_model_input
,
kv_caches
,
kv_caches
,
...
@@ -301,8 +357,10 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -301,8 +357,10 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
output
[
0
].
logprobs
=
None
output
[
0
].
logprobs
=
None
# Pythonize the output if CPU is ahead and the previous step is
# Pythonize the output if CPU is ahead and the previous step is
# ready.
# ready.
if
not
frozen_model_input
.
use_async_and_multi_step
:
for
model_output
in
model_input
.
cached_outputs
:
for
model_output
in
model_input
.
cached_outputs
:
model_output
.
maybe_pythonize
(
model_input
,
self
.
_copy_stream
,
model_output
.
maybe_pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
self
.
pinned_sampled_token_ids
)
model_input
.
current_step
+=
1
model_input
.
current_step
+=
1
...
@@ -316,11 +374,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -316,11 +374,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# Pythonize the output and block if needed since it is the last step
# Pythonize the output and block if needed since it is the last step
if
model_input
.
is_last_step
:
if
model_input
.
is_last_step
:
outputs
=
[]
outputs
=
self
.
_final_process_outputs
(
model_input
,
for
output
in
model_input
.
cached_outputs
:
output_proc_callback
)
output
.
pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
outputs
.
append
(
output
.
sampler_output
)
return
outputs
return
outputs
# should be [SamplerOutput]
# should be [SamplerOutput]
...
...
vllm/worker/multi_step_worker.py
View file @
3f60f224
import
dataclasses
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
...
@@ -61,6 +62,13 @@ class MultiStepWorker(Worker):
...
@@ -61,6 +62,13 @@ class MultiStepWorker(Worker):
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
virtual_engine
,
execute_model_req
.
virtual_engine
,
execute_model_req
.
finished_requests_ids
))
execute_model_req
.
finished_requests_ids
))
if
execute_model_req
.
async_callback
:
model_input
.
frozen_model_input
=
dataclasses
.
replace
(
# type: ignore
model_input
.
frozen_model_input
,
async_callback
=
execute_model_req
.
async_callback
,
use_async_and_multi_step
=
execute_model_req
.
use_async_and_multi_step
)
else
:
else
:
# on subsequent steps we reuse the worker input and model input
# on subsequent steps we reuse the worker input and model input
multi_step_state
=
self
.
multi_step_states
[
virtual_engine
]
multi_step_state
=
self
.
multi_step_states
[
virtual_engine
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment