norm / vllm · Commits · ce741ba3

Commit ce741ba3 (Unverified)
Refactor AsyncLLMEngine (#880)

Authored Sep 03, 2023 by Antoni Baum; committed by GitHub on Sep 03, 2023.
Parent: bf87484e

Showing 3 changed files with 267 additions and 145 deletions (+267 -145):

    vllm/core/scheduler.py             +9    -4
    vllm/engine/async_llm_engine.py    +212  -109
    vllm/engine/llm_engine.py          +46   -32
vllm/core/scheduler.py

 import enum
 import time
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, Union

 from vllm.config import CacheConfig, SchedulerConfig
 from vllm.core.block_manager import BlockSpaceManager
...
@@ -87,17 +87,22 @@ class Scheduler:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)

-    def abort_seq_group(self, request_id: str) -> None:
+    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
+        if isinstance(request_id, str):
+            request_id = (request_id, )
+        request_ids = set(request_id)
         for state_queue in [self.waiting, self.running, self.swapped]:
             for seq_group in state_queue:
-                if seq_group.request_id == request_id:
+                if seq_group.request_id in request_ids:
                     # Remove the sequence group from the state queue.
                     state_queue.remove(seq_group)
                     for seq in seq_group.seqs:
                         if seq.is_finished():
                             continue
                         self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
-                    return
+                    request_ids.remove(seq_group.request_id)
+                    if not request_ids:
+                        return

     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped
...
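With this change, a single call can abort several requests: a plain string is wrapped into a one-element tuple and converted to a set, so single-id and batch aborts share one code path. A minimal usage sketch, assuming `scheduler` is an existing Scheduler and the (hypothetical) request ids were previously added via add_seq_group:

    # Both call shapes are now accepted.
    scheduler.abort_seq_group("request-0")                  # single id
    scheduler.abort_seq_group(["request-1", "request-2"])   # batch of ids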
vllm/engine/async_llm_engine.py

 import asyncio
 import time
-from typing import Dict, List, Optional
+from functools import partial
+from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union

 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
...
@@ -12,7 +13,105 @@ from vllm.sampling_params import SamplingParams

 logger = init_logger(__name__)

-TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
+
+class AsyncStream:
+    """A stream of RequestOutputs for a request that can be
+    iterated over asynchronously."""
+
+    def __init__(self, request_id: str) -> None:
+        self.request_id = request_id
+        self._queue = asyncio.Queue()
+        self._finished = False
+
+    def put(self, item: RequestOutput) -> None:
+        if self._finished:
+            return
+        self._queue.put_nowait(item)
+
+    def finish(self) -> None:
+        self._queue.put_nowait(StopIteration)
+        self._finished = True
+
+    @property
+    def finished(self) -> bool:
+        return self._finished
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self) -> RequestOutput:
+        result = await self._queue.get()
+        if result is StopIteration:
+            raise StopAsyncIteration
+        return result
+
+
+def _raise_exception_on_finish(task: asyncio.Task) -> None:
+    try:
+        task.result()
+    except Exception as e:
+        raise RuntimeError("Task finished unexpectedly.") from e
+    raise RuntimeError("Task finished unexpectedly.")
+
+
+class _AsyncLLMEngine(LLMEngine):
+    """Extension of LLMEngine to add async methods."""
+
+    async def step_async(self) -> List[RequestOutput]:
+        """Performs one decoding iteration and returns newly generated results.
+        The workers are ran asynchronously if possible.
+
+        This function performs one decoding iteration of the engine. It first
+        schedules the sequences to be executed in the next iteration and the
+        token blocks to be swapped in/out/copy. Then, it executes the model
+        and updates the scheduler with the model outputs. Finally, it decodes
+        the sequences and returns the newly generated results.
+        """
+        (seq_group_metadata_list, scheduler_outputs,
+         early_return) = self._schedule()
+        if early_return is not None:
+            return early_return
+
+        # Execute the model.
+        output = await self._run_workers_async(
+            "execute_model",
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
+            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
+            blocks_to_copy=scheduler_outputs.blocks_to_copy,
+        )
+
+        return self._process_worker_outputs(output, scheduler_outputs)
+
+    async def _run_workers_async(
+        self,
+        method: str,
+        *args,
+        get_all_outputs: bool = False,
+        **kwargs,
+    ) -> Any:
+        """Runs the given method on all workers."""
+        all_outputs = []
+        for worker in self.workers:
+            if self.parallel_config.worker_use_ray:
+                executor = partial(worker.execute_method.remote, method)
+            else:
+                executor = getattr(worker, method)
+
+            output = executor(*args, **kwargs)
+            all_outputs.append(output)
+
+        if self.parallel_config.worker_use_ray:
+            all_outputs = await asyncio.gather(*all_outputs)
+
+        if get_all_outputs:
+            return all_outputs
+
+        # Make sure all workers have the same results.
+        output = all_outputs[0]
+        for other_output in all_outputs[1:]:
+            assert output == other_output
+        return output
+
+
 class AsyncLLMEngine:
...
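AsyncStream is a small producer/consumer wrapper around asyncio.Queue: engine_step() calls put() and finish(), and the caller consumes the stream with `async for`, which ends when the StopIteration sentinel is dequeued. A self-contained sketch of that contract (the string payloads stand in for real RequestOutput objects, and AsyncStream is assumed importable from the module above):

    import asyncio

    async def produce(stream):
        for i in range(3):
            stream.put(f"output-{i}")   # engine_step() does this per new output
            await asyncio.sleep(0)
        stream.finish()                 # enqueue the sentinel; iteration ends

    async def consume(stream):
        async for item in stream:       # __anext__ raises StopAsyncIteration
            print(item)                 # once the sentinel is seen

    async def main():
        stream = AsyncStream("request-0")
        await asyncio.gather(produce(stream), consume(stream))

    # asyncio.run(main())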
@@ -37,49 +136,111 @@ class AsyncLLMEngine:
         *args, *kwargs: Arguments for LLMEngine.
     """

+    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
+
     def __init__(self,
                  worker_use_ray: bool,
                  engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
+                 start_engine_loop: bool = False,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.engine_use_ray = engine_use_ray
         self.log_requests = log_requests
+        self.engine = self._init_engine(*args, **kwargs)
+
+        # Request id -> stream.
+        self.request_streams: Dict[str, AsyncStream] = {}
+        self.finished_requests: Set[str] = set()
+        self.background_loop = None
+        if start_engine_loop:
+            self._start_background_loop()
+
+    def _start_background_loop(self) -> None:
+        """Start the background loop."""
+        if self.background_loop is not None:
+            raise RuntimeError("Background loop is already running.")
+        self.background_loop = asyncio.get_event_loop().create_task(
+            self.run_engine_loop())
+        self.background_loop.add_done_callback(_raise_exception_on_finish)
+
+    def _init_engine(self, *args,
+                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
         if not self.engine_use_ray:
-            engine_class = LLMEngine
+            engine_class = self._engine_class
         elif self.worker_use_ray:
-            engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
+            engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
         else:
-            engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
-        self.engine = engine_class(*args, **kwargs)
-        # Request id -> request output.
-        self.request_outputs: Dict[str, RequestOutput] = {}
-        # Request id -> event to notify that there is new output.
-        self.request_events: Dict[str, asyncio.Event] = {}
-        self.is_engine_running = False
-        self.kicking_request_id: Optional[str] = None
-
-    async def engine_step(self, kicking_request_id: Optional[str] = None):
+            engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
+        return engine_class(*args, **kwargs)
+
+    async def engine_step(self):
         """Kick the engine to process the waiting requests."""
-        self.is_engine_running = True
-        self.kicking_request_id = kicking_request_id
         if self.engine_use_ray:
             request_outputs = await self.engine.step.remote()
         else:
-            # Yield to the event loop to allow other coroutines to run
-            # while is_engine_running is True. This let the engine to add new
-            # requests into the queue.
-            await asyncio.sleep(0)
-            request_outputs = self.engine.step()
-        self.is_engine_running = False
-        self.kicking_request_id = None
+            request_outputs = await self.engine.step_async()

-        # Notify the waiting coroutines that there are new outputs ready.
+        # Put the outputs into the corresponding streams.
         for request_output in request_outputs:
             request_id = request_output.request_id
-            self.request_outputs[request_id] = request_output
-            self.request_events[request_id].set()
+            self.request_streams[request_id].put(request_output)
+            if request_output.finished:
+                if self.log_requests:
+                    logger.info(f"Finished request {request_id}.")
+                self.request_streams[request_id].finish()
+                self.finished_requests.add(request_id)
+
+        await self._engine_abort(self.finished_requests)
+        for request_id in self.finished_requests:
+            del self.request_streams[request_id]
+        self.finished_requests.clear()
+
+    async def _engine_abort(self, request_ids: Iterable[str]):
+        if self.engine_use_ray:
+            await self.engine.abort_request.remote(request_ids)
+        else:
+            self.engine.abort_request(request_ids)
+
+    async def run_engine_loop(self):
+        while True:
+            await self.engine_step()
+            await asyncio.sleep(0)
+
+    async def add_request(
+        self,
+        request_id: str,
+        prompt: Optional[str],
+        sampling_params: SamplingParams,
+        prompt_token_ids: Optional[List[int]] = None,
+        arrival_time: Optional[float] = None,
+    ) -> AsyncStream:
+        if self.log_requests:
+            logger.info(f"Received request {request_id}: "
+                        f"prompt: {prompt!r}, "
+                        f"sampling params: {sampling_params}, "
+                        f"prompt token ids: {prompt_token_ids}.")
+
+        stream = AsyncStream(request_id)
+        self.request_streams[request_id] = stream
+
+        # Add the request into the vLLM engine's waiting queue.
+        if self.engine_use_ray:
+            await self.engine.add_request.remote(
+                request_id,
+                prompt,
+                sampling_params,
+                prompt_token_ids=prompt_token_ids,
+                arrival_time=arrival_time)
+        else:
+            self.engine.add_request(request_id,
+                                    prompt,
+                                    sampling_params,
+                                    prompt_token_ids=prompt_token_ids,
+                                    arrival_time=arrival_time)
+
+        return stream

     async def generate(
         self,
...
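The refactor decouples request submission from engine stepping: add_request() registers an AsyncStream and enqueues the prompt, while run_engine_loop() keeps awaiting engine_step(), which pushes each RequestOutput into the matching stream and finishes the stream when the request completes. A rough sketch of driving this directly, assuming `async_engine` is an AsyncLLMEngine constructed with start_engine_loop=True so the background task is already running:

    from vllm.sampling_params import SamplingParams

    async def submit_and_collect(async_engine, request_id: str, prompt: str):
        # Registers the stream and hands the prompt to the engine's waiting
        # queue; the background loop fills the stream as decoding progresses.
        stream = await async_engine.add_request(
            request_id,
            prompt,
            SamplingParams(temperature=0.8, max_tokens=64),
        )
        last_output = None
        async for request_output in stream:
            last_output = request_output    # newest cumulative output
        return last_output                  # stream ends once finished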
@@ -108,78 +269,32 @@ class AsyncLLMEngine:
         # Preprocess the request.
         arrival_time = time.time()

-        # Create an event to notify us that there is new output from the
-        # vLLM engine.
-        request_event = asyncio.Event()
-        self.request_events[request_id] = request_event
+        try:
+            stream = await self.add_request(request_id,
+                                            prompt,
+                                            sampling_params,
+                                            prompt_token_ids=prompt_token_ids,
+                                            arrival_time=arrival_time)

-        if self.log_requests:
-            logger.info(f"Received request {request_id}: "
-                        f"prompt: {prompt!r}, "
-                        f"sampling params: {sampling_params}, "
-                        f"prompt token ids: {prompt_token_ids}.")
+            async for request_output in stream:
+                yield request_output
+        except Exception as e:
+            # If there is an exception, abort the request.
+            self._abort(request_id)
+            raise e

-        # Add the request into the vLLM engine's waiting queue.
-        if self.engine_use_ray:
-            await self.engine.add_request.remote(
-                request_id,
-                prompt,
-                sampling_params,
-                prompt_token_ids=prompt_token_ids,
-                arrival_time=arrival_time)
-        else:
-            self.engine.add_request(request_id,
-                                    prompt,
-                                    sampling_params,
-                                    prompt_token_ids=prompt_token_ids,
-                                    arrival_time=arrival_time)
+    async def abort(self, request_id: str) -> None:
+        """Abort a request.

-        # The vLLM engine does not have a background loop that keeps
-        # processing incoming requests. Therefore, we need to keep kicking
-        # the engine to process the requests.
-        while True:
-            if request_id not in self.request_events:
-                # The request has been aborted.
-                return
-
-            # Kick the engine if the engine is not running.
-            if not self.is_engine_running:
-                try:
-                    await self.engine_step(request_id)
-                except RuntimeError as e:
-                    await self.abort(request_id)
-                    raise e
-
-            # Wait for new output. The group_event will be set in engine_step
-            # when there is new output available for the sequence group.
-            # Added a timeout to prevent deadlock.
-            try:
-                await asyncio.wait_for(request_event.wait(),
-                                       timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
-            except asyncio.TimeoutError:
-                continue
-            # Reset the event to wait for the next output.
-            request_event.clear()
-
-            # Decode and return new outputs.
-            request_output = self.request_outputs[request_id]
-            yield request_output
-
-            # Once finished, release the resources of the sequence group.
-            if request_output.finished:
-                if self.log_requests:
-                    logger.info(f"Finished request {request_id}.")
+        Abort a submitted request. If the request is finished or not found,
+        this method will be a no-op.

-                del self.request_outputs[request_id]
-                del self.request_events[request_id]
-                # Kick the engine if the engine is not running. This is to
-                # prevent that there are still requests in engine's waiting
-                # queue to be executed.
-                if not self.is_engine_running:
-                    await self.engine_step()
-                break
+        Args:
+            request_id: The unique id of the request.
+        """
+        return self._abort(request_id)

-    async def abort(self, request_id: str) -> None:
+    def _abort(self, request_id: str) -> None:
         """Abort a request.

         Abort a submitted request. If the request is finished or not found,
...
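generate() remains the public entry point; it now simply awaits add_request() and iterates the returned stream, calling _abort() if the consumer raises. A possible consumption pattern, hedged: it assumes the generate(prompt, sampling_params, request_id) argument order, which this commit does not change but also does not show in full:

    import asyncio
    from vllm.sampling_params import SamplingParams

    async def stream_completion(async_engine, request_id: str, prompt: str):
        params = SamplingParams(temperature=0.0, max_tokens=32)
        try:
            async for request_output in async_engine.generate(
                    prompt, params, request_id):
                print(request_output.outputs[0].text)  # text generated so far
        except asyncio.CancelledError:
            # e.g. the HTTP client disconnected: abort so the scheduler frees
            # the sequence group instead of decoding to completion.
            await async_engine.abort(request_id)
            raise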
@@ -188,28 +303,16 @@ class AsyncLLMEngine:
         Args:
             request_id: The unique id of the request.
         """
-        if request_id not in self.request_events:
+        if request_id not in self.request_streams or self.request_streams[
+                request_id].finished:
             # The request has already finished or been aborted.
             return

         if self.log_requests:
             logger.info(f"Aborted request {request_id}.")

-        if self.engine_use_ray:
-            await self.engine.abort_request.remote(request_id)
-        else:
-            self.engine.abort_request(request_id)
-
-        if request_id in self.request_events:
-            del self.request_events[request_id]
-        if request_id in self.request_outputs:
-            del self.request_outputs[request_id]
-
-        # To prevent deadlock when a request is aborted while the engine is
-        # running.
-        if self.kicking_request_id == request_id:
-            self.is_engine_running = False
-            self.kicking_request_id = None
+        self.request_streams[request_id].finish()
+        self.finished_requests.add(request_id)

     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""
...
vllm/engine/llm_engine.py

+import copy
 import time
 from functools import partial
-from typing import Any, List, Optional, Tuple, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union

 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
-from vllm.core.scheduler import Scheduler
+from vllm.core.scheduler import Scheduler, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker
+from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
+from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
+                           SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
 from vllm.utils import Counter
...
@@ -135,7 +136,8 @@ class LLMEngine:
             get_all_outputs=True,
         )

-    def _init_workers_ray(self, placement_group: "PlacementGroup"):
+    def _init_workers_ray(self, placement_group: "PlacementGroup",
+                          **ray_remote_kwargs):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
         from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel
...
@@ -150,6 +152,7 @@ class LLMEngine:
                 scheduling_strategy=PlacementGroupSchedulingStrategy(
                     placement_group=placement_group,
                     placement_group_capture_child_tasks=True),
+                **ray_remote_kwargs,
             )(RayWorker).remote()
             self.workers.append(worker)
...
@@ -268,11 +271,11 @@ class LLMEngine:
         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)

-    def abort_request(self, request_id: str) -> None:
-        """Aborts a request with the given ID.
+    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
+        """Aborts a request(s) with the given ID.

         Args:
-            request_id: The ID of the request to abort.
+            request_id: The ID(s) of the request to abort.
         """
         self.scheduler.abort_seq_group(request_id)
...
@@ -288,35 +291,21 @@ class LLMEngine:
         """Returns True if there are unfinished requests."""
         return self.scheduler.has_unfinished_seqs()

-    def step(self) -> List[RequestOutput]:
-        """Performs one decoding iteration and returns newly generated results.
-
-        This function performs one decoding iteration of the engine. It first
-        schedules the sequences to be executed in the next iteration and the
-        token blocks to be swapped in/out/copy. Then, it executes the model
-        and updates the scheduler with the model outputs. Finally, it decodes
-        the sequences and returns the newly generated results.
-        """
+    def _schedule(
+        self
+    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
+               Optional[List[RequestOutput]]]:
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
         if scheduler_outputs.is_empty():
-            if not scheduler_outputs.ignored_seq_groups:
-                # Nothing to do.
-                return []
-            # If there are ignored seq groups, we need to return them as the
-            # request outputs.
-            return [
+            return seq_group_metadata_list, scheduler_outputs, [
                 RequestOutput.from_seq_group(seq_group)
                 for seq_group in scheduler_outputs.ignored_seq_groups
             ]
+        return seq_group_metadata_list, scheduler_outputs, None

-        # Execute the model.
-        output = self._run_workers(
-            "execute_model",
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
-            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
-            blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        )
-
+    def _process_worker_outputs(
+            self, output,
+            scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
         # Update the scheduler with the model outputs.
         seq_groups = self.scheduler.update(output)
...
@@ -339,6 +328,31 @@ class LLMEngine:
                 scheduler_outputs.num_batched_tokens)
         return request_outputs

+    def step(self) -> List[RequestOutput]:
+        """Performs one decoding iteration and returns newly generated results.
+
+        This function performs one decoding iteration of the engine. It first
+        schedules the sequences to be executed in the next iteration and the
+        token blocks to be swapped in/out/copy. Then, it executes the model
+        and updates the scheduler with the model outputs. Finally, it decodes
+        the sequences and returns the newly generated results.
+        """
+        (seq_group_metadata_list, scheduler_outputs,
+         early_return) = self._schedule()
+        if early_return is not None:
+            return early_return
+
+        # Execute the model.
+        output = self._run_workers(
+            "execute_model",
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
+            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
+            blocks_to_copy=scheduler_outputs.blocks_to_copy,
+        )
+
+        return self._process_worker_outputs(output, scheduler_outputs)
+
     def _log_system_stats(
         self,
         prompt_run: bool,
...
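Splitting the old step() into _schedule(), the execute_model call, and _process_worker_outputs() is what lets _AsyncLLMEngine.step_async() reuse the scheduling and post-processing logic while swapping only the worker execution for the awaitable _run_workers_async(). From a caller's point of view the synchronous path is unchanged; a minimal sketch, assuming `engine` is an LLMEngine with requests already added:

    finished = []
    while engine.has_unfinished_requests():
        # step() now delegates to _schedule() / _run_workers() /
        # _process_worker_outputs(); ignored sequence groups come back
        # immediately via the early-return path.
        for request_output in engine.step():
            if request_output.finished:
                finished.append(request_output)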