Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f508e03e
Unverified
Commit
f508e03e
authored
Aug 28, 2024
by
Alexander Matveev
Committed by
GitHub
Aug 28, 2024
Browse files
[Core] Async_output_proc: Add virtual engine support (towards pipeline parallel) (#7911)
parent
51f86bf4
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
122 additions
and
67 deletions
+122
-67
vllm/core/scheduler.py
vllm/core/scheduler.py
+5
-6
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+27
-10
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+79
-42
vllm/sequence.py
vllm/sequence.py
+6
-3
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+3
-3
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+2
-3
No files found.
vllm/core/scheduler.py
View file @
f508e03e
...
...
@@ -302,7 +302,7 @@ class Scheduler:
cache_config
:
CacheConfig
,
lora_config
:
Optional
[
LoRAConfig
],
pipeline_parallel_size
:
int
=
1
,
output_proc_callback
_fn
:
Optional
[
Callable
]
=
None
,
output_proc_callback
:
Optional
[
Callable
]
=
None
,
)
->
None
:
self
.
scheduler_config
=
scheduler_config
self
.
cache_config
=
cache_config
...
...
@@ -376,8 +376,8 @@ class Scheduler:
# iterations. I.e. since the output processing is lagged one step,
# we cannot reuse the cached objects immediately when the schedule()
# is called again, but only when schedule() is called the second time.
self
.
output_proc_callback
_fn
=
output_proc_callback
_fn
self
.
use_async_output_proc
=
self
.
output_proc_callback
_fn
is
not
None
self
.
output_proc_callback
=
output_proc_callback
self
.
use_async_output_proc
=
self
.
output_proc_callback
is
not
None
self
.
num_cache_iters
=
2
if
self
.
use_async_output_proc
else
1
self
.
cache_id
=
0
...
...
@@ -573,8 +573,8 @@ class Scheduler:
seq_group
):
tmp
=
self
.
running
self
.
running
=
orig_running
assert
self
.
output_proc_callback
_fn
is
not
None
self
.
output_proc_callback
_fn
(
is_async
=
True
)
assert
self
.
output_proc_callback
is
not
None
self
.
output_proc_callback
(
)
self
.
running
=
tmp
while
not
self
.
_can_append_slots
(
seq_group
):
...
...
@@ -1091,7 +1091,6 @@ class Scheduler:
no_beam_search
=
seq_group
.
sampling_params
is
None
or
(
seq_group
.
sampling_params
.
best_of
==
1
and
not
seq_group
.
sampling_params
.
use_beam_search
)
return
no_beam_search
def
schedule
(
...
...
vllm/engine/async_llm_engine.py
View file @
f508e03e
...
...
@@ -279,10 +279,16 @@ class _AsyncLLMEngine(LLMEngine):
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# Clear outputs on scheduler iteration start
ctx
.
request_outputs
.
clear
()
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
...
...
@@ -290,8 +296,9 @@ class _AsyncLLMEngine(LLMEngine):
# If current scheduler iteration has no async postprocessor,
# then we need first to drain the pending async postprocessor
# before moving forward
if
not
allow_async_output_proc
and
len
(
self
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
is_async
=
True
)
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
...
...
@@ -332,8 +339,8 @@ class _AsyncLLMEngine(LLMEngine):
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
execute_model_req
.
output_proc_callback_fn
=
\
self
.
_process_model_outputs
execute_model_req
.
async_callback
=
self
.
async_callback
[
virtual_engine
]
# Execute the model.
output
=
await
self
.
model_executor
.
execute_model_async
(
...
...
@@ -343,9 +350,10 @@ class _AsyncLLMEngine(LLMEngine):
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
else
:
if
len
(
self
.
output_queue
)
>
0
:
if
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
is_async
=
True
)
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
output
=
[]
# Finish the current step for all the sequence groups.
...
...
@@ -360,7 +368,7 @@ class _AsyncLLMEngine(LLMEngine):
virtual_engine
]
=
SchedulerOutputState
()
# Cache results in engine
self
.
output_queue
.
append
(
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
if
output
and
allow_async_output_proc
:
...
...
@@ -372,7 +380,8 @@ class _AsyncLLMEngine(LLMEngine):
scheduler_outputs
.
scheduled_seq_groups
)
if
not
allow_async_output_proc
:
self
.
_process_model_outputs
(
is_async
=
False
)
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
False
)
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
...
...
@@ -381,9 +390,17 @@ class _AsyncLLMEngine(LLMEngine):
self
.
do_tracing
(
scheduler_outputs
)
else
:
self
.
request_outputs
=
[]
ctx
.
request_outputs
=
[]
if
not
self
.
has_unfinished_requests
():
# Drain async postprocessor (if exists)
if
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
assert
len
(
ctx
.
output_queue
)
==
0
return
self
.
request_outputs
return
ctx
.
request_outputs
async
def
stop_remote_worker_execution_loop_async
(
self
)
->
None
:
"""Stop the remote worker execution loop."""
...
...
vllm/engine/llm_engine.py
View file @
f508e03e
import
functools
import
time
from
collections
import
deque
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
field
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
...
...
@@ -88,6 +89,17 @@ class SchedulerOutputState:
last_output
:
Optional
[
SamplerOutput
]
=
None
@
dataclass
class
SchedulerContext
:
output_queue
:
Deque
[
Tuple
[
List
[
SamplerOutput
],
List
[
SequenceGroupMetadata
],
SchedulerOutputs
]]
=
field
(
default_factory
=
lambda
:
deque
())
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
field
(
default_factory
=
lambda
:
[])
class
LLMEngine
:
"""An LLM engine that receives requests and generates texts.
...
...
@@ -350,9 +362,11 @@ class LLMEngine:
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
,
parallel_config
.
pipeline_parallel_size
,
self
.
_process_model_outputs
functools
.
partial
(
self
.
_process_model_outputs
,
virtual_engine
=
v_id
,
is_async
=
True
)
if
model_config
.
use_async_output_proc
else
None
)
for
_
in
range
(
parallel_config
.
pipeline_parallel_size
)
for
v_id
in
range
(
parallel_config
.
pipeline_parallel_size
)
]
# Metric Logging.
...
...
@@ -406,12 +420,17 @@ class LLMEngine:
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
# Async output processing pointers
self
.
output_queue
:
Deque
[
Tuple
[
List
[
SamplerOutput
],
List
[
SequenceGroupMetadata
],
SchedulerOutputs
]]
=
deque
()
self
.
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
self
.
scheduler_contexts
=
[
SchedulerContext
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
self
.
async_callback
=
[
functools
.
partial
(
self
.
_process_model_outputs
,
virtual_engine
=
v_id
,
is_async
=
True
)
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
...
...
@@ -1221,32 +1240,28 @@ class LLMEngine:
return
def
_process_model_outputs
(
self
,
is_async
:
bool
,
clear_outputs
:
bool
=
True
)
->
None
:
def
_process_model_outputs
(
self
,
virtual_engine
:
int
,
is_async
:
bool
)
->
None
:
"""Apply the model output to the sequences in the scheduled seq groups.
virtual_engine: The engine id to operate on
is_async: Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
clear_outputs: Sometimes existing outputs need to be combined
with outputs of this call. This happens for postprocessor
draining at the final stage (like when sequences are finished)
Returns RequestOutputs that can be returned to the client.
"""
now
=
time
.
time
()
if
clear_outputs
:
self
.
request_outputs
.
clear
()
ctx
:
SchedulerContext
=
self
.
scheduler_contexts
[
virtual_engine
]
if
len
(
self
.
output_queue
)
==
0
:
if
len
(
ctx
.
output_queue
)
==
0
:
return
None
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
)
=
self
.
output_queue
.
popleft
()
scheduler_outputs
)
=
ctx
.
output_queue
.
popleft
()
# Sanity check
assert
len
(
seq_group_metadata_list
)
==
len
(
...
...
@@ -1321,11 +1336,11 @@ class LLMEngine:
if
(
seq_group
.
is_finished
()
if
self
.
step_return_finished_only
else
True
):
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
self
.
request_outputs
.
append
(
request_output
)
ctx
.
request_outputs
.
append
(
request_output
)
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
:
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
self
.
request_outputs
.
append
(
request_output
)
ctx
.
request_outputs
.
append
(
request_output
)
if
is_async
:
# Log stats.
...
...
@@ -1421,29 +1436,43 @@ class LLMEngine:
"Pipeline parallelism is only supported through AsyncLLMEngine "
"as performance will be severely degraded otherwise."
)
# For llm_engine, there is no pipeline parallel support, so the engine
# used is always 0
virtual_engine
=
0
# These are cached outputs from previous iterations. None if on first
# iteration
cached_outputs
=
self
.
cached_scheduler_outputs
[
0
]
cached_outputs
=
self
.
cached_scheduler_outputs
[
virtual_engine
]
seq_group_metadata_list
=
cached_outputs
.
seq_group_metadata_list
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# Clear outputs on scheduler iteration start
ctx
.
request_outputs
.
clear
()
# Schedule iteration
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
=
self
.
scheduler
[
0
].
schedule
()
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
if
not
allow_async_output_proc
and
len
(
self
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
is_async
=
True
)
# Maybe switch from async mode to sync mode
if
not
allow_async_output_proc
and
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self
.
_cache_scheduler_outputs_for_multi_step
(
0
,
seq_group_metadata_list
,
scheduler_outputs
,
virtual_engine
,
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
assert
seq_group_metadata_list
is
not
None
...
...
@@ -1454,14 +1483,14 @@ class LLMEngine:
if
not
scheduler_outputs
.
is_empty
():
finished_requests_ids
=
self
.
scheduler
[
0
].
get_and_reset_finished_requests_ids
()
virtual_engine
].
get_and_reset_finished_requests_ids
()
# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
# sampled_token_ids, as a separate broadcast over all the PP stages
# will cause one virtual engine's microbatch to block the pipeline.
last_sampled_token_ids
=
\
self
.
_get_last_sampled_token_ids
(
0
)
self
.
_get_last_sampled_token_ids
(
virtual_engine
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
...
...
@@ -1476,20 +1505,24 @@ class LLMEngine:
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
execute_model_req
.
output_proc_callback_fn
=
\
self
.
_process_model_outputs
execute_model_req
.
async_callback
=
self
.
async_callback
[
virtual_engine
]
output
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
#
w
e need to do this here so that last step's sampled_token_ids can
#
W
e need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
0
,
output
)
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
else
:
if
len
(
self
.
output_queue
)
>
0
:
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
if
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
is_async
=
True
)
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
# No outputs in this case
output
=
[]
# Finish the current step for all the sequence groups.
...
...
@@ -1504,7 +1537,7 @@ class LLMEngine:
# Add results to the output_queue
# (for async or non-async postprocessing)
self
.
output_queue
.
append
(
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
if
output
and
allow_async_output_proc
:
...
...
@@ -1515,8 +1548,10 @@ class LLMEngine:
output
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
# Check if need to run the usual non-async path
if
not
allow_async_output_proc
:
self
.
_process_model_outputs
(
is_async
=
False
)
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
False
)
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
...
...
@@ -1524,14 +1559,16 @@ class LLMEngine:
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
else
:
self
.
request_outputs
=
[]
# Multi-step case
ctx
.
request_outputs
=
[]
if
not
self
.
has_unfinished_requests
():
# Drain async postprocessor
if
len
(
self
.
output_queue
)
>
0
:
# Drain async postprocessor
(if exists)
if
len
(
ctx
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
is_async
=
True
,
clear_outputs
=
False
)
assert
len
(
self
.
output_queue
)
==
0
self
.
_process_model_outputs
(
virtual_engine
=
virtual_engine
,
is_async
=
True
)
assert
len
(
ctx
.
output_queue
)
==
0
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
...
...
@@ -1540,7 +1577,7 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters.
self
.
model_executor
.
stop_remote_worker_execution_loop
()
return
self
.
request_outputs
return
ctx
.
request_outputs
def
_has_remaining_steps
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
...
...
vllm/sequence.py
View file @
f508e03e
...
...
@@ -811,6 +811,9 @@ class SequenceGroup:
self
.
is_single_seq
=
len
(
self
.
seqs
)
==
1
def
is_finished
(
self
)
->
bool
:
if
self
.
is_single_seq
:
return
self
.
seqs
[
0
].
is_finished
()
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
seqs
)
def
is_prefill
(
self
)
->
bool
:
...
...
@@ -1290,8 +1293,8 @@ class ExecuteModelRequest(
finished_requests_ids
:
List
[
str
]
=
msgspec
.
field
(
default_factory
=
list
)
# The last sampled token ids for multi step decoding.
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
# Async
postprocessor
output_pro
c_callback
_fn
:
Optional
[
Callable
]
=
None
# Async
callback
asyn
c_callback
:
Optional
[
Callable
]
=
None
@
property
def
is_first_multi_step
(
self
)
->
bool
:
...
...
@@ -1338,4 +1341,4 @@ class ExecuteModelRequest(
finished_requests_ids
=
self
.
finished_requests_ids
,
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
clone
()
if
self
.
last_sampled_token_ids
is
not
None
else
None
,
output_pro
c_callback
_fn
=
self
.
output_pro
c_callback
_fn
)
asyn
c_callback
=
self
.
asyn
c_callback
)
vllm/worker/model_runner.py
View file @
f508e03e
...
...
@@ -91,7 +91,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
request_ids_to_seq_ids
:
Optional
[
Dict
[
str
,
List
[
int
]]]
=
None
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
virtual_engine
:
int
=
0
output_pro
c_callback
_fn
:
Optional
[
Callable
]
=
None
asyn
c_callback
:
Optional
[
Callable
]
=
None
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
...
...
@@ -1457,8 +1457,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if
not
self
.
is_driver_worker
:
return
[]
if
model_input
.
output_pro
c_callback
_fn
is
not
None
:
model_input
.
output_proc_callback_fn
(
is_async
=
True
)
if
model_input
.
asyn
c_callback
is
not
None
:
model_input
.
async_callback
(
)
# Sample the next token.
output
:
SamplerOutput
=
self
.
model
.
sample
(
...
...
vllm/worker/worker_base.py
View file @
f508e03e
...
...
@@ -263,11 +263,10 @@ class LocalOrDistributedWorkerBase(WorkerBase):
broadcast_data
.
update
(
kwargs
)
broadcast_tensor_dict
(
broadcast_data
,
src
=
0
)
if
execute_model_req
.
output_pro
c_callback
_fn
:
if
execute_model_req
.
asyn
c_callback
:
model_input
=
dataclasses
.
replace
(
# type: ignore
model_input
,
output_proc_callback_fn
=
execute_model_req
.
output_proc_callback_fn
)
async_callback
=
execute_model_req
.
async_callback
)
return
model_input
,
worker_input
,
kwargs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment