norm / vllm

Commit ce741ba3 (Unverified)
Authored Sep 03, 2023 by Antoni Baum; committed by GitHub on Sep 03, 2023
Parent: bf87484e

    Refactor AsyncLLMEngine (#880)

Showing 3 changed files with 267 additions and 145 deletions (+267 −145):
- vllm/core/scheduler.py: +9 −4
- vllm/engine/async_llm_engine.py: +212 −109
- vllm/engine/llm_engine.py: +46 −32
vllm/core/scheduler.py (view file @ ce741ba3)

```diff
 import enum
 import time
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, Union

 from vllm.config import CacheConfig, SchedulerConfig
 from vllm.core.block_manager import BlockSpaceManager
 ...
@@ -87,17 +87,22 @@ class Scheduler:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)

-    def abort_seq_group(self, request_id: str) -> None:
+    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
+        if isinstance(request_id, str):
+            request_id = (request_id, )
+        request_ids = set(request_id)
         for state_queue in [self.waiting, self.running, self.swapped]:
             for seq_group in state_queue:
-                if seq_group.request_id == request_id:
+                if seq_group.request_id in request_ids:
                     # Remove the sequence group from the state queue.
                     state_queue.remove(seq_group)
                     for seq in seq_group.seqs:
                         if seq.is_finished():
                             continue
                         self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
-                    return
+                    request_ids.remove(seq_group.request_id)
+                    if not request_ids:
+                        return

     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped
 ...
```
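With this change, `Scheduler.abort_seq_group` (and `LLMEngine.abort_request`, which forwards to it) accepts either a single request id or an iterable of ids. A minimal usage sketch, assuming `scheduler` is an already-constructed `Scheduler` and the request ids shown are hypothetical:

```python
# Hypothetical ids; the scheduler instance comes from the surrounding engine.
scheduler.abort_seq_group("request-0")                 # single id, as before
scheduler.abort_seq_group(["request-1", "request-2"])  # iterable of ids, new in this commit
```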
vllm/engine/async_llm_engine.py (view file @ ce741ba3)

```diff
 import asyncio
 import time
-from typing import Dict, List, Optional
+from functools import partial
+from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union

 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 ...
@@ -12,7 +13,105 @@ from vllm.sampling_params import SamplingParams

 logger = init_logger(__name__)

-TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
+
+class AsyncStream:
+    """A stream of RequestOutputs for a request that can be
+    iterated over asynchronously."""
+
+    def __init__(self, request_id: str) -> None:
+        self.request_id = request_id
+        self._queue = asyncio.Queue()
+        self._finished = False
+
+    def put(self, item: RequestOutput) -> None:
+        if self._finished:
+            return
+        self._queue.put_nowait(item)
+
+    def finish(self) -> None:
+        self._queue.put_nowait(StopIteration)
+        self._finished = True
+
+    @property
+    def finished(self) -> bool:
+        return self._finished
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self) -> RequestOutput:
+        result = await self._queue.get()
+        if result is StopIteration:
+            raise StopAsyncIteration
+        return result
+
+
+def _raise_exception_on_finish(task: asyncio.Task) -> None:
+    try:
+        task.result()
+    except Exception as e:
+        raise RuntimeError("Task finished unexpectedly.") from e
+    raise RuntimeError("Task finished unexpectedly.")
+
+
+class _AsyncLLMEngine(LLMEngine):
+    """Extension of LLMEngine to add async methods."""
+
+    async def step_async(self) -> List[RequestOutput]:
+        """Performs one decoding iteration and returns newly generated results.
+        The workers are ran asynchronously if possible.
+
+        This function performs one decoding iteration of the engine. It first
+        schedules the sequences to be executed in the next iteration and the
+        token blocks to be swapped in/out/copy. Then, it executes the model
+        and updates the scheduler with the model outputs. Finally, it decodes
+        the sequences and returns the newly generated results.
+        """
+        (seq_group_metadata_list, scheduler_outputs,
+         early_return) = self._schedule()
+        if early_return is not None:
+            return early_return
+
+        # Execute the model.
+        output = await self._run_workers_async(
+            "execute_model",
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
+            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
+            blocks_to_copy=scheduler_outputs.blocks_to_copy,
+        )
+
+        return self._process_worker_outputs(output, scheduler_outputs)
+
+    async def _run_workers_async(
+        self,
+        method: str,
+        *args,
+        get_all_outputs: bool = False,
+        **kwargs,
+    ) -> Any:
+        """Runs the given method on all workers."""
+        all_outputs = []
+        for worker in self.workers:
+            if self.parallel_config.worker_use_ray:
+                executor = partial(worker.execute_method.remote, method)
+            else:
+                executor = getattr(worker, method)
+
+            output = executor(*args, **kwargs)
+            all_outputs.append(output)
+
+        if self.parallel_config.worker_use_ray:
+            all_outputs = await asyncio.gather(*all_outputs)
+
+        if get_all_outputs:
+            return all_outputs
+
+        # Make sure all workers have the same results.
+        output = all_outputs[0]
+        for other_output in all_outputs[1:]:
+            assert output == other_output
+        return output


 class AsyncLLMEngine:
 ...
@@ -37,49 +136,111 @@ class AsyncLLMEngine:
         *args, *kwargs: Arguments for LLMEngine.
     """

+    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
+
     def __init__(self,
                  worker_use_ray: bool,
                  engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
+                 start_engine_loop: bool = False,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.engine_use_ray = engine_use_ray
         self.log_requests = log_requests
+        self.engine = self._init_engine(*args, **kwargs)
+
+        # Request id -> stream.
+        self.request_streams: Dict[str, AsyncStream] = {}
+        self.finished_requests: Set[str] = set()
+        self.background_loop = None
+        if start_engine_loop:
+            self._start_background_loop()
+
+    def _start_background_loop(self) -> None:
+        """Start the background loop."""
+        if self.background_loop is not None:
+            raise RuntimeError("Background loop is already running.")
+        self.background_loop = asyncio.get_event_loop().create_task(
+            self.run_engine_loop())
+        self.background_loop.add_done_callback(_raise_exception_on_finish)
+
+    def _init_engine(self, *args,
+                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
         if not self.engine_use_ray:
-            engine_class = LLMEngine
+            engine_class = self._engine_class
         elif self.worker_use_ray:
-            engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
+            engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
         else:
-            engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
-        self.engine = engine_class(*args, **kwargs)
+            engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
+        return engine_class(*args, **kwargs)

-        # Request id -> request output.
-        self.request_outputs: Dict[str, RequestOutput] = {}
-        # Request id -> event to notify that there is new output.
-        self.request_events: Dict[str, asyncio.Event] = {}
-        self.is_engine_running = False
-        self.kicking_request_id: Optional[str] = None
-
-    async def engine_step(self, kicking_request_id: Optional[str] = None):
+    async def engine_step(self):
         """Kick the engine to process the waiting requests."""
-        self.is_engine_running = True
-        self.kicking_request_id = kicking_request_id
         if self.engine_use_ray:
             request_outputs = await self.engine.step.remote()
         else:
-            # Yield to the event loop to allow other coroutines to run
-            # while is_engine_running is True. This let the engine to add new
-            # requests into the queue.
-            await asyncio.sleep(0)
-            request_outputs = self.engine.step()
-        self.is_engine_running = False
-        self.kicking_request_id = None
+            request_outputs = await self.engine.step_async()

-        # Notify the waiting coroutines that there are new outputs ready.
+        # Put the outputs into the corresponding streams.
         for request_output in request_outputs:
             request_id = request_output.request_id
-            self.request_outputs[request_id] = request_output
-            self.request_events[request_id].set()
+            self.request_streams[request_id].put(request_output)
+            if request_output.finished:
+                if self.log_requests:
+                    logger.info(f"Finished request {request_id}.")
+                self.request_streams[request_id].finish()
+                self.finished_requests.add(request_id)
+
+        await self._engine_abort(self.finished_requests)
+        for request_id in self.finished_requests:
+            del self.request_streams[request_id]
+        self.finished_requests.clear()
+
+    async def _engine_abort(self, request_ids: Iterable[str]):
+        if self.engine_use_ray:
+            await self.engine.abort_request.remote(request_ids)
+        else:
+            self.engine.abort_request(request_ids)
+
+    async def run_engine_loop(self):
+        while True:
+            await self.engine_step()
+            await asyncio.sleep(0)
+
+    async def add_request(
+        self,
+        request_id: str,
+        prompt: Optional[str],
+        sampling_params: SamplingParams,
+        prompt_token_ids: Optional[List[int]] = None,
+        arrival_time: Optional[float] = None,
+    ) -> AsyncStream:
+        if self.log_requests:
+            logger.info(f"Received request {request_id}: "
+                        f"prompt: {prompt!r}, "
+                        f"sampling params: {sampling_params}, "
+                        f"prompt token ids: {prompt_token_ids}.")
+
+        stream = AsyncStream(request_id)
+        self.request_streams[request_id] = stream
+
+        # Add the request into the vLLM engine's waiting queue.
+        if self.engine_use_ray:
+            await self.engine.add_request.remote(
+                request_id,
+                prompt,
+                sampling_params,
+                prompt_token_ids=prompt_token_ids,
+                arrival_time=arrival_time)
+        else:
+            self.engine.add_request(request_id,
+                                    prompt,
+                                    sampling_params,
+                                    prompt_token_ids=prompt_token_ids,
+                                    arrival_time=arrival_time)
+
+        return stream

     async def generate(
         self,
 ...
@@ -108,78 +269,32 @@ class AsyncLLMEngine:
         # Preprocess the request.
         arrival_time = time.time()

-        # Create an event to notify us that there is new output from the
-        # vLLM engine.
-        request_event = asyncio.Event()
-        self.request_events[request_id] = request_event
-
-        if self.log_requests:
-            logger.info(f"Received request {request_id}: "
-                        f"prompt: {prompt!r}, "
-                        f"sampling params: {sampling_params}, "
-                        f"prompt token ids: {prompt_token_ids}.")
-
-        # Add the request into the vLLM engine's waiting queue.
-        if self.engine_use_ray:
-            await self.engine.add_request.remote(
-                request_id,
-                prompt,
-                sampling_params,
-                prompt_token_ids=prompt_token_ids,
-                arrival_time=arrival_time)
-        else:
-            self.engine.add_request(request_id,
-                                    prompt,
-                                    sampling_params,
-                                    prompt_token_ids=prompt_token_ids,
-                                    arrival_time=arrival_time)
-
-        # The vLLM engine does not have a background loop that keeps
-        # processing incoming requests. Therefore, we need to keep kicking
-        # the engine to process the requests.
-        while True:
-            if request_id not in self.request_events:
-                # The request has been aborted.
-                return
-
-            # Kick the engine if the engine is not running.
-            if not self.is_engine_running:
-                try:
-                    await self.engine_step(request_id)
-                except RuntimeError as e:
-                    await self.abort(request_id)
-                    raise e
-
-            # Wait for new output. The group_event will be set in engine_step
-            # when there is new output available for the sequence group.
-            # Added a timeout to prevent deadlock.
-            try:
-                await asyncio.wait_for(request_event.wait(),
-                                       timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
-            except asyncio.TimeoutError:
-                continue
-            # Reset the event to wait for the next output.
-            request_event.clear()
-
-            # Decode and return new outputs.
-            request_output = self.request_outputs[request_id]
-            yield request_output
-
-            # Once finished, release the resources of the sequence group.
-            if request_output.finished:
-                if self.log_requests:
-                    logger.info(f"Finished request {request_id}.")
-
-                del self.request_outputs[request_id]
-                del self.request_events[request_id]
-                # Kick the engine if the engine is not running. This is to
-                # prevent that there are still requests in engine's waiting
-                # queue to be executed.
-                if not self.is_engine_running:
-                    await self.engine_step()
-                break
+        try:
+            stream = await self.add_request(request_id,
+                                            prompt,
+                                            sampling_params,
+                                            prompt_token_ids=prompt_token_ids,
+                                            arrival_time=arrival_time)
+
+            async for request_output in stream:
+                yield request_output
+        except Exception as e:
+            # If there is an exception, abort the request.
+            self._abort(request_id)
+            raise e

     async def abort(self, request_id: str) -> None:
+        """Abort a request.
+
+        Abort a submitted request. If the request is finished or not found,
+        this method will be a no-op.
+
+        Args:
+            request_id: The unique id of the request.
+        """
+        return self._abort(request_id)
+
+    def _abort(self, request_id: str) -> None:
         """Abort a request.

         Abort a submitted request. If the request is finished or not found,
 ...
@@ -188,28 +303,16 @@ class AsyncLLMEngine:
         Args:
            request_id: The unique id of the request.
         """
-        if request_id not in self.request_events:
+        if request_id not in self.request_streams or self.request_streams[
+                request_id].finished:
             # The request has already finished or been aborted.
             return

         if self.log_requests:
             logger.info(f"Aborted request {request_id}.")

-        if self.engine_use_ray:
-            await self.engine.abort_request.remote(request_id)
-        else:
-            self.engine.abort_request(request_id)
-
-        if request_id in self.request_events:
-            del self.request_events[request_id]
-        if request_id in self.request_outputs:
-            del self.request_outputs[request_id]
-
-        # To prevent deadlock when a request is aborted while the engine is
-        # running.
-        if self.kicking_request_id == request_id:
-            self.is_engine_running = False
-            self.kicking_request_id = None
+        self.request_streams[request_id].finish()
+        self.finished_requests.add(request_id)

     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""
 ...
```
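After the refactor, `generate` simply awaits `add_request` and iterates the returned `AsyncStream`, so callers consume outputs with a plain `async for`. A minimal, hypothetical consumer sketch follows; the engine construction, prompt, and sampling settings are illustrative assumptions, not part of this diff, and it presumes the engine was created with `start_engine_loop=True` so that `run_engine_loop()` keeps calling `engine_step()`:

```python
from vllm.sampling_params import SamplingParams


async def stream_one_request(engine, prompt: str, request_id: str) -> str:
    # generate() is an async generator: each iteration yields the latest
    # RequestOutput that engine_step() put into this request's AsyncStream.
    params = SamplingParams(temperature=0.8, max_tokens=64)
    text = ""
    async for request_output in engine.generate(prompt, params, request_id):
        text = request_output.outputs[0].text
    return text
```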
vllm/engine/llm_engine.py (view file @ ce741ba3)

```diff
-import time
 import copy
+import time
 from functools import partial
-from typing import Any, List, Optional, Tuple, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union

 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
-from vllm.core.scheduler import Scheduler
+from vllm.core.scheduler import Scheduler, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker
+from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
+from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
+                           SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
 from vllm.utils import Counter
 ...
@@ -135,7 +136,8 @@ class LLMEngine:
             get_all_outputs=True,
         )

-    def _init_workers_ray(self, placement_group: "PlacementGroup"):
+    def _init_workers_ray(self, placement_group: "PlacementGroup",
+                          **ray_remote_kwargs):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
         from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel
 ...
@@ -150,6 +152,7 @@ class LLMEngine:
                 scheduling_strategy=PlacementGroupSchedulingStrategy(
                     placement_group=placement_group,
                     placement_group_capture_child_tasks=True),
+                **ray_remote_kwargs,
             )(RayWorker).remote()
             self.workers.append(worker)
 ...
@@ -268,11 +271,11 @@ class LLMEngine:
         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)

-    def abort_request(self, request_id: str) -> None:
-        """Aborts a request with the given ID.
+    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
+        """Aborts a request(s) with the given ID.

         Args:
-            request_id: The ID of the request to abort.
+            request_id: The ID(s) of the request to abort.
         """
         self.scheduler.abort_seq_group(request_id)
 ...
@@ -288,35 +291,21 @@ class LLMEngine:
         """Returns True if there are unfinished requests."""
         return self.scheduler.has_unfinished_seqs()

-    def step(self) -> List[RequestOutput]:
-        """Performs one decoding iteration and returns newly generated results.
-
-        This function performs one decoding iteration of the engine. It first
-        schedules the sequences to be executed in the next iteration and the
-        token blocks to be swapped in/out/copy. Then, it executes the model
-        and updates the scheduler with the model outputs. Finally, it decodes
-        the sequences and returns the newly generated results.
-        """
+    def _schedule(
+        self
+    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
+               Optional[List[RequestOutput]]]:
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
         if scheduler_outputs.is_empty():
-            if not scheduler_outputs.ignored_seq_groups:
-                # Nothing to do.
-                return []
-            # If there are ignored seq groups, we need to return them as the
-            # request outputs.
-            return [
+            return seq_group_metadata_list, scheduler_outputs, [
                 RequestOutput.from_seq_group(seq_group)
                 for seq_group in scheduler_outputs.ignored_seq_groups
             ]
+        return seq_group_metadata_list, scheduler_outputs, None

-        # Execute the model.
-        output = self._run_workers(
-            "execute_model",
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
-            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
-            blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        )
-
+    def _process_worker_outputs(
+            self, output,
+            scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
         # Update the scheduler with the model outputs.
         seq_groups = self.scheduler.update(output)
 ...
@@ -339,6 +328,31 @@ class LLMEngine:
                                    scheduler_outputs.num_batched_tokens)
         return request_outputs

+    def step(self) -> List[RequestOutput]:
+        """Performs one decoding iteration and returns newly generated results.
+
+        This function performs one decoding iteration of the engine. It first
+        schedules the sequences to be executed in the next iteration and the
+        token blocks to be swapped in/out/copy. Then, it executes the model
+        and updates the scheduler with the model outputs. Finally, it decodes
+        the sequences and returns the newly generated results.
+        """
+        (seq_group_metadata_list, scheduler_outputs,
+         early_return) = self._schedule()
+        if early_return is not None:
+            return early_return
+
+        # Execute the model.
+        output = self._run_workers(
+            "execute_model",
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
+            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
+            blocks_to_copy=scheduler_outputs.blocks_to_copy,
+        )
+
+        return self._process_worker_outputs(output, scheduler_outputs)
+
     def _log_system_stats(
         self,
         prompt_run: bool,
 ...
```
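The llm_engine.py changes split the old monolithic `step()` into `_schedule()` and `_process_worker_outputs()`, which `_AsyncLLMEngine.step_async` reuses, and `AsyncLLMEngine` now builds its engine from the `_engine_class` class attribute. A hypothetical sketch of how downstream code could hook into the refactored structure; the subclass names and timing logic are illustrative assumptions, not part of this commit:

```python
import time

from vllm.engine.async_llm_engine import AsyncLLMEngine, _AsyncLLMEngine


class TimedLLMEngine(_AsyncLLMEngine):
    """Hypothetical engine subclass that times each async decoding step."""

    async def step_async(self):
        start = time.time()
        # step_async() composes the shared _schedule() /
        # _process_worker_outputs() helpers inherited from LLMEngine.
        outputs = await super().step_async()
        print(f"step_async: {time.time() - start:.3f}s, {len(outputs)} outputs")
        return outputs


class TimedAsyncLLMEngine(AsyncLLMEngine):
    # _init_engine() instantiates whatever _engine_class points to,
    # so overriding the attribute is enough to swap in the custom engine.
    _engine_class = TimedLLMEngine
```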