Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3b3f1e74
Unverified
Commit
3b3f1e74
authored
Oct 30, 2024
by
Joe Runde
Committed by
GitHub
Oct 30, 2024
Browse files
[Bugfix][core] replace heartbeat with pid check (#9818)
Signed-off-by:
Joe Runde
<
Joseph.Runde@ibm.com
>
parent
9ff4511e
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
62 additions
and
62 deletions
+62
-62
tests/mq_llm_engine/test_error_handling.py
tests/mq_llm_engine/test_error_handling.py
+26
-1
tests/mq_llm_engine/utils.py
tests/mq_llm_engine/utils.py
+1
-1
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+19
-10
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+11
-48
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+5
-2
No files found.
tests/mq_llm_engine/test_error_handling.py
View file @
3b3f1e74
...
@@ -21,7 +21,7 @@ from vllm.usage.usage_lib import UsageContext
...
@@ -21,7 +21,7 @@ from vllm.usage.usage_lib import UsageContext
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
MODEL
=
"google/gemma-1.1-2b-it"
MODEL
=
"google/gemma-1.1-2b-it"
ENGINE_ARGS
=
AsyncEngineArgs
(
model
=
MODEL
)
ENGINE_ARGS
=
AsyncEngineArgs
(
model
=
MODEL
,
enforce_eager
=
True
)
RAISED_ERROR
=
KeyError
RAISED_ERROR
=
KeyError
RAISED_VALUE
=
"foo"
RAISED_VALUE
=
"foo"
...
@@ -266,3 +266,28 @@ async def test_mp_cuda_init():
...
@@ -266,3 +266,28 @@ async def test_mp_cuda_init():
async
with
build_async_engine_client
(
args
):
async
with
build_async_engine_client
(
args
):
pass
pass
@
pytest
.
mark
.
asyncio
async
def
test_engine_process_death
(
tmp_socket
):
with
RemoteMQLLMEngine
(
engine_args
=
ENGINE_ARGS
,
ipc_path
=
tmp_socket
)
as
engine
:
client
=
await
engine
.
make_client
()
assert
client
.
is_running
# kill the engine process
engine
.
proc
.
kill
()
# Generate call should fail
with
pytest
.
raises
(
MQEngineDeadError
):
async
for
_
in
client
.
generate
(
prompt
=
"Hello my name is"
,
sampling_params
=
SamplingParams
(),
request_id
=
uuid
.
uuid4
()):
pass
# And the health check should show the engine is dead
with
pytest
.
raises
(
RuntimeError
,
match
=
"Engine process .* died"
):
await
client
.
check_health
()
client
.
close
()
tests/mq_llm_engine/utils.py
View file @
3b3f1e74
...
@@ -68,7 +68,7 @@ class RemoteMQLLMEngine:
...
@@ -68,7 +68,7 @@ class RemoteMQLLMEngine:
async
def
make_client
(
self
)
->
MQLLMEngineClient
:
async
def
make_client
(
self
)
->
MQLLMEngineClient
:
engine_config
=
self
.
engine_args
.
create_engine_config
()
engine_config
=
self
.
engine_args
.
create_engine_config
()
client
=
MQLLMEngineClient
(
self
.
ipc_path
,
engine_config
)
client
=
MQLLMEngineClient
(
self
.
ipc_path
,
engine_config
,
self
.
proc
.
pid
)
while
True
:
while
True
:
try
:
try
:
await
client
.
setup
()
await
client
.
setup
()
...
...
vllm/engine/multiprocessing/client.py
View file @
3b3f1e74
...
@@ -6,6 +6,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
...
@@ -6,6 +6,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
Optional
,
Union
,
cast
,
overload
)
Optional
,
Union
,
cast
,
overload
)
import
cloudpickle
import
cloudpickle
import
psutil
import
zmq
import
zmq
import
zmq.asyncio
import
zmq.asyncio
from
zmq
import
Frame
# type: ignore[attr-defined]
from
zmq
import
Frame
# type: ignore[attr-defined]
...
@@ -77,7 +78,8 @@ class MQLLMEngineClient(EngineClient):
...
@@ -77,7 +78,8 @@ class MQLLMEngineClient(EngineClient):
every N seconds, confirming the engine is healthy
every N seconds, confirming the engine is healthy
"""
"""
def
__init__
(
self
,
ipc_path
:
str
,
engine_config
:
EngineConfig
):
def
__init__
(
self
,
ipc_path
:
str
,
engine_config
:
EngineConfig
,
engine_pid
:
int
):
self
.
context
=
zmq
.
asyncio
.
Context
()
self
.
context
=
zmq
.
asyncio
.
Context
()
self
.
_errored_with
:
Optional
[
BaseException
]
=
None
self
.
_errored_with
:
Optional
[
BaseException
]
=
None
...
@@ -115,6 +117,7 @@ class MQLLMEngineClient(EngineClient):
...
@@ -115,6 +117,7 @@ class MQLLMEngineClient(EngineClient):
# Loop to check health of the LLMEngine periodically.
# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
# Started after the MQLLMEngine is ready.
self
.
health_loop
:
Optional
[
asyncio
.
Task
]
=
None
self
.
health_loop
:
Optional
[
asyncio
.
Task
]
=
None
self
.
_engine_process
=
psutil
.
Process
(
engine_pid
)
@
staticmethod
@
staticmethod
def
is_unsupported_config
(
engine_args
:
AsyncEngineArgs
):
def
is_unsupported_config
(
engine_args
:
AsyncEngineArgs
):
...
@@ -131,21 +134,22 @@ class MQLLMEngineClient(EngineClient):
...
@@ -131,21 +134,22 @@ class MQLLMEngineClient(EngineClient):
socket
.
close
(
linger
=
0
)
socket
.
close
(
linger
=
0
)
async
def
run_heartbeat_loop
(
self
,
timeout
:
int
):
async
def
run_heartbeat_loop
(
self
,
timeout
:
int
):
"""Background loop that continually
listens to the RPCServer for
"""Background loop that continually
checks to ensure the engine process
heartbeats
.
is still alive
.
"""
"""
try
:
try
:
while
True
:
while
True
:
if
await
self
.
heartbeat_socket
.
poll
(
timeout
=
timeout
)
==
0
:
# Check if the engine process is running:
# No heartbeat was received. Set error and exit the loop
if
not
self
.
_engine_process
.
is_running
()
or
(
self
.
_engine_process
.
status
()
==
psutil
.
STATUS_ZOMBIE
):
# NB: is_running() returns True for zombies
self
.
_set_errored
(
self
.
_set_errored
(
TimeoutError
(
"No heartbeat received "
RuntimeError
(
"from MQLLMEngine"
))
f
"Engine process (pid
{
self
.
_engine_process
.
pid
}
) "
logger
.
debug
(
"Shutting down MQLLMEngineClient check "
"died."
))
"health loop due to timeout"
)
break
break
else
:
if
await
self
.
heartbeat_socket
.
poll
(
timeout
=
timeout
)
:
# Heartbeat received- check the message
# Heartbeat received- check the message
await
self
.
_check_success
(
await
self
.
_check_success
(
error_message
=
"Heartbeat failed."
,
error_message
=
"Heartbeat failed."
,
...
@@ -156,6 +160,11 @@ class MQLLMEngineClient(EngineClient):
...
@@ -156,6 +160,11 @@ class MQLLMEngineClient(EngineClient):
except
asyncio
.
CancelledError
:
except
asyncio
.
CancelledError
:
logger
.
debug
(
"Shutting down MQLLMEngineClient check health loop."
)
logger
.
debug
(
"Shutting down MQLLMEngineClient check health loop."
)
except
psutil
.
NoSuchProcess
:
self
.
_set_errored
(
RuntimeError
(
f
"Engine process (pid
{
self
.
_engine_process
.
pid
}
) died."
))
except
Exception
as
e
:
except
Exception
as
e
:
self
.
_set_errored
(
e
)
self
.
_set_errored
(
e
)
...
...
vllm/engine/multiprocessing/engine.py
View file @
3b3f1e74
import
pickle
import
pickle
import
signal
import
signal
import
threading
import
time
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Iterator
,
List
,
Optional
,
Union
from
typing
import
Iterator
,
List
,
Optional
,
Union
...
@@ -21,7 +19,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
...
@@ -21,7 +19,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCStartupRequest
,
RPCStartupResponse
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCUProfileRequest
)
RPCUProfileRequest
)
# yapf: enable
# yapf: enable
from
vllm.envs
import
VLLM_RPC_TIMEOUT
,
VLLM_USE_V1
from
vllm.envs
import
VLLM_USE_V1
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
...
@@ -108,20 +106,6 @@ class MQLLMEngine:
...
@@ -108,20 +106,6 @@ class MQLLMEngine:
# Error state.
# Error state.
self
.
_errored_with
:
Optional
[
BaseException
]
=
None
self
.
_errored_with
:
Optional
[
BaseException
]
=
None
# Heartbeat thread
self
.
heartbeat_thread
=
threading
.
Thread
(
target
=
self
.
_heartbeat_loop
,
daemon
=
True
)
self
.
_heartbeat_stop_event
=
threading
.
Event
()
# The heartbeat needs to be faster than what the client will wait for
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self
.
heartbeat_interval_seconds
=
VLLM_RPC_TIMEOUT
/
5000.0
self
.
_last_alive_time
=
time
.
time
()
# The heartbeats can tolerate a long period of the engine chugging
# away at a generation request.
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self
.
last_alive_threshold
=
VLLM_RPC_TIMEOUT
*
3.0
/
1000.0
@
property
@
property
def
dead_error
(
self
)
->
BaseException
:
def
dead_error
(
self
)
->
BaseException
:
if
self
.
_errored_with
is
not
None
:
if
self
.
_errored_with
is
not
None
:
...
@@ -157,8 +141,6 @@ class MQLLMEngine:
...
@@ -157,8 +141,6 @@ class MQLLMEngine:
try
:
try
:
logger
.
debug
(
"Starting Startup Loop."
)
logger
.
debug
(
"Starting Startup Loop."
)
self
.
run_startup_loop
()
self
.
run_startup_loop
()
logger
.
debug
(
"Starting heartbeat thread"
)
self
.
heartbeat_thread
.
start
()
logger
.
debug
(
"Starting Engine Loop."
)
logger
.
debug
(
"Starting Engine Loop."
)
self
.
run_engine_loop
()
self
.
run_engine_loop
()
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -172,7 +154,6 @@ class MQLLMEngine:
...
@@ -172,7 +154,6 @@ class MQLLMEngine:
def
cleanup
(
self
):
def
cleanup
(
self
):
"""Cleanup zeromq state on shutdown."""
"""Cleanup zeromq state on shutdown."""
# Closes all sockets and destroys context.
# Closes all sockets and destroys context.
self
.
_heartbeat_stop_event
.
set
()
self
.
ctx
.
destroy
(
linger
=
0
)
self
.
ctx
.
destroy
(
linger
=
0
)
del
self
.
engine
del
self
.
engine
...
@@ -211,11 +192,12 @@ class MQLLMEngine:
...
@@ -211,11 +192,12 @@ class MQLLMEngine:
"""Core busy loop of the LLMEngine."""
"""Core busy loop of the LLMEngine."""
while
True
:
while
True
:
self
.
_alive
()
if
not
self
.
engine
.
has_unfinished_requests
():
if
not
self
.
engine
.
has_unfinished_requests
():
# Poll until there is work to do.
# Poll until there is work to do.
while
self
.
input_socket
.
poll
(
timeout
=
POLLING_TIMEOUT_MS
)
==
0
:
while
self
.
input_socket
.
poll
(
timeout
=
POLLING_TIMEOUT_MS
)
==
0
:
self
.
_alive
()
# When there's no work, check on engine health and send
# health status back to client
self
.
_health_check
()
self
.
engine
.
do_log_stats
()
self
.
engine
.
do_log_stats
()
logger
.
debug
(
"Waiting for new requests in engine loop."
)
logger
.
debug
(
"Waiting for new requests in engine loop."
)
...
@@ -314,26 +296,10 @@ class MQLLMEngine:
...
@@ -314,26 +296,10 @@ class MQLLMEngine:
if
self
.
log_requests
:
if
self
.
log_requests
:
logger
.
info
(
"Aborted request %s."
,
request
.
request_id
)
logger
.
info
(
"Aborted request %s."
,
request
.
request_id
)
def
_heartbeat_loop
(
self
):
def
_health_check
(
self
):
while
not
self
.
_heartbeat_stop_event
.
wait
(
timeout
=
self
.
heartbeat_interval_seconds
):
# Loops until the stop event is set
self
.
_heartbeat
()
logger
.
debug
(
"Exiting MQLLMEngine heartbeat thread"
)
def
_heartbeat
(
self
):
# Send unhealthy if engine has already errored
# Send unhealthy if engine has already errored
if
self
.
_errored_with
is
not
None
:
if
self
.
_errored_with
is
not
None
:
self
.
_send_unhealthy
(
self
.
_errored_with
)
self
.
_send_unhealthy
(
self
.
_errored_with
)
# Check for life of the main loop
elif
time
.
time
()
-
self
.
_last_alive_time
>
self
.
last_alive_threshold
:
self
.
_send_unhealthy
(
RuntimeError
(
"Engine loop has died"
))
else
:
# Otherwise- check health of the engine
# self.engine.check_health() raises on unhealthy
try
:
try
:
self
.
engine
.
check_health
()
self
.
engine
.
check_health
()
self
.
_send_healthy
()
self
.
_send_healthy
()
...
@@ -369,9 +335,6 @@ class MQLLMEngine:
...
@@ -369,9 +335,6 @@ class MQLLMEngine:
if
self
.
_errored_with
is
None
:
if
self
.
_errored_with
is
None
:
self
.
_errored_with
=
e
self
.
_errored_with
=
e
def
_alive
(
self
):
self
.
_last_alive_time
=
time
.
time
()
def
start_profile
(
self
)
->
None
:
def
start_profile
(
self
)
->
None
:
if
type
(
self
.
engine
.
model_executor
)
is
GPUExecutor
:
if
type
(
self
.
engine
.
model_executor
)
is
GPUExecutor
:
self
.
engine
.
model_executor
.
start_profile
()
self
.
engine
.
model_executor
.
start_profile
()
...
...
vllm/entrypoints/openai/api_server.py
View file @
3b3f1e74
...
@@ -176,13 +176,16 @@ async def build_async_engine_client_from_engine_args(
...
@@ -176,13 +176,16 @@ async def build_async_engine_client_from_engine_args(
UsageContext
.
OPENAI_API_SERVER
,
UsageContext
.
OPENAI_API_SERVER
,
ipc_path
))
ipc_path
))
engine_process
.
start
()
engine_process
.
start
()
logger
.
info
(
"Started engine process with PID %d"
,
engine_process
.
pid
)
engine_pid
=
engine_process
.
pid
assert
engine_pid
is
not
None
,
"Engine process failed to start"
logger
.
info
(
"Started engine process with PID %d"
,
engine_pid
)
# Build RPCClient, which conforms to EngineClient Protocol.
# Build RPCClient, which conforms to EngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
# embedding models via RPC (see TODO above)
engine_config
=
engine_args
.
create_engine_config
()
engine_config
=
engine_args
.
create_engine_config
()
mp_engine_client
=
MQLLMEngineClient
(
ipc_path
,
engine_config
)
mp_engine_client
=
MQLLMEngineClient
(
ipc_path
,
engine_config
,
engine_pid
)
try
:
try
:
while
True
:
while
True
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment