Unverified Commit 21b9c49a authored by Joe Runde's avatar Joe Runde Committed by GitHub
Browse files

[Frontend] Kill the server on engine death (#6594)


Signed-off-by: default avatarJoe Runde <joe@joerun.de>
Signed-off-by: default avatarJoe Runde <Joseph.Runde@ibm.com>
parent 5fb4a3f6
import json
import os
import openai
import pytest
from ...utils import RemoteOpenAIServer
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@pytest.mark.asyncio
async def test_shutdown_on_engine_failure(tmp_path):
# Use a bad adapter to crash the engine
# (This test will fail when that bug is fixed)
adapter_path = tmp_path / "bad_adapter"
os.mkdir(adapter_path)
with open(adapter_path / "adapter_model_config.json", "w") as f:
json.dump({"not": "real"}, f)
with open(adapter_path / "adapter_model.safetensors", "wb") as f:
f.write(b"this is fake")
# dtype, max-len etc set so that this can run in CI
args = [
"--dtype",
"bfloat16",
"--max-model-len",
"8192",
"--enforce-eager",
"--max-num-seqs",
"128",
"--enable-lora",
"--lora-modules",
f"bad-adapter={tmp_path / 'bad_adapter'}",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
client = remote_server.get_async_client()
with pytest.raises(openai.APIConnectionError):
# This crashes the engine
await client.completions.create(model="bad-adapter",
prompt="Hello, my name is")
# Now the server should shut down
return_code = remote_server.proc.wait(timeout=1)
assert return_code is not None
...@@ -58,7 +58,7 @@ def _log_task_completion(task: asyncio.Task, ...@@ -58,7 +58,7 @@ def _log_task_completion(task: asyncio.Task,
error_callback(exception) error_callback(exception)
raise AsyncEngineDeadError( raise AsyncEngineDeadError(
"Task finished unexpectedly. This should never happen! " "Task finished unexpectedly. This should never happen! "
"Please open an issue on Github. See stack trace above for the" "Please open an issue on Github. See stack trace above for the "
"actual cause.") from e "actual cause.") from e
...@@ -132,7 +132,9 @@ class RequestTracker: ...@@ -132,7 +132,9 @@ class RequestTracker:
self._request_streams[request_id].put(exc) self._request_streams[request_id].put(exc)
self.abort_request(request_id) self.abort_request(request_id)
else: else:
for rid, stream in self._request_streams.items(): # NB: list() used here because self.abort_request pops the stream
# out of self._request_streams, so we can't iterate on it directly
for rid, stream in list(self._request_streams.items()):
stream.put(exc) stream.put(exc)
self.abort_request(rid) self.abort_request(rid)
......
...@@ -118,6 +118,7 @@ async def run_server(args: Namespace, ...@@ -118,6 +118,7 @@ async def run_server(args: Namespace,
shutdown_task = await serve_http( shutdown_task = await serve_http(
app, app,
engine=engine,
host=args.host, host=args.host,
port=args.port, port=args.port,
log_level=args.log_level, log_level=args.log_level,
......
import asyncio import asyncio
import signal import signal
from http import HTTPStatus
from typing import Any from typing import Any
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI, Response
from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.protocol import AsyncEngineClient
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
async def serve_http(app: FastAPI, **uvicorn_kwargs: Any): async def serve_http(app: FastAPI, engine: AsyncEngineClient,
**uvicorn_kwargs: Any):
logger.info("Available routes are:") logger.info("Available routes are:")
for route in app.routes: for route in app.routes:
methods = getattr(route, "methods", None) methods = getattr(route, "methods", None)
...@@ -23,6 +28,7 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any): ...@@ -23,6 +28,7 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
config = uvicorn.Config(app, **uvicorn_kwargs) config = uvicorn.Config(app, **uvicorn_kwargs)
server = uvicorn.Server(config) server = uvicorn.Server(config)
_add_shutdown_handlers(app, server, engine)
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
...@@ -44,3 +50,37 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any): ...@@ -44,3 +50,37 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("Gracefully stopping http server") logger.info("Gracefully stopping http server")
return server.shutdown() return server.shutdown()
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
engine: AsyncEngineClient) -> None:
"""Adds handlers for fatal errors that should crash the server"""
@app.exception_handler(RuntimeError)
async def runtime_error_handler(_, __):
"""On generic runtime error, check to see if the engine has died.
It probably has, in which case the server will no longer be able to
handle requests. Trigger a graceful shutdown with a SIGTERM."""
if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
and not engine.is_running):
logger.fatal("AsyncLLMEngine has failed, terminating server "
"process")
# See discussions here on shutting down a uvicorn server
# https://github.com/encode/uvicorn/discussions/1103
# In this case we cannot await the server shutdown here because
# this handler must first return to close the connection for
# this request.
server.should_exit = True
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
@app.exception_handler(AsyncEngineDeadError)
async def engine_dead_handler(_, __):
"""Kill the server if the async engine is already dead. It will
not handle any further requests."""
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
logger.fatal("AsyncLLMEngine is already dead, terminating server "
"process")
server.should_exit = True
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
...@@ -357,6 +357,7 @@ async def run_server(args, **uvicorn_kwargs) -> None: ...@@ -357,6 +357,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
shutdown_task = await serve_http( shutdown_task = await serve_http(
app, app,
engine=async_engine_client,
host=args.host, host=args.host,
port=args.port, port=args.port,
log_level=args.uvicorn_log_level, log_level=args.uvicorn_log_level,
......
...@@ -33,6 +33,7 @@ class AsyncEngineRPCClient: ...@@ -33,6 +33,7 @@ class AsyncEngineRPCClient:
# Wait until server is ready. # Wait until server is ready.
await self.wait_for_server() await self.wait_for_server()
self._errored = False
# Get the configs. # Get the configs.
self.model_config = await self._get_model_config_rpc() self.model_config = await self._get_model_config_rpc()
...@@ -169,7 +170,7 @@ class AsyncEngineRPCClient: ...@@ -169,7 +170,7 @@ class AsyncEngineRPCClient:
expected_type=SchedulerConfig, expected_type=SchedulerConfig,
error_message="Could not get SchedulerConfig from RPC Server") error_message="Could not get SchedulerConfig from RPC Server")
async def _get_lora_config_rpc(self): async def _get_lora_config_rpc(self) -> LoRAConfig:
"""Get LoRAConfig from the RPCServer""" """Get LoRAConfig from the RPCServer"""
return await self._send_get_data_rpc_request( return await self._send_get_data_rpc_request(
...@@ -177,7 +178,7 @@ class AsyncEngineRPCClient: ...@@ -177,7 +178,7 @@ class AsyncEngineRPCClient:
expected_type=LoRAConfig, expected_type=LoRAConfig,
error_message="Could not get LoRAConfig from RPC Server") error_message="Could not get LoRAConfig from RPC Server")
async def _is_tracing_enabled_rpc(self) -> ParallelConfig: async def _is_tracing_enabled_rpc(self) -> bool:
"""Get is_tracing_enabled flag from the RPCServer""" """Get is_tracing_enabled flag from the RPCServer"""
return await self._send_get_data_rpc_request( return await self._send_get_data_rpc_request(
...@@ -200,6 +201,18 @@ class AsyncEngineRPCClient: ...@@ -200,6 +201,18 @@ class AsyncEngineRPCClient:
request=RPCUtilityRequest.DO_LOG_STATS, request=RPCUtilityRequest.DO_LOG_STATS,
error_message="RPCRequest DO_LOG_STATS failed.") error_message="RPCRequest DO_LOG_STATS failed.")
@property
def is_running(self) -> bool:
return not self._errored
@property
def is_stopped(self) -> bool:
return self._errored
@property
def errored(self) -> bool:
return self._errored
async def generate( async def generate(
self, self,
inputs: PromptInputs, inputs: PromptInputs,
...@@ -233,6 +246,15 @@ class AsyncEngineRPCClient: ...@@ -233,6 +246,15 @@ class AsyncEngineRPCClient:
request_output = cloudpickle.loads(message) request_output = cloudpickle.loads(message)
if isinstance(request_output, Exception): if isinstance(request_output, Exception):
# On exception, check if the server is still healthy.
# Use this to set the sync `is_running` and `errored`
# properties.
try:
await self.check_health()
except Exception:
self._errored = True
# NB: do before raising here so that the flag is set
# by the time the caller receives this exception
raise request_output raise request_output
finished = request_output.finished finished = request_output.finished
......
...@@ -96,14 +96,17 @@ class AsyncEngineRPCServer: ...@@ -96,14 +96,17 @@ class AsyncEngineRPCServer:
async def abort(self, identity, request: RPCAbortRequest): async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success.""" """Abort request and notify the client of success."""
# Abort the request in the llm engine. try:
await self.engine.abort(request.request_id) # Abort the request in the llm engine.
await self.engine.abort(request.request_id)
# Send confirmation to the client. except Exception:
await self.socket.send_multipart([ logger.warning("Failed to abort request %s", request.request_id)
identity, finally:
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), # Send confirmation to the client.
]) await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def generate(self, identity, generate_request: RPCGenerateRequest): async def generate(self, identity, generate_request: RPCGenerateRequest):
try: try:
......
...@@ -49,6 +49,7 @@ if TYPE_CHECKING: ...@@ -49,6 +49,7 @@ if TYPE_CHECKING:
NVCC_THREADS: Optional[str] = None NVCC_THREADS: Optional[str] = None
VLLM_USE_PRECOMPILED: bool = False VLLM_USE_PRECOMPILED: bool = False
VLLM_NO_DEPRECATION_WARNING: bool = False VLLM_NO_DEPRECATION_WARNING: bool = False
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False
CMAKE_BUILD_TYPE: Optional[str] = None CMAKE_BUILD_TYPE: Optional[str] = None
VERBOSE: bool = False VERBOSE: bool = False
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
...@@ -335,6 +336,11 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -335,6 +336,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_NO_DEPRECATION_WARNING": "VLLM_NO_DEPRECATION_WARNING":
lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))), lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))),
# If set, the OpenAI API server will stay alive even after the underlying
# AsyncLLMEngine errors and stops serving requests
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH":
lambda: bool(os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0)),
# If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows # If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows
# the user to specify a max sequence length greater than # the user to specify a max sequence length greater than
# the max length derived from the model's config.json. # the max length derived from the model's config.json.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment