Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6e0c9d6b
Unverified
Commit
6e0c9d6b
authored
Sep 24, 2024
by
Joe Runde
Committed by
GitHub
Sep 24, 2024
Browse files
[Bugfix] Use heartbeats instead of health checks (#8583)
parent
6da1ab6b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
87 additions
and
63 deletions
+87
-63
tests/mq_llm_engine/test_error_handling.py
tests/mq_llm_engine/test_error_handling.py
+4
-11
vllm/engine/multiprocessing/__init__.py
vllm/engine/multiprocessing/__init__.py
+1
-6
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+22
-29
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+60
-17
No files found.
tests/mq_llm_engine/test_error_handling.py
View file @
6e0c9d6b
...
@@ -153,27 +153,20 @@ async def test_failed_abort(tmp_socket):
...
@@ -153,27 +153,20 @@ async def test_failed_abort(tmp_socket):
await
client
.
check_health
()
await
client
.
check_health
()
# Trigger an abort on the client side.
# Trigger an abort on the client side.
async
def
bad_abort_after_2s
():
# This request ID does not exist, and will cause the engine to error
await
asyncio
.
sleep
(
2.0
)
await
client
.
abort
(
request_id
=
"foo"
)
await
client
.
abort
(
request_id
=
"foo"
)
# Trigger an abort in 2s from now.
# Future generation requests will now fail
abort_task
=
asyncio
.
create_task
(
bad_abort_after_2s
())
# Exception in abort() will happen during this generation.
# This will kill the engine and should return ENGINE_DEAD_ERROR
# with reference to the original KeyError("foo")
# with reference to the original KeyError("foo")
with
pytest
.
raises
(
MQEngineDeadError
)
as
execinfo
:
with
pytest
.
raises
(
MQEngineDeadError
)
as
execinfo
:
async
for
_
in
client
.
generate
(
async
for
_
in
client
.
generate
(
inputs
=
"Hello my name is"
,
inputs
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(
max_tokens
=
200
0
),
sampling_params
=
SamplingParams
(
max_tokens
=
1
0
),
request_id
=
uuid
.
uuid4
()):
request_id
=
uuid
.
uuid4
()):
pass
pass
assert
"KeyError"
in
repr
(
execinfo
.
value
)
assert
"KeyError"
in
repr
(
execinfo
.
value
)
assert
client
.
errored
assert
client
.
errored
await
abort_task
# This should raise the original error.
# This should raise the original error.
with
pytest
.
raises
(
RAISED_ERROR
):
with
pytest
.
raises
(
RAISED_ERROR
):
await
client
.
check_health
()
await
client
.
check_health
()
...
...
vllm/engine/multiprocessing/__init__.py
View file @
6e0c9d6b
...
@@ -43,10 +43,6 @@ class RPCAbortRequest:
...
@@ -43,10 +43,6 @@ class RPCAbortRequest:
request_id
:
str
request_id
:
str
class
RPCHealthRequest
:
pass
class
RPCStartupRequest
(
Enum
):
class
RPCStartupRequest
(
Enum
):
IS_SERVER_READY
=
1
IS_SERVER_READY
=
1
...
@@ -56,8 +52,7 @@ class RPCStartupResponse:
...
@@ -56,8 +52,7 @@ class RPCStartupResponse:
tracing_enabled
:
bool
tracing_enabled
:
bool
RPC_REQUEST_T
=
Union
[
RPCProcessRequest
,
RPCAbortRequest
,
RPCHealthRequest
,
RPC_REQUEST_T
=
Union
[
RPCProcessRequest
,
RPCAbortRequest
,
RPCStartupRequest
]
RPCStartupRequest
]
REQUEST_OUTPUTS_T
=
Union
[
List
[
RequestOutput
],
RPCError
]
REQUEST_OUTPUTS_T
=
Union
[
List
[
RequestOutput
],
RPCError
]
...
...
vllm/engine/multiprocessing/client.py
View file @
6e0c9d6b
...
@@ -20,9 +20,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
...
@@ -20,9 +20,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_OUTPUT_EXT
,
RPC_REQUEST_T
,
IPC_OUTPUT_EXT
,
RPC_REQUEST_T
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCError
,
RPCHealthRequest
,
RPCError
,
RPCProcessRequest
,
RPCProcessRequest
,
RPCStartupRequest
,
RPCStartupRequest
,
RPCStartupResponse
)
RPCStartupResponse
)
# yapf: enable
# yapf: enable
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.inputs
import
PromptInputs
from
vllm.inputs
import
PromptInputs
...
@@ -95,9 +94,9 @@ class MQLLMEngineClient:
...
@@ -95,9 +94,9 @@ class MQLLMEngineClient:
self
.
output_socket
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
PULL
)
self
.
output_socket
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
PULL
)
self
.
output_socket
.
connect
(
f
"
{
ipc_path
}{
IPC_OUTPUT_EXT
}
"
)
self
.
output_socket
.
connect
(
f
"
{
ipc_path
}{
IPC_OUTPUT_EXT
}
"
)
# IPC path for ack
of check_health reques
ts.
# IPC path for ack
ing heartbea
ts.
self
.
hea
lth
_socket
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
PULL
)
self
.
hea
rtbeat
_socket
:
Socket
=
self
.
context
.
socket
(
zmq
.
constants
.
PULL
)
self
.
hea
lth
_socket
.
connect
(
f
"
{
ipc_path
}{
IPC_HEALTH_EXT
}
"
)
self
.
hea
rtbeat
_socket
.
connect
(
f
"
{
ipc_path
}{
IPC_HEALTH_EXT
}
"
)
# IPC path for the data socket.
# IPC path for the data socket.
self
.
data_ipc_path
=
f
"
{
ipc_path
}{
IPC_DATA_EXT
}
"
self
.
data_ipc_path
=
f
"
{
ipc_path
}{
IPC_DATA_EXT
}
"
...
@@ -124,34 +123,28 @@ class MQLLMEngineClient:
...
@@ -124,34 +123,28 @@ class MQLLMEngineClient:
finally
:
finally
:
socket
.
close
(
linger
=
0
)
socket
.
close
(
linger
=
0
)
async
def
run_check_health_loop
(
self
,
timeout
:
int
):
async
def
run_heartbeat_loop
(
self
,
timeout
:
int
):
"""Background loop that continually probes the RPCServer for health.
"""Background loop that continually listens to the RPCServer for
heartbeats.
The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which
the MQLLMEngine server is blocking on.
The Server replies on the HEALTH_SOCKET (rather than on the
OUTPUT_SOCKET such that the messages are not intermingled with
output streaming).
"""
"""
try
:
try
:
while
True
:
while
True
:
if
await
self
.
health_socket
.
poll
(
timeout
=
timeout
)
==
0
:
if
await
self
.
heartbeat_socket
.
poll
(
timeout
=
timeout
)
==
0
:
# Wakeup every N seconds and do a health probe.
# No heartbeat was received. Set error and exit the loop
await
self
.
_send_one_way_rpc_request
(
self
.
_set_errored
(
RPCHealthRequest
(),
self
.
input_socket
)
TimeoutError
(
"No heartbeat received "
"from MQLLMEngine"
))
# Wait for ack from the health socket.
logger
.
debug
(
"Shutting down MQLLMEngineClient check "
await
self
.
_await_ack
(
error_message
=
"Health check failed."
,
"health loop due to timeout"
)
socket
=
self
.
health_socket
)
break
else
:
else
:
#
Server sent a health status message unprompted.
#
Heartbeat received- check the message
await
self
.
_check_success
(
await
self
.
_check_success
(
error_message
=
"Hea
lth check
failed."
,
error_message
=
"Hea
rtbeat
failed."
,
socket
=
self
.
hea
lth
_socket
)
socket
=
self
.
hea
rtbeat
_socket
)
logger
.
debug
(
"Hea
lth probe
successful."
)
logger
.
debug
(
"Hea
rtbeat
successful."
)
except
asyncio
.
CancelledError
:
except
asyncio
.
CancelledError
:
logger
.
debug
(
"Shutting down MQLLMEngineClient check health loop."
)
logger
.
debug
(
"Shutting down MQLLMEngineClient check health loop."
)
...
@@ -234,7 +227,7 @@ class MQLLMEngineClient:
...
@@ -234,7 +227,7 @@ class MQLLMEngineClient:
# Start health_loop.
# Start health_loop.
self
.
health_loop
=
asyncio
.
create_task
(
self
.
health_loop
=
asyncio
.
create_task
(
self
.
run_
c
he
ck_health
_loop
(
timeout
=
VLLM_RPC_TIMEOUT
))
self
.
run_he
artbeat
_loop
(
timeout
=
VLLM_RPC_TIMEOUT
))
def
close
(
self
):
def
close
(
self
):
"""Destroy the ZeroMQ Context."""
"""Destroy the ZeroMQ Context."""
...
...
vllm/engine/multiprocessing/engine.py
View file @
6e0c9d6b
import
pickle
import
pickle
import
signal
import
signal
import
threading
import
time
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Iterator
,
List
,
Optional
,
Union
from
typing
import
Iterator
,
List
,
Optional
,
Union
...
@@ -15,10 +17,10 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
...
@@ -15,10 +17,10 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_OUTPUT_EXT
,
REQUEST_OUTPUTS_T
,
IPC_OUTPUT_EXT
,
REQUEST_OUTPUTS_T
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
RPCError
,
RPCHealthRequest
,
RPCError
,
RPCProcessRequest
,
RPCProcessRequest
,
RPCStartupRequest
,
RPCStartupRequest
,
RPCStartupResponse
)
RPCStartupResponse
)
# yapf: enable
# yapf: enable
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
...
@@ -91,9 +93,9 @@ class MQLLMEngine:
...
@@ -91,9 +93,9 @@ class MQLLMEngine:
self
.
output_socket
=
self
.
ctx
.
socket
(
zmq
.
constants
.
PUSH
)
self
.
output_socket
=
self
.
ctx
.
socket
(
zmq
.
constants
.
PUSH
)
self
.
output_socket
.
bind
(
f
"
{
ipc_path
}{
IPC_OUTPUT_EXT
}
"
)
self
.
output_socket
.
bind
(
f
"
{
ipc_path
}{
IPC_OUTPUT_EXT
}
"
)
# Send hea
lth st
at
u
s back to client.
# Send hea
rtbe
ats back to client.
self
.
hea
lth
_socket
=
self
.
ctx
.
socket
(
zmq
.
constants
.
PUSH
)
self
.
hea
rtbeat
_socket
=
self
.
ctx
.
socket
(
zmq
.
constants
.
PUSH
)
self
.
hea
lth
_socket
.
bind
(
f
"
{
ipc_path
}{
IPC_HEALTH_EXT
}
"
)
self
.
hea
rtbeat
_socket
.
bind
(
f
"
{
ipc_path
}{
IPC_HEALTH_EXT
}
"
)
# IPC path for the data socket.
# IPC path for the data socket.
self
.
data_ipc_path
=
f
"
{
ipc_path
}{
IPC_DATA_EXT
}
"
self
.
data_ipc_path
=
f
"
{
ipc_path
}{
IPC_DATA_EXT
}
"
...
@@ -101,6 +103,20 @@ class MQLLMEngine:
...
@@ -101,6 +103,20 @@ class MQLLMEngine:
# Error state.
# Error state.
self
.
_errored_with
:
Optional
[
BaseException
]
=
None
self
.
_errored_with
:
Optional
[
BaseException
]
=
None
# Heartbeat thread
self
.
heartbeat_thread
=
threading
.
Thread
(
target
=
self
.
_heartbeat_loop
,
daemon
=
True
)
self
.
_heartbeat_stop_event
=
threading
.
Event
()
# The heartbeat needs to be faster than what the client will wait for
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self
.
heartbeat_interval_seconds
=
VLLM_RPC_TIMEOUT
/
5000.0
self
.
_last_alive_time
=
time
.
time
()
# The heartbeats can tolerate a long period of the engine chugging
# away at a generation request.
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self
.
last_alive_threshold
=
VLLM_RPC_TIMEOUT
*
3.0
/
1000.0
@
property
@
property
def
dead_error
(
self
)
->
BaseException
:
def
dead_error
(
self
)
->
BaseException
:
if
self
.
_errored_with
is
not
None
:
if
self
.
_errored_with
is
not
None
:
...
@@ -131,6 +147,8 @@ class MQLLMEngine:
...
@@ -131,6 +147,8 @@ class MQLLMEngine:
try
:
try
:
logger
.
debug
(
"Starting Startup Loop."
)
logger
.
debug
(
"Starting Startup Loop."
)
self
.
run_startup_loop
()
self
.
run_startup_loop
()
logger
.
debug
(
"Starting heartbeat thread"
)
self
.
heartbeat_thread
.
start
()
logger
.
debug
(
"Starting Engine Loop."
)
logger
.
debug
(
"Starting Engine Loop."
)
self
.
run_engine_loop
()
self
.
run_engine_loop
()
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -144,6 +162,7 @@ class MQLLMEngine:
...
@@ -144,6 +162,7 @@ class MQLLMEngine:
def
cleanup
(
self
):
def
cleanup
(
self
):
"""Cleanup zeromq state on shutdown."""
"""Cleanup zeromq state on shutdown."""
# Closes all sockets and destroys context.
# Closes all sockets and destroys context.
self
.
_heartbeat_stop_event
.
set
()
self
.
ctx
.
destroy
(
linger
=
0
)
self
.
ctx
.
destroy
(
linger
=
0
)
del
self
.
engine
del
self
.
engine
...
@@ -182,9 +201,11 @@ class MQLLMEngine:
...
@@ -182,9 +201,11 @@ class MQLLMEngine:
"""Core busy loop of the LLMEngine."""
"""Core busy loop of the LLMEngine."""
while
True
:
while
True
:
self
.
_alive
()
if
not
self
.
engine
.
has_unfinished_requests
():
if
not
self
.
engine
.
has_unfinished_requests
():
# Poll until there is work to do.
# Poll until there is work to do.
while
self
.
input_socket
.
poll
(
timeout
=
POLLING_TIMEOUT_MS
)
==
0
:
while
self
.
input_socket
.
poll
(
timeout
=
POLLING_TIMEOUT_MS
)
==
0
:
self
.
_alive
()
self
.
engine
.
do_log_stats
()
self
.
engine
.
do_log_stats
()
logger
.
debug
(
"Waiting for new requests in engine loop."
)
logger
.
debug
(
"Waiting for new requests in engine loop."
)
...
@@ -200,7 +221,6 @@ class MQLLMEngine:
...
@@ -200,7 +221,6 @@ class MQLLMEngine:
def
engine_step
(
self
)
->
List
[
RequestOutput
]:
def
engine_step
(
self
)
->
List
[
RequestOutput
]:
"""Engine step wrapper with error handling."""
"""Engine step wrapper with error handling."""
try
:
try
:
return
self
.
engine
.
step
()
return
self
.
engine
.
step
()
except
SystemExit
:
except
SystemExit
:
...
@@ -229,10 +249,9 @@ class MQLLMEngine:
...
@@ -229,10 +249,9 @@ class MQLLMEngine:
self
.
_handle_process_request
(
request
)
self
.
_handle_process_request
(
request
)
elif
isinstance
(
request
,
RPCAbortRequest
):
elif
isinstance
(
request
,
RPCAbortRequest
):
self
.
_handle_abort_request
(
request
)
self
.
_handle_abort_request
(
request
)
elif
isinstance
(
request
,
RPCHealthRequest
):
self
.
_handle_health_request
()
else
:
else
:
raise
ValueError
(
"Unknown RPCRequest Type: {request}"
)
raise
ValueError
(
"Unknown RPCRequest Type: "
f
"
{
type
(
request
)
}
"
)
except
Exception
as
e
:
except
Exception
as
e
:
self
.
_set_errored
(
e
)
self
.
_set_errored
(
e
)
...
@@ -279,13 +298,32 @@ class MQLLMEngine:
...
@@ -279,13 +298,32 @@ class MQLLMEngine:
if
self
.
log_requests
:
if
self
.
log_requests
:
logger
.
info
(
"Aborted request %s."
,
request
.
request_id
)
logger
.
info
(
"Aborted request %s."
,
request
.
request_id
)
def
_handle_health_request
(
self
):
def
_heartbeat_loop
(
self
):
while
not
self
.
_heartbeat_stop_event
.
wait
(
timeout
=
self
.
heartbeat_interval_seconds
):
# Loops until the stop event is set
self
.
_heartbeat
()
logger
.
debug
(
"Exiting MQLLMEngine heartbeat thread"
)
def
_heartbeat
(
self
):
# Send unhealthy if engine has already errored
if
self
.
_errored_with
is
not
None
:
if
self
.
_errored_with
is
not
None
:
self
.
_send_unhealthy
(
self
.
_errored_with
)
self
.
_send_unhealthy
(
self
.
_errored_with
)
# Raises error if unhealthy.
# Check for life of the main loop
self
.
engine
.
check_health
()
elif
time
.
time
()
-
self
.
_last_alive_time
>
self
.
last_alive_threshold
:
self
.
_send_healthy
()
self
.
_send_unhealthy
(
RuntimeError
(
"Engine loop has died"
))
else
:
# Otherwise- check health of the engine
# self.engine.check_health() raises on unhealthy
try
:
self
.
engine
.
check_health
()
self
.
_send_healthy
()
except
Exception
as
e
:
self
.
_set_errored
(
e
)
self
.
_send_unhealthy
(
e
)
def
_send_outputs
(
self
,
outputs
:
REQUEST_OUTPUTS_T
):
def
_send_outputs
(
self
,
outputs
:
REQUEST_OUTPUTS_T
):
"""Send List of RequestOutput to RPCClient."""
"""Send List of RequestOutput to RPCClient."""
...
@@ -295,12 +333,14 @@ class MQLLMEngine:
...
@@ -295,12 +333,14 @@ class MQLLMEngine:
def
_send_healthy
(
self
):
def
_send_healthy
(
self
):
"""Send HEALTHY message to RPCClient."""
"""Send HEALTHY message to RPCClient."""
self
.
health_socket
.
send_multipart
(
HEALTHY_RESPONSE
,
copy
=
False
)
if
not
self
.
heartbeat_socket
.
closed
:
self
.
heartbeat_socket
.
send_multipart
(
HEALTHY_RESPONSE
,
copy
=
False
)
def
_send_unhealthy
(
self
,
error
:
BaseException
):
def
_send_unhealthy
(
self
,
error
:
BaseException
):
"""Send UNHEALTHY message to RPCClient."""
"""Send UNHEALTHY message to RPCClient."""
error_bytes
=
pickle
.
dumps
(
error
)
if
not
self
.
heartbeat_socket
.
closed
:
self
.
health_socket
.
send_multipart
((
error_bytes
,
),
copy
=
False
)
error_bytes
=
pickle
.
dumps
(
error
)
self
.
heartbeat_socket
.
send_multipart
((
error_bytes
,
),
copy
=
False
)
def
_async_socket_engine_callback
(
self
,
def
_async_socket_engine_callback
(
self
,
request_outputs
:
REQUEST_OUTPUTS_T
):
request_outputs
:
REQUEST_OUTPUTS_T
):
...
@@ -313,6 +353,9 @@ class MQLLMEngine:
...
@@ -313,6 +353,9 @@ class MQLLMEngine:
if
self
.
_errored_with
is
None
:
if
self
.
_errored_with
is
None
:
self
.
_errored_with
=
e
self
.
_errored_with
=
e
def
_alive
(
self
):
self
.
_last_alive_time
=
time
.
time
()
def
run_mp_engine
(
engine_args
:
AsyncEngineArgs
,
usage_context
:
UsageContext
,
def
run_mp_engine
(
engine_args
:
AsyncEngineArgs
,
usage_context
:
UsageContext
,
ipc_path
:
str
):
ipc_path
:
str
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment