Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
ff36139f
Unverified
Commit
ff36139f
authored
Sep 17, 2023
by
Antoni Baum
Committed by
GitHub
Sep 17, 2023
Browse files
Remove AsyncLLMEngine busy loop, shield background task (#1059)
parent
e3e79e9e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
154 additions
and
18 deletions
+154
-18
requirements-dev.txt
requirements-dev.txt
+1
-0
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+80
-0
tests/async_engine/test_request_tracker.py
tests/async_engine/test_request_tracker.py
+21
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+52
-18
No files found.
requirements-dev.txt
View file @
ff36139f
...
@@ -11,3 +11,4 @@ types-setuptools
...
@@ -11,3 +11,4 @@ types-setuptools
# testing
# testing
pytest
pytest
pytest-forked
pytest-forked
pytest-asyncio
tests/async_engine/test_async_llm_engine.py
0 → 100644
View file @
ff36139f
import
asyncio
from
dataclasses
import
dataclass
import
pytest
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
@dataclass
class RequestOutput:
    """Minimal stand-in for ``vllm.outputs.RequestOutput`` used by MockEngine.

    Only the attributes the async-engine plumbing reads are modeled here.
    """

    # Id of the request this output belongs to.  NOTE(review): the tests
    # feed string ids (e.g. "2") despite the ``int`` annotation — confirm
    # the intended type.
    request_id: int
    # Completion flag checked by the request tracker when routing outputs.
    finished: bool = False
class MockEngine:
    """Call-counting stand-in for the real async engine.

    The tests inspect ``step_calls``, ``add_request_calls`` and
    ``abort_request_calls`` to verify how often the background loop
    drives the engine.
    """

    def __init__(self):
        # Counters the tests assert on.
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        # When truthy, step_async() pretends this request produced output.
        self.request_id = None

    async def step_async(self):
        """Record one engine step; emit output only for the active request."""
        self.step_calls += 1
        if not self.request_id:
            return []
        return [RequestOutput(request_id=self.request_id)]

    def generate(self, request_id):
        """Mark ``request_id`` as actively producing outputs."""
        self.request_id = request_id

    def stop_generating(self):
        """Clear the active request so step_async() yields nothing."""
        self.request_id = None

    def add_request(self, **kwargs):
        """Count an add_request call; all arguments are ignored."""
        self.add_request_calls += 1

    def abort_request(self, request_id):
        """Count an abort_request call; the id itself is ignored."""
        self.abort_request_calls += 1
class MockAsyncLLMEngine(AsyncLLMEngine):
    """AsyncLLMEngine whose underlying engine is replaced with MockEngine."""

    def _init_engine(self, *args, **kwargs):
        """Ignore every construction argument and hand back a fresh mock."""
        mock = MockEngine()
        return mock
@pytest.mark.asyncio
async def test_new_requests_event():
    """Verify the background loop only steps while requests are in flight.

    The loop should wake when a request is added, keep stepping while the
    mock engine reports in-progress output, and park again (no busy loop)
    once nothing is running.
    """
    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
    engine.start_background_loop()
    await asyncio.sleep(0.01)
    # No requests yet: the loop must be parked, not busy-stepping.
    assert engine.engine.step_calls == 0

    await engine.add_request("1", "", None)
    await asyncio.sleep(0.01)
    # One request wakes the loop for exactly one step (mock emits no output).
    assert engine.engine.add_request_calls == 1
    assert engine.engine.step_calls == 1

    await engine.add_request("2", "", None)
    # Make the mock report request "2" as producing output on every step.
    engine.engine.generate("2")
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls == 2
    await asyncio.sleep(0)
    # While output keeps coming, the loop steps once per event-loop tick.
    assert engine.engine.step_calls == 3
    engine.engine.stop_generating()
    await asyncio.sleep(0)
    # One final step observes the now-empty output...
    assert engine.engine.step_calls == 4
    await asyncio.sleep(0)
    # ...after which the loop parks again instead of spinning.
    assert engine.engine.step_calls == 4

    await engine.add_request("3", "", None)
    await asyncio.sleep(0.01)
    # A fresh request wakes the loop once more.
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == 5
    await asyncio.sleep(0.01)
    # With no in-progress output, no further steps occur.
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == 5
tests/async_engine/test_request_tracker.py
View file @
ff36139f
...
@@ -4,10 +4,25 @@ from vllm.engine.async_llm_engine import RequestTracker
...
@@ -4,10 +4,25 @@ from vllm.engine.async_llm_engine import RequestTracker
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
class
DummyEvent
:
def
__init__
(
self
):
self
.
_flag
=
False
def
set
(
self
):
self
.
_flag
=
True
def
clear
(
self
):
self
.
_flag
=
False
def
test_request_tracker
():
def
test_request_tracker
():
tracker
=
RequestTracker
()
tracker
=
RequestTracker
()
tracker
.
new_requests_event
=
DummyEvent
()
stream_1
=
tracker
.
add_request
(
"1"
)
stream_1
=
tracker
.
add_request
(
"1"
)
assert
tracker
.
new_requests_event
.
_flag
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
assert
not
tracker
.
new_requests_event
.
_flag
assert
len
(
new
)
==
1
assert
len
(
new
)
==
1
assert
new
[
0
][
"request_id"
]
==
"1"
assert
new
[
0
][
"request_id"
]
==
"1"
assert
not
finished
assert
not
finished
...
@@ -15,7 +30,9 @@ def test_request_tracker():
...
@@ -15,7 +30,9 @@ def test_request_tracker():
stream_2
=
tracker
.
add_request
(
"2"
)
stream_2
=
tracker
.
add_request
(
"2"
)
stream_3
=
tracker
.
add_request
(
"3"
)
stream_3
=
tracker
.
add_request
(
"3"
)
assert
tracker
.
new_requests_event
.
_flag
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
assert
not
tracker
.
new_requests_event
.
_flag
assert
len
(
new
)
==
2
assert
len
(
new
)
==
2
assert
new
[
0
][
"request_id"
]
==
"2"
assert
new
[
0
][
"request_id"
]
==
"2"
assert
new
[
1
][
"request_id"
]
==
"3"
assert
new
[
1
][
"request_id"
]
==
"3"
...
@@ -26,6 +43,7 @@ def test_request_tracker():
...
@@ -26,6 +43,7 @@ def test_request_tracker():
# request_ids must be unique
# request_ids must be unique
with
pytest
.
raises
(
KeyError
):
with
pytest
.
raises
(
KeyError
):
tracker
.
add_request
(
"1"
)
tracker
.
add_request
(
"1"
)
assert
not
tracker
.
new_requests_event
.
_flag
tracker
.
abort_request
(
"1"
)
tracker
.
abort_request
(
"1"
)
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
...
@@ -36,6 +54,7 @@ def test_request_tracker():
...
@@ -36,6 +54,7 @@ def test_request_tracker():
stream_4
=
tracker
.
add_request
(
"4"
)
stream_4
=
tracker
.
add_request
(
"4"
)
tracker
.
abort_request
(
"4"
)
tracker
.
abort_request
(
"4"
)
assert
tracker
.
new_requests_event
.
_flag
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
assert
len
(
finished
)
==
1
assert
len
(
finished
)
==
1
assert
"4"
in
finished
assert
"4"
in
finished
...
@@ -43,9 +62,11 @@ def test_request_tracker():
...
@@ -43,9 +62,11 @@ def test_request_tracker():
assert
stream_4
.
finished
assert
stream_4
.
finished
stream_5
=
tracker
.
add_request
(
"5"
)
stream_5
=
tracker
.
add_request
(
"5"
)
assert
tracker
.
new_requests_event
.
_flag
tracker
.
process_request_output
(
tracker
.
process_request_output
(
RequestOutput
(
"2"
,
"output"
,
[],
[],
finished
=
True
))
RequestOutput
(
"2"
,
"output"
,
[],
[],
finished
=
True
))
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
assert
not
tracker
.
new_requests_event
.
_flag
assert
len
(
finished
)
==
1
assert
len
(
finished
)
==
1
assert
"2"
in
finished
assert
"2"
in
finished
assert
len
(
new
)
==
1
assert
len
(
new
)
==
1
...
...
vllm/engine/async_llm_engine.py
View file @
ff36139f
import
asyncio
import
asyncio
import
time
import
time
from
functools
import
partial
from
functools
import
partial
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
...
@@ -78,14 +79,24 @@ class RequestTracker:
...
@@ -78,14 +79,24 @@ class RequestTracker:
self
.
_finished_requests
:
asyncio
.
Queue
[
str
]
=
asyncio
.
Queue
()
self
.
_finished_requests
:
asyncio
.
Queue
[
str
]
=
asyncio
.
Queue
()
self
.
_new_requests
:
asyncio
.
Queue
[
Tuple
[
AsyncStream
,
self
.
_new_requests
:
asyncio
.
Queue
[
Tuple
[
AsyncStream
,
dict
]]
=
asyncio
.
Queue
()
dict
]]
=
asyncio
.
Queue
()
self
.
new_requests_event
=
None
def
__contains__
(
self
,
item
):
def
__contains__
(
self
,
item
):
return
item
in
self
.
_request_streams
return
item
in
self
.
_request_streams
def
propagate_exception
(
self
,
exc
:
Exception
)
->
None
:
def
init_event
(
self
):
"""Propagate an exception to all request streams."""
self
.
new_requests_event
=
asyncio
.
Event
()
for
stream
in
self
.
_request_streams
.
values
():
stream
.
put
(
exc
)
def
propagate_exception
(
self
,
exc
:
Exception
,
request_id
:
Optional
[
str
]
=
None
)
->
None
:
"""Propagate an exception to request streams
(all if request_id is None)."""
if
request_id
is
not
None
:
self
.
_request_streams
[
request_id
].
put
(
exc
)
else
:
for
stream
in
self
.
_request_streams
.
values
():
stream
.
put
(
exc
)
def
process_request_output
(
self
,
def
process_request_output
(
self
,
request_output
:
RequestOutput
,
request_output
:
RequestOutput
,
...
@@ -112,6 +123,9 @@ class RequestTracker:
...
@@ -112,6 +123,9 @@ class RequestTracker:
"request_id"
:
request_id
,
"request_id"
:
request_id
,
**
engine_add_request_kwargs
**
engine_add_request_kwargs
}))
}))
self
.
new_requests_event
.
set
()
return
stream
return
stream
def
abort_request
(
self
,
request_id
:
str
,
*
,
verbose
:
bool
=
False
)
->
None
:
def
abort_request
(
self
,
request_id
:
str
,
*
,
verbose
:
bool
=
False
)
->
None
:
...
@@ -148,8 +162,13 @@ class RequestTracker:
...
@@ -148,8 +162,13 @@ class RequestTracker:
self
.
_request_streams
[
stream
.
request_id
]
=
stream
self
.
_request_streams
[
stream
.
request_id
]
=
stream
new_requests
.
append
(
new_request
)
new_requests
.
append
(
new_request
)
self
.
new_requests_event
.
clear
()
return
new_requests
,
finished_requests
return
new_requests
,
finished_requests
async
def
wait_for_new_requests
(
self
):
await
self
.
new_requests_event
.
wait
()
class
_AsyncLLMEngine
(
LLMEngine
):
class
_AsyncLLMEngine
(
LLMEngine
):
"""Extension of LLMEngine to add async methods."""
"""Extension of LLMEngine to add async methods."""
...
@@ -251,9 +270,13 @@ class AsyncLLMEngine:
...
@@ -251,9 +270,13 @@ class AsyncLLMEngine:
self
.
max_log_len
=
max_log_len
self
.
max_log_len
=
max_log_len
self
.
engine
=
self
.
_init_engine
(
*
args
,
**
kwargs
)
self
.
engine
=
self
.
_init_engine
(
*
args
,
**
kwargs
)
self
.
request_tracker
:
RequestTracker
=
RequestTracker
()
self
.
background_loop
=
None
self
.
background_loop
=
None
# We need to keep a reference to unshielded
# task as well to prevent it from being garbage
# collected
self
.
_background_loop_unshielded
=
None
self
.
start_engine_loop
=
start_engine_loop
self
.
start_engine_loop
=
start_engine_loop
self
.
_request_tracker
=
RequestTracker
()
@
property
@
property
def
is_running
(
self
)
->
bool
:
def
is_running
(
self
)
->
bool
:
...
@@ -264,11 +287,14 @@ class AsyncLLMEngine:
...
@@ -264,11 +287,14 @@ class AsyncLLMEngine:
"""Start the background loop."""
"""Start the background loop."""
if
self
.
is_running
:
if
self
.
is_running
:
raise
RuntimeError
(
"Background loop is already running."
)
raise
RuntimeError
(
"Background loop is already running."
)
self
.
background_loop
=
asyncio
.
get_event_loop
().
create_task
(
self
.
_request_tracker
.
init_event
()
self
.
run_engine_loop
())
self
.
background_loop
.
add_done_callback
(
self
.
_background_loop_unshielded
=
asyncio
.
get_event_loop
(
).
create_task
(
self
.
run_engine_loop
())
self
.
_background_loop_unshielded
.
add_done_callback
(
partial
(
_raise_exception_on_finish
,
partial
(
_raise_exception_on_finish
,
request_tracker
=
self
.
request_tracker
))
request_tracker
=
self
.
_request_tracker
))
self
.
background_loop
=
asyncio
.
shield
(
self
.
_background_loop_unshielded
)
def
_init_engine
(
self
,
*
args
,
def
_init_engine
(
self
,
*
args
,
**
kwargs
)
->
Union
[
_AsyncLLMEngine
,
"ray.ObjectRef"
]:
**
kwargs
)
->
Union
[
_AsyncLLMEngine
,
"ray.ObjectRef"
]:
...
@@ -280,11 +306,13 @@ class AsyncLLMEngine:
...
@@ -280,11 +306,13 @@ class AsyncLLMEngine:
engine_class
=
ray
.
remote
(
num_gpus
=
1
)(
self
.
_engine_class
).
remote
engine_class
=
ray
.
remote
(
num_gpus
=
1
)(
self
.
_engine_class
).
remote
return
engine_class
(
*
args
,
**
kwargs
)
return
engine_class
(
*
args
,
**
kwargs
)
async
def
engine_step
(
self
):
async
def
engine_step
(
self
)
->
bool
:
"""Kick the engine to process the waiting requests."""
"""Kick the engine to process the waiting requests.
Returns True if there are in-progress requests."""
new_requests
,
finished_requests
=
(
new_requests
,
finished_requests
=
(
self
.
request_tracker
.
get_new_and_finished_requests
())
self
.
_
request_tracker
.
get_new_and_finished_requests
())
for
new_request
in
new_requests
:
for
new_request
in
new_requests
:
# Add the request into the vLLM engine's waiting queue.
# Add the request into the vLLM engine's waiting queue.
...
@@ -304,9 +332,11 @@ class AsyncLLMEngine:
...
@@ -304,9 +332,11 @@ class AsyncLLMEngine:
# Put the outputs into the corresponding streams.
# Put the outputs into the corresponding streams.
for
request_output
in
request_outputs
:
for
request_output
in
request_outputs
:
self
.
request_tracker
.
process_request_output
(
self
.
_
request_tracker
.
process_request_output
(
request_output
,
verbose
=
self
.
log_requests
)
request_output
,
verbose
=
self
.
log_requests
)
return
len
(
request_outputs
)
>
0
async
def
_engine_abort
(
self
,
request_ids
:
Iterable
[
str
]):
async
def
_engine_abort
(
self
,
request_ids
:
Iterable
[
str
]):
if
self
.
engine_use_ray
:
if
self
.
engine_use_ray
:
await
self
.
engine
.
abort_request
.
remote
(
request_ids
)
await
self
.
engine
.
abort_request
.
remote
(
request_ids
)
...
@@ -314,8 +344,12 @@ class AsyncLLMEngine:
...
@@ -314,8 +344,12 @@ class AsyncLLMEngine:
self
.
engine
.
abort_request
(
request_ids
)
self
.
engine
.
abort_request
(
request_ids
)
async
def
run_engine_loop
(
self
):
async
def
run_engine_loop
(
self
):
# Initialize the RequestTracker here so it uses the right event loop.
has_requests_in_progress
=
False
while
True
:
while
True
:
await
self
.
engine_step
()
if
not
has_requests_in_progress
:
await
self
.
_request_tracker
.
wait_for_new_requests
()
has_requests_in_progress
=
await
self
.
engine_step
()
await
asyncio
.
sleep
(
0
)
await
asyncio
.
sleep
(
0
)
async
def
add_request
(
async
def
add_request
(
...
@@ -350,7 +384,7 @@ class AsyncLLMEngine:
...
@@ -350,7 +384,7 @@ class AsyncLLMEngine:
"error that caused the background loop to stop "
"error that caused the background loop to stop "
"(AsyncEngineDeadError)."
)
"(AsyncEngineDeadError)."
)
stream
=
self
.
request_tracker
.
add_request
(
stream
=
self
.
_
request_tracker
.
add_request
(
request_id
,
request_id
,
prompt
=
prompt
,
prompt
=
prompt
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
...
@@ -428,8 +462,8 @@ class AsyncLLMEngine:
...
@@ -428,8 +462,8 @@ class AsyncLLMEngine:
Args:
Args:
request_id: The unique id of the request.
request_id: The unique id of the request.
"""
"""
self
.
request_tracker
.
abort_request
(
request_id
,
self
.
_
request_tracker
.
abort_request
(
request_id
,
verbose
=
self
.
log_requests
)
verbose
=
self
.
log_requests
)
async
def
get_model_config
(
self
)
->
ModelConfig
:
async
def
get_model_config
(
self
)
->
ModelConfig
:
"""Get the model configuration of the vLLM engine."""
"""Get the model configuration of the vLLM engine."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment