Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ff578cae
Unverified
Commit
ff578cae
authored
Mar 04, 2024
by
Antoni Baum
Committed by
GitHub
Mar 04, 2024
Browse files
Add health check, make async Engine more robust (#3015)
Co-authored-by:
Zhuohan Li
<
zhuohan123@gmail.com
>
parent
22de4523
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
138 additions
and
65 deletions
+138
-65
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+16
-16
tests/async_engine/test_request_tracker.py
tests/async_engine/test_request_tracker.py
+15
-23
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+87
-26
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+20
-0
No files found.
tests/async_engine/test_async_llm_engine.py
View file @
ff578cae
...
@@ -25,12 +25,8 @@ class MockEngine:
...
@@ -25,12 +25,8 @@ class MockEngine:
return
[
RequestOutput
(
return
[
RequestOutput
(
request_id
=
self
.
request_id
)]
if
self
.
request_id
else
[]
request_id
=
self
.
request_id
)]
if
self
.
request_id
else
[]
async
def
encode_request_async
(
async
def
encode_request_async
(
self
,
*
args
,
**
kwargs
):
self
,
pass
*
args
,
**
kwargs
,
):
return
[
1
]
def
generate
(
self
,
request_id
):
def
generate
(
self
,
request_id
):
self
.
request_id
=
request_id
self
.
request_id
=
request_id
...
@@ -43,13 +39,16 @@ class MockEngine:
...
@@ -43,13 +39,16 @@ class MockEngine:
self
.
add_request_calls
+=
1
self
.
add_request_calls
+=
1
async
def
add_request_async
(
self
,
**
kwargs
):
async
def
add_request_async
(
self
,
**
kwargs
):
del
kwargs
# Unused
self
.
add_request_calls
+=
1
self
.
add_request_calls
+=
1
return
def
abort_request
(
self
,
request_id
):
def
abort_request
(
self
,
request_id
):
del
request_id
# Unused
del
request_id
# Unused
self
.
abort_request_calls
+=
1
self
.
abort_request_calls
+=
1
def
has_unfinished_requests
(
self
):
return
self
.
request_id
is
not
None
class
MockAsyncLLMEngine
(
AsyncLLMEngine
):
class
MockAsyncLLMEngine
(
AsyncLLMEngine
):
...
@@ -72,20 +71,21 @@ async def test_new_requests_event():
...
@@ -72,20 +71,21 @@ async def test_new_requests_event():
await
engine
.
add_request
(
"2"
,
""
,
None
)
await
engine
.
add_request
(
"2"
,
""
,
None
)
engine
.
engine
.
generate
(
"2"
)
engine
.
engine
.
generate
(
"2"
)
await
asyncio
.
sleep
(
0
)
await
asyncio
.
sleep
(
0
)
assert
engine
.
engine
.
add_request_calls
==
2
assert
engine
.
engine
.
step_calls
==
2
await
asyncio
.
sleep
(
0
)
await
asyncio
.
sleep
(
0
)
assert
engine
.
engine
.
step_calls
==
3
assert
engine
.
engine
.
add_request_calls
==
2
assert
engine
.
engine
.
step_calls
>=
2
await
asyncio
.
sleep
(
0.001
)
assert
engine
.
engine
.
step_calls
>=
3
engine
.
engine
.
stop_generating
()
engine
.
engine
.
stop_generating
()
await
asyncio
.
sleep
(
0
)
await
asyncio
.
sleep
(
0
.001
)
assert
engine
.
engine
.
step_calls
==
4
old_step_calls
=
engine
.
engine
.
step_calls
await
asyncio
.
sleep
(
0
)
await
asyncio
.
sleep
(
0
.001
)
assert
engine
.
engine
.
step_calls
==
4
assert
engine
.
engine
.
step_calls
==
old_step_calls
await
engine
.
add_request
(
"3"
,
""
,
None
)
await
engine
.
add_request
(
"3"
,
""
,
None
)
await
asyncio
.
sleep
(
0.01
)
await
asyncio
.
sleep
(
0.01
)
assert
engine
.
engine
.
add_request_calls
==
3
assert
engine
.
engine
.
add_request_calls
==
3
assert
engine
.
engine
.
step_calls
==
5
assert
engine
.
engine
.
step_calls
==
old_step_calls
+
1
await
asyncio
.
sleep
(
0.01
)
await
asyncio
.
sleep
(
0.01
)
assert
engine
.
engine
.
add_request_calls
==
3
assert
engine
.
engine
.
add_request_calls
==
3
assert
engine
.
engine
.
step_calls
==
5
assert
engine
.
engine
.
step_calls
==
old_step_calls
+
1
tests/async_engine/test_request_tracker.py
View file @
ff578cae
...
@@ -4,25 +4,14 @@ from vllm.engine.async_llm_engine import RequestTracker
...
@@ -4,25 +4,14 @@ from vllm.engine.async_llm_engine import RequestTracker
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
class
DummyEvent
:
@
pytest
.
mark
.
asyncio
async
def
test_request_tracker
():
def
__init__
(
self
):
self
.
flag
=
False
def
set
(
self
):
self
.
flag
=
True
def
clear
(
self
):
self
.
flag
=
False
def
test_request_tracker
():
tracker
=
RequestTracker
()
tracker
=
RequestTracker
()
tracker
.
new_requests_event
=
DummyEvent
()
stream_1
=
tracker
.
add_request
(
"1"
)
stream_1
=
tracker
.
add_request
(
"1"
)
assert
tracker
.
new_requests_event
.
flag
assert
tracker
.
new_requests_event
.
is_set
()
await
tracker
.
wait_for_new_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
assert
not
tracker
.
new_requests_event
.
flag
assert
not
tracker
.
new_requests_event
.
is_set
()
assert
len
(
new
)
==
1
assert
len
(
new
)
==
1
assert
new
[
0
][
"request_id"
]
==
"1"
assert
new
[
0
][
"request_id"
]
==
"1"
assert
not
finished
assert
not
finished
...
@@ -30,9 +19,10 @@ def test_request_tracker():
...
@@ -30,9 +19,10 @@ def test_request_tracker():
stream_2
=
tracker
.
add_request
(
"2"
)
stream_2
=
tracker
.
add_request
(
"2"
)
stream_3
=
tracker
.
add_request
(
"3"
)
stream_3
=
tracker
.
add_request
(
"3"
)
assert
tracker
.
new_requests_event
.
flag
assert
tracker
.
new_requests_event
.
is_set
()
await
tracker
.
wait_for_new_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
assert
not
tracker
.
new_requests_event
.
flag
assert
not
tracker
.
new_requests_event
.
is_set
()
assert
len
(
new
)
==
2
assert
len
(
new
)
==
2
assert
new
[
0
][
"request_id"
]
==
"2"
assert
new
[
0
][
"request_id"
]
==
"2"
assert
new
[
1
][
"request_id"
]
==
"3"
assert
new
[
1
][
"request_id"
]
==
"3"
...
@@ -43,7 +33,7 @@ def test_request_tracker():
...
@@ -43,7 +33,7 @@ def test_request_tracker():
# request_ids must be unique
# request_ids must be unique
with
pytest
.
raises
(
KeyError
):
with
pytest
.
raises
(
KeyError
):
tracker
.
add_request
(
"1"
)
tracker
.
add_request
(
"1"
)
assert
not
tracker
.
new_requests_event
.
flag
assert
not
tracker
.
new_requests_event
.
is_set
()
tracker
.
abort_request
(
"1"
)
tracker
.
abort_request
(
"1"
)
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
...
@@ -54,7 +44,8 @@ def test_request_tracker():
...
@@ -54,7 +44,8 @@ def test_request_tracker():
stream_4
=
tracker
.
add_request
(
"4"
)
stream_4
=
tracker
.
add_request
(
"4"
)
tracker
.
abort_request
(
"4"
)
tracker
.
abort_request
(
"4"
)
assert
tracker
.
new_requests_event
.
flag
assert
tracker
.
new_requests_event
.
is_set
()
await
tracker
.
wait_for_new_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
assert
len
(
finished
)
==
1
assert
len
(
finished
)
==
1
assert
"4"
in
finished
assert
"4"
in
finished
...
@@ -62,11 +53,12 @@ def test_request_tracker():
...
@@ -62,11 +53,12 @@ def test_request_tracker():
assert
stream_4
.
finished
assert
stream_4
.
finished
stream_5
=
tracker
.
add_request
(
"5"
)
stream_5
=
tracker
.
add_request
(
"5"
)
assert
tracker
.
new_requests_event
.
flag
assert
tracker
.
new_requests_event
.
is_set
()
tracker
.
process_request_output
(
tracker
.
process_request_output
(
RequestOutput
(
"2"
,
"output"
,
[],
[],
[],
bool
(
finished
)))
RequestOutput
(
"2"
,
"output"
,
[],
[],
[],
finished
=
True
))
await
tracker
.
wait_for_new_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
new
,
finished
=
tracker
.
get_new_and_finished_requests
()
assert
not
tracker
.
new_requests_event
.
flag
assert
not
tracker
.
new_requests_event
.
is_set
()
assert
len
(
finished
)
==
1
assert
len
(
finished
)
==
1
assert
"2"
in
finished
assert
"2"
in
finished
assert
len
(
new
)
==
1
assert
len
(
new
)
==
1
...
...
vllm/engine/async_llm_engine.py
View file @
ff578cae
import
asyncio
import
asyncio
import
os
import
time
import
time
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
,
AsyncIterator
)
Union
,
AsyncIterator
,
Callable
)
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
...
@@ -14,28 +15,31 @@ from vllm.outputs import RequestOutput
...
@@ -14,28 +15,31 @@ from vllm.outputs import RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
ENGINE_ITERATION_TIMEOUT_S
=
int
(
os
.
environ
.
get
(
"VLLM_ENGINE_ITERATION_TIMEOUT_S"
,
"60"
))
class
AsyncEngineDeadError
(
RuntimeError
):
class
AsyncEngineDeadError
(
RuntimeError
):
pass
pass
def
_raise_exception_on_finish
(
task
:
asyncio
.
Task
,
def
_raise_exception_on_finish
(
request_tracker
:
"RequestTracker"
)
->
None
:
task
:
asyncio
.
Task
,
error_callback
:
Callable
[[
Exception
],
None
])
->
None
:
msg
=
(
"Task finished unexpectedly. This should never happen! "
msg
=
(
"Task finished unexpectedly. This should never happen! "
"Please open an issue on Github."
)
"Please open an issue on Github."
)
exception
=
None
try
:
try
:
try
:
task
.
result
()
task
.
result
()
# NOTE: This will be thrown if task exits normally (which it should not)
except
asyncio
.
CancelledError
:
return
except
Exception
as
exc
:
raise
AsyncEngineDeadError
(
msg
+
" See stack trace above for the actual cause."
)
from
exc
raise
AsyncEngineDeadError
(
msg
)
raise
AsyncEngineDeadError
(
msg
)
except
Exception
as
exc
:
except
Exception
as
e
:
request_tracker
.
propagate_exception
(
exc
)
exception
=
e
raise
exc
logger
.
error
(
"Engine background task failed"
,
exc_info
=
e
)
error_callback
(
exception
)
raise
AsyncEngineDeadError
(
msg
+
" See stack trace above for the actual cause."
)
from
e
class
AsyncStream
:
class
AsyncStream
:
...
@@ -78,13 +82,13 @@ class RequestTracker:
...
@@ -78,13 +82,13 @@ class RequestTracker:
self
.
_finished_requests
:
asyncio
.
Queue
[
str
]
=
asyncio
.
Queue
()
self
.
_finished_requests
:
asyncio
.
Queue
[
str
]
=
asyncio
.
Queue
()
self
.
_new_requests
:
asyncio
.
Queue
[
Tuple
[
AsyncStream
,
self
.
_new_requests
:
asyncio
.
Queue
[
Tuple
[
AsyncStream
,
dict
]]
=
asyncio
.
Queue
()
dict
]]
=
asyncio
.
Queue
()
self
.
new_requests_event
=
None
self
.
new_requests_event
=
asyncio
.
Event
()
def
__contains__
(
self
,
item
):
def
__contains__
(
self
,
item
):
return
item
in
self
.
_request_streams
return
item
in
self
.
_request_streams
def
init_event
(
self
)
:
def
__len__
(
self
)
->
int
:
self
.
new
_request
s_event
=
asyncio
.
Event
(
)
return
len
(
self
.
_request
_streams
)
def
propagate_exception
(
self
,
def
propagate_exception
(
self
,
exc
:
Exception
,
exc
:
Exception
,
...
@@ -93,9 +97,11 @@ class RequestTracker:
...
@@ -93,9 +97,11 @@ class RequestTracker:
(all if request_id is None)."""
(all if request_id is None)."""
if
request_id
is
not
None
:
if
request_id
is
not
None
:
self
.
_request_streams
[
request_id
].
put
(
exc
)
self
.
_request_streams
[
request_id
].
put
(
exc
)
self
.
abort_request
(
request_id
)
else
:
else
:
for
stream
in
self
.
_request_streams
.
value
s
():
for
rid
,
stream
in
self
.
_request_streams
.
item
s
():
stream
.
put
(
exc
)
stream
.
put
(
exc
)
self
.
abort_request
(
rid
)
def
process_request_output
(
self
,
def
process_request_output
(
self
,
request_output
:
RequestOutput
,
request_output
:
RequestOutput
,
...
@@ -172,12 +178,15 @@ class RequestTracker:
...
@@ -172,12 +178,15 @@ class RequestTracker:
self
.
_request_streams
[
stream
.
request_id
]
=
stream
self
.
_request_streams
[
stream
.
request_id
]
=
stream
new_requests
.
append
(
new_request
)
new_requests
.
append
(
new_request
)
self
.
new_requests_event
.
clear
()
return
new_requests
,
finished_requests
return
new_requests
,
finished_requests
async
def
wait_for_new_requests
(
self
):
async
def
wait_for_new_requests
(
self
):
await
self
.
new_requests_event
.
wait
()
if
not
self
.
has_new_requests
():
await
self
.
new_requests_event
.
wait
()
self
.
new_requests_event
.
clear
()
def
has_new_requests
(
self
):
return
not
self
.
_new_requests
.
empty
()
class
_AsyncLLMEngine
(
LLMEngine
):
class
_AsyncLLMEngine
(
LLMEngine
):
...
@@ -285,6 +294,10 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -285,6 +294,10 @@ class _AsyncLLMEngine(LLMEngine):
all_outputs
=
await
asyncio
.
gather
(
*
coros
)
all_outputs
=
await
asyncio
.
gather
(
*
coros
)
return
all_outputs
return
all_outputs
async
def
check_health_async
(
self
):
"""Raises an error if engine is unhealthy."""
self
.
_check_if_any_actor_is_dead
()
class
AsyncLLMEngine
:
class
AsyncLLMEngine
:
"""An asynchronous wrapper for LLMEngine.
"""An asynchronous wrapper for LLMEngine.
...
@@ -335,27 +348,48 @@ class AsyncLLMEngine:
...
@@ -335,27 +348,48 @@ class AsyncLLMEngine:
# collected
# collected
self
.
_background_loop_unshielded
=
None
self
.
_background_loop_unshielded
=
None
self
.
start_engine_loop
=
start_engine_loop
self
.
start_engine_loop
=
start_engine_loop
self
.
_request_tracker
=
RequestTracker
()
self
.
_request_tracker
:
Optional
[
RequestTracker
]
=
None
self
.
_errored_with
:
Optional
[
BaseException
]
=
None
@
property
@
property
def
is_running
(
self
)
->
bool
:
def
is_running
(
self
)
->
bool
:
return
(
self
.
background_loop
is
not
None
return
(
self
.
background_loop
is
not
None
and
not
self
.
background_loop
.
done
())
and
not
self
.
_background_loop_unshielded
.
done
())
@
property
def
is_stopped
(
self
)
->
bool
:
return
self
.
errored
or
(
self
.
background_loop
is
not
None
and
self
.
_background_loop_unshielded
.
done
())
@
property
def
errored
(
self
)
->
bool
:
return
self
.
_errored_with
is
not
None
def
set_errored
(
self
,
exc
:
Exception
)
->
None
:
self
.
_errored_with
=
exc
def
_error_callback
(
self
,
exc
:
Exception
)
->
None
:
self
.
set_errored
(
exc
)
self
.
_request_tracker
.
propagate_exception
(
exc
)
def
get_tokenizer
(
self
):
def
get_tokenizer
(
self
):
return
self
.
engine
.
tokenizer
.
tokenizer
return
self
.
engine
.
tokenizer
.
tokenizer
def
start_background_loop
(
self
)
->
None
:
def
start_background_loop
(
self
)
->
None
:
"""Start the background loop."""
"""Start the background loop."""
if
self
.
errored
:
raise
AsyncEngineDeadError
(
"Background loop has errored already."
)
from
self
.
_errored_with
if
self
.
is_running
:
if
self
.
is_running
:
raise
RuntimeError
(
"Background loop is already running."
)
raise
RuntimeError
(
"Background loop is already running."
)
self
.
_request_tracker
.
init_event
()
# Initialize the RequestTracker here so it uses the right event loop.
self
.
_request_tracker
=
RequestTracker
()
self
.
_background_loop_unshielded
=
asyncio
.
get_event_loop
(
self
.
_background_loop_unshielded
=
asyncio
.
get_event_loop
(
).
create_task
(
self
.
run_engine_loop
())
).
create_task
(
self
.
run_engine_loop
())
self
.
_background_loop_unshielded
.
add_done_callback
(
self
.
_background_loop_unshielded
.
add_done_callback
(
partial
(
_raise_exception_on_finish
,
partial
(
_raise_exception_on_finish
,
request_tr
ack
er
=
self
.
_
request_tr
ack
er
))
error_callb
ack
=
self
.
_
error_callb
ack
))
self
.
background_loop
=
asyncio
.
shield
(
self
.
_background_loop_unshielded
)
self
.
background_loop
=
asyncio
.
shield
(
self
.
_background_loop_unshielded
)
def
_init_engine
(
self
,
*
args
,
def
_init_engine
(
self
,
*
args
,
...
@@ -423,12 +457,23 @@ class AsyncLLMEngine:
...
@@ -423,12 +457,23 @@ class AsyncLLMEngine:
self
.
engine
.
abort_request
(
request_ids
)
self
.
engine
.
abort_request
(
request_ids
)
async
def
run_engine_loop
(
self
):
async
def
run_engine_loop
(
self
):
# Initialize the RequestTracker here so it uses the right event loop.
has_requests_in_progress
=
False
has_requests_in_progress
=
False
while
True
:
while
True
:
if
not
has_requests_in_progress
:
if
not
has_requests_in_progress
:
logger
.
debug
(
"Waiting for new requests..."
)
await
self
.
_request_tracker
.
wait_for_new_requests
()
await
self
.
_request_tracker
.
wait_for_new_requests
()
has_requests_in_progress
=
await
self
.
engine_step
()
logger
.
debug
(
"Got new requests!"
)
# Abort if iteration takes too long due to unrecoverable errors
# (eg. NCCL timeouts).
try
:
has_requests_in_progress
=
await
asyncio
.
wait_for
(
self
.
engine_step
(),
ENGINE_ITERATION_TIMEOUT_S
)
except
asyncio
.
TimeoutError
as
exc
:
logger
.
error
(
"Engine iteration timed out. This should never happen!"
)
self
.
set_errored
(
exc
)
raise
await
asyncio
.
sleep
(
0
)
await
asyncio
.
sleep
(
0
)
async
def
add_request
(
async
def
add_request
(
...
@@ -647,3 +692,19 @@ class AsyncLLMEngine:
...
@@ -647,3 +692,19 @@ class AsyncLLMEngine:
await
self
.
engine
.
do_log_stats
.
remote
()
await
self
.
engine
.
do_log_stats
.
remote
()
else
:
else
:
self
.
engine
.
do_log_stats
()
self
.
engine
.
do_log_stats
()
async
def
check_health
(
self
):
"""Raises an error if engine is unhealthy."""
t
=
time
.
perf_counter
()
logger
.
debug
(
"Starting health check..."
)
if
self
.
is_stopped
:
raise
AsyncEngineDeadError
(
"Background loop is stopped."
)
if
self
.
engine_use_ray
:
try
:
await
self
.
engine
.
check_health
.
remote
()
except
ray
.
exceptions
.
RayActorError
as
e
:
raise
RuntimeError
(
"Engine is dead."
)
from
e
else
:
await
self
.
engine
.
check_health_async
()
logger
.
debug
(
f
"Health check took
{
time
.
perf_counter
()
-
t
}
s"
)
vllm/engine/llm_engine.py
View file @
ff578cae
...
@@ -1119,3 +1119,23 @@ class LLMEngine:
...
@@ -1119,3 +1119,23 @@ class LLMEngine:
for
worker
in
self
.
workers
for
worker
in
self
.
workers
])
])
return
forward_dag
.
experimental_compile
()
return
forward_dag
.
experimental_compile
()
def
check_health
(
self
)
->
None
:
"""Raises an error if engine is unhealthy."""
self
.
_check_if_any_actor_is_dead
()
def
_check_if_any_actor_is_dead
(
self
):
if
not
self
.
parallel_config
.
worker_use_ray
:
return
if
not
self
.
workers
:
return
dead_actors
=
[]
for
actor
in
self
.
workers
:
actor_state
=
ray
.
state
.
actors
(
actor
.
_ray_actor_id
.
hex
())
# pylint: disable=protected-access
if
actor_state
[
"State"
]
==
"DEAD"
:
dead_actors
.
append
(
actor
)
if
dead_actors
:
raise
RuntimeError
(
"At least one Worker is dead. "
f
"Dead Workers:
{
dead_actors
}
. "
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment