Unverified Commit f7e3b0c5 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[Bugfix][Frontend] Fix Issues Under High Load With `zeromq` Frontend (#7394)


Co-authored-by: default avatarNick Hill <nickhill@us.ibm.com>
parent d3c002ea
...@@ -86,6 +86,7 @@ steps: ...@@ -86,6 +86,7 @@ steps:
- vllm/ - vllm/
commands: commands:
- pip install -e ./plugins/vllm_add_dummy_model - pip install -e ./plugins/vllm_add_dummy_model
- pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api]
- pytest -v -s entrypoints/llm - pytest -v -s entrypoints/llm
- pytest -v -s entrypoints/openai - pytest -v -s entrypoints/openai
......
"""
This file test accuracy of the vLLM server via LMEval.
It uses local-completions, which interacts with vLLM
through the OAI API with N concurrent connections.
This simulates real work usage of the API and makes
sure that the zmq frontend mp RPC message passing and
AsyncLLMEngine are working correctly.
"""
import lm_eval
import pytest
from ...utils import RemoteOpenAIServer
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
NUM_CONCURRENT = 500
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
EXPECTED_VALUE = 0.58
@pytest.fixture(scope="module")
def server():
args = [
"--max-model-len", "4096", "--enable-chunked-prefill",
"--disable-log-requests", "--enforce-eager"
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def server_data(server):
return {
"url": f"{server.url_for('v1')}/completions",
}
def test_lm_eval_accuracy(server_data):
model_args = (f"model={MODEL_NAME},"
f"base_url={server_data['url']},"
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
results = lm_eval.simple_evaluate(
model="local-completions",
model_args=model_args,
tasks=TASK,
)
measured_value = results["results"][TASK][FILTER]
assert (measured_value - RTOL < EXPECTED_VALUE
and measured_value + RTOL > EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
...@@ -766,6 +766,11 @@ class AsyncLLMEngine: ...@@ -766,6 +766,11 @@ class AsyncLLMEngine:
def errored(self) -> bool: def errored(self) -> bool:
return self._errored_with is not None return self._errored_with is not None
@property
def limit_concurrency(self) -> Optional[int]:
"""Maximum number of concurrently running requests."""
return None
def set_errored(self, exc: Exception) -> None: def set_errored(self, exc: Exception) -> None:
self._errored_with = exc self._errored_with = exc
......
...@@ -29,6 +29,10 @@ class AsyncEngineClient(Protocol): ...@@ -29,6 +29,10 @@ class AsyncEngineClient(Protocol):
def errored(self) -> bool: def errored(self) -> bool:
... ...
@property
def limit_concurrency(self) -> Optional[int]:
"""Maximum number of concurrently running requests."""
def generate( def generate(
self, self,
inputs: PromptInputs, inputs: PromptInputs,
......
...@@ -27,6 +27,15 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient, ...@@ -27,6 +27,15 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
# Set concurrency limits in uvicorn if running in multiprocessing mode
# since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
if engine.limit_concurrency is not None:
logger.info(
"Launching Uvicorn with --limit_concurrency %s. To avoid this "
"limit at the expense of performance run with "
"--disable-frontend-multiprocessing", engine.limit_concurrency)
uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency
config = uvicorn.Config(app, **uvicorn_kwargs) config = uvicorn.Config(app, **uvicorn_kwargs)
server = uvicorn.Server(config) server = uvicorn.Server(config)
_add_shutdown_handlers(app, server, engine) _add_shutdown_handlers(app, server, engine)
......
...@@ -135,6 +135,12 @@ async def build_async_engine_client( ...@@ -135,6 +135,12 @@ async def build_async_engine_client(
logger.info("Multiprocessing frontend to use %s for RPC Path.", logger.info("Multiprocessing frontend to use %s for RPC Path.",
rpc_path) rpc_path)
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client = AsyncEngineRPCClient(rpc_path)
async_engine_client = rpc_client # type: ignore
# Start RPCServer in separate process (holds the AsyncLLMEngine). # Start RPCServer in separate process (holds the AsyncLLMEngine).
context = multiprocessing.get_context("spawn") context = multiprocessing.get_context("spawn")
# the current process might have CUDA context, # the current process might have CUDA context,
...@@ -145,11 +151,6 @@ async def build_async_engine_client( ...@@ -145,11 +151,6 @@ async def build_async_engine_client(
rpc_server_process.start() rpc_server_process.start()
logger.info("Started engine process with PID %d", logger.info("Started engine process with PID %d",
rpc_server_process.pid) rpc_server_process.pid)
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client = AsyncEngineRPCClient(rpc_path)
async_engine_client = rpc_client # type: ignore
try: try:
while True: while True:
......
...@@ -7,8 +7,18 @@ from vllm.lora.request import LoRARequest ...@@ -7,8 +7,18 @@ from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
# Success string used for RPC instructions.
VLLM_RPC_SUCCESS_STR = "SUCCESS" VLLM_RPC_SUCCESS_STR = "SUCCESS"
VLLM_RPC_HEALTHY_STR = "HEALTHY"
# Timeouts.
VLLM_RPC_SERVER_START_TIMEOUT_MS = 1000
VLLM_RPC_HEALTH_TIMEOUT_MS = 10000
# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000
# HWM is set to Infinity.
VLLM_RPC_ZMQ_HWM = 0
@dataclass @dataclass
...@@ -34,7 +44,7 @@ class RPCUtilityRequest(Enum): ...@@ -34,7 +44,7 @@ class RPCUtilityRequest(Enum):
GET_SCHEDULER_CONFIG = 5 GET_SCHEDULER_CONFIG = 5
GET_LORA_CONFIG = 6 GET_LORA_CONFIG = 6
DO_LOG_STATS = 7 DO_LOG_STATS = 7
CHECK_HEALTH = 8 IS_SERVER_HEALTHY = 8
IS_TRACING_ENABLED = 9 IS_TRACING_ENABLED = 9
......
import asyncio
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, AsyncGenerator, Mapping, Optional from typing import Any, AsyncGenerator, Mapping, Optional
from uuid import uuid4
import cloudpickle import cloudpickle
import zmq import zmq
...@@ -7,32 +9,140 @@ import zmq.asyncio ...@@ -7,32 +9,140 @@ import zmq.asyncio
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig)
# yapf: disable
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
VLLM_RPC_HEALTHY_STR, VLLM_RPC_HEALTH_TIMEOUT_MS,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest, VLLM_RPC_SERVER_START_TIMEOUT_MS,
VLLM_RPC_SOCKET_LIMIT_CUTOFF,
VLLM_RPC_SUCCESS_STR,
VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest) RPCGenerateRequest, RPCUtilityRequest)
# yapf: enable
from vllm.inputs import PromptInputs from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
# Time to wait before checking it the server process is alive. logger = init_logger(__name__)
SERVER_START_TIMEOUT_MS = 1000
# Path used for inprocess proxy.
INPROC_PROXY_PATH = f"inproc://{uuid4()}"
class AsyncEngineRPCClient: class AsyncEngineRPCClient:
"""
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
The overall design mirrors the Asynchronous Client Server Pattern
https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern
On startup, the RPCClient:
- makes DEALER socket (to_rpc_server) that connects to the RPCServer
via ipc, which uses unix sockets under the hood
(https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html)
- makes ROUTER socket (from_api_server) that binds to a random
inproc address, which uses memory under the hood
(https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html)
- runs a proxy in a background asyncio task between
from_api_server (ROUTER, inproc) and to_rpc_server (DEALER ipc, )
Each request handled by the asyncio api_server calls generate():
- make a DEALER socket that connects to from_api_server via inproc
- send a RCPGenerateRequest to the inproc socket
- background proxy forwards the request from inproc -> ipc
- RPCServer responds to the request one token at a time over ipc
- background proxy forwards the response from ipc -> inproc
The connection looks like this:
DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER
Message routing is performed via identities that are managed by the
ROUTER socket. ROUTER sockets track every connection it has and
tells the caller about these. The way it tells the caller is to stick
the connection identity in front of each message received. When we
send the message via a ROUTER, we first send an identity frame.
See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope
for more details on connection identities.
This proxy design enables us to use a single unix socket, which
improves performance by avoiding syscalls (~5%) and avoids resource limits
such as ulimit, which defaults to 1024 on ubuntu.
Note: we run set_hwm(0) on each socket, which sets the HWM to inf,
which is required to avoid dropping messages under high load.
This is generally not advisable. However, since we are in control
of both sides of the connection + failure on either side is
catastrophic to the overall system health and memory profiling
suggests limited memory overhead relative to asyncio, we will
proceed for now.
See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks
for more details on high water marks.
"""
def __init__(self, rpc_path: str): def __init__(self, rpc_path: str):
self.context = zmq.asyncio.Context() self.context = zmq.asyncio.Context()
self.rpc_path = rpc_path
# Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF:
raise ValueError(
f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
"the number of concurrent requests vLLM can process. Launch "
"vLLM with --disable-frontend-multiprocessing and open a "
"GitHub issue so we can investigate.")
# We only have 1 ipc connection that uses unix sockets, so
# safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will
# not run into ulimit issues)
self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)
# IPC connection to RPC Server (uses unix sockets).
self.to_rpc_server = self.context.socket(zmq.constants.DEALER)
self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM)
self.to_rpc_server.bind(rpc_path)
# In process proxy to RPC Server (uses memory-based messaging).
self.from_api_server = self.context.socket(zmq.constants.ROUTER)
self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM)
self.from_api_server.bind(INPROC_PROXY_PATH)
# Asyncio background task for the proxy.
self.proxy_task = asyncio.create_task(
self.run_proxy(self.from_api_server, self.to_rpc_server))
# Since we open 1 inproc socket per request, we have a hard cap on
# the number of requests that can run in vLLM w. frontend
# mulitprocessing. This value is used uvicorn to launch
# with --limit-concurrency to return 503 when server is overloaded.
# We need 2 sockets per request - 2:
# 1 for generate(), 1 for abort(), do_log_stats(), check_health()
self.limit_concurrency = socket_limit // 2 - 2
async def run_proxy(self, socket_from, socket_to):
"""Background task that runs a proxy"""
poller = zmq.asyncio.Poller()
poller.register(socket_from, zmq.constants.POLLIN)
poller.register(socket_to, zmq.constants.POLLIN)
while True:
events = await poller.poll()
events = dict(events)
if socket_from in events:
identity, msg = await socket_from.recv_multipart()
await socket_to.send_multipart([identity, msg])
if socket_to in events:
identity, msg = await socket_to.recv_multipart()
await socket_from.send_multipart([identity, msg])
async def setup(self): async def setup(self):
"""Setup the client before it starts sending server requests.""" """Setup the client before it starts sending server requests."""
# Wait until server is ready. # Wait until server is ready.
await self.wait_for_server() await self._wait_for_server_rpc()
self._errored = False self._errored = False
# Get the configs. # Get the configs.
...@@ -51,29 +161,23 @@ class AsyncEngineRPCClient: ...@@ -51,29 +161,23 @@ class AsyncEngineRPCClient:
def close(self): def close(self):
"""Destroy the ZeroMQ Context.""" """Destroy the ZeroMQ Context."""
# Close all sockets associated with this context and
# then terminate the context.
self.from_api_server.close()
self.to_rpc_server.close()
self.context.destroy() self.context.destroy()
@contextmanager @contextmanager
def socket(self): def to_proxy_socket(self):
# Ensure client sockets are always closed after use # Connect to the RPCServer via the proxy.
# Connect to RPC socket for Request-Reply pattern,
# Note that we use DEALER to enable asynchronous communication # Note that we use DEALER to enable asynchronous communication
# to enable streaming. # to enable streaming.
socket = self.context.socket(zmq.constants.DEALER) socket = self.context.socket(zmq.constants.DEALER)
socket.set_hwm(VLLM_RPC_ZMQ_HWM)
try: try:
socket.connect(self.rpc_path) socket.connect(INPROC_PROXY_PATH)
yield socket yield socket
finally: finally:
# linger == 0 means discard unsent messages
# when the socket is closed. This is necessary
# because otherwise self.context.destroy() will
# wait for 30 seconds until unsent messages are
# received, which is impossible if the server
# crashed. In the absence of a server crash we
# always expect a response before closing the
# socket anyway.
# Reference: http://api.zeromq.org/4-2:zmq-setsockopt#toc24
socket.close(linger=0) socket.close(linger=0)
async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
...@@ -81,10 +185,9 @@ class AsyncEngineRPCClient: ...@@ -81,10 +185,9 @@ class AsyncEngineRPCClient:
error_message: str) -> Any: error_message: str) -> Any:
"""Send an RPC request that is expecting data back.""" """Send an RPC request that is expecting data back."""
with self.socket() as socket: with self.to_proxy_socket() as socket:
# Ping RPCServer with a request. # Ping RPCServer with a request.
await socket.send(cloudpickle.dumps(request)) await socket.send_multipart([cloudpickle.dumps(request)])
# Await the data from the Server. # Await the data from the Server.
data = cloudpickle.loads(await socket.recv()) data = cloudpickle.loads(await socket.recv())
...@@ -93,31 +196,48 @@ class AsyncEngineRPCClient: ...@@ -93,31 +196,48 @@ class AsyncEngineRPCClient:
# LoRAConfig can be None. # LoRAConfig can be None.
if expected_type == LoRAConfig and data is None: if expected_type == LoRAConfig and data is None:
pass pass
elif isinstance(data, Exception):
logger.error(error_message)
raise data
else: else:
raise ValueError(error_message) raise ValueError(error_message)
return data return data
async def _send_one_way_rpc_request(self, async def _send_one_way_rpc_request(
request: RPC_REQUEST_TYPE, self,
error_message: str, request: RPC_REQUEST_TYPE,
timeout: Optional[int] = None): error_message: str,
timeout: Optional[int] = None,
socket: Optional[zmq.asyncio.Socket] = None):
"""Send one-way RPC request to trigger an action.""" """Send one-way RPC request to trigger an action."""
with self.socket() as socket:
# Ping RPC Server with request.
await socket.send(cloudpickle.dumps(request))
# Await acknowledgement from RPCServer. async def do_rpc_call(socket: zmq.asyncio.Socket,
request: RPC_REQUEST_TYPE,
timeout=None):
await socket.send_multipart([cloudpickle.dumps(request)])
if timeout is not None and await socket.poll(timeout=timeout) == 0: if timeout is not None and await socket.poll(timeout=timeout) == 0:
raise TimeoutError(f"server didn't reply within {timeout} ms") raise TimeoutError(f"Server didn't reply within {timeout} ms")
return cloudpickle.loads(await socket.recv())
response = cloudpickle.loads(await socket.recv()) # Make a new socket connection.
if socket is None:
with self.to_proxy_socket() as socket:
response = await do_rpc_call(socket, request, timeout)
# Use existing socket connection.
else:
response = await do_rpc_call(socket, request, timeout)
if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
if isinstance(response, Exception):
logger.error(error_message)
raise response
raise ValueError(error_message) raise ValueError(error_message)
return response
async def get_tokenizer(self, lora_request: LoRARequest): async def get_tokenizer(self, lora_request: LoRARequest):
return await self.tokenizer.get_lora_tokenizer_async(lora_request) return await self.tokenizer.get_lora_tokenizer_async(lora_request)
...@@ -130,13 +250,13 @@ class AsyncEngineRPCClient: ...@@ -130,13 +250,13 @@ class AsyncEngineRPCClient:
async def is_tracing_enabled(self) -> bool: async def is_tracing_enabled(self) -> bool:
return self.tracing_flag return self.tracing_flag
async def wait_for_server(self): async def _wait_for_server_rpc(self):
"""Wait for the RPCServer to start up.""" """Wait for the RPCServer to start up."""
await self._send_one_way_rpc_request( await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_READY, request=RPCUtilityRequest.IS_SERVER_READY,
error_message="Unable to start RPC Server.", error_message="Unable to start RPC Server",
timeout=SERVER_START_TIMEOUT_MS) timeout=VLLM_RPC_SERVER_START_TIMEOUT_MS)
async def _get_model_config_rpc(self) -> ModelConfig: async def _get_model_config_rpc(self) -> ModelConfig:
"""Get the ModelConfig object from the RPC Server""" """Get the ModelConfig object from the RPC Server"""
...@@ -184,8 +304,7 @@ class AsyncEngineRPCClient: ...@@ -184,8 +304,7 @@ class AsyncEngineRPCClient:
return await self._send_get_data_rpc_request( return await self._send_get_data_rpc_request(
RPCUtilityRequest.IS_TRACING_ENABLED, RPCUtilityRequest.IS_TRACING_ENABLED,
expected_type=bool, expected_type=bool,
error_message="Could not get is_tracing_enabled flag from RPC " error_message="Could not get is_tracing_enabled from RPC Server")
"Server")
async def abort(self, request_id: str): async def abort(self, request_id: str):
"""Send an ABORT_REQUEST signal to the RPC Server""" """Send an ABORT_REQUEST signal to the RPC Server"""
...@@ -226,8 +345,7 @@ class AsyncEngineRPCClient: ...@@ -226,8 +345,7 @@ class AsyncEngineRPCClient:
finished = False finished = False
try: try:
with self.socket() as socket: with self.to_proxy_socket() as socket:
# Send RPCGenerateRequest to the RPCServer. # Send RPCGenerateRequest to the RPCServer.
await socket.send_multipart([ await socket.send_multipart([
cloudpickle.dumps( cloudpickle.dumps(
...@@ -246,43 +364,37 @@ class AsyncEngineRPCClient: ...@@ -246,43 +364,37 @@ class AsyncEngineRPCClient:
request_output = cloudpickle.loads(message) request_output = cloudpickle.loads(message)
if isinstance(request_output, Exception): if isinstance(request_output, Exception):
# On exception, check if the server is still healthy. # On exception, check if the server is still healthy
# Use this to set the sync `is_running` and `errored` # possibly setting the `errored` property.
# properties. if not self._errored:
try: try:
await self.check_health() await self.check_health(socket=socket)
except Exception: except Exception as e:
self._errored = True self._errored = True
logger.exception(repr(e))
# NB: do before raising here so that the flag is set # NB: do before raising here so that the flag is set
# by the time the caller receives this exception # by the time the caller receives this exception
raise request_output raise request_output
finished = request_output.finished finished = request_output.finished
yield request_output yield request_output
finally: finally:
if not finished: # Request was canceled by the client.
if not finished and not self._errored:
await self.abort(request_id) await self.abort(request_id)
async def check_health(self) -> None: async def check_health(self,
socket: Optional[zmq.asyncio.Socket] = None
) -> None:
"""Raise if unhealthy""" """Raise if unhealthy"""
with self.socket() as socket: await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_HEALTHY,
# Ping RPCServer with CHECK_HEALTH request. error_message="Got Unhealthy response from RPC Server",
await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH) timeout=VLLM_RPC_HEALTH_TIMEOUT_MS,
) socket=socket)
# Await the reply from the server.
# TODO: do we need an internal timeout here?
# Or do we expect the external probe to timeout and let this chill?
health_message = cloudpickle.loads(await socket.recv())
if isinstance(health_message, Exception):
raise health_message
if health_message != VLLM_RPC_HEALTHY_STR:
raise ValueError("Expected healthy response from backend but got "
"f{health_message}")
async def encode(self, *args, async def encode(self, *args,
**kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
......
import asyncio import asyncio
import signal import signal
from typing import Any, Coroutine from typing import Any, Coroutine, Union
import cloudpickle import cloudpickle
import uvloop import uvloop
...@@ -9,14 +9,19 @@ import zmq.asyncio ...@@ -9,14 +9,19 @@ import zmq.asyncio
from typing_extensions import Never from typing_extensions import Never
from vllm import AsyncEngineArgs, AsyncLLMEngine from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR, from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest, ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR,
VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest) RPCGenerateRequest, RPCUtilityRequest)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__) logger = init_logger(__name__)
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
SchedulerConfig, LoRAConfig]
class AsyncEngineRPCServer: class AsyncEngineRPCServer:
...@@ -29,9 +34,10 @@ class AsyncEngineRPCServer: ...@@ -29,9 +34,10 @@ class AsyncEngineRPCServer:
# Initialize context. # Initialize context.
self.context = zmq.asyncio.Context() self.context = zmq.asyncio.Context()
# Init socket for readiness state. # Init socket.
self.socket = self.context.socket(zmq.constants.ROUTER) self.socket = self.context.socket(zmq.constants.DEALER)
self.socket.bind(rpc_path) self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
self.socket.connect(rpc_path)
def cleanup(self): def cleanup(self):
"""Cleanup all resources.""" """Cleanup all resources."""
...@@ -41,39 +47,27 @@ class AsyncEngineRPCServer: ...@@ -41,39 +47,27 @@ class AsyncEngineRPCServer:
# Clear the engine reference so that it can be GC'ed. # Clear the engine reference so that it can be GC'ed.
del self.engine del self.engine
async def get_model_config(self, identity): async def get_config(self, identity, request):
"""Send the ModelConfig""" try:
model_config = await self.engine.get_model_config() config: CONFIG_TYPE
if request == RPCUtilityRequest.GET_MODEL_CONFIG:
await self.socket.send_multipart( config = await self.engine.get_model_config()
[identity, cloudpickle.dumps(model_config)]) elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
config = await self.engine.get_decoding_config()
async def get_decoding_config(self, identity): elif request == RPCUtilityRequest.GET_LORA_CONFIG:
"""Send the DecodingConfig""" config = await self.engine.get_lora_config()
decoding_config = await self.engine.get_decoding_config() elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
config = await self.engine.get_scheduler_config()
await self.socket.send_multipart( elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
[identity, cloudpickle.dumps(decoding_config)]) config = await self.engine.get_parallel_config()
else:
async def get_lora_config(self, identity): raise ValueError("Unknown Config Request: %s", request)
lora_config = await self.engine.get_lora_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(lora_config)])
async def get_scheduler_config(self, identity):
"""Send the SchedulerConfig"""
parallel_config = await self.engine.get_scheduler_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(parallel_config)])
async def get_parallel_config(self, identity): await self.socket.send_multipart(
"""Send the ParallelConfig""" [identity, cloudpickle.dumps(config)])
parallel_config = await self.engine.get_parallel_config()
await self.socket.send_multipart( except Exception as e:
[identity, cloudpickle.dumps(parallel_config)]) await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
async def is_tracing_enabled(self, identity): async def is_tracing_enabled(self, identity):
"""Send the is_tracing_enabled flag""" """Send the is_tracing_enabled flag"""
...@@ -86,31 +80,23 @@ class AsyncEngineRPCServer: ...@@ -86,31 +80,23 @@ class AsyncEngineRPCServer:
"""Log stats and confirm success.""" """Log stats and confirm success."""
await self.engine.do_log_stats() await self.engine.do_log_stats()
await self.socket.send_multipart([ await self.socket.send_multipart(
identity, [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def is_server_ready(self, identity): async def is_server_ready(self, identity):
"""Notify the client that we are ready.""" """Notify the client that we are ready."""
await self.socket.send_multipart([ await self.socket.send_multipart(
identity, [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def abort(self, identity, request: RPCAbortRequest): async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success.""" """Abort request and notify the client of success."""
try: try:
# Abort the request in the llm engine. # Abort the request in the llm engine.
await self.engine.abort(request.request_id) await self.engine.abort(request.request_id)
except Exception: result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
logger.warning("Failed to abort request %s", request.request_id) except Exception as e:
finally: result = e
# Send confirmation to the client. await self.socket.send_multipart([identity, cloudpickle.dumps(result)])
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def generate(self, identity, generate_request: RPCGenerateRequest): async def generate(self, identity, generate_request: RPCGenerateRequest):
try: try:
...@@ -127,14 +113,14 @@ class AsyncEngineRPCServer: ...@@ -127,14 +113,14 @@ class AsyncEngineRPCServer:
[identity, cloudpickle.dumps(request_output)]) [identity, cloudpickle.dumps(request_output)])
except Exception as e: except Exception as e:
### Notify client of all failures
await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
async def check_health(self, identity): async def check_health(self, identity):
try: try:
await self.engine.check_health() await self.engine.check_health()
await self.socket.send_multipart( await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)]) [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
except Exception as e: except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
...@@ -151,21 +137,19 @@ class AsyncEngineRPCServer: ...@@ -151,21 +137,19 @@ class AsyncEngineRPCServer:
return self.abort(identity, request) return self.abort(identity, request)
elif isinstance(request, RPCUtilityRequest): elif isinstance(request, RPCUtilityRequest):
if request == RPCUtilityRequest.GET_MODEL_CONFIG: if request in [
return self.get_model_config(identity) RPCUtilityRequest.GET_MODEL_CONFIG,
elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: RPCUtilityRequest.GET_PARALLEL_CONFIG,
return self.get_parallel_config(identity) RPCUtilityRequest.GET_DECODING_CONFIG,
elif request == RPCUtilityRequest.GET_DECODING_CONFIG: RPCUtilityRequest.GET_SCHEDULER_CONFIG,
return self.get_decoding_config(identity) RPCUtilityRequest.GET_LORA_CONFIG
elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: ]:
return self.get_scheduler_config(identity) return self.get_config(identity, request)
elif request == RPCUtilityRequest.GET_LORA_CONFIG:
return self.get_lora_config(identity)
elif request == RPCUtilityRequest.DO_LOG_STATS: elif request == RPCUtilityRequest.DO_LOG_STATS:
return self.do_log_stats(identity) return self.do_log_stats(identity)
elif request == RPCUtilityRequest.IS_SERVER_READY: elif request == RPCUtilityRequest.IS_SERVER_READY:
return self.is_server_ready(identity) return self.is_server_ready(identity)
elif request == RPCUtilityRequest.CHECK_HEALTH: elif request == RPCUtilityRequest.IS_SERVER_HEALTHY:
return self.check_health(identity) return self.check_health(identity)
elif request == RPCUtilityRequest.IS_TRACING_ENABLED: elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
return self.is_tracing_enabled(identity) return self.is_tracing_enabled(identity)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment