Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4ef41b84
Unverified
Commit
4ef41b84
authored
Sep 08, 2024
by
Alexander Matveev
Committed by
GitHub
Sep 07, 2024
Browse files
[Bugfix] Fix async postprocessor in case of preemption (#8267)
parent
cfe712bf
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
172 additions
and
114 deletions
+172
-114
vllm/core/scheduler.py
vllm/core/scheduler.py
+47
-40
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+12
-12
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+99
-50
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+14
-12
No files found.
vllm/core/scheduler.py
View file @
4ef41b84
...
...
@@ -537,13 +537,6 @@ class Scheduler:
preempted
:
List
[
SequenceGroup
]
=
ret
.
preempted
swapped_out
:
List
[
SequenceGroup
]
=
ret
.
swapped_out
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# Store original running requests for the case of async + preemption
if
self
.
use_async_output_proc
:
orig_running
=
self
.
running
.
copy
()
running_queue
=
self
.
running
assert
len
(
self
.
_async_stopped
)
==
0
while
running_queue
:
...
...
@@ -552,6 +545,7 @@ class Scheduler:
seq_group
,
SequenceStatus
.
RUNNING
,
enable_chunking
,
budget
)
if
num_running_tokens
==
0
:
# No budget => Stop
break
running_queue
.
popleft
()
...
...
@@ -565,18 +559,8 @@ class Scheduler:
self
.
_async_stopped
.
append
(
seq_group
)
continue
# With async postprocessor, when preemption kicks in, we need
# first to drain the async postprocessor, so that all async
# block_table freeing is applied before the preemption freeing
# is applied.
if
self
.
use_async_output_proc
and
not
self
.
_can_append_slots
(
seq_group
):
tmp
=
self
.
running
self
.
running
=
orig_running
assert
self
.
output_proc_callback
is
not
None
self
.
output_proc_callback
()
self
.
running
=
tmp
# NOTE(woosuk): Preemption happens only when there is no available
# slot to keep all the sequence groups in the RUNNING state.
while
not
self
.
_can_append_slots
(
seq_group
):
budget
.
subtract_num_batched_tokens
(
seq_group
.
request_id
,
num_running_tokens
)
...
...
@@ -588,24 +572,43 @@ class Scheduler:
and
seq_group
.
lora_int_id
in
curr_loras
):
curr_loras
.
remove
(
seq_group
.
lora_int_id
)
# Determine victim sequence
cont_loop
=
True
if
running_queue
:
# Preempt the lowest-priority sequence group
s
.
# Preempt the lowest-priority sequence group.
victim_seq_group
=
running_queue
.
pop
()
else
:
# No other sequence group can be preempted.
# Preempt the current sequence group.
# Note: This is also where we stop this loop
# (since there is nothing else to preempt)
victim_seq_group
=
seq_group
cont_loop
=
False
# With async postprocessor, before preempting a sequence
# we need to ensure it has no pending async postprocessor
do_preempt
=
True
if
self
.
use_async_output_proc
:
assert
self
.
output_proc_callback
is
not
None
self
.
output_proc_callback
(
request_id
=
victim_seq_group
.
request_id
)
# It may be that the async pending "victim_seq_group"
# becomes finished, in which case we simply free it.
if
victim_seq_group
.
is_finished
():
self
.
_free_finished_seq_group
(
victim_seq_group
)
do_preempt
=
False
# Do preemption
if
do_preempt
:
preempted_mode
=
self
.
_preempt
(
victim_seq_group
,
blocks_to_swap_out
)
if
preempted_mode
==
PreemptionMode
.
RECOMPUTE
:
preempted
.
append
(
victim_seq_group
)
else
:
swapped_out
.
append
(
victim_seq_group
)
else
:
# No other sequence groups can be preempted.
# Preempt the current sequence group.
preempted_mode
=
self
.
_preempt
(
seq_group
,
blocks_to_swap_out
)
if
preempted_mode
==
PreemptionMode
.
RECOMPUTE
:
preempted
.
append
(
seq_group
)
else
:
swapped_out
.
append
(
seq_group
)
if
not
cont_loop
:
break
else
:
self
.
_append_slots
(
seq_group
,
blocks_to_copy
)
...
...
@@ -1264,22 +1267,26 @@ class Scheduler:
if
seq
.
is_finished
():
self
.
free_seq
(
seq
)
def
_free_finished_seq_group
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
if
seq_group
.
is_finished
():
# Free cross-attention block table, if it exists
self
.
_free_seq_group_cross_attn_blocks
(
seq_group
)
# Add the finished requests to the finished requests list.
# This list will be used to update the Mamba cache in the
# next step.
self
.
_finished_requests_ids
.
append
(
seq_group
.
request_id
)
# Free finished seqs
self
.
_free_finished_seqs
(
seq_group
)
def
free_finished_seq_groups
(
self
)
->
None
:
remaining
:
Deque
[
SequenceGroup
]
=
deque
()
for
seq_group
in
self
.
running
:
if
seq_group
.
is_finished
():
# Free cross-attention block table, if it exists
self
.
_free_seq_group_cross_attn_blocks
(
seq_group
)
# Add the finished requests to the finished requests list.
# This list will be used to update the Mamba cache in the
# next step.
self
.
_finished_requests_ids
.
append
(
seq_group
.
request_id
)
else
:
self
.
_free_finished_seq_group
(
seq_group
)
if
not
seq_group
.
is_finished
():
remaining
.
append
(
seq_group
)
# Free finished seqs
self
.
_free_finished_seqs
(
seq_group
)
self
.
running
=
remaining
# Handle async stopped sequence groups
...
...
vllm/engine/async_llm_engine.py
View file @
4ef41b84
...
...
@@ -342,17 +342,17 @@ class _AsyncLLMEngine(LLMEngine):
virtual_engine
]
# Execute the model.
output
=
await
self
.
model_executor
.
execute_model_async
(
output
s
=
await
self
.
model_executor
.
execute_model_async
(
execute_model_req
)
# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
s
)
else
:
if
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
output
=
[]
output
s
=
[]
# Finish the current step for all the sequence groups.
if
self
.
scheduler_config
.
is_multi_step
:
...
...
@@ -365,25 +365,25 @@ class _AsyncLLMEngine(LLMEngine):
self
.
cached_scheduler_outputs
[
virtual_engine
]
=
SchedulerOutputState
()
is_async
=
allow_async_output_proc
is_last_step
=
True
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
,
is_asyn
c
,
is_last_step
)
)
ctx
.
append_output
(
outputs
=
outputs
,
seq_group_metadata_list
=
seq_group_metadata_list
,
scheduler_outputs
=
scheduler_outputs
,
is_async
=
allow_async_output_pro
c
,
is_last_step
=
True
)
if
output
and
allow_async_output_proc
:
if
output
s
and
allow_async_output_proc
:
assert
len
(
output
output
s
)
==
1
,
"Async postprocessor expects only a single output set"
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
output
s
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
if
not
allow_async_output_proc
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
do_log_stats
(
scheduler_outputs
,
output
s
)
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
...
...
vllm/engine/llm_engine.py
View file @
4ef41b84
...
...
@@ -2,9 +2,9 @@ import functools
import
time
from
collections
import
deque
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
)
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Type
,
Union
...
...
@@ -90,17 +90,36 @@ class SchedulerOutputState:
last_output
:
Optional
[
SamplerOutput
]
=
None
@
dataclass
class
OutputData
(
NamedTuple
):
outputs
:
List
[
SamplerOutput
]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
scheduler_outputs
:
SchedulerOutputs
is_async
:
bool
is_last_step
:
bool
skip
:
List
[
int
]
class
SchedulerContext
:
output_queue
:
Deque
[
Tuple
[
Optional
[
List
[
SamplerOutput
]],
List
[
SequenceGroupMetadata
],
SchedulerOutputs
,
bool
,
bool
]]
=
field
(
default_factory
=
lambda
:
deque
())
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
field
(
default_factory
=
lambda
:
[])
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
def
__init__
(
self
):
self
.
output_queue
:
Deque
[
OutputData
]
=
deque
()
self
.
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
self
.
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
self
.
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
def
append_output
(
self
,
outputs
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduler_outputs
:
SchedulerOutputs
,
is_async
:
bool
,
is_last_step
:
bool
):
self
.
output_queue
.
append
(
OutputData
(
outputs
=
outputs
,
seq_group_metadata_list
=
seq_group_metadata_list
,
scheduler_outputs
=
scheduler_outputs
,
is_async
=
is_async
,
is_last_step
=
is_last_step
,
skip
=
[]))
class
LLMEngine
:
...
...
@@ -1246,23 +1265,15 @@ class LLMEngine:
return
def
_process_model_outputs
(
self
,
ctx
:
SchedulerContext
)
->
None
:
"""Apply the model output to the sequences in the scheduled seq groups.
def
_process_model_outputs
(
self
,
ctx
:
SchedulerContext
,
request_id
:
Optional
[
str
]
=
None
)
->
None
:
"""Apply the model output to the sequences in the scheduled seq groups
and return responses.
virtual_engine: The engine id to operate on
ctx: The virtual engine context to work on
request_id: If provided, then only this request is going to be processed
is_async: Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
sampler_output: Used with multi-step execution to provide
sampler_output of each step
is_last_output: Used with multi-step execution to indicate
the last step (of each multi-step group)
Returns RequestOutputs that can be returned to the client.
"""
now
=
time
.
time
()
...
...
@@ -1270,9 +1281,14 @@ class LLMEngine:
return
None
# Get pending async postprocessor
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
)
=
ctx
.
output_queue
.
popleft
()
assert
outputs
is
not
None
if
request_id
:
# When we process only one request, no pop is required
# (since later we will process all of the rest)
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
,
skip
)
=
ctx
.
output_queue
[
0
]
else
:
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
,
skip
)
=
ctx
.
output_queue
.
popleft
()
# Sanity check
assert
len
(
seq_group_metadata_list
)
==
len
(
...
...
@@ -1286,9 +1302,30 @@ class LLMEngine:
else
:
outputs_by_sequence_group
=
outputs
# Determine the requests we need to operate on
if
request_id
:
indices
=
[]
for
i
,
seq_group_meta
in
enumerate
(
seq_group_metadata_list
):
if
seq_group_meta
.
request_id
==
request_id
:
assert
i
not
in
skip
# Cannot be called twice
indices
.
append
(
i
)
break
# If the request_id was not found, then it means that
# this is a new request that has no pending async
# postprocessor
if
not
indices
:
return
else
:
indices
=
range
(
len
(
seq_group_metadata_list
))
# type: ignore
finished_before
:
List
[
int
]
=
[]
finished_now
:
List
[
int
]
=
[]
for
i
,
seq_group_meta
in
enumerate
(
seq_group_metadata_list
):
for
i
in
indices
:
if
i
in
skip
:
continue
seq_group_meta
=
seq_group_metadata_list
[
i
]
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
seq_group
=
scheduled_seq_group
.
seq_group
...
...
@@ -1343,6 +1380,18 @@ class LLMEngine:
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
ctx
.
request_outputs
.
append
(
request_output
)
# When we process a single request, we skip it for the next time,
# and invoke the request output callback (if there was final output)
if
request_id
:
assert
len
(
indices
)
==
1
skip
.
append
(
indices
[
0
])
if
(
finished_now
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
ctx
.
request_outputs
.
clear
()
return
# Free currently finished requests
if
finished_now
:
for
scheduler
in
self
.
scheduler
:
...
...
@@ -1354,17 +1403,16 @@ class LLMEngine:
if
(
finished_now
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
ctx
.
request_outputs
.
clear
()
return
# Create the outputs
# Note: scheduled_seq_groups and seq_group_metadata_list
# must match with the indices
for
i
,
scheduled_seq_group
in
enumerate
(
scheduler_outputs
.
scheduled_seq_groups
):
if
i
in
finished_before
or
i
in
finished_now
:
for
i
in
indices
:
if
i
in
skip
or
i
in
finished_before
or
i
in
finished_now
:
continue
# Avoids double processing
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
if
(
seq_group
.
is_finished
()
...
...
@@ -1380,6 +1428,7 @@ class LLMEngine:
if
(
ctx
.
request_outputs
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
ctx
.
request_outputs
.
clear
()
# For async case, we need to record the stats here.
# For non-async case, the stats are done in the
...
...
@@ -1548,20 +1597,20 @@ class LLMEngine:
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
output
=
self
.
model_executor
.
execute_model
(
output
s
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
s
)
else
:
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
if
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
# No outputs in this case
output
=
[]
output
s
=
[]
# Finish the current step for all the sequence groups.
if
self
.
scheduler_config
.
is_multi_step
:
...
...
@@ -1574,18 +1623,18 @@ class LLMEngine:
self
.
cached_scheduler_outputs
[
0
]
=
SchedulerOutputState
()
# Add results to the output_queue
is_async
=
allow_async_output_proc
is_last_step
=
True
ctx
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
,
is_asyn
c
,
is_last_step
)
)
if
output
and
allow_async_output_proc
:
assert
len
(
output
)
==
1
,
(
ctx
.
append_output
(
outputs
=
outputs
,
seq_group_metadata_list
=
seq_group_metadata_list
,
scheduler_outputs
=
scheduler_outputs
,
is_async
=
allow_async_output_pro
c
,
is_last_step
=
True
)
if
output
s
and
allow_async_output_proc
:
assert
len
(
output
s
)
==
1
,
(
"Async postprocessor expects only a single output set"
)
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
output
s
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
# Check if need to run the usual non-async path
...
...
@@ -1593,7 +1642,7 @@ class LLMEngine:
self
.
_process_model_outputs
(
ctx
=
ctx
)
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
do_log_stats
(
scheduler_outputs
,
output
s
)
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
...
...
vllm/worker/multi_step_model_runner.py
View file @
4ef41b84
...
...
@@ -274,12 +274,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
self
.
pinned_sampled_token_ids
)
if
model_output
.
pythonized
:
ctx
=
output_proc_callback
.
keywords
[
"ctx"
]
is_async
=
False
is_last_step
=
False
ctx
.
output_queue
.
append
(
([
model_output
.
sampler_output
],
ctx
.
seq_group_metadata_list
,
ctx
.
scheduler_outputs
,
is_async
,
is_last_step
))
ctx
.
append_output
(
outputs
=
[
model_output
.
sampler_output
],
seq_group_metadata_list
=
ctx
.
seq_group_metadata_list
,
scheduler_outputs
=
ctx
.
scheduler_outputs
,
is_async
=
False
,
is_last_step
=
False
)
output_proc_callback
()
else
:
cont
=
False
...
...
@@ -319,12 +320,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
if
not
is_last_step
:
ctx
=
output_proc_callback
.
keywords
[
# type: ignore
"ctx"
]
# type: ignore
is_async
=
False
is_last_step
=
False
ctx
.
output_queue
.
append
(
([
output
.
sampler_output
],
ctx
.
seq_group_metadata_list
,
ctx
.
scheduler_outputs
,
is_async
,
is_last_step
))
ctx
.
append_output
(
outputs
=
[
output
.
sampler_output
],
seq_group_metadata_list
=
ctx
.
seq_group_metadata_list
,
scheduler_outputs
=
ctx
.
scheduler_outputs
,
is_async
=
False
,
is_last_step
=
False
)
else
:
outputs
.
append
(
output
.
sampler_output
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment