Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f508e03e
Unverified
Commit
f508e03e
authored
Aug 28, 2024
by
Alexander Matveev
Committed by
GitHub
Aug 28, 2024
Browse files
[Core] Async_output_proc: Add virtual engine support (towards pipeline parallel) (#7911)
parent
51f86bf4
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
122 additions
and
67 deletions
+122
-67
vllm/core/scheduler.py
vllm/core/scheduler.py
+5
-6
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+27
-10
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+79
-42
vllm/sequence.py
vllm/sequence.py
+6
-3
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+3
-3
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+2
-3
No files found.
vllm/core/scheduler.py
View file @
f508e03e
...
@@ -302,7 +302,7 @@ class Scheduler:
...
@@ -302,7 +302,7 @@ class Scheduler:
cache_config
:
CacheConfig
,
cache_config
:
CacheConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
pipeline_parallel_size
:
int
=
1
,
pipeline_parallel_size
:
int
=
1
,
output_proc_callback
_fn
:
Optional
[
Callable
]
=
None
,
output_proc_callback
:
Optional
[
Callable
]
=
None
,
)
->
None
:
)
->
None
:
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
...
@@ -376,8 +376,8 @@ class Scheduler:
...
@@ -376,8 +376,8 @@ class Scheduler:
# iterations. I.e. since the output processing is lagged one step,
# iterations. I.e. since the output processing is lagged one step,
# we cannot reuse the cached objects immediately when the schedule()
# we cannot reuse the cached objects immediately when the schedule()
# is called again, but only when schedule() is called the second time.
# is called again, but only when schedule() is called the second time.
self
.
output_proc_callback
_fn
=
output_proc_callback
_fn
self
.
output_proc_callback
=
output_proc_callback
self
.
use_async_output_proc
=
self
.
output_proc_callback
_fn
is
not
None
self
.
use_async_output_proc
=
self
.
output_proc_callback
is
not
None
self
.
num_cache_iters
=
2
if
self
.
use_async_output_proc
else
1
self
.
num_cache_iters
=
2
if
self
.
use_async_output_proc
else
1
self
.
cache_id
=
0
self
.
cache_id
=
0
...
@@ -573,8 +573,8 @@ class Scheduler:
...
@@ -573,8 +573,8 @@ class Scheduler:
seq_group
):
seq_group
):
tmp
=
self
.
running
tmp
=
self
.
running
self
.
running
=
orig_running
self
.
running
=
orig_running
assert
self
.
output_proc_callback
_fn
is
not
None
assert
self
.
output_proc_callback
is
not
None
self
.
output_proc_callback
_fn
(
is_async
=
True
)
self
.
output_proc_callback
(
)
self
.
running
=
tmp
self
.
running
=
tmp
while
not
self
.
_can_append_slots
(
seq_group
):
while
not
self
.
_can_append_slots
(
seq_group
):
...
@@ -1091,7 +1091,6 @@ class Scheduler:
...
@@ -1091,7 +1091,6 @@ class Scheduler:
no_beam_search
=
seq_group
.
sampling_params
is
None
or
(
no_beam_search
=
seq_group
.
sampling_params
is
None
or
(
seq_group
.
sampling_params
.
best_of
==
1
seq_group
.
sampling_params
.
best_of
==
1
and
not
seq_group
.
sampling_params
.
use_beam_search
)
and
not
seq_group
.
sampling_params
.
use_beam_search
)
return
no_beam_search
return
no_beam_search
def
schedule
(
def
schedule
(
...
...
vllm/engine/async_llm_engine.py
View file @
f508e03e
...
@@ -279,10 +279,16 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -279,10 +279,16 @@ class _AsyncLLMEngine(LLMEngine):
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
# skip the scheduler if there are any remaining steps in the seq groups.
# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# This ensures that the scheduler is only called again when the current
# batch has completed.
# batch has completed.
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# Clear outputs on scheduler iteration start
ctx
.
request_outputs
.
clear
()
(
seq_group_metadata_list
,
scheduler_outputs
,
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
...
@@ -290,8 +296,9 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -290,8 +296,9 @@ class _AsyncLLMEngine(LLMEngine):
# If current scheduler iteration has no async postprocessor,
# If current scheduler iteration has no async postprocessor,
# then we need first to drain the pending async postprocessor
# then we need first to drain the pending async postprocessor
# before moving forward
# before moving forward
if
not
allow_async_output_proc
and
len
(
self
.
output_queue
)
>
0
:
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
is_async
=
True
)
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
if
(
self
.
scheduler_config
.
is_multi_step
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
...
@@ -332,8 +339,8 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -332,8 +339,8 @@ class _AsyncLLMEngine(LLMEngine):
last_sampled_token_ids
=
last_sampled_token_ids
)
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
if
allow_async_output_proc
:
execute_model_req
.
output_proc_callback_fn
=
\
execute_model_req
.
async_callback
=
self
.
async_callback
[
self
.
_process_model_outputs
virtual_engine
]
# Execute the model.
# Execute the model.
output
=
await
self
.
model_executor
.
execute_model_async
(
output
=
await
self
.
model_executor
.
execute_model_async
(
...
@@ -343,9 +350,10 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -343,9 +350,10 @@ class _AsyncLLMEngine(LLMEngine):
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
else
:
else
:
if
len
(
self
.
output_queue
)
>
0
:
if
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
is_async
=
True
)
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
output
=
[]
output
=
[]
# Finish the current step for all the sequence groups.
# Finish the current step for all the sequence groups.
...
@@ -360,7 +368,7 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -360,7 +368,7 @@ class _AsyncLLMEngine(LLMEngine):
virtual_engine
]
=
SchedulerOutputState
()
virtual_engine
]
=
SchedulerOutputState
()
# Cache results in engine
# Cache results in engine
self
.
output_queue
.
append
(
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
if
output
and
allow_async_output_proc
:
if
output
and
allow_async_output_proc
:
...
@@ -372,7 +380,8 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -372,7 +380,8 @@ class _AsyncLLMEngine(LLMEngine):
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
)
if
not
allow_async_output_proc
:
if
not
allow_async_output_proc
:
self
.
_process_model_outputs
(
is_async
=
False
)
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
False
)
# Log stats.
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
do_log_stats
(
scheduler_outputs
,
output
)
...
@@ -381,9 +390,17 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -381,9 +390,17 @@ class _AsyncLLMEngine(LLMEngine):
self
.
do_tracing
(
scheduler_outputs
)
self
.
do_tracing
(
scheduler_outputs
)
else
:
else
:
self
.
request_outputs
=
[]
ctx
.
request_outputs
=
[]
if
not
self
.
has_unfinished_requests
():
# Drain async postprocessor (if exists)
if
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
assert
len
(
ctx
.
output_queue
)
==
0
return
self
.
request_outputs
return
ctx
.
request_outputs
async
def
stop_remote_worker_execution_loop_async
(
self
)
->
None
:
async
def
stop_remote_worker_execution_loop_async
(
self
)
->
None
:
"""Stop the remote worker execution loop."""
"""Stop the remote worker execution loop."""
...
...
vllm/engine/llm_engine.py
View file @
f508e03e
import
functools
import
time
import
time
from
collections
import
deque
from
collections
import
deque
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
field
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
)
Mapping
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
...
@@ -88,6 +89,17 @@ class SchedulerOutputState:
...
@@ -88,6 +89,17 @@ class SchedulerOutputState:
last_output
:
Optional
[
SamplerOutput
]
=
None
last_output
:
Optional
[
SamplerOutput
]
=
None
@
dataclass
class
SchedulerContext
:
output_queue
:
Deque
[
Tuple
[
List
[
SamplerOutput
],
List
[
SequenceGroupMetadata
],
SchedulerOutputs
]]
=
field
(
default_factory
=
lambda
:
deque
())
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
field
(
default_factory
=
lambda
:
[])
class
LLMEngine
:
class
LLMEngine
:
"""An LLM engine that receives requests and generates texts.
"""An LLM engine that receives requests and generates texts.
...
@@ -350,9 +362,11 @@ class LLMEngine:
...
@@ -350,9 +362,11 @@ class LLMEngine:
Scheduler
(
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
,
scheduler_config
,
cache_config
,
lora_config
,
parallel_config
.
pipeline_parallel_size
,
parallel_config
.
pipeline_parallel_size
,
self
.
_process_model_outputs
functools
.
partial
(
self
.
_process_model_outputs
,
virtual_engine
=
v_id
,
is_async
=
True
)
if
model_config
.
use_async_output_proc
else
None
)
if
model_config
.
use_async_output_proc
else
None
)
for
_
in
range
(
parallel_config
.
pipeline_parallel_size
)
for
v_id
in
range
(
parallel_config
.
pipeline_parallel_size
)
]
]
# Metric Logging.
# Metric Logging.
...
@@ -406,12 +420,17 @@ class LLMEngine:
...
@@ -406,12 +420,17 @@ class LLMEngine:
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
]
# Async output processing pointers
self
.
scheduler_contexts
=
[
self
.
output_queue
:
Deque
[
Tuple
[
List
[
SamplerOutput
],
SchedulerContext
()
List
[
SequenceGroupMetadata
],
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
SchedulerOutputs
]]
=
deque
()
]
self
.
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
self
.
async_callback
=
[
functools
.
partial
(
self
.
_process_model_outputs
,
virtual_engine
=
v_id
,
is_async
=
True
)
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
def
_initialize_kv_caches
(
self
)
->
None
:
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
"""Initialize the KV cache in the worker(s).
...
@@ -1221,32 +1240,28 @@ class LLMEngine:
...
@@ -1221,32 +1240,28 @@ class LLMEngine:
return
return
def
_process_model_outputs
(
self
,
def
_process_model_outputs
(
self
,
virtual_engine
:
int
,
is_async
:
bool
,
is_async
:
bool
)
->
None
:
clear_outputs
:
bool
=
True
)
->
None
:
"""Apply the model output to the sequences in the scheduled seq groups.
"""Apply the model output to the sequences in the scheduled seq groups.
virtual_engine: The engine id to operate on
is_async: Indicates whether this postprocessor runs in
is_async: Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
no tokens need to be appended since it is already done
externally (before the next schedule() call)
externally (before the next schedule() call)
clear_outputs: Sometimes existing outputs need to be combined
with outputs of this call. This happens for postprocessor
draining at the final stage (like when sequences are finished)
Returns RequestOutputs that can be returned to the client.
Returns RequestOutputs that can be returned to the client.
"""
"""
now
=
time
.
time
()
now
=
time
.
time
()
if
clear_outputs
:
ctx
:
SchedulerContext
=
self
.
scheduler_contexts
[
virtual_engine
]
self
.
request_outputs
.
clear
()
if
len
(
self
.
output_queue
)
==
0
:
if
len
(
ctx
.
output_queue
)
==
0
:
return
None
return
None
(
outputs
,
seq_group_metadata_list
,
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
)
=
self
.
output_queue
.
popleft
()
scheduler_outputs
)
=
ctx
.
output_queue
.
popleft
()
# Sanity check
# Sanity check
assert
len
(
seq_group_metadata_list
)
==
len
(
assert
len
(
seq_group_metadata_list
)
==
len
(
...
@@ -1321,11 +1336,11 @@ class LLMEngine:
...
@@ -1321,11 +1336,11 @@ class LLMEngine:
if
(
seq_group
.
is_finished
()
if
(
seq_group
.
is_finished
()
if
self
.
step_return_finished_only
else
True
):
if
self
.
step_return_finished_only
else
True
):
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
self
.
request_outputs
.
append
(
request_output
)
ctx
.
request_outputs
.
append
(
request_output
)
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
:
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
:
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
self
.
request_outputs
.
append
(
request_output
)
ctx
.
request_outputs
.
append
(
request_output
)
if
is_async
:
if
is_async
:
# Log stats.
# Log stats.
...
@@ -1421,29 +1436,43 @@ class LLMEngine:
...
@@ -1421,29 +1436,43 @@ class LLMEngine:
"Pipeline parallelism is only supported through AsyncLLMEngine "
"Pipeline parallelism is only supported through AsyncLLMEngine "
"as performance will be severely degraded otherwise."
)
"as performance will be severely degraded otherwise."
)
# For llm_engine, there is no pipeline parallel support, so the engine
# used is always 0
virtual_engine
=
0
# These are cached outputs from previous iterations. None if on first
# These are cached outputs from previous iterations. None if on first
# iteration
# iteration
cached_outputs
=
self
.
cached_scheduler_outputs
[
0
]
cached_outputs
=
self
.
cached_scheduler_outputs
[
virtual_engine
]
seq_group_metadata_list
=
cached_outputs
.
seq_group_metadata_list
seq_group_metadata_list
=
cached_outputs
.
seq_group_metadata_list
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
# Skip the scheduler if there are any remaining steps in the seq groups.
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# This ensures that the scheduler is only called again when the current
# batch has completed.
# batch has completed.
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# Clear outputs on scheduler iteration start
ctx
.
request_outputs
.
clear
()
# Schedule iteration
(
seq_group_metadata_list
,
scheduler_outputs
,
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
=
self
.
scheduler
[
0
].
schedule
()
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
if
not
allow_async_output_proc
and
len
(
self
.
output_queue
)
>
0
:
# Maybe switch from async mode to sync mode
self
.
_process_model_outputs
(
is_async
=
True
)
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
if
(
self
.
scheduler_config
.
is_multi_step
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
# cache the scheduler outputs for the next iteration if we have
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
# lookahead slots
self
.
_cache_scheduler_outputs_for_multi_step
(
self
.
_cache_scheduler_outputs_for_multi_step
(
0
,
seq_group_metadata_list
,
scheduler_outputs
,
virtual_engine
,
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
allow_async_output_proc
)
assert
seq_group_metadata_list
is
not
None
assert
seq_group_metadata_list
is
not
None
...
@@ -1454,14 +1483,14 @@ class LLMEngine:
...
@@ -1454,14 +1483,14 @@ class LLMEngine:
if
not
scheduler_outputs
.
is_empty
():
if
not
scheduler_outputs
.
is_empty
():
finished_requests_ids
=
self
.
scheduler
[
finished_requests_ids
=
self
.
scheduler
[
0
].
get_and_reset_finished_requests_ids
()
virtual_engine
].
get_and_reset_finished_requests_ids
()
# Check if we have a cached last_output from the previous iteration.
# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
# For supporting PP this is probably the best way to pass the
# sampled_token_ids, as a separate broadcast over all the PP stages
# sampled_token_ids, as a separate broadcast over all the PP stages
# will cause one virtual engine's microbatch to block the pipeline.
# will cause one virtual engine's microbatch to block the pipeline.
last_sampled_token_ids
=
\
last_sampled_token_ids
=
\
self
.
_get_last_sampled_token_ids
(
0
)
self
.
_get_last_sampled_token_ids
(
virtual_engine
)
execute_model_req
=
ExecuteModelRequest
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
...
@@ -1476,20 +1505,24 @@ class LLMEngine:
...
@@ -1476,20 +1505,24 @@ class LLMEngine:
last_sampled_token_ids
=
last_sampled_token_ids
)
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
if
allow_async_output_proc
:
execute_model_req
.
output_proc_callback_fn
=
\
execute_model_req
.
async_callback
=
self
.
async_callback
[
self
.
_process_model_outputs
virtual_engine
]
output
=
self
.
model_executor
.
execute_model
(
output
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
execute_model_req
=
execute_model_req
)
#
w
e need to do this here so that last step's sampled_token_ids can
#
W
e need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
# be passed to the next iteration for PP.
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
0
,
output
)
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
else
:
else
:
if
len
(
self
.
output_queue
)
>
0
:
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
if
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
is_async
=
True
)
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
# No outputs in this case
output
=
[]
output
=
[]
# Finish the current step for all the sequence groups.
# Finish the current step for all the sequence groups.
...
@@ -1504,7 +1537,7 @@ class LLMEngine:
...
@@ -1504,7 +1537,7 @@ class LLMEngine:
# Add results to the output_queue
# Add results to the output_queue
# (for async or non-async postprocessing)
# (for async or non-async postprocessing)
self
.
output_queue
.
append
(
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
if
output
and
allow_async_output_proc
:
if
output
and
allow_async_output_proc
:
...
@@ -1515,8 +1548,10 @@ class LLMEngine:
...
@@ -1515,8 +1548,10 @@ class LLMEngine:
output
[
0
],
seq_group_metadata_list
,
output
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
)
# Check if need to run the usual non-async path
if
not
allow_async_output_proc
:
if
not
allow_async_output_proc
:
self
.
_process_model_outputs
(
is_async
=
False
)
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
False
)
# Log stats.
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
do_log_stats
(
scheduler_outputs
,
output
)
...
@@ -1524,14 +1559,16 @@ class LLMEngine:
...
@@ -1524,14 +1559,16 @@ class LLMEngine:
# Tracing
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
self
.
do_tracing
(
scheduler_outputs
)
else
:
else
:
self
.
request_outputs
=
[]
# Multi-step case
ctx
.
request_outputs
=
[]
if
not
self
.
has_unfinished_requests
():
if
not
self
.
has_unfinished_requests
():
# Drain async postprocessor
# Drain async postprocessor
(if exists)
if
len
(
self
.
output_queue
)
>
0
:
if
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
is_async
=
True
,
clear_outputs
=
False
)
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
assert
len
(
self
.
output_queue
)
==
0
is_async
=
True
)
assert
len
(
ctx
.
output_queue
)
==
0
# Stop the execute model loop in parallel workers until there are
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# more requests to process. This avoids waiting indefinitely in
...
@@ -1540,7 +1577,7 @@ class LLMEngine:
...
@@ -1540,7 +1577,7 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters.
# queued control plane messages, such as add/remove lora adapters.
self
.
model_executor
.
stop_remote_worker_execution_loop
()
self
.
model_executor
.
stop_remote_worker_execution_loop
()
return
self
.
request_outputs
return
ctx
.
request_outputs
def
_has_remaining_steps
(
def
_has_remaining_steps
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
...
...
vllm/sequence.py
View file @
f508e03e
...
@@ -811,6 +811,9 @@ class SequenceGroup:
...
@@ -811,6 +811,9 @@ class SequenceGroup:
self
.
is_single_seq
=
len
(
self
.
seqs
)
==
1
self
.
is_single_seq
=
len
(
self
.
seqs
)
==
1
def
is_finished
(
self
)
->
bool
:
def
is_finished
(
self
)
->
bool
:
if
self
.
is_single_seq
:
return
self
.
seqs
[
0
].
is_finished
()
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
seqs
)
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
seqs
)
def
is_prefill
(
self
)
->
bool
:
def
is_prefill
(
self
)
->
bool
:
...
@@ -1290,8 +1293,8 @@ class ExecuteModelRequest(
...
@@ -1290,8 +1293,8 @@ class ExecuteModelRequest(
finished_requests_ids
:
List
[
str
]
=
msgspec
.
field
(
default_factory
=
list
)
finished_requests_ids
:
List
[
str
]
=
msgspec
.
field
(
default_factory
=
list
)
# The last sampled token ids for multi step decoding.
# The last sampled token ids for multi step decoding.
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
# Async
postprocessor
# Async
callback
output_pro
c_callback
_fn
:
Optional
[
Callable
]
=
None
asyn
c_callback
:
Optional
[
Callable
]
=
None
@
property
@
property
def
is_first_multi_step
(
self
)
->
bool
:
def
is_first_multi_step
(
self
)
->
bool
:
...
@@ -1338,4 +1341,4 @@ class ExecuteModelRequest(
...
@@ -1338,4 +1341,4 @@ class ExecuteModelRequest(
finished_requests_ids
=
self
.
finished_requests_ids
,
finished_requests_ids
=
self
.
finished_requests_ids
,
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
clone
()
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
clone
()
if
self
.
last_sampled_token_ids
is
not
None
else
None
,
if
self
.
last_sampled_token_ids
is
not
None
else
None
,
output_pro
c_callback
_fn
=
self
.
output_pro
c_callback
_fn
)
asyn
c_callback
=
self
.
asyn
c_callback
)
vllm/worker/model_runner.py
View file @
f508e03e
...
@@ -91,7 +91,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
...
@@ -91,7 +91,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
request_ids_to_seq_ids
:
Optional
[
Dict
[
str
,
List
[
int
]]]
=
None
request_ids_to_seq_ids
:
Optional
[
Dict
[
str
,
List
[
int
]]]
=
None
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
virtual_engine
:
int
=
0
virtual_engine
:
int
=
0
output_pro
c_callback
_fn
:
Optional
[
Callable
]
=
None
asyn
c_callback
:
Optional
[
Callable
]
=
None
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
tensor_dict
=
{
...
@@ -1457,8 +1457,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1457,8 +1457,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if
not
self
.
is_driver_worker
:
if
not
self
.
is_driver_worker
:
return
[]
return
[]
if
model_input
.
output_pro
c_callback
_fn
is
not
None
:
if
model_input
.
asyn
c_callback
is
not
None
:
model_input
.
output_proc_callback_fn
(
is_async
=
True
)
model_input
.
async_callback
(
)
# Sample the next token.
# Sample the next token.
output
:
SamplerOutput
=
self
.
model
.
sample
(
output
:
SamplerOutput
=
self
.
model
.
sample
(
...
...
vllm/worker/worker_base.py
View file @
f508e03e
...
@@ -263,11 +263,10 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -263,11 +263,10 @@ class LocalOrDistributedWorkerBase(WorkerBase):
broadcast_data
.
update
(
kwargs
)
broadcast_data
.
update
(
kwargs
)
broadcast_tensor_dict
(
broadcast_data
,
src
=
0
)
broadcast_tensor_dict
(
broadcast_data
,
src
=
0
)
if
execute_model_req
.
output_pro
c_callback
_fn
:
if
execute_model_req
.
asyn
c_callback
:
model_input
=
dataclasses
.
replace
(
# type: ignore
model_input
=
dataclasses
.
replace
(
# type: ignore
model_input
,
model_input
,
output_proc_callback_fn
=
execute_model_req
.
async_callback
=
execute_model_req
.
async_callback
)
output_proc_callback_fn
)
return
model_input
,
worker_input
,
kwargs
return
model_input
,
worker_input
,
kwargs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment