Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
ff36139f
"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "79df50388df09d9615e3c067695a453bb0a694c0"
Unverified
Commit
ff36139f
authored
Sep 17, 2023
by
Antoni Baum
Committed by
GitHub
Sep 17, 2023
Browse files
Remove AsyncLLMEngine busy loop, shield background task (#1059)
parent
e3e79e9e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
154 additions
and
18 deletions
+154
-18
requirements-dev.txt
requirements-dev.txt
+1
-0
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+80
-0
tests/async_engine/test_request_tracker.py
tests/async_engine/test_request_tracker.py
+21
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+52
-18
No files found.
requirements-dev.txt
View file @
ff36139f
...
@@ -11,3 +11,4 @@ types-setuptools
...
@@ -11,3 +11,4 @@ types-setuptools
# testing
# testing
pytest
pytest
pytest-forked
pytest-forked
pytest-asyncio
tests/async_engine/test_async_llm_engine.py
0 → 100644
View file @
ff36139f
import
asyncio
from
dataclasses
import
dataclass
import
pytest
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
@dataclass
class RequestOutput:
    # Lightweight stand-in for vllm.outputs.RequestOutput carrying only the
    # fields the async-engine tests inspect.
    # NOTE(review): the tests pass string ids (e.g. "1") despite the `int`
    # annotation — annotation is unenforced here; confirm intended type.
    request_id: int
    # Whether the request has completed; MockEngine never sets this.
    finished: bool = False
class MockEngine:
    """Minimal stand-in for the synchronous engine driven by AsyncLLMEngine.

    Counts how often each entry point is invoked so tests can observe
    whether the background loop is stepping, and emits one fake
    ``RequestOutput`` per step while a request id is marked as generating.
    """

    def __init__(self):
        # Call counters inspected by the tests.
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        # While set, step_async() yields one output tagged with this id.
        self.request_id = None

    async def step_async(self):
        """Record one engine step and return outputs for the active request."""
        self.step_calls += 1
        if not self.request_id:
            return []
        return [RequestOutput(request_id=self.request_id)]

    def generate(self, request_id):
        """Mark *request_id* as in progress so subsequent steps produce output."""
        self.request_id = request_id

    def stop_generating(self):
        """Clear the in-progress request; subsequent steps produce nothing."""
        self.request_id = None

    def add_request(self, **kwargs):
        """Count an add_request call; all arguments are ignored."""
        self.add_request_calls += 1

    def abort_request(self, request_id):
        """Count an abort_request call; the id itself is ignored."""
        self.abort_request_calls += 1
class MockAsyncLLMEngine(AsyncLLMEngine):
    # Subclass that swaps the real engine for MockEngine so the background
    # loop can run in tests without loading a model or any workers.

    def _init_engine(self, *args, **kwargs):
        # Invoked by AsyncLLMEngine's constructor; all engine args ignored.
        return MockEngine()
@pytest.mark.asyncio
async def test_new_requests_event():
    """The background loop must park on the new-requests event while idle and
    only step the engine when a request is pending or still producing output
    (i.e. no busy loop)."""
    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
    engine.start_background_loop()
    await asyncio.sleep(0.01)
    # No requests yet: the loop is blocked on the event, not busy-stepping.
    assert engine.engine.step_calls == 0

    await engine.add_request("1", "", None)
    await asyncio.sleep(0.01)
    # One request added -> exactly one step, then the loop parks again.
    assert engine.engine.add_request_calls == 1
    assert engine.engine.step_calls == 1

    await engine.add_request("2", "", None)
    # Make the mock report request "2" as in progress so steps keep coming.
    engine.engine.generate("2")
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls == 2
    await asyncio.sleep(0)
    # Still in progress: the loop steps again without a new request.
    assert engine.engine.step_calls == 3
    engine.engine.stop_generating()
    await asyncio.sleep(0)
    # One final step observes the now-empty output and goes idle.
    assert engine.engine.step_calls == 4
    await asyncio.sleep(0)
    # Idle: no further steps occur until another request arrives.
    assert engine.engine.step_calls == 4

    await engine.add_request("3", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == 5
    await asyncio.sleep(0.01)
    # Confirm the loop stayed parked after the single step for request "3".
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == 5
tests/async_engine/test_request_tracker.py
View file @
ff36139f
...
@@ -4,10 +4,25 @@ from vllm.engine.async_llm_engine import RequestTracker
...
@@ -4,10 +4,25 @@ from vllm.engine.async_llm_engine import RequestTracker
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
class DummyEvent:
    """Synchronous substitute for asyncio.Event used by the tracker tests.

    Exposes the same set()/clear() surface and records the flag state in
    ``_flag`` so tests can assert on it without an event loop.
    """

    def __init__(self):
        # Mirrors asyncio.Event's internal flag; starts unset.
        self._flag: bool = False

    def clear(self):
        """Reset the flag to the unset state."""
        self._flag = False

    def set(self):
        """Mark the flag as set."""
        self._flag = True
def
test_request_tracker
():
def
test_request_tracker
():
tracker
=
RequestTracker
()
tracker
=
RequestTracker
()
tracker
.
new_requests_event
=
DummyEvent
()
stream_1
=
tracker
.
add_request
(
"1"
)
stream_1
=
tracker
.
add_request
(
"1"
)
assert
tracker
.
new_requests_event
.
_flag
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
assert
not
tracker
.
new_requests_event
.
_flag
assert
len
(
new
)
==
1
assert
len
(
new
)
==
1
assert
new
[
0
][
"request_id"
]
==
"1"
assert
new
[
0
][
"request_id"
]
==
"1"
assert
not
finished
assert
not
finished
...
@@ -15,7 +30,9 @@ def test_request_tracker():
...
@@ -15,7 +30,9 @@ def test_request_tracker():
stream_2
=
tracker
.
add_request
(
"2"
)
stream_2
=
tracker
.
add_request
(
"2"
)
stream_3
=
tracker
.
add_request
(
"3"
)
stream_3
=
tracker
.
add_request
(
"3"
)
assert
tracker
.
new_requests_event
.
_flag
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
assert
not
tracker
.
new_requests_event
.
_flag
assert
len
(
new
)
==
2
assert
len
(
new
)
==
2
assert
new
[
0
][
"request_id"
]
==
"2"
assert
new
[
0
][
"request_id"
]
==
"2"
assert
new
[
1
][
"request_id"
]
==
"3"
assert
new
[
1
][
"request_id"
]
==
"3"
...
@@ -26,6 +43,7 @@ def test_request_tracker():
...
@@ -26,6 +43,7 @@ def test_request_tracker():
# request_ids must be unique
# request_ids must be unique
with
pytest
.
raises
(
KeyError
):
with
pytest
.
raises
(
KeyError
):
tracker
.
add_request
(
"1"
)
tracker
.
add_request
(
"1"
)
assert
not
tracker
.
new_requests_event
.
_flag
tracker
.
abort_request
(
"1"
)
tracker
.
abort_request
(
"1"
)
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
...
@@ -36,6 +54,7 @@ def test_request_tracker():
...
@@ -36,6 +54,7 @@ def test_request_tracker():
stream_4
=
tracker
.
add_request
(
"4"
)
stream_4
=
tracker
.
add_request
(
"4"
)
tracker
.
abort_request
(
"4"
)
tracker
.
abort_request
(
"4"
)
assert
tracker
.
new_requests_event
.
_flag
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
assert
len
(
finished
)
==
1
assert
len
(
finished
)
==
1
assert
"4"
in
finished
assert
"4"
in
finished
...
@@ -43,9 +62,11 @@ def test_request_tracker():
...
@@ -43,9 +62,11 @@ def test_request_tracker():
assert
stream_4
.
finished
assert
stream_4
.
finished
stream_5
=
tracker
.
add_request
(
"5"
)
stream_5
=
tracker
.
add_request
(
"5"
)
assert
tracker
.
new_requests_event
.
_flag
tracker
.
process_request_output
(
tracker
.
process_request_output
(
RequestOutput
(
"2"
,
"output"
,
[],
[],
finished
=
True
))
RequestOutput
(
"2"
,
"output"
,
[],
[],
finished
=
True
))
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
assert
not
tracker
.
new_requests_event
.
_flag
assert
len
(
finished
)
==
1
assert
len
(
finished
)
==
1
assert
"2"
in
finished
assert
"2"
in
finished
assert
len
(
new
)
==
1
assert
len
(
new
)
==
1
...
...
vllm/engine/async_llm_engine.py
View file @
ff36139f
import
asyncio
import
asyncio
import
time
import
time
from
functools
import
partial
from
functools
import
partial
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
...
@@ -78,14 +79,24 @@ class RequestTracker:
...
@@ -78,14 +79,24 @@ class RequestTracker:
self
.
_finished_requests
:
asyncio
.
Queue
[
str
]
=
asyncio
.
Queue
()
self
.
_finished_requests
:
asyncio
.
Queue
[
str
]
=
asyncio
.
Queue
()
self
.
_new_requests
:
asyncio
.
Queue
[
Tuple
[
AsyncStream
,
self
.
_new_requests
:
asyncio
.
Queue
[
Tuple
[
AsyncStream
,
dict
]]
=
asyncio
.
Queue
()
dict
]]
=
asyncio
.
Queue
()
self
.
new_requests_event
=
None
def
__contains__
(
self
,
item
):
def
__contains__
(
self
,
item
):
return
item
in
self
.
_request_streams
return
item
in
self
.
_request_streams
def
propagate_exception
(
self
,
exc
:
Exception
)
->
None
:
def init_event(self):
    # Create the event lazily rather than in __init__.
    # NOTE(review): presumably deferred so the Event is created from inside
    # the running loop (start_background_loop calls this) and is therefore
    # bound to the loop the background task runs on — confirm.
    self.new_requests_event = asyncio.Event()
for
stream
in
self
.
_request_streams
.
values
():
stream
.
put
(
exc
)
def propagate_exception(self,
                        exc: Exception,
                        request_id: Optional[str] = None) -> None:
    """Propagate an exception to request streams
    (all if request_id is None)."""
    if request_id is not None:
        # Fail only the stream that owns this request.
        self._request_streams[request_id].put(exc)
    else:
        # No specific request given: fan the failure out to every open stream.
        for stream in self._request_streams.values():
            stream.put(exc)
def
process_request_output
(
self
,
def
process_request_output
(
self
,
request_output
:
RequestOutput
,
request_output
:
RequestOutput
,
...
@@ -112,6 +123,9 @@ class RequestTracker:
...
@@ -112,6 +123,9 @@ class RequestTracker:
"request_id"
:
request_id
,
"request_id"
:
request_id
,
**
engine_add_request_kwargs
**
engine_add_request_kwargs
}))
}))
self
.
new_requests_event
.
set
()
return
stream
return
stream
def
abort_request
(
self
,
request_id
:
str
,
*
,
verbose
:
bool
=
False
)
->
None
:
def
abort_request
(
self
,
request_id
:
str
,
*
,
verbose
:
bool
=
False
)
->
None
:
...
@@ -148,8 +162,13 @@ class RequestTracker:
...
@@ -148,8 +162,13 @@ class RequestTracker:
self
.
_request_streams
[
stream
.
request_id
]
=
stream
self
.
_request_streams
[
stream
.
request_id
]
=
stream
new_requests
.
append
(
new_request
)
new_requests
.
append
(
new_request
)
self
.
new_requests_event
.
clear
()
return
new_requests
,
finished_requests
return
new_requests
,
finished_requests
async def wait_for_new_requests(self):
    # Block until add_request()/abort_request() sets the event; this is what
    # lets run_engine_loop sleep while idle instead of busy-looping.
    await self.new_requests_event.wait()
class
_AsyncLLMEngine
(
LLMEngine
):
class
_AsyncLLMEngine
(
LLMEngine
):
"""Extension of LLMEngine to add async methods."""
"""Extension of LLMEngine to add async methods."""
...
@@ -251,9 +270,13 @@ class AsyncLLMEngine:
...
@@ -251,9 +270,13 @@ class AsyncLLMEngine:
self
.
max_log_len
=
max_log_len
self
.
max_log_len
=
max_log_len
self
.
engine
=
self
.
_init_engine
(
*
args
,
**
kwargs
)
self
.
engine
=
self
.
_init_engine
(
*
args
,
**
kwargs
)
self
.
request_tracker
:
RequestTracker
=
RequestTracker
()
self
.
background_loop
=
None
self
.
background_loop
=
None
# We need to keep a reference to unshielded
# task as well to prevent it from being garbage
# collected
self
.
_background_loop_unshielded
=
None
self
.
start_engine_loop
=
start_engine_loop
self
.
start_engine_loop
=
start_engine_loop
self
.
_request_tracker
=
RequestTracker
()
@
property
@
property
def
is_running
(
self
)
->
bool
:
def
is_running
(
self
)
->
bool
:
...
@@ -264,11 +287,14 @@ class AsyncLLMEngine:
...
@@ -264,11 +287,14 @@ class AsyncLLMEngine:
"""Start the background loop."""
"""Start the background loop."""
if
self
.
is_running
:
if
self
.
is_running
:
raise
RuntimeError
(
"Background loop is already running."
)
raise
RuntimeError
(
"Background loop is already running."
)
self
.
background_loop
=
asyncio
.
get_event_loop
().
create_task
(
self
.
_request_tracker
.
init_event
()
self
.
run_engine_loop
())
self
.
background_loop
.
add_done_callback
(
self
.
_background_loop_unshielded
=
asyncio
.
get_event_loop
(
).
create_task
(
self
.
run_engine_loop
())
self
.
_background_loop_unshielded
.
add_done_callback
(
partial
(
_raise_exception_on_finish
,
partial
(
_raise_exception_on_finish
,
request_tracker
=
self
.
request_tracker
))
request_tracker
=
self
.
_request_tracker
))
self
.
background_loop
=
asyncio
.
shield
(
self
.
_background_loop_unshielded
)
def
_init_engine
(
self
,
*
args
,
def
_init_engine
(
self
,
*
args
,
**
kwargs
)
->
Union
[
_AsyncLLMEngine
,
"ray.ObjectRef"
]:
**
kwargs
)
->
Union
[
_AsyncLLMEngine
,
"ray.ObjectRef"
]:
...
@@ -280,11 +306,13 @@ class AsyncLLMEngine:
...
@@ -280,11 +306,13 @@ class AsyncLLMEngine:
engine_class
=
ray
.
remote
(
num_gpus
=
1
)(
self
.
_engine_class
).
remote
engine_class
=
ray
.
remote
(
num_gpus
=
1
)(
self
.
_engine_class
).
remote
return
engine_class
(
*
args
,
**
kwargs
)
return
engine_class
(
*
args
,
**
kwargs
)
async
def
engine_step
(
self
):
async
def
engine_step
(
self
)
->
bool
:
"""Kick the engine to process the waiting requests."""
"""Kick the engine to process the waiting requests.
Returns True if there are in-progress requests."""
new_requests
,
finished_requests
=
(
new_requests
,
finished_requests
=
(
self
.
request_tracker
.
get_new_and_finished_requests
())
self
.
_
request_tracker
.
get_new_and_finished_requests
())
for
new_request
in
new_requests
:
for
new_request
in
new_requests
:
# Add the request into the vLLM engine's waiting queue.
# Add the request into the vLLM engine's waiting queue.
...
@@ -304,9 +332,11 @@ class AsyncLLMEngine:
...
@@ -304,9 +332,11 @@ class AsyncLLMEngine:
# Put the outputs into the corresponding streams.
# Put the outputs into the corresponding streams.
for
request_output
in
request_outputs
:
for
request_output
in
request_outputs
:
self
.
request_tracker
.
process_request_output
(
self
.
_
request_tracker
.
process_request_output
(
request_output
,
verbose
=
self
.
log_requests
)
request_output
,
verbose
=
self
.
log_requests
)
return
len
(
request_outputs
)
>
0
async
def
_engine_abort
(
self
,
request_ids
:
Iterable
[
str
]):
async
def
_engine_abort
(
self
,
request_ids
:
Iterable
[
str
]):
if
self
.
engine_use_ray
:
if
self
.
engine_use_ray
:
await
self
.
engine
.
abort_request
.
remote
(
request_ids
)
await
self
.
engine
.
abort_request
.
remote
(
request_ids
)
...
@@ -314,8 +344,12 @@ class AsyncLLMEngine:
...
@@ -314,8 +344,12 @@ class AsyncLLMEngine:
self
.
engine
.
abort_request
(
request_ids
)
self
.
engine
.
abort_request
(
request_ids
)
async def run_engine_loop(self):
    # Initialize the RequestTracker here so it uses the right event loop.
    has_requests_in_progress = False
    while True:
        if not has_requests_in_progress:
            # Idle: park on the tracker's event instead of spinning.
            await self._request_tracker.wait_for_new_requests()
        # engine_step() reports whether any request produced output,
        # i.e. whether work remains for the next iteration.
        has_requests_in_progress = await self.engine_step()
        # Yield control so other coroutines (e.g. add_request) can run.
        await asyncio.sleep(0)
async
def
add_request
(
async
def
add_request
(
...
@@ -350,7 +384,7 @@ class AsyncLLMEngine:
...
@@ -350,7 +384,7 @@ class AsyncLLMEngine:
"error that caused the background loop to stop "
"error that caused the background loop to stop "
"(AsyncEngineDeadError)."
)
"(AsyncEngineDeadError)."
)
stream
=
self
.
request_tracker
.
add_request
(
stream
=
self
.
_
request_tracker
.
add_request
(
request_id
,
request_id
,
prompt
=
prompt
,
prompt
=
prompt
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
...
@@ -428,8 +462,8 @@ class AsyncLLMEngine:
...
@@ -428,8 +462,8 @@ class AsyncLLMEngine:
Args:
Args:
request_id: The unique id of the request.
request_id: The unique id of the request.
"""
"""
self
.
request_tracker
.
abort_request
(
request_id
,
self
.
_
request_tracker
.
abort_request
(
request_id
,
verbose
=
self
.
log_requests
)
verbose
=
self
.
log_requests
)
async
def
get_model_config
(
self
)
->
ModelConfig
:
async
def
get_model_config
(
self
)
->
ModelConfig
:
"""Get the model configuration of the vLLM engine."""
"""Get the model configuration of the vLLM engine."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment