Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3f60f224
Unverified
Commit
3f60f224
authored
Aug 29, 2024
by
Alexander Matveev
Committed by
GitHub
Aug 29, 2024
Browse files
[Core] Combine async postprocessor and multi-step (#7921)
parent
f205c098
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
215 additions
and
65 deletions
+215
-65
tests/multi_step/test_correctness_async_llm.py
tests/multi_step/test_correctness_async_llm.py
+6
-4
vllm/core/scheduler.py
vllm/core/scheduler.py
+1
-4
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+44
-21
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+88
-26
vllm/sequence.py
vllm/sequence.py
+3
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+1
-0
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+64
-9
vllm/worker/multi_step_worker.py
vllm/worker/multi_step_worker.py
+8
-0
No files found.
tests/multi_step/test_correctness_async_llm.py
View file @
3f60f224
...
...
@@ -47,10 +47,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
@
pytest
.
mark
.
parametrize
(
"eager_mode"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"is_async"
,
[
False
,
True
])
@
pytest
.
mark
.
asyncio
async
def
test_multi_step
(
example_prompts
,
model
:
str
,
tp_size
:
int
,
pp_size
:
int
,
eager_mode
:
int
,
num_scheduler_steps
:
int
,
num_prompts
:
int
):
num_scheduler_steps
:
int
,
num_prompts
:
int
,
is_async
:
bool
):
prompts
=
example_prompts
if
len
(
prompts
)
<
num_prompts
:
...
...
@@ -62,9 +64,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
ms_server_args
=
DEFAULT_SERVER_ARGS
+
\
[
"--num-scheduler-steps"
,
f
"
{
num_scheduler_steps
}
"
]
# Disable output proc callback as its not supported
# with multi-step right now
ms_server_args
+=
[
"--disable-async-output-proc"
]
if
not
is_async
:
ms_server_args
+=
[
"--disable-async-output-proc"
]
if
eager_mode
:
ms_server_args
.
append
(
"--enforce-eager"
)
...
...
vllm/core/scheduler.py
View file @
3f60f224
...
...
@@ -1107,10 +1107,7 @@ class Scheduler:
if
not
self
.
cache_config
.
enable_prefix_caching
:
common_computed_block_nums
=
[]
# TODO: Combine multi-step and async postprocessor
allow_async_output_proc
:
bool
=
(
self
.
use_async_output_proc
and
not
self
.
scheduler_config
.
is_multi_step
)
allow_async_output_proc
:
bool
=
self
.
use_async_output_proc
# Create input data structures.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
...
...
vllm/engine/async_llm_engine.py
View file @
3f60f224
...
...
@@ -279,6 +279,10 @@ class _AsyncLLMEngine(LLMEngine):
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
# Detect async + multi-step
use_async_and_multi_step
=
(
self
.
scheduler_config
.
is_multi_step
and
allow_async_output_proc
)
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
# skip the scheduler if there are any remaining steps in the seq groups.
...
...
@@ -289,17 +293,27 @@ class _AsyncLLMEngine(LLMEngine):
# Clear outputs on scheduler iteration start
ctx
.
request_outputs
.
clear
()
# Schedule iteration
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
# If current scheduler iteration has no async postprocessor,
# then we need first to drain the pending async postprocessor
# before moving forward
# Detect async + multi-step
use_async_and_multi_step
=
(
self
.
scheduler_config
.
is_multi_step
and
allow_async_output_proc
)
# Maybe switch from async mode to sync mode
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
# For async + multi-step, init the queue
if
use_async_and_multi_step
:
assert
len
(
ctx
.
output_queue
)
==
0
assert
seq_group_metadata_list
is
not
None
ctx
.
output_queue
.
append
(
(
None
,
seq_group_metadata_list
,
scheduler_outputs
))
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
# cache the scheduler outputs for the next iteration if we have
...
...
@@ -311,9 +325,6 @@ class _AsyncLLMEngine(LLMEngine):
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
assert
not
(
self
.
scheduler_config
.
is_multi_step
and
\
allow_async_output_proc
)
if
not
scheduler_outputs
.
is_empty
():
finished_requests_ids
=
self
.
scheduler
[
virtual_engine
].
get_and_reset_finished_requests_ids
()
...
...
@@ -339,8 +350,13 @@ class _AsyncLLMEngine(LLMEngine):
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
execute_model_req
.
async_callback
=
self
.
async_callback
[
virtual_engine
]
async_callback
=
self
.
async_callback_multi_step
[
virtual_engine
]
if
use_async_and_multi_step
\
else
self
.
async_callback
[
virtual_engine
]
execute_model_req
.
async_callback
=
async_callback
execute_model_req
.
use_async_and_multi_step
=
\
use_async_and_multi_step
# Execute the model.
output
=
await
self
.
model_executor
.
execute_model_async
(
...
...
@@ -350,7 +366,7 @@ class _AsyncLLMEngine(LLMEngine):
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
else
:
if
len
(
ctx
.
output_queue
)
>
0
:
if
not
use_async_and_multi_step
and
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
...
...
@@ -362,22 +378,25 @@ class _AsyncLLMEngine(LLMEngine):
seq_group
.
finish_step
()
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
#
c
lear the cache if we have finished all the steps
#
C
lear the cache if we have finished all the steps
if
self
.
scheduler_config
.
is_multi_step
:
self
.
cached_scheduler_outputs
[
virtual_engine
]
=
SchedulerOutputState
()
# Cache results in engine
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
if
use_async_and_multi_step
:
# For async + multi-step, clear the queue
ctx
.
output_queue
.
clear
()
else
:
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
if
output
and
allow_async_output_proc
:
assert
len
(
output
)
==
1
,
"Multi step decoding does not work with async output processing."
# noqa: E501
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
if
output
and
allow_async_output_proc
:
assert
len
(
output
)
==
1
,
"Multi step decoding does not work with async output processing."
# noqa: E501
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
if
not
allow_async_output_proc
:
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
...
...
@@ -390,7 +409,11 @@ class _AsyncLLMEngine(LLMEngine):
self
.
do_tracing
(
scheduler_outputs
)
else
:
ctx
.
request_outputs
=
[]
# Multi-step case
if
use_async_and_multi_step
:
return
[]
else
:
ctx
.
request_outputs
=
[]
if
not
self
.
has_unfinished_requests
():
# Drain async postprocessor (if exists)
...
...
vllm/engine/llm_engine.py
View file @
3f60f224
...
...
@@ -91,7 +91,8 @@ class SchedulerOutputState:
@
dataclass
class
SchedulerContext
:
output_queue
:
Deque
[
Tuple
[
List
[
SamplerOutput
],
List
[
SequenceGroupMetadata
],
output_queue
:
Deque
[
Tuple
[
Optional
[
List
[
SamplerOutput
]],
List
[
SequenceGroupMetadata
],
SchedulerOutputs
]]
=
field
(
default_factory
=
lambda
:
deque
())
...
...
@@ -432,6 +433,13 @@ class LLMEngine:
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
async_callback_multi_step
=
[
functools
.
partial
(
self
.
_process_model_outputs
,
virtual_engine
=
v_id
,
is_async
=
False
)
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
...
...
@@ -1240,28 +1248,49 @@ class LLMEngine:
return
def
_process_model_outputs
(
self
,
virtual_engine
:
int
,
is_async
:
bool
)
->
None
:
def
_process_model_outputs
(
self
,
virtual_engine
:
int
,
is_async
:
bool
,
sampler_output
:
Optional
[
SamplerOutput
]
=
None
,
is_last_output
:
bool
=
False
)
->
None
:
"""Apply the model output to the sequences in the scheduled seq groups.
virtual_engine: The engine id to operate on
is_async: Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
sampler_output: Used with multi-step execution to provide
sampler_output of each step
is_last_output: Used with multi-step execution to indicate
the last step (of each multi-step group)
Returns RequestOutputs that can be returned to the client.
"""
now
=
time
.
time
()
is_multi_step
=
sampler_output
is
not
None
ctx
:
SchedulerContext
=
self
.
scheduler_contexts
[
virtual_engine
]
if
len
(
ctx
.
output_queue
)
==
0
:
return
None
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
)
=
ctx
.
output_queue
.
popleft
()
if
is_multi_step
:
# Async + multi-step case
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
)
=
ctx
.
output_queue
[
0
]
assert
outputs
is
None
outputs
=
[
sampler_output
]
else
:
# Async standard case
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
)
=
ctx
.
output_queue
.
popleft
()
assert
outputs
is
not
None
# Sanity check
assert
len
(
seq_group_metadata_list
)
==
len
(
...
...
@@ -1320,7 +1349,11 @@ class LLMEngine:
self
.
output_processor
.
process_outputs
(
seq_group
,
output
,
is_async
)
# Free the finished sequence groups.
# For async + multi-step, free finished seqs and create outputs
# only on the final step.
if
is_multi_step
and
not
is_last_output
:
return
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_finished_seq_groups
()
...
...
@@ -1328,7 +1361,7 @@ class LLMEngine:
for
i
,
_
in
enumerate
(
seq_group_metadata_list
):
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
if
i
in
finished_before
:
if
not
is_multi_step
and
i
in
finished_before
:
continue
# Avoids double processing
seq_group
=
scheduled_seq_group
.
seq_group
...
...
@@ -1342,7 +1375,11 @@ class LLMEngine:
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
ctx
.
request_outputs
.
append
(
request_output
)
if
is_async
:
# For async + multi-step, do stats only on the last output.
# Otherwise, do stats if the execution is async
do_stats
=
is_multi_step
or
is_async
if
do_stats
:
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
outputs
,
finished_before
)
...
...
@@ -1437,7 +1474,7 @@ class LLMEngine:
"as performance will be severely degraded otherwise."
)
# For llm_engine, there is no pipeline parallel support, so the engine
# used is always 0
# used is always 0
.
virtual_engine
=
0
# These are cached outputs from previous iterations. None if on first
...
...
@@ -1447,6 +1484,10 @@ class LLMEngine:
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
# Detect async + multi-step
use_async_and_multi_step
=
(
self
.
scheduler_config
.
is_multi_step
and
allow_async_output_proc
)
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
# Skip the scheduler if there are any remaining steps in the seq groups.
...
...
@@ -1462,11 +1503,22 @@ class LLMEngine:
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
# Detect async + multi-step
use_async_and_multi_step
=
(
self
.
scheduler_config
.
is_multi_step
and
allow_async_output_proc
)
# Maybe switch from async mode to sync mode
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
# For async + multi-step, init the queue
if
use_async_and_multi_step
:
assert
len
(
ctx
.
output_queue
)
==
0
assert
seq_group_metadata_list
is
not
None
ctx
.
output_queue
.
append
(
(
None
,
seq_group_metadata_list
,
scheduler_outputs
))
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
# cache the scheduler outputs for the next iteration if we have
...
...
@@ -1478,9 +1530,6 @@ class LLMEngine:
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
assert
not
(
self
.
scheduler_config
.
is_multi_step
and
\
allow_async_output_proc
)
if
not
scheduler_outputs
.
is_empty
():
finished_requests_ids
=
self
.
scheduler
[
virtual_engine
].
get_and_reset_finished_requests_ids
()
...
...
@@ -1505,8 +1554,13 @@ class LLMEngine:
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
execute_model_req
.
async_callback
=
self
.
async_callback
[
virtual_engine
]
async_callback
=
self
.
async_callback_multi_step
[
virtual_engine
]
if
use_async_and_multi_step
\
else
self
.
async_callback
[
virtual_engine
]
execute_model_req
.
async_callback
=
async_callback
execute_model_req
.
use_async_and_multi_step
=
\
use_async_and_multi_step
output
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
...
...
@@ -1518,7 +1572,7 @@ class LLMEngine:
else
:
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
if
len
(
ctx
.
output_queue
)
>
0
:
if
not
use_async_and_multi_step
and
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
...
...
@@ -1535,18 +1589,23 @@ class LLMEngine:
if
self
.
scheduler_config
.
is_multi_step
:
self
.
cached_scheduler_outputs
[
0
]
=
SchedulerOutputState
()
# Add results to the output_queue
# (for async or non-async postprocessing)
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
if
use_async_and_multi_step
:
# For async + multi-step, clear the queue
ctx
.
output_queue
.
clear
()
else
:
# Add results to the output_queue
# (for async or non-async postprocessing)
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
if
output
and
allow_async_output_proc
:
assert
len
(
output
)
==
1
,
(
"Multi step decoding does not work "
"with async output processing."
)
if
output
and
allow_async_output_proc
:
assert
len
(
output
)
==
1
,
(
"Multi step decoding does not work "
"with async output processing."
)
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
# Check if need to run the usual non-async path
if
not
allow_async_output_proc
:
...
...
@@ -1560,7 +1619,10 @@ class LLMEngine:
self
.
do_tracing
(
scheduler_outputs
)
else
:
# Multi-step case
ctx
.
request_outputs
=
[]
if
use_async_and_multi_step
:
return
[]
else
:
ctx
.
request_outputs
=
[]
if
not
self
.
has_unfinished_requests
():
# Drain async postprocessor (if exists)
...
...
vllm/sequence.py
View file @
3f60f224
...
...
@@ -1295,6 +1295,7 @@ class ExecuteModelRequest(
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
# Async callback
async_callback
:
Optional
[
Callable
]
=
None
use_async_and_multi_step
:
bool
=
False
@
property
def
is_first_multi_step
(
self
)
->
bool
:
...
...
@@ -1341,4 +1342,5 @@ class ExecuteModelRequest(
finished_requests_ids
=
self
.
finished_requests_ids
,
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
clone
()
if
self
.
last_sampled_token_ids
is
not
None
else
None
,
async_callback
=
self
.
async_callback
)
async_callback
=
self
.
async_callback
,
use_async_and_multi_step
=
self
.
use_async_and_multi_step
)
vllm/worker/model_runner.py
View file @
3f60f224
...
...
@@ -92,6 +92,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
virtual_engine
:
int
=
0
async_callback
:
Optional
[
Callable
]
=
None
use_async_and_multi_step
:
bool
=
False
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
...
...
vllm/worker/multi_step_model_runner.py
View file @
3f60f224
import
dataclasses
import
functools
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
try
:
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
...
...
@@ -215,6 +217,46 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
)
return
model_input
def
_async_process_outputs
(
self
,
model_input
:
StatefulModelInput
,
output_proc_callback
:
Callable
):
# Proceed with pythonization and output_proc in order.
# Stop on the first one that fails to pythonize
cont
=
True
for
model_output
in
model_input
.
cached_outputs
:
if
not
model_output
.
pythonized
:
model_output
.
maybe_pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
if
model_output
.
pythonized
:
output_proc_callback
(
sampler_output
=
model_output
.
sampler_output
)
else
:
cont
=
False
if
not
cont
:
break
def
_final_process_outputs
(
self
,
model_input
:
StatefulModelInput
,
output_proc_callback
:
Optional
[
Callable
]):
assert
model_input
.
frozen_model_input
is
not
None
outputs
=
[]
for
output_id
in
range
(
len
(
model_input
.
cached_outputs
)):
is_last_output
=
output_id
==
len
(
model_input
.
cached_outputs
)
-
1
output
=
model_input
.
cached_outputs
[
output_id
]
if
not
output
.
pythonized
:
output
.
pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
if
model_input
.
frozen_model_input
.
use_async_and_multi_step
:
assert
output_proc_callback
is
not
None
output_proc_callback
(
sampler_output
=
output
.
sampler_output
,
is_last_output
=
is_last_output
)
outputs
.
append
(
output
.
sampler_output
)
return
outputs
@
torch
.
inference_mode
()
def
execute_model
(
self
,
...
...
@@ -271,6 +313,20 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
model_input
=
self
.
_advance_step
(
model_input
,
model_input
.
cached_outputs
[
-
1
].
sampler_output
)
output_proc_callback
=
None
if
frozen_model_input
.
use_async_and_multi_step
:
output_proc_callback
=
frozen_model_input
.
async_callback
assert
output_proc_callback
is
not
None
async_callback
=
functools
.
partial
(
self
.
_async_process_outputs
,
model_input
=
model_input
,
output_proc_callback
=
output_proc_callback
)
frozen_model_input
=
dataclasses
.
replace
(
# type: ignore
model_input
.
frozen_model_input
,
async_callback
=
async_callback
)
assert
frozen_model_input
is
not
None
# Execute the model
output
=
self
.
_base_model_runner
.
execute_model
(
frozen_model_input
,
kv_caches
,
...
...
@@ -301,9 +357,11 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
output
[
0
].
logprobs
=
None
# Pythonize the output if CPU is ahead and the previous step is
# ready.
for
model_output
in
model_input
.
cached_outputs
:
model_output
.
maybe_pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
if
not
frozen_model_input
.
use_async_and_multi_step
:
for
model_output
in
model_input
.
cached_outputs
:
model_output
.
maybe_pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
model_input
.
current_step
+=
1
...
...
@@ -316,11 +374,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# Pythonize the output and block if needed since it is the last step
if
model_input
.
is_last_step
:
outputs
=
[]
for
output
in
model_input
.
cached_outputs
:
output
.
pythonize
(
model_input
,
self
.
_copy_stream
,
self
.
pinned_sampled_token_ids
)
outputs
.
append
(
output
.
sampler_output
)
outputs
=
self
.
_final_process_outputs
(
model_input
,
output_proc_callback
)
return
outputs
# should be [SamplerOutput]
...
...
vllm/worker/multi_step_worker.py
View file @
3f60f224
import
dataclasses
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
...
...
@@ -61,6 +62,13 @@ class MultiStepWorker(Worker):
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
virtual_engine
,
execute_model_req
.
finished_requests_ids
))
if
execute_model_req
.
async_callback
:
model_input
.
frozen_model_input
=
dataclasses
.
replace
(
# type: ignore
model_input
.
frozen_model_input
,
async_callback
=
execute_model_req
.
async_callback
,
use_async_and_multi_step
=
execute_model_req
.
use_async_and_multi_step
)
else
:
# on subsequent steps we reuse the worker input and model input
multi_step_state
=
self
.
multi_step_states
[
virtual_engine
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment