Unverified commit 6e0c9d6b, authored by Joe Runde, committed by GitHub
Browse files

[Bugfix] Use heartbeats instead of health checks (#8583)

parent 6da1ab6b
...@@ -153,27 +153,20 @@ async def test_failed_abort(tmp_socket): ...@@ -153,27 +153,20 @@ async def test_failed_abort(tmp_socket):
await client.check_health() await client.check_health()
# Trigger an abort on the client side. # Trigger an abort on the client side.
async def bad_abort_after_2s(): # This request ID does not exist, and will cause the engine to error
await asyncio.sleep(2.0)
await client.abort(request_id="foo") await client.abort(request_id="foo")
# Trigger an abort in 2s from now. # Future generation requests will now fail
abort_task = asyncio.create_task(bad_abort_after_2s())
# Exception in abort() will happen during this generation.
# This will kill the engine and should return ENGINE_DEAD_ERROR
# with reference to the original KeyError("foo") # with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo: with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate( async for _ in client.generate(
inputs="Hello my name is", inputs="Hello my name is",
sampling_params=SamplingParams(max_tokens=2000), sampling_params=SamplingParams(max_tokens=10),
request_id=uuid.uuid4()): request_id=uuid.uuid4()):
pass pass
assert "KeyError" in repr(execinfo.value) assert "KeyError" in repr(execinfo.value)
assert client.errored assert client.errored
await abort_task
# This should raise the original error. # This should raise the original error.
with pytest.raises(RAISED_ERROR): with pytest.raises(RAISED_ERROR):
await client.check_health() await client.check_health()
......
...@@ -43,10 +43,6 @@ class RPCAbortRequest: ...@@ -43,10 +43,6 @@ class RPCAbortRequest:
request_id: str request_id: str
class RPCHealthRequest:
pass
class RPCStartupRequest(Enum): class RPCStartupRequest(Enum):
IS_SERVER_READY = 1 IS_SERVER_READY = 1
...@@ -56,8 +52,7 @@ class RPCStartupResponse: ...@@ -56,8 +52,7 @@ class RPCStartupResponse:
tracing_enabled: bool tracing_enabled: bool
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest, RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest]
RPCStartupRequest]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
......
...@@ -20,9 +20,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -20,9 +20,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T, IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest, VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCHealthRequest, RPCError, RPCProcessRequest,
RPCProcessRequest, RPCStartupRequest, RPCStartupRequest, RPCStartupResponse)
RPCStartupResponse)
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptInputs from vllm.inputs import PromptInputs
...@@ -95,9 +94,9 @@ class MQLLMEngineClient: ...@@ -95,9 +94,9 @@ class MQLLMEngineClient:
self.output_socket: Socket = self.context.socket(zmq.constants.PULL) self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")
# IPC path for ack of check_health requests. # IPC path for acking heartbeats.
self.health_socket: Socket = self.context.socket(zmq.constants.PULL) self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
# IPC path for the data socket. # IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
...@@ -124,34 +123,28 @@ class MQLLMEngineClient: ...@@ -124,34 +123,28 @@ class MQLLMEngineClient:
finally: finally:
socket.close(linger=0) socket.close(linger=0)
async def run_check_health_loop(self, timeout: int): async def run_heartbeat_loop(self, timeout: int):
"""Background loop that continually probes the RPCServer for health. """Background loop that continually listens to the RPCServer for
heartbeats.
The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which
the MQLLMEngine server is blocking on.
The Server replies on the HEALTH_SOCKET (rather than on the
OUTPUT_SOCKET such that the messages are not intermingled with
output streaming).
""" """
try: try:
while True: while True:
if await self.health_socket.poll(timeout=timeout) == 0: if await self.heartbeat_socket.poll(timeout=timeout) == 0:
# Wakeup every N seconds and do a health probe. # No heartbeat was received. Set error and exit the loop
await self._send_one_way_rpc_request( self._set_errored(
RPCHealthRequest(), self.input_socket) TimeoutError("No heartbeat received "
"from MQLLMEngine"))
logger.debug("Shutting down MQLLMEngineClient check "
"health loop due to timeout")
break
# Wait for ack from the health socket.
await self._await_ack(error_message="Health check failed.",
socket=self.health_socket)
else: else:
# Server sent a health status message unprompted. # Heartbeat received- check the message
await self._check_success( await self._check_success(
error_message="Health check failed.", error_message="Heartbeat failed.",
socket=self.health_socket) socket=self.heartbeat_socket)
logger.debug("Health probe successful.") logger.debug("Heartbeat successful.")
except asyncio.CancelledError: except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient check health loop.") logger.debug("Shutting down MQLLMEngineClient check health loop.")
...@@ -234,7 +227,7 @@ class MQLLMEngineClient: ...@@ -234,7 +227,7 @@ class MQLLMEngineClient:
# Start health_loop. # Start health_loop.
self.health_loop = asyncio.create_task( self.health_loop = asyncio.create_task(
self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT)) self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
def close(self): def close(self):
"""Destroy the ZeroMQ Context.""" """Destroy the ZeroMQ Context."""
......
import pickle import pickle
import signal import signal
import threading
import time
from contextlib import contextmanager from contextlib import contextmanager
from typing import Iterator, List, Optional, Union from typing import Iterator, List, Optional, Union
...@@ -15,10 +17,10 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -15,10 +17,10 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest, VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCHealthRequest, RPCError, RPCProcessRequest,
RPCProcessRequest, RPCStartupRequest, RPCStartupRequest, RPCStartupResponse)
RPCStartupResponse)
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
...@@ -91,9 +93,9 @@ class MQLLMEngine: ...@@ -91,9 +93,9 @@ class MQLLMEngine:
self.output_socket = self.ctx.socket(zmq.constants.PUSH) self.output_socket = self.ctx.socket(zmq.constants.PUSH)
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")
# Send health status back to client. # Send heartbeats back to client.
self.health_socket = self.ctx.socket(zmq.constants.PUSH) self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
# IPC path for the data socket. # IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
...@@ -101,6 +103,20 @@ class MQLLMEngine: ...@@ -101,6 +103,20 @@ class MQLLMEngine:
# Error state. # Error state.
self._errored_with: Optional[BaseException] = None self._errored_with: Optional[BaseException] = None
# Heartbeat thread
self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop,
daemon=True)
self._heartbeat_stop_event = threading.Event()
# The heartbeat needs to be faster than what the client will wait for
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0
self._last_alive_time = time.time()
# The heartbeats can tolerate a long period of the engine chugging
# away at a generation request.
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0
@property @property
def dead_error(self) -> BaseException: def dead_error(self) -> BaseException:
if self._errored_with is not None: if self._errored_with is not None:
...@@ -131,6 +147,8 @@ class MQLLMEngine: ...@@ -131,6 +147,8 @@ class MQLLMEngine:
try: try:
logger.debug("Starting Startup Loop.") logger.debug("Starting Startup Loop.")
self.run_startup_loop() self.run_startup_loop()
logger.debug("Starting heartbeat thread")
self.heartbeat_thread.start()
logger.debug("Starting Engine Loop.") logger.debug("Starting Engine Loop.")
self.run_engine_loop() self.run_engine_loop()
except Exception as e: except Exception as e:
...@@ -144,6 +162,7 @@ class MQLLMEngine: ...@@ -144,6 +162,7 @@ class MQLLMEngine:
def cleanup(self): def cleanup(self):
"""Cleanup zeromq state on shutdown.""" """Cleanup zeromq state on shutdown."""
# Closes all sockets and destroys context. # Closes all sockets and destroys context.
self._heartbeat_stop_event.set()
self.ctx.destroy(linger=0) self.ctx.destroy(linger=0)
del self.engine del self.engine
...@@ -182,9 +201,11 @@ class MQLLMEngine: ...@@ -182,9 +201,11 @@ class MQLLMEngine:
"""Core busy loop of the LLMEngine.""" """Core busy loop of the LLMEngine."""
while True: while True:
self._alive()
if not self.engine.has_unfinished_requests(): if not self.engine.has_unfinished_requests():
# Poll until there is work to do. # Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
self._alive()
self.engine.do_log_stats() self.engine.do_log_stats()
logger.debug("Waiting for new requests in engine loop.") logger.debug("Waiting for new requests in engine loop.")
...@@ -200,7 +221,6 @@ class MQLLMEngine: ...@@ -200,7 +221,6 @@ class MQLLMEngine:
def engine_step(self) -> List[RequestOutput]: def engine_step(self) -> List[RequestOutput]:
"""Engine step wrapper with error handling.""" """Engine step wrapper with error handling."""
try: try:
return self.engine.step() return self.engine.step()
except SystemExit: except SystemExit:
...@@ -229,10 +249,9 @@ class MQLLMEngine: ...@@ -229,10 +249,9 @@ class MQLLMEngine:
self._handle_process_request(request) self._handle_process_request(request)
elif isinstance(request, RPCAbortRequest): elif isinstance(request, RPCAbortRequest):
self._handle_abort_request(request) self._handle_abort_request(request)
elif isinstance(request, RPCHealthRequest):
self._handle_health_request()
else: else:
raise ValueError("Unknown RPCRequest Type: {request}") raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}")
except Exception as e: except Exception as e:
self._set_errored(e) self._set_errored(e)
...@@ -279,13 +298,32 @@ class MQLLMEngine: ...@@ -279,13 +298,32 @@ class MQLLMEngine:
if self.log_requests: if self.log_requests:
logger.info("Aborted request %s.", request.request_id) logger.info("Aborted request %s.", request.request_id)
def _handle_health_request(self): def _heartbeat_loop(self):
while not self._heartbeat_stop_event.wait(
timeout=self.heartbeat_interval_seconds):
# Loops until the stop event is set
self._heartbeat()
logger.debug("Exiting MQLLMEngine heartbeat thread")
def _heartbeat(self):
# Send unhealthy if engine has already errored
if self._errored_with is not None: if self._errored_with is not None:
self._send_unhealthy(self._errored_with) self._send_unhealthy(self._errored_with)
# Raises error if unhealthy. # Check for life of the main loop
elif time.time() - self._last_alive_time > self.last_alive_threshold:
self._send_unhealthy(RuntimeError("Engine loop has died"))
else:
# Otherwise- check health of the engine
# self.engine.check_health() raises on unhealthy
try:
self.engine.check_health() self.engine.check_health()
self._send_healthy() self._send_healthy()
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)
def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
"""Send List of RequestOutput to RPCClient.""" """Send List of RequestOutput to RPCClient."""
...@@ -295,12 +333,14 @@ class MQLLMEngine: ...@@ -295,12 +333,14 @@ class MQLLMEngine:
def _send_healthy(self): def _send_healthy(self):
"""Send HEALTHY message to RPCClient.""" """Send HEALTHY message to RPCClient."""
self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False) if not self.heartbeat_socket.closed:
self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
def _send_unhealthy(self, error: BaseException): def _send_unhealthy(self, error: BaseException):
"""Send UNHEALTHY message to RPCClient.""" """Send UNHEALTHY message to RPCClient."""
if not self.heartbeat_socket.closed:
error_bytes = pickle.dumps(error) error_bytes = pickle.dumps(error)
self.health_socket.send_multipart((error_bytes, ), copy=False) self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)
def _async_socket_engine_callback(self, def _async_socket_engine_callback(self,
request_outputs: REQUEST_OUTPUTS_T): request_outputs: REQUEST_OUTPUTS_T):
...@@ -313,6 +353,9 @@ class MQLLMEngine: ...@@ -313,6 +353,9 @@ class MQLLMEngine:
if self._errored_with is None: if self._errored_with is None:
self._errored_with = e self._errored_with = e
def _alive(self):
self._last_alive_time = time.time()
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str): ipc_path: str):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment