Unverified Commit baece8c3 authored by yyweiss's avatar yyweiss Committed by GitHub
Browse files

[Frontend] Add unix domain socket support (#18097)



Signed-off-by: <yyweiss@gmail.com>
Signed-off-by: yyw <yyweiss@gmail.com>
parent 2fcf6b27
...@@ -29,6 +29,9 @@ Start the vLLM OpenAI Compatible API server. ...@@ -29,6 +29,9 @@ Start the vLLM OpenAI Compatible API server.
# Specify the port # Specify the port
vllm serve meta-llama/Llama-2-7b-hf --port 8100 vllm serve meta-llama/Llama-2-7b-hf --port 8100
# Serve over a Unix domain socket
vllm serve meta-llama/Llama-2-7b-hf --uds /tmp/vllm.sock
# Check with --help for more options # Check with --help for more options
# To list all groups # To list all groups
vllm serve --help=listgroup vllm serve --help=listgroup
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from tempfile import TemporaryDirectory
import httpx
import pytest
from vllm.version import __version__ as VLLM_VERSION
from ...utils import RemoteOpenAIServer
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@pytest.fixture(scope="module")
def server():
    """Start one server per test module, serving on a Unix domain socket.

    The socket path lives inside a temporary directory, which is cleaned up
    when the fixture is torn down.

    Yields:
        The ``RemoteOpenAIServer`` context object for the launched server.
    """
    with TemporaryDirectory() as tmpdir:
        # use half precision for speed and memory savings in CI environment
        precision_opts = ["--dtype", "bfloat16"]
        engine_opts = [
            "--max-model-len",
            "8192",
            "--enforce-eager",
            "--max-num-seqs",
            "128",
        ]
        # Serve over a socket file instead of a TCP host/port pair.
        socket_opts = ["--uds", f"{tmpdir}/vllm.sock"]
        serve_args = [*precision_opts, *engine_opts, *socket_opts]

        with RemoteOpenAIServer(MODEL_NAME, serve_args) as remote_server:
            yield remote_server
@pytest.mark.asyncio
async def test_show_version(server: RemoteOpenAIServer):
    """Check that the ``version`` endpoint answers over the Unix socket.

    Args:
        server: Fixture-provided server whose ``uds`` attribute holds the
            socket path and whose ``url_for`` builds the request URL.
    """
    transport = httpx.HTTPTransport(uds=server.uds)
    # Use the client as a context manager so its connection pool (and the
    # underlying socket connection) is released even if the request fails.
    with httpx.Client(transport=transport) as client:
        response = client.get(server.url_for("version"))

    response.raise_for_status()
    assert response.json() == {"version": VLLM_VERSION}
...@@ -17,6 +17,7 @@ from pathlib import Path ...@@ -17,6 +17,7 @@ from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union from typing import Any, Callable, Literal, Optional, Union
import cloudpickle import cloudpickle
import httpx
import openai import openai
import pytest import pytest
import requests import requests
...@@ -88,10 +89,12 @@ class RemoteOpenAIServer: ...@@ -88,10 +89,12 @@ class RemoteOpenAIServer:
raise ValueError("You have manually specified the port " raise ValueError("You have manually specified the port "
"when `auto_port=True`.") "when `auto_port=True`.")
# Don't mutate the input args # No need for a port if using unix sockets
vllm_serve_args = vllm_serve_args + [ if "--uds" not in vllm_serve_args:
"--port", str(get_open_port()) # Don't mutate the input args
] vllm_serve_args = vllm_serve_args + [
"--port", str(get_open_port())
]
if seed is not None: if seed is not None:
if "--seed" in vllm_serve_args: if "--seed" in vllm_serve_args:
raise ValueError("You have manually specified the seed " raise ValueError("You have manually specified the seed "
...@@ -104,8 +107,13 @@ class RemoteOpenAIServer: ...@@ -104,8 +107,13 @@ class RemoteOpenAIServer:
subparsers = parser.add_subparsers(required=False, dest="subparser") subparsers = parser.add_subparsers(required=False, dest="subparser")
parser = ServeSubcommand().subparser_init(subparsers) parser = ServeSubcommand().subparser_init(subparsers)
args = parser.parse_args(["--model", model, *vllm_serve_args]) args = parser.parse_args(["--model", model, *vllm_serve_args])
self.host = str(args.host or 'localhost') self.uds = args.uds
self.port = int(args.port) if args.uds:
self.host = None
self.port = None
else:
self.host = str(args.host or 'localhost')
self.port = int(args.port)
self.show_hidden_metrics = \ self.show_hidden_metrics = \
args.show_hidden_metrics_for_version is not None args.show_hidden_metrics_for_version is not None
...@@ -150,9 +158,11 @@ class RemoteOpenAIServer: ...@@ -150,9 +158,11 @@ class RemoteOpenAIServer:
def _wait_for_server(self, *, url: str, timeout: float): def _wait_for_server(self, *, url: str, timeout: float):
# run health check # run health check
start = time.time() start = time.time()
client = (httpx.Client(transport=httpx.HTTPTransport(
uds=self.uds)) if self.uds else requests)
while True: while True:
try: try:
if requests.get(url).status_code == 200: if client.get(url).status_code == 200:
break break
except Exception: except Exception:
# this exception can only be raised by requests.get, # this exception can only be raised by requests.get,
...@@ -170,7 +180,8 @@ class RemoteOpenAIServer: ...@@ -170,7 +180,8 @@ class RemoteOpenAIServer:
@property @property
def url_root(self) -> str: def url_root(self) -> str:
return f"http://{self.host}:{self.port}" return (f"http://{self.uds.split('/')[-1]}"
if self.uds else f"http://{self.host}:{self.port}")
def url_for(self, *parts: str) -> str: def url_for(self, *parts: str) -> str:
return self.url_root + "/" + "/".join(parts) return self.url_root + "/" + "/".join(parts)
......
...@@ -1777,6 +1777,12 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket: ...@@ -1777,6 +1777,12 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket:
return sock return sock
def create_server_unix_socket(path: str) -> socket.socket:
    """Create a stream socket bound to the Unix domain socket at *path*.

    Args:
        path: Filesystem path at which to bind the socket.

    Returns:
        The bound ``AF_UNIX`` / ``SOCK_STREAM`` socket. ``listen`` is not
        called here.

    Raises:
        OSError: If binding fails, e.g. because the path is already in use
            or its parent directory does not exist. The socket is closed
            before the error propagates.
    """
    sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
    try:
        sock.bind(path)
    except Exception:
        # Don't leak the file descriptor when the bind fails.
        sock.close()
        raise
    return sock
def validate_api_server_args(args): def validate_api_server_args(args):
valid_tool_parses = ToolParserManager.tool_parsers.keys() valid_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \ if args.enable_auto_tool_choice \
...@@ -1807,8 +1813,11 @@ def setup_server(args): ...@@ -1807,8 +1813,11 @@ def setup_server(args):
# workaround to make sure that we bind the port before the engine is set up. # workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray. # This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204 # see https://github.com/vllm-project/vllm/issues/8204
sock_addr = (args.host or "", args.port) if args.uds:
sock = create_server_socket(sock_addr) sock = create_server_unix_socket(args.uds)
else:
sock_addr = (args.host or "", args.port)
sock = create_server_socket(sock_addr)
# workaround to avoid footguns where uvicorn drops requests with too # workaround to avoid footguns where uvicorn drops requests with too
# many concurrent requests active # many concurrent requests active
...@@ -1820,12 +1829,14 @@ def setup_server(args): ...@@ -1820,12 +1829,14 @@ def setup_server(args):
signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGTERM, signal_handler)
addr, port = sock_addr if args.uds:
is_ssl = args.ssl_keyfile and args.ssl_certfile listen_address = f"unix:{args.uds}"
host_part = f"[{addr}]" if is_valid_ipv6_address( else:
addr) else addr or "0.0.0.0" addr, port = sock_addr
listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" is_ssl = args.ssl_keyfile and args.ssl_certfile
host_part = f"[{addr}]" if is_valid_ipv6_address(
addr) else addr or "0.0.0.0"
listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
return listen_address, sock return listen_address, sock
......
...@@ -72,6 +72,8 @@ class FrontendArgs: ...@@ -72,6 +72,8 @@ class FrontendArgs:
"""Host name.""" """Host name."""
port: int = 8000 port: int = 8000
"""Port number.""" """Port number."""
uds: Optional[str] = None
"""Unix domain socket path. If set, host and port arguments are ignored."""
uvicorn_log_level: Literal["debug", "info", "warning", "error", "critical", uvicorn_log_level: Literal["debug", "info", "warning", "error", "critical",
"trace"] = "info" "trace"] = "info"
"""Log level for uvicorn.""" """Log level for uvicorn."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment