Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4ef41b84
Unverified
Commit
4ef41b84
authored
Sep 08, 2024
by
Alexander Matveev
Committed by
GitHub
Sep 07, 2024
Browse files
[Bugfix] Fix async postprocessor in case of preemption (#8267)
parent
cfe712bf
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
172 additions
and
114 deletions
+172
-114
vllm/core/scheduler.py
vllm/core/scheduler.py
+47
-40
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+12
-12
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+99
-50
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+14
-12
No files found.
vllm/core/scheduler.py
View file @
4ef41b84
...
@@ -537,13 +537,6 @@ class Scheduler:
...
@@ -537,13 +537,6 @@ class Scheduler:
preempted
:
List
[
SequenceGroup
]
=
ret
.
preempted
preempted
:
List
[
SequenceGroup
]
=
ret
.
preempted
swapped_out
:
List
[
SequenceGroup
]
=
ret
.
swapped_out
swapped_out
:
List
[
SequenceGroup
]
=
ret
.
swapped_out
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# Store original running requests for the case of async + preemption
if
self
.
use_async_output_proc
:
orig_running
=
self
.
running
.
copy
()
running_queue
=
self
.
running
running_queue
=
self
.
running
assert
len
(
self
.
_async_stopped
)
==
0
assert
len
(
self
.
_async_stopped
)
==
0
while
running_queue
:
while
running_queue
:
...
@@ -552,6 +545,7 @@ class Scheduler:
...
@@ -552,6 +545,7 @@ class Scheduler:
seq_group
,
SequenceStatus
.
RUNNING
,
enable_chunking
,
budget
)
seq_group
,
SequenceStatus
.
RUNNING
,
enable_chunking
,
budget
)
if
num_running_tokens
==
0
:
if
num_running_tokens
==
0
:
# No budget => Stop
break
break
running_queue
.
popleft
()
running_queue
.
popleft
()
...
@@ -565,18 +559,8 @@ class Scheduler:
...
@@ -565,18 +559,8 @@ class Scheduler:
self
.
_async_stopped
.
append
(
seq_group
)
self
.
_async_stopped
.
append
(
seq_group
)
continue
continue
# With async postprocessor, when preemption kicks in, we need
# NOTE(woosuk): Preemption happens only when there is no available
# first to drain the async postprocessor, so that all async
# slot to keep all the sequence groups in the RUNNING state.
# block_table freeing is applied before the preemption freeing
# is applied.
if
self
.
use_async_output_proc
and
not
self
.
_can_append_slots
(
seq_group
):
tmp
=
self
.
running
self
.
running
=
orig_running
assert
self
.
output_proc_callback
is
not
None
self
.
output_proc_callback
()
self
.
running
=
tmp
while
not
self
.
_can_append_slots
(
seq_group
):
while
not
self
.
_can_append_slots
(
seq_group
):
budget
.
subtract_num_batched_tokens
(
seq_group
.
request_id
,
budget
.
subtract_num_batched_tokens
(
seq_group
.
request_id
,
num_running_tokens
)
num_running_tokens
)
...
@@ -588,24 +572,43 @@ class Scheduler:
...
@@ -588,24 +572,43 @@ class Scheduler:
and
seq_group
.
lora_int_id
in
curr_loras
):
and
seq_group
.
lora_int_id
in
curr_loras
):
curr_loras
.
remove
(
seq_group
.
lora_int_id
)
curr_loras
.
remove
(
seq_group
.
lora_int_id
)
# Determine victim sequence
cont_loop
=
True
if
running_queue
:
if
running_queue
:
# Preempt the lowest-priority sequence group
s
.
# Preempt the lowest-priority sequence group.
victim_seq_group
=
running_queue
.
pop
()
victim_seq_group
=
running_queue
.
pop
()
else
:
# No other sequence group can be preempted.
# Preempt the current sequence group.
# Note: This is also where we stop this loop
# (since there is nothing else to preempt)
victim_seq_group
=
seq_group
cont_loop
=
False
# With async postprocessor, before preempting a sequence
# we need to ensure it has no pending async postprocessor
do_preempt
=
True
if
self
.
use_async_output_proc
:
assert
self
.
output_proc_callback
is
not
None
self
.
output_proc_callback
(
request_id
=
victim_seq_group
.
request_id
)
# It may be that the async pending "victim_seq_group"
# becomes finished, in which case we simply free it.
if
victim_seq_group
.
is_finished
():
self
.
_free_finished_seq_group
(
victim_seq_group
)
do_preempt
=
False
# Do preemption
if
do_preempt
:
preempted_mode
=
self
.
_preempt
(
victim_seq_group
,
preempted_mode
=
self
.
_preempt
(
victim_seq_group
,
blocks_to_swap_out
)
blocks_to_swap_out
)
if
preempted_mode
==
PreemptionMode
.
RECOMPUTE
:
if
preempted_mode
==
PreemptionMode
.
RECOMPUTE
:
preempted
.
append
(
victim_seq_group
)
preempted
.
append
(
victim_seq_group
)
else
:
else
:
swapped_out
.
append
(
victim_seq_group
)
swapped_out
.
append
(
victim_seq_group
)
else
:
# No other sequence groups can be preempted.
if
not
cont_loop
:
# Preempt the current sequence group.
preempted_mode
=
self
.
_preempt
(
seq_group
,
blocks_to_swap_out
)
if
preempted_mode
==
PreemptionMode
.
RECOMPUTE
:
preempted
.
append
(
seq_group
)
else
:
swapped_out
.
append
(
seq_group
)
break
break
else
:
else
:
self
.
_append_slots
(
seq_group
,
blocks_to_copy
)
self
.
_append_slots
(
seq_group
,
blocks_to_copy
)
...
@@ -1264,22 +1267,26 @@ class Scheduler:
...
@@ -1264,22 +1267,26 @@ class Scheduler:
if
seq
.
is_finished
():
if
seq
.
is_finished
():
self
.
free_seq
(
seq
)
self
.
free_seq
(
seq
)
def
free_finished_seq_groups
(
self
)
->
None
:
def
_free_finished_seq_group
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
remaining
:
Deque
[
SequenceGroup
]
=
deque
()
for
seq_group
in
self
.
running
:
if
seq_group
.
is_finished
():
if
seq_group
.
is_finished
():
# Free cross-attention block table, if it exists
# Free cross-attention block table, if it exists
self
.
_free_seq_group_cross_attn_blocks
(
seq_group
)
self
.
_free_seq_group_cross_attn_blocks
(
seq_group
)
# Add the finished requests to the finished requests list.
# Add the finished requests to the finished requests list.
# This list will be used to update the Mamba cache in the
# This list will be used to update the Mamba cache in the
# next step.
# next step.
self
.
_finished_requests_ids
.
append
(
seq_group
.
request_id
)
self
.
_finished_requests_ids
.
append
(
seq_group
.
request_id
)
else
:
remaining
.
append
(
seq_group
)
# Free finished seqs
# Free finished seqs
self
.
_free_finished_seqs
(
seq_group
)
self
.
_free_finished_seqs
(
seq_group
)
def
free_finished_seq_groups
(
self
)
->
None
:
remaining
:
Deque
[
SequenceGroup
]
=
deque
()
for
seq_group
in
self
.
running
:
self
.
_free_finished_seq_group
(
seq_group
)
if
not
seq_group
.
is_finished
():
remaining
.
append
(
seq_group
)
self
.
running
=
remaining
self
.
running
=
remaining
# Handle async stopped sequence groups
# Handle async stopped sequence groups
...
...
vllm/engine/async_llm_engine.py
View file @
4ef41b84
...
@@ -342,17 +342,17 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -342,17 +342,17 @@ class _AsyncLLMEngine(LLMEngine):
virtual_engine
]
virtual_engine
]
# Execute the model.
# Execute the model.
output
=
await
self
.
model_executor
.
execute_model_async
(
output
s
=
await
self
.
model_executor
.
execute_model_async
(
execute_model_req
)
execute_model_req
)
# we need to do this here so that last step's sampled_token_ids can
# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
# be passed to the next iteration for PP.
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
s
)
else
:
else
:
if
len
(
ctx
.
output_queue
)
>
0
:
if
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
self
.
_process_model_outputs
(
ctx
=
ctx
)
output
=
[]
output
s
=
[]
# Finish the current step for all the sequence groups.
# Finish the current step for all the sequence groups.
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
...
@@ -365,25 +365,25 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -365,25 +365,25 @@ class _AsyncLLMEngine(LLMEngine):
self
.
cached_scheduler_outputs
[
self
.
cached_scheduler_outputs
[
virtual_engine
]
=
SchedulerOutputState
()
virtual_engine
]
=
SchedulerOutputState
()
is_async
=
allow_async_output_proc
ctx
.
append_output
(
outputs
=
outputs
,
is_last_step
=
True
seq_group_metadata_list
=
seq_group_metadata_list
,
ctx
.
output_queue
.
append
(
scheduler_outputs
=
scheduler_outputs
,
(
output
,
seq_group_metadata_list
,
scheduler_outputs
,
is_asyn
c
,
is_async
=
allow_async_output_pro
c
,
is_last_step
)
)
is_last_step
=
True
)
if
output
and
allow_async_output_proc
:
if
output
s
and
allow_async_output_proc
:
assert
len
(
assert
len
(
output
output
s
)
==
1
,
"Async postprocessor expects only a single output set"
)
==
1
,
"Async postprocessor expects only a single output set"
self
.
_advance_to_next_step
(
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
output
s
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
)
if
not
allow_async_output_proc
:
if
not
allow_async_output_proc
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
self
.
_process_model_outputs
(
ctx
=
ctx
)
# Log stats.
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
do_log_stats
(
scheduler_outputs
,
output
s
)
# Tracing
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
self
.
do_tracing
(
scheduler_outputs
)
...
...
vllm/engine/llm_engine.py
View file @
4ef41b84
...
@@ -2,9 +2,9 @@ import functools
...
@@ -2,9 +2,9 @@ import functools
import
time
import
time
from
collections
import
deque
from
collections
import
deque
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
)
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Type
,
Union
from
typing
import
Set
,
Tuple
,
Type
,
Union
...
@@ -90,17 +90,36 @@ class SchedulerOutputState:
...
@@ -90,17 +90,36 @@ class SchedulerOutputState:
last_output
:
Optional
[
SamplerOutput
]
=
None
last_output
:
Optional
[
SamplerOutput
]
=
None
@
dataclass
class
OutputData
(
NamedTuple
):
outputs
:
List
[
SamplerOutput
]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
scheduler_outputs
:
SchedulerOutputs
is_async
:
bool
is_last_step
:
bool
skip
:
List
[
int
]
class
SchedulerContext
:
class
SchedulerContext
:
output_queue
:
Deque
[
Tuple
[
Optional
[
List
[
SamplerOutput
]],
List
[
SequenceGroupMetadata
],
SchedulerOutputs
,
def
__init__
(
self
):
bool
,
self
.
output_queue
:
Deque
[
OutputData
]
=
deque
()
bool
]]
=
field
(
default_factory
=
lambda
:
deque
())
self
.
request_outputs
:
List
[
Union
[
RequestOutput
,
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
EmbeddingRequestOutput
]]
=
field
(
self
.
seq_group_metadata_list
:
Optional
[
default_factory
=
lambda
:
[])
List
[
SequenceGroupMetadata
]]
=
None
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
self
.
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
def
append_output
(
self
,
outputs
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduler_outputs
:
SchedulerOutputs
,
is_async
:
bool
,
is_last_step
:
bool
):
self
.
output_queue
.
append
(
OutputData
(
outputs
=
outputs
,
seq_group_metadata_list
=
seq_group_metadata_list
,
scheduler_outputs
=
scheduler_outputs
,
is_async
=
is_async
,
is_last_step
=
is_last_step
,
skip
=
[]))
class
LLMEngine
:
class
LLMEngine
:
...
@@ -1246,23 +1265,15 @@ class LLMEngine:
...
@@ -1246,23 +1265,15 @@ class LLMEngine:
return
return
def
_process_model_outputs
(
self
,
ctx
:
SchedulerContext
)
->
None
:
def
_process_model_outputs
(
self
,
"""Apply the model output to the sequences in the scheduled seq groups.
ctx
:
SchedulerContext
,
request_id
:
Optional
[
str
]
=
None
)
->
None
:
"""Apply the model output to the sequences in the scheduled seq groups
and return responses.
virtual_engine: The engine id to operate on
ctx: The virtual engine context to work on
request_id: If provided, then only this request is going to be processed
is_async: Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
sampler_output: Used with multi-step execution to provide
sampler_output of each step
is_last_output: Used with multi-step execution to indicate
the last step (of each multi-step group)
Returns RequestOutputs that can be returned to the client.
"""
"""
now
=
time
.
time
()
now
=
time
.
time
()
...
@@ -1270,9 +1281,14 @@ class LLMEngine:
...
@@ -1270,9 +1281,14 @@ class LLMEngine:
return
None
return
None
# Get pending async postprocessor
# Get pending async postprocessor
if
request_id
:
# When we process only one request, no pop is required
# (since later we will process all of the rest)
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
)
=
ctx
.
output_queue
.
popleft
()
is_last_step
,
skip
)
=
ctx
.
output_queue
[
0
]
assert
outputs
is
not
None
else
:
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
,
skip
)
=
ctx
.
output_queue
.
popleft
()
# Sanity check
# Sanity check
assert
len
(
seq_group_metadata_list
)
==
len
(
assert
len
(
seq_group_metadata_list
)
==
len
(
...
@@ -1286,9 +1302,30 @@ class LLMEngine:
...
@@ -1286,9 +1302,30 @@ class LLMEngine:
else
:
else
:
outputs_by_sequence_group
=
outputs
outputs_by_sequence_group
=
outputs
# Determine the requests we need to operate on
if
request_id
:
indices
=
[]
for
i
,
seq_group_meta
in
enumerate
(
seq_group_metadata_list
):
if
seq_group_meta
.
request_id
==
request_id
:
assert
i
not
in
skip
# Cannot be called twice
indices
.
append
(
i
)
break
# If the request_id was not found, then it means that
# this is a new request that has no pending async
# postprocessor
if
not
indices
:
return
else
:
indices
=
range
(
len
(
seq_group_metadata_list
))
# type: ignore
finished_before
:
List
[
int
]
=
[]
finished_before
:
List
[
int
]
=
[]
finished_now
:
List
[
int
]
=
[]
finished_now
:
List
[
int
]
=
[]
for
i
,
seq_group_meta
in
enumerate
(
seq_group_metadata_list
):
for
i
in
indices
:
if
i
in
skip
:
continue
seq_group_meta
=
seq_group_metadata_list
[
i
]
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
...
@@ -1343,6 +1380,18 @@ class LLMEngine:
...
@@ -1343,6 +1380,18 @@ class LLMEngine:
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
ctx
.
request_outputs
.
append
(
request_output
)
ctx
.
request_outputs
.
append
(
request_output
)
# When we process a single request, we skip it for the next time,
# and invoke the request output callback (if there was final output)
if
request_id
:
assert
len
(
indices
)
==
1
skip
.
append
(
indices
[
0
])
if
(
finished_now
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
ctx
.
request_outputs
.
clear
()
return
# Free currently finished requests
# Free currently finished requests
if
finished_now
:
if
finished_now
:
for
scheduler
in
self
.
scheduler
:
for
scheduler
in
self
.
scheduler
:
...
@@ -1354,17 +1403,16 @@ class LLMEngine:
...
@@ -1354,17 +1403,16 @@ class LLMEngine:
if
(
finished_now
if
(
finished_now
and
self
.
process_request_outputs_callback
is
not
None
):
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
ctx
.
request_outputs
.
clear
()
return
return
# Create the outputs
# Create the outputs
# Note: scheduled_seq_groups and seq_group_metadata_list
for
i
in
indices
:
# must match with the indices
if
i
in
skip
or
i
in
finished_before
or
i
in
finished_now
:
for
i
,
scheduled_seq_group
in
enumerate
(
scheduler_outputs
.
scheduled_seq_groups
):
if
i
in
finished_before
or
i
in
finished_now
:
continue
# Avoids double processing
continue
# Avoids double processing
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
seq_group
.
maybe_set_first_token_time
(
now
)
if
(
seq_group
.
is_finished
()
if
(
seq_group
.
is_finished
()
...
@@ -1380,6 +1428,7 @@ class LLMEngine:
...
@@ -1380,6 +1428,7 @@ class LLMEngine:
if
(
ctx
.
request_outputs
if
(
ctx
.
request_outputs
and
self
.
process_request_outputs_callback
is
not
None
):
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
ctx
.
request_outputs
.
clear
()
# For async case, we need to record the stats here.
# For async case, we need to record the stats here.
# For non-async case, the stats are done in the
# For non-async case, the stats are done in the
...
@@ -1548,20 +1597,20 @@ class LLMEngine:
...
@@ -1548,20 +1597,20 @@ class LLMEngine:
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
virtual_engine
]
output
=
self
.
model_executor
.
execute_model
(
output
s
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
execute_model_req
=
execute_model_req
)
# We need to do this here so that last step's sampled_token_ids can
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
# be passed to the next iteration for PP.
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
s
)
else
:
else
:
# Nothing scheduled => If there is pending async postprocessor,
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
# then finish it here.
if
len
(
ctx
.
output_queue
)
>
0
:
if
len
(
ctx
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
self
.
_process_model_outputs
(
ctx
=
ctx
)
# No outputs in this case
# No outputs in this case
output
=
[]
output
s
=
[]
# Finish the current step for all the sequence groups.
# Finish the current step for all the sequence groups.
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
...
@@ -1574,18 +1623,18 @@ class LLMEngine:
...
@@ -1574,18 +1623,18 @@ class LLMEngine:
self
.
cached_scheduler_outputs
[
0
]
=
SchedulerOutputState
()
self
.
cached_scheduler_outputs
[
0
]
=
SchedulerOutputState
()
# Add results to the output_queue
# Add results to the output_queue
is_async
=
allow_async_output_proc
ctx
.
append_output
(
outputs
=
outputs
,
is_last_step
=
True
seq_group_metadata_list
=
seq_group_metadata_list
,
ctx
.
output_queue
.
append
(
scheduler_outputs
=
scheduler_outputs
,
(
output
,
seq_group_metadata_list
,
scheduler_outputs
,
is_asyn
c
,
is_async
=
allow_async_output_pro
c
,
is_last_step
)
)
is_last_step
=
True
)
if
output
and
allow_async_output_proc
:
if
output
s
and
allow_async_output_proc
:
assert
len
(
output
)
==
1
,
(
assert
len
(
output
s
)
==
1
,
(
"Async postprocessor expects only a single output set"
)
"Async postprocessor expects only a single output set"
)
self
.
_advance_to_next_step
(
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
output
s
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
)
# Check if need to run the usual non-async path
# Check if need to run the usual non-async path
...
@@ -1593,7 +1642,7 @@ class LLMEngine:
...
@@ -1593,7 +1642,7 @@ class LLMEngine:
self
.
_process_model_outputs
(
ctx
=
ctx
)
self
.
_process_model_outputs
(
ctx
=
ctx
)
# Log stats.
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
do_log_stats
(
scheduler_outputs
,
output
s
)
# Tracing
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
self
.
do_tracing
(
scheduler_outputs
)
...
...
vllm/worker/multi_step_model_runner.py
View file @
4ef41b84
...
@@ -274,12 +274,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -274,12 +274,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
self
.
pinned_sampled_token_ids
)
self
.
pinned_sampled_token_ids
)
if
model_output
.
pythonized
:
if
model_output
.
pythonized
:
ctx
=
output_proc_callback
.
keywords
[
"ctx"
]
ctx
=
output_proc_callback
.
keywords
[
"ctx"
]
is_async
=
False
ctx
.
append_output
(
is_last_step
=
False
outputs
=
[
model_output
.
sampler_output
],
ctx
.
output_queue
.
append
(
seq_group_metadata_list
=
ctx
.
seq_group_metadata_list
,
([
model_output
.
sampler_output
scheduler_outputs
=
ctx
.
scheduler_outputs
,
],
ctx
.
seq_group_metadata_list
,
is_async
=
False
,
ctx
.
scheduler_outputs
,
is_async
,
is_last_step
))
is_last_step
=
False
)
output_proc_callback
()
output_proc_callback
()
else
:
else
:
cont
=
False
cont
=
False
...
@@ -319,12 +320,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -319,12 +320,13 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
if
not
is_last_step
:
if
not
is_last_step
:
ctx
=
output_proc_callback
.
keywords
[
# type: ignore
ctx
=
output_proc_callback
.
keywords
[
# type: ignore
"ctx"
]
# type: ignore
"ctx"
]
# type: ignore
is_async
=
False
ctx
.
append_output
(
is_last_step
=
False
outputs
=
[
output
.
sampler_output
],
ctx
.
output_queue
.
append
(
seq_group_metadata_list
=
ctx
.
([
output
.
sampler_output
seq_group_metadata_list
,
],
ctx
.
seq_group_metadata_list
,
scheduler_outputs
=
ctx
.
scheduler_outputs
,
ctx
.
scheduler_outputs
,
is_async
,
is_last_step
))
is_async
=
False
,
is_last_step
=
False
)
else
:
else
:
outputs
.
append
(
output
.
sampler_output
)
outputs
.
append
(
output
.
sampler_output
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment