Unverified Commit 6e95f5e5 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Simplify `Router` arguments passing and build it in docker image (#9964)

parent 0e9387a9
......@@ -36,7 +36,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
ibverbs-providers infiniband-diags perftest \
libgoogle-glog-dev libgtest-dev libjsoncpp-dev libunwind-dev \
libboost-all-dev libssl-dev \
libgrpc-dev libgrpc++-dev libprotobuf-dev protobuf-compiler-grpc \
libgrpc-dev libgrpc++-dev libprotobuf-dev protobuf-compiler protobuf-compiler-grpc \
pybind11-dev \
libhiredis-dev libcurl4-openssl-dev \
libczmq4 libczmq-dev \
......@@ -218,6 +218,19 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.1/cmake-3.31.1
&& cp -r cmake-3.31.1-linux-x86_64/share/* /usr/local/share/ \
&& rm -rf cmake-3.31.1-linux-x86_64 cmake-3.31.1-linux-x86_64.tar.gz
# Install Rust toolchain for sgl-router
ENV PATH="/root/.cargo/bin:${PATH}"
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \
&& rustc --version && cargo --version
# Build and install sgl-router
RUN python3 -m pip install --no-cache-dir setuptools-rust \
&& cd /sgl-workspace/sglang/sgl-router \
&& cargo build --release \
&& python3 -m pip install --no-cache-dir . \
&& rm -rf /root/.cache
# Add yank script
COPY --chown=root:root <<-"EOF" /usr/local/bin/yank
#!/bin/bash
......
......@@ -36,7 +36,7 @@ uv pip install mooncake-transfer-engine
```bash
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-ib-device mlx5_roce0
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-ib-device mlx5_roce0
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
$ python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
```
### DeepSeek Multi-Node
......@@ -100,7 +100,7 @@ pip install . --config-settings=setup-args="-Ducx_path=/path/to/ucx"
```bash
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend nixl
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend nixl
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
$ python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
```
### DeepSeek Multi-Node
......@@ -137,7 +137,7 @@ export ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE=true
```bash
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend ascend
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend ascend
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
$ python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
```
### DeepSeek Multi-Node
......
......@@ -278,7 +278,7 @@ The most sophisticated policy that combines cache optimization with load balanci
3. **Cache Management**:
- Maintains approximate radix trees per worker
- Periodically evicts LRU entries based on `--eviction-interval` and `--max-tree-size`
- Periodically evicts LRU entries based on `--eviction-interval-secs` and `--max-tree-size`
### Data Parallelism Aware Routing
......@@ -296,7 +296,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### Core Settings
| Parameter | Type | Default | Description |
|-----------------------------|------|-------------|-----------------------------------------------------------------|
| --------------------------- | ---- | ----------- | --------------------------------------------------------------- |
| `--host` | str | 127.0.0.1 | Router server host address |
| `--port` | int | 30000 | Router server port |
| `--worker-urls` | list | [] | Worker URLs for separate launch mode |
......@@ -307,18 +307,18 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### Cache-Aware Routing Parameters
| Parameter | Type | Default | Description |
|---------------------------|-------|----------|--------------------------------------------------------|
| `--cache-threshold` | float | 0.5 | Minimum prefix match ratio for cache routing (0.0-1.0) |
| `--balance-abs-threshold` | int | 32 | Absolute load difference threshold |
| `--balance-rel-threshold` | float | 1.0001 | Relative load ratio threshold |
| `--eviction-interval` | int | 60 | Seconds between cache eviction cycles |
| `--max-tree-size` | int | 16777216 | Maximum nodes in routing tree |
| Parameter | Type | Default | Description |
| -------------------------- | ----- | -------- | ------------------------------------------------------ |
| `--cache-threshold` | float | 0.5 | Minimum prefix match ratio for cache routing (0.0-1.0) |
| `--balance-abs-threshold` | int | 32 | Absolute load difference threshold |
| `--balance-rel-threshold` | float | 1.0001 | Relative load ratio threshold |
| `--eviction-interval-secs` | int | 60 | Seconds between cache eviction cycles |
| `--max-tree-size` | int | 16777216 | Maximum nodes in routing tree |
### Fault Tolerance Parameters
| Parameter | Type | Default | Description |
|------------------------------|-------|---------|---------------------------------------|
| ---------------------------- | ----- | ------- | ------------------------------------- |
| `--retry-max-retries` | int | 3 | Maximum retry attempts per request |
| `--retry-initial-backoff-ms` | int | 100 | Initial retry backoff in milliseconds |
| `--retry-max-backoff-ms` | int | 10000 | Maximum retry backoff in milliseconds |
......@@ -334,7 +334,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### Prefill-Decode Disaggregation Parameters
| Parameter | Type | Default | Description |
|-----------------------------------|------|---------|-------------------------------------------------------|
| --------------------------------- | ---- | ------- | ----------------------------------------------------- |
| `--pd-disaggregation` | flag | False | Enable PD disaggregated mode |
| `--prefill` | list | [] | Prefill server URLs with optional bootstrap ports |
| `--decode` | list | [] | Decode server URLs |
......@@ -346,7 +346,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### Kubernetes Integration
| Parameter | Type | Default | Description |
|---------------------------------|------|--------------------------|------------------------------------------------------|
| ------------------------------- | ---- | ------------------------ | ---------------------------------------------------- |
| `--service-discovery` | flag | False | Enable Kubernetes service discovery |
| `--selector` | list | [] | Label selector for workers (key1=value1 key2=value2) |
| `--prefill-selector` | list | [] | Label selector for prefill servers in PD mode |
......@@ -358,7 +358,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### Observability
| Parameter | Type | Default | Description |
|------------------------|------|-----------|-------------------------------------------------------|
| ---------------------- | ---- | --------- | ----------------------------------------------------- |
| `--prometheus-port` | int | 29000 | Prometheus metrics port |
| `--prometheus-host` | str | 127.0.0.1 | Prometheus metrics host |
| `--log-dir` | str | None | Directory for log files |
......@@ -368,7 +368,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### CORS Configuration
| Parameter | Type | Default | Description |
|--------------------------|------|---------|----------------------|
| ------------------------ | ---- | ------- | -------------------- |
| `--cors-allowed-origins` | list | [] | Allowed CORS origins |
## Advanced Features
......@@ -429,7 +429,7 @@ python -m sglang_router.launch_router \
2. **High latency**: Check if cache-aware routing is causing imbalance. Try adjusting `--balance-abs-threshold` and `--balance-rel-threshold`.
3. **Memory growth**: Reduce `--max-tree-size` or decrease `--eviction-interval` for more aggressive cache cleanup.
3. **Memory growth**: Reduce `--max-tree-size` or decrease `--eviction-interval-secs` for more aggressive cache cleanup.
4. **Circuit breaker triggering frequently**: Increase `--cb-failure-threshold` or extend `--cb-window-duration-secs`.
......
......@@ -27,7 +27,8 @@ spec:
command:
- python
- -m
- sglang.srt.disaggregation.mini_lb
- sglang_router.launch_router
- --pd-disaggregation
- --prefill
- http://deepseekr10528-prefill-main:30000
- --decode
......
......@@ -714,7 +714,8 @@ spec:
command:
- python
- -m
- sglang.srt.disaggregation.mini_lb
- sglang_router.launch_router
- --pd-disaggregation
- --prefill
- http://deepseekr10528-prefill-main:30000
- --decode
......
import argparse
import dataclasses
from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
@dataclasses.dataclass
class LBArgs:
host: str = "0.0.0.0"
port: int = 8000
policy: str = "random"
prefill_infos: list = dataclasses.field(default_factory=list)
decode_infos: list = dataclasses.field(default_factory=list)
log_interval: int = 5
timeout: int = 600
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--host",
type=str,
default=LBArgs.host,
help=f"Host to bind the server (default: {LBArgs.host})",
)
parser.add_argument(
"--port",
type=int,
default=LBArgs.port,
help=f"Port to bind the server (default: {LBArgs.port})",
)
parser.add_argument(
"--policy",
type=str,
default=LBArgs.policy,
choices=["random", "po2"],
help=f"Policy to use for load balancing (default: {LBArgs.policy})",
)
parser.add_argument(
"--prefill",
type=str,
default=[],
nargs="+",
help="URLs for prefill servers",
)
parser.add_argument(
"--decode",
type=str,
default=[],
nargs="+",
help="URLs for decode servers",
)
parser.add_argument(
"--prefill-bootstrap-ports",
type=int,
nargs="+",
help="Bootstrap ports for prefill servers",
)
parser.add_argument(
"--log-interval",
type=int,
default=LBArgs.log_interval,
help=f"Log interval in seconds (default: {LBArgs.log_interval})",
)
parser.add_argument(
"--timeout",
type=int,
default=LBArgs.timeout,
help=f"Timeout in seconds (default: {LBArgs.timeout})",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "LBArgs":
bootstrap_ports = args.prefill_bootstrap_ports
if bootstrap_ports is None:
bootstrap_ports = [None] * len(args.prefill)
elif len(bootstrap_ports) == 1:
bootstrap_ports = bootstrap_ports * len(args.prefill)
else:
if len(bootstrap_ports) != len(args.prefill):
raise ValueError(
"Number of prefill URLs must match number of bootstrap ports"
)
prefill_infos = [
(url, port) for url, port in zip(args.prefill, bootstrap_ports)
]
return cls(
host=args.host,
port=args.port,
policy=args.policy,
prefill_infos=prefill_infos,
decode_infos=args.decode,
log_interval=args.log_interval,
timeout=args.timeout,
)
def main():
parser = argparse.ArgumentParser(
description="PD Disaggregation Load Balancer Server"
)
LBArgs.add_cli_args(parser)
args = parser.parse_args()
lb_args = LBArgs.from_cli_args(args)
prefill_configs = [PrefillConfig(url, port) for url, port in lb_args.prefill_infos]
run(
prefill_configs,
lb_args.decode_infos,
lb_args.host,
lb_args.port,
lb_args.timeout,
)
if __name__ == "__main__":
main()
"""
Minimal HTTP load balancer for prefill and decode servers for testing.
"""
import asyncio
import dataclasses
import logging
import random
import urllib
from http import HTTPStatus
from itertools import chain
from typing import List, Optional
import aiohttp
import orjson
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang.srt.disaggregation.utils import PDRegistryRequest
from sglang.srt.utils import maybe_wrap_ipv6_address
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
1024 * 64
) # 64KB, to prevent aiohttp's "Chunk too big" error
def setup_logger():
logger = logging.getLogger("pdlb")
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
logger = setup_logger()
@dataclasses.dataclass
class PrefillConfig:
url: str
bootstrap_port: Optional[int] = None
class MiniLoadBalancer:
def __init__(
self,
prefill_configs: List[PrefillConfig],
decode_servers: List[str],
timeout: int,
):
self.prefill_configs = prefill_configs
self.prefill_servers = [p.url for p in prefill_configs]
self.decode_servers = decode_servers
self.timeout = timeout
def add_prefill_server(self, new_prefill_config: PrefillConfig):
self.prefill_configs.append(new_prefill_config)
self.prefill_servers.append(new_prefill_config.url)
def add_decode_server(self, new_decode_server: str):
self.decode_servers.append(new_decode_server)
def select_pair(self):
# TODO: return some message instead of panic
assert len(self.prefill_configs) > 0, "No prefill servers available"
assert len(self.decode_servers) > 0, "No decode servers available"
prefill_config = random.choice(self.prefill_configs)
decode_server = random.choice(self.decode_servers)
return prefill_config.url, prefill_config.bootstrap_port, decode_server
async def generate(
self, modified_request, prefill_server, decode_server, endpoint
) -> ORJSONResponse:
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=self.timeout
) # Add timeout for request reliability
) as session:
tasks = [
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
session.post(f"{decode_server}/{endpoint}", json=modified_request),
]
# Wait for both responses to complete. Prefill should end first.
prefill_response, decode_response = await asyncio.gather(*tasks)
if "return_logprob" in modified_request:
prefill_json = await prefill_response.json()
ret_json = await decode_response.json()
# merge `meta_info.input_token_logprobs` from prefill to decode
if "meta_info" in ret_json:
if "input_token_logprobs" in ret_json["meta_info"]:
ret_json["meta_info"]["input_token_logprobs"] = (
prefill_json["meta_info"]["input_token_logprobs"]
+ ret_json["meta_info"]["input_token_logprobs"]
)
else:
ret_json = await decode_response.json()
return ORJSONResponse(
content=ret_json,
status_code=decode_response.status,
)
async def generate_stream(
self, modified_request, prefill_server, decode_server, endpoint="generate"
):
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
async def stream_results():
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=self.timeout
) # Add timeout for request reliability
) as session:
# Create the tasks for both prefill and decode requests
tasks = [
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
session.post(f"{decode_server}/{endpoint}", json=modified_request),
]
# Wait for both responses to complete. Since this is streaming, they return immediately.
prefill_response, decode_response = await asyncio.gather(*tasks)
if modified_request.get("return_logprob", False):
prefill_chunks = []
async for chunk in prefill_response.content:
prefill_chunks.append(chunk)
first_prefill_chunk = (
prefill_chunks[0].decode("utf-8")[5:].strip("\n")
)
first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
async for chunk in decode_response.content:
# Note: This is inefficient
# merge prefill input_token_logprobs, output_token_logprobs to decode
decoded_chunk = chunk.decode("utf-8")
if (
decoded_chunk
and decoded_chunk.startswith("data:")
and "[DONE]" not in decoded_chunk
):
ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
ret_json["meta_info"]["input_token_logprobs"] = (
first_prefill_chunk_json["meta_info"][
"input_token_logprobs"
]
+ ret_json["meta_info"]["input_token_logprobs"]
)
yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
else:
yield chunk
else:
async for chunk in decode_response.content.iter_chunked(
AIOHTTP_STREAM_READ_CHUNK_SIZE
):
yield chunk
return StreamingResponse(
stream_results(),
media_type="text/event-stream",
)
app = FastAPI()
load_balancer: Optional[MiniLoadBalancer] = None
@app.get("/health")
async def health_check():
return Response(status_code=200)
@app.get("/health_generate")
async def health_generate():
prefill_servers, decode_servers = (
load_balancer.prefill_servers,
load_balancer.decode_servers,
)
async with aiohttp.ClientSession() as session:
# Create the tasks
tasks = []
for server in chain(prefill_servers, decode_servers):
tasks.append(session.get(f"{server}/health_generate"))
for i, response in enumerate(asyncio.as_completed(tasks)):
await response
return Response(status_code=200)
@app.post("/flush_cache")
async def flush_cache():
prefill_servers, decode_servers = (
load_balancer.prefill_servers,
load_balancer.decode_servers,
)
async with aiohttp.ClientSession() as session:
# Create the tasks
tasks = []
for server in chain(prefill_servers, decode_servers):
tasks.append(session.post(f"{server}/flush_cache"))
for i, response in enumerate(asyncio.as_completed(tasks)):
await response
return Response(status_code=200)
@app.get("/get_server_info")
async def get_server_info():
prefill_servers, decode_servers = (
load_balancer.prefill_servers,
load_balancer.decode_servers,
)
prefill_infos = []
decode_infos = []
all_internal_states = []
async with aiohttp.ClientSession() as session:
for server in chain(prefill_servers):
server_info = await session.get(f"{server}/get_server_info")
prefill_infos.append(await server_info.json())
for server in chain(decode_servers):
server_info = await session.get(f"{server}/get_server_info")
info_json = await server_info.json()
decode_infos.append(info_json)
# Extract internal_states from decode servers
if "internal_states" in info_json:
all_internal_states.extend(info_json["internal_states"])
# Return format expected by bench_one_batch_server.py
if all_internal_states:
return {
"internal_states": all_internal_states,
"prefill": prefill_infos,
"decode": decode_infos,
}
else:
# Fallback with dummy data if no internal states found
return {
"internal_states": [
{
"last_gen_throughput": 0.0,
"avg_spec_accept_length": None,
}
],
"prefill": prefill_infos,
"decode": decode_infos,
}
@app.get("/get_model_info")
async def get_model_info():
global load_balancer
if not load_balancer or not load_balancer.prefill_servers:
raise HTTPException(
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
detail="There is no server registered",
)
target_server_url = load_balancer.prefill_servers[0]
endpoint_url = f"{target_server_url}/get_model_info"
async with aiohttp.ClientSession() as session:
try:
async with session.get(endpoint_url) as response:
if response.status != 200:
error_text = await response.text()
raise HTTPException(
status_code=HTTPStatus.BAD_GATEWAY,
detail=(
f"Failed to get model info from {target_server_url}"
f"Status: {response.status}, Response: {error_text}"
),
)
model_info_json = await response.json()
return ORJSONResponse(content=model_info_json)
except aiohttp.ClientError as e:
raise HTTPException(
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
detail=f"Failed to get model info from backend",
)
@app.post("/generate")
async def handle_generate_request(request_data: dict):
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
# Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
modified_request = request_data.copy()
batch_size = _get_request_batch_size(modified_request)
if batch_size is not None:
modified_request.update(
{
"bootstrap_host": [hostname] * batch_size,
"bootstrap_port": [bootstrap_port] * batch_size,
"bootstrap_room": [
_generate_bootstrap_room() for _ in range(batch_size)
],
}
)
else:
modified_request.update(
{
"bootstrap_host": hostname,
"bootstrap_port": bootstrap_port,
"bootstrap_room": _generate_bootstrap_room(),
}
)
if request_data.get("stream", False):
return await load_balancer.generate_stream(
modified_request, prefill_server, decode_server, "generate"
)
else:
return await load_balancer.generate(
modified_request, prefill_server, decode_server, "generate"
)
async def _forward_to_backend(request_data: dict, endpoint_name: str):
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
# Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
modified_request = request_data.copy()
modified_request.update(
{
"bootstrap_host": hostname,
"bootstrap_port": bootstrap_port,
"bootstrap_room": _generate_bootstrap_room(),
}
)
if request_data.get("stream", False):
return await load_balancer.generate_stream(
modified_request,
prefill_server,
decode_server,
endpoint=endpoint_name,
)
else:
return await load_balancer.generate(
modified_request,
prefill_server,
decode_server,
endpoint=endpoint_name,
)
@app.post("/v1/chat/completions")
async def handle_chat_completion_request(request_data: dict):
return await _forward_to_backend(request_data, "v1/chat/completions")
@app.post("/v1/completions")
async def handle_completion_request(request_data: dict):
return await _forward_to_backend(request_data, "v1/completions")
def _generate_bootstrap_room():
return random.randint(0, 2**63 - 1)
# We may utilize `GenerateReqInput`'s logic later
def _get_request_batch_size(request):
if (text := request.get("text")) is not None:
return None if isinstance(text, str) else len(text)
if (input_ids := request.get("input_ids")) is not None:
return None if isinstance(input_ids[0], int) else len(input_ids)
return None
@app.get("/v1/models")
async def get_models():
prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server
async with aiohttp.ClientSession() as session:
try:
response = await session.get(f"{prefill_server}/v1/models")
if response.status != 200:
raise HTTPException(
status_code=response.status,
detail=f"Prefill server error: Status {response.status}",
)
return ORJSONResponse(content=await response.json())
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/register")
async def register(obj: PDRegistryRequest):
if obj.mode == "prefill":
load_balancer.add_prefill_server(
PrefillConfig(obj.registry_url, obj.bootstrap_port)
)
logger.info(
f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
)
elif obj.mode == "decode":
load_balancer.add_decode_server(obj.registry_url)
logger.info(f"Registered decode server: {obj.registry_url}")
else:
raise HTTPException(
status_code=400,
detail="Invalid mode. Must be either PREFILL or DECODE.",
)
logger.info(
f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
f"#Decode servers: {len(load_balancer.decode_servers)}"
)
return Response(status_code=200)
def run(prefill_configs, decode_addrs, host, port, timeout):
global load_balancer
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs, timeout=timeout)
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
# FIXME: remove this, use the unified entry point: sglang.srt.disaggregation.launch_lb
from sglang.srt.disaggregation.launch_lb import main
main()
raise RuntimeError(
"""The 'mini_lb' module has been relocated to the 'sglang_router' package.
We recommend installing 'sglang-router' with Rust support for optimal performance.
If you encounter issues building the router with Rust, set the environment variable
'SGLANG_ROUTER_BUILD_NO_RUST=1' and add '--mini-lb' to the command line to use the Python version of 'mini_lb'."""
)
from __future__ import annotations
import dataclasses
import os
import random
import threading
import warnings
from collections import deque
from contextlib import nullcontext
from enum import Enum
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
import numpy as np
import requests
import torch
import torch.distributed as dist
from sglang.srt.utils import get_ip, is_npu
from sglang.srt.utils import is_npu
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
......@@ -305,49 +301,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
return (num_kv_indices + page_size - 1) // page_size
#########################
# PDLB Registry
#########################
@dataclasses.dataclass
class PDRegistryRequest:
"""A request to register a machine itself to the LB."""
mode: str
registry_url: str
bootstrap_port: Optional[int] = None
def __post_init__(self):
if self.mode == "prefill" and self.bootstrap_port is None:
raise ValueError("Bootstrap port must be set in PREFILL mode.")
elif self.mode == "decode" and self.bootstrap_port is not None:
raise ValueError("Bootstrap port must not be set in DECODE mode.")
elif self.mode not in ["prefill", "decode"]:
raise ValueError(
f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
)
def register_disaggregation_server(
mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
):
boostrap_port = bootstrap_port if mode == "prefill" else None
registry_request = PDRegistryRequest(
mode=mode,
registry_url=f"http://{get_ip()}:{server_port}",
bootstrap_port=boostrap_port,
)
res = requests.post(
f"{pdlb_url}/register",
json=dataclasses.asdict(registry_request),
)
if res.status_code != 200:
warnings.warn(
f"Failed to register disaggregation server: {res.status_code} {res.text}"
)
#########################
# Misc
#########################
......
......@@ -47,11 +47,7 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST,
DisaggregationMode,
register_disaggregation_server,
)
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
......@@ -1405,13 +1401,5 @@ def _wait_and_warmup(
if server_args.debug_tensor_dump_input_file:
kill_process_tree(os.getpid())
if server_args.pdlb_url is not None:
register_disaggregation_server(
server_args.disaggregation_mode,
server_args.port,
server_args.disaggregation_bootstrap_port,
server_args.pdlb_url,
)
if launch_callback is not None:
launch_callback()
......@@ -367,7 +367,6 @@ class ServerArgs:
disaggregation_prefill_pp: Optional[int] = 1
disaggregation_ib_device: Optional[str] = None
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
pdlb_url: Optional[str] = None
# For model weight update
custom_weight_loader: Optional[List[str]] = None
......@@ -2071,12 +2070,6 @@ class ServerArgs:
default=ServerArgs.num_reserved_decode_tokens,
help="Number of decode tokens that will have memory reserved when adding new request to the running batch.",
)
parser.add_argument(
"--pdlb-url",
type=str,
default=None,
help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
)
# Custom weight loader
parser.add_argument(
......
......@@ -466,6 +466,25 @@ def try_cached_model(model_repo: str):
return model_dir if model_dir else model_repo
def popen_with_error_check(command: list[str], allow_exit: bool = False):
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
def _run_and_check():
stdout, stderr = process.communicate()
while process.poll() is None:
time.sleep(5)
if not allow_exit or process.returncode != 0:
raise Exception(
f"{command} exited with code {process.returncode}\n{stdout=}\n{stderr=}"
)
t = threading.Thread(target=_run_and_check)
t.start()
return process
def popen_launch_server(
model: str,
base_url: str,
......
......@@ -45,6 +45,10 @@ fi
# Install the main package
$PIP_CMD install -e "python[dev]" --extra-index-url https://download.pytorch.org/whl/${CU_VERSION} $PIP_INSTALL_SUFFIX
# Install router for pd-disagg test
SGLANG_ROUTER_BUILD_NO_RUST=1 $PIP_CMD install -e "sgl-router" $PIP_INSTALL_SUFFIX
if [ "$IS_BLACKWELL" = "1" ]; then
# TODO auto determine sgl-kernel version
SGL_KERNEL_VERSION=0.3.8
......
# a lightweihgt wrapper on router with argument type and comments
# no wrapper on policy type => direct export
from sglang_router.router import Router
from sglang_router.version import __version__
from sglang_router_rs import PolicyType
__all__ = ["Router", "PolicyType", "__version__"]
__all__ = ["__version__"]
import argparse
import dataclasses
import logging
import sys
from typing import Dict, List, Optional
from typing import List, Optional
from sglang_router import Router
from sglang_router_rs import PolicyType
import setproctitle
from sglang_router.mini_lb import MiniLoadBalancer
from sglang_router.router_args import RouterArgs
logger = logging.getLogger("router")
def setup_logger():
logger = logging.getLogger("router")
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
try:
from sglang_router.router import Router
except ImportError:
Router = None
logger.warning(
"Rust Router is not installed, only python MiniLB (debugging only) is available"
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
@dataclasses.dataclass
class RouterArgs:
# Worker configuration
worker_urls: List[str] = dataclasses.field(default_factory=list)
host: str = "127.0.0.1"
port: int = 30000
# PD-specific configuration
pd_disaggregation: bool = False # Enable PD disaggregated mode
prefill_urls: List[tuple] = dataclasses.field(
default_factory=list
) # List of (url, bootstrap_port)
decode_urls: List[str] = dataclasses.field(default_factory=list)
# Routing policy
policy: str = "cache_aware"
prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode
decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode
worker_startup_timeout_secs: int = 600
worker_startup_check_interval: int = 30
cache_threshold: float = 0.3
balance_abs_threshold: int = 64
balance_rel_threshold: float = 1.5
eviction_interval: int = 120
max_tree_size: int = 2**26
max_payload_size: int = 512 * 1024 * 1024 # 512MB default for large batches
dp_aware: bool = False
api_key: Optional[str] = None
log_dir: Optional[str] = None
log_level: Optional[str] = None
# Service discovery configuration
service_discovery: bool = False
selector: Dict[str, str] = dataclasses.field(default_factory=dict)
service_discovery_port: int = 80
service_discovery_namespace: Optional[str] = None
# PD service discovery configuration
prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port"
# Prometheus configuration
prometheus_port: Optional[int] = None
prometheus_host: Optional[str] = None
# Request ID headers configuration
request_id_headers: Optional[List[str]] = None
# Request timeout in seconds
request_timeout_secs: int = 1800
# Max concurrent requests for rate limiting
max_concurrent_requests: int = 256
# Queue size for pending requests when max concurrent limit reached
queue_size: int = 100
# Maximum time (in seconds) a request can wait in queue before timing out
queue_timeout_secs: int = 60
# Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
rate_limit_tokens_per_second: Optional[int] = None
# CORS allowed origins
cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
# Retry configuration
retry_max_retries: int = 5
retry_initial_backoff_ms: int = 50
retry_max_backoff_ms: int = 30_000
retry_backoff_multiplier: float = 1.5
retry_jitter_factor: float = 0.2
disable_retries: bool = False
# Health check configuration
health_failure_threshold: int = 3
health_success_threshold: int = 2
health_check_timeout_secs: int = 5
health_check_interval_secs: int = 60
health_check_endpoint: str = "/health"
# Circuit breaker configuration
cb_failure_threshold: int = 10
cb_success_threshold: int = 3
cb_timeout_duration_secs: int = 60
cb_window_duration_secs: int = 120
disable_circuit_breaker: bool = False
# Tokenizer configuration
model_path: Optional[str] = None
tokenizer_path: Optional[str] = None
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser,
use_router_prefix: bool = False,
exclude_host_port: bool = False,
):
"""
Add router-specific arguments to an argument parser.
Args:
parser: The argument parser to add arguments to
use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
"""
prefix = "router-" if use_router_prefix else ""
# Worker configuration
if not exclude_host_port:
parser.add_argument(
"--host",
type=str,
default=RouterArgs.host,
help="Host address to bind the router server",
)
parser.add_argument(
"--port",
type=int,
default=RouterArgs.port,
help="Port number to bind the router server",
)
parser.add_argument(
"--worker-urls",
type=str,
nargs="*",
default=[],
help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
)
# Routing policy configuration
parser.add_argument(
f"--{prefix}policy",
type=str,
default=RouterArgs.policy,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
)
parser.add_argument(
f"--{prefix}prefill-policy",
type=str,
default=None,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
)
parser.add_argument(
f"--{prefix}decode-policy",
type=str,
default=None,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
)
# PD-specific arguments
parser.add_argument(
f"--{prefix}pd-disaggregation",
action="store_true",
help="Enable PD (Prefill-Decode) disaggregated mode",
)
parser.add_argument(
f"--{prefix}prefill",
nargs="+",
action="append",
help="Prefill server URL and optional bootstrap port. Can be specified multiple times. "
"Format: --prefill URL [BOOTSTRAP_PORT]. "
"BOOTSTRAP_PORT can be a port number, 'none', or omitted (defaults to none).",
)
parser.add_argument(
f"--{prefix}decode",
nargs=1,
action="append",
metavar=("URL",),
help="Decode server URL. Can be specified multiple times.",
)
parser.add_argument(
f"--{prefix}worker-startup-timeout-secs",
type=int,
default=RouterArgs.worker_startup_timeout_secs,
help="Timeout in seconds for worker startup",
)
parser.add_argument(
f"--{prefix}worker-startup-check-interval",
type=int,
default=RouterArgs.worker_startup_check_interval,
help="Interval in seconds between checks for worker startup",
)
parser.add_argument(
f"--{prefix}cache-threshold",
type=float,
default=RouterArgs.cache_threshold,
help="Cache threshold (0.0-1.0) for cache-aware routing",
)
parser.add_argument(
f"--{prefix}balance-abs-threshold",
type=int,
default=RouterArgs.balance_abs_threshold,
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
)
parser.add_argument(
f"--{prefix}balance-rel-threshold",
type=float,
default=RouterArgs.balance_rel_threshold,
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
)
parser.add_argument(
f"--{prefix}eviction-interval",
type=int,
default=RouterArgs.eviction_interval,
help="Interval in seconds between cache eviction operations",
)
parser.add_argument(
f"--{prefix}max-tree-size",
type=int,
default=RouterArgs.max_tree_size,
help="Maximum size of the approximation tree for cache-aware routing",
)
parser.add_argument(
f"--{prefix}max-payload-size",
type=int,
default=RouterArgs.max_payload_size,
help="Maximum payload size in bytes",
)
parser.add_argument(
f"--{prefix}dp-aware",
action="store_true",
help="Enable data parallelism aware schedule",
)
parser.add_argument(
f"--{prefix}api-key",
type=str,
default=None,
help="The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enaled.",
)
parser.add_argument(
f"--{prefix}log-dir",
type=str,
default=None,
help="Directory to store log files. If not specified, logs are only output to console.",
)
parser.add_argument(
f"--{prefix}log-level",
type=str,
default="info",
choices=["debug", "info", "warning", "error", "critical"],
help="Set the logging level. If not specified, defaults to INFO.",
)
parser.add_argument(
f"--{prefix}service-discovery",
action="store_true",
help="Enable Kubernetes service discovery",
)
parser.add_argument(
f"--{prefix}selector",
type=str,
nargs="+",
help="Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)",
)
parser.add_argument(
f"--{prefix}service-discovery-port",
type=int,
default=RouterArgs.service_discovery_port,
help="Port to use for discovered worker pods",
)
parser.add_argument(
f"--{prefix}service-discovery-namespace",
type=str,
help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
)
parser.add_argument(
f"--{prefix}prefill-selector",
type=str,
nargs="+",
help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)",
)
parser.add_argument(
f"--{prefix}decode-selector",
type=str,
nargs="+",
help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)",
)
# Prometheus configuration
parser.add_argument(
f"--{prefix}prometheus-port",
type=int,
default=29000,
help="Port to expose Prometheus metrics. If not specified, Prometheus metrics are disabled",
)
parser.add_argument(
f"--{prefix}prometheus-host",
type=str,
default="127.0.0.1",
help="Host address to bind the Prometheus metrics server",
)
parser.add_argument(
f"--{prefix}request-id-headers",
type=str,
nargs="*",
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
)
parser.add_argument(
f"--{prefix}request-timeout-secs",
type=int,
default=RouterArgs.request_timeout_secs,
help="Request timeout in seconds",
)
# Retry configuration
parser.add_argument(
f"--{prefix}retry-max-retries",
type=int,
default=RouterArgs.retry_max_retries,
)
parser.add_argument(
f"--{prefix}retry-initial-backoff-ms",
type=int,
default=RouterArgs.retry_initial_backoff_ms,
)
parser.add_argument(
f"--{prefix}retry-max-backoff-ms",
type=int,
default=RouterArgs.retry_max_backoff_ms,
)
parser.add_argument(
f"--{prefix}retry-backoff-multiplier",
type=float,
default=RouterArgs.retry_backoff_multiplier,
)
parser.add_argument(
f"--{prefix}retry-jitter-factor",
type=float,
default=RouterArgs.retry_jitter_factor,
)
parser.add_argument(
f"--{prefix}disable-retries",
action="store_true",
help="Disable retries (equivalent to setting retry_max_retries=1)",
)
# Circuit breaker configuration
parser.add_argument(
f"--{prefix}cb-failure-threshold",
type=int,
default=RouterArgs.cb_failure_threshold,
)
parser.add_argument(
f"--{prefix}cb-success-threshold",
type=int,
default=RouterArgs.cb_success_threshold,
)
parser.add_argument(
f"--{prefix}cb-timeout-duration-secs",
type=int,
default=RouterArgs.cb_timeout_duration_secs,
)
parser.add_argument(
f"--{prefix}cb-window-duration-secs",
type=int,
default=RouterArgs.cb_window_duration_secs,
)
parser.add_argument(
f"--{prefix}disable-circuit-breaker",
action="store_true",
help="Disable circuit breaker (equivalent to setting cb_failure_threshold to u32::MAX)",
)
# Health check configuration
parser.add_argument(
f"--{prefix}health-failure-threshold",
type=int,
default=RouterArgs.health_failure_threshold,
help="Number of consecutive health check failures before marking worker unhealthy",
)
parser.add_argument(
f"--{prefix}health-success-threshold",
type=int,
default=RouterArgs.health_success_threshold,
help="Number of consecutive health check successes before marking worker healthy",
)
parser.add_argument(
f"--{prefix}health-check-timeout-secs",
type=int,
default=RouterArgs.health_check_timeout_secs,
help="Timeout in seconds for health check requests",
)
parser.add_argument(
f"--{prefix}health-check-interval-secs",
type=int,
default=RouterArgs.health_check_interval_secs,
help="Interval in seconds between runtime health checks",
)
parser.add_argument(
f"--{prefix}health-check-endpoint",
type=str,
default=RouterArgs.health_check_endpoint,
help="Health check endpoint path",
)
parser.add_argument(
f"--{prefix}max-concurrent-requests",
type=int,
default=RouterArgs.max_concurrent_requests,
help="Maximum number of concurrent requests allowed (for rate limiting)",
)
parser.add_argument(
f"--{prefix}queue-size",
type=int,
default=RouterArgs.queue_size,
help="Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)",
)
parser.add_argument(
f"--{prefix}queue-timeout-secs",
type=int,
default=RouterArgs.queue_timeout_secs,
help="Maximum time (in seconds) a request can wait in queue before timing out",
)
parser.add_argument(
f"--{prefix}rate-limit-tokens-per-second",
type=int,
default=RouterArgs.rate_limit_tokens_per_second,
help="Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests",
)
parser.add_argument(
f"--{prefix}cors-allowed-origins",
type=str,
nargs="*",
default=[],
help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)",
)
# Tokenizer configuration
parser.add_argument(
f"--{prefix}model-path",
type=str,
default=None,
help="Model path for loading tokenizer (HuggingFace model ID or local path)",
)
parser.add_argument(
f"--{prefix}tokenizer-path",
type=str,
default=None,
help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
)
@classmethod
def from_cli_args(
cls, args: argparse.Namespace, use_router_prefix: bool = False
) -> "RouterArgs":
"""
Create RouterArgs instance from parsed command line arguments.
Args:
args: Parsed command line arguments
use_router_prefix: If True, look for arguments with 'router-' prefix
"""
prefix = "router_" if use_router_prefix else ""
worker_urls = getattr(args, "worker_urls", [])
# Parse PD URLs
prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None))
decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None))
return cls(
worker_urls=worker_urls,
host=args.host,
port=args.port,
pd_disaggregation=getattr(args, f"{prefix}pd_disaggregation", False),
prefill_urls=prefill_urls,
decode_urls=decode_urls,
policy=getattr(args, f"{prefix}policy"),
prefill_policy=getattr(args, f"{prefix}prefill_policy", None),
decode_policy=getattr(args, f"{prefix}decode_policy", None),
worker_startup_timeout_secs=getattr(
args, f"{prefix}worker_startup_timeout_secs"
),
worker_startup_check_interval=getattr(
args, f"{prefix}worker_startup_check_interval"
),
cache_threshold=getattr(args, f"{prefix}cache_threshold"),
balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"),
balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"),
eviction_interval=getattr(args, f"{prefix}eviction_interval"),
max_tree_size=getattr(args, f"{prefix}max_tree_size"),
max_payload_size=getattr(args, f"{prefix}max_payload_size"),
dp_aware=getattr(args, f"{prefix}dp_aware", False),
api_key=getattr(args, f"{prefix}api_key", None),
log_dir=getattr(args, f"{prefix}log_dir", None),
log_level=getattr(args, f"{prefix}log_level", None),
service_discovery=getattr(args, f"{prefix}service_discovery", False),
selector=cls._parse_selector(getattr(args, f"{prefix}selector", None)),
service_discovery_port=getattr(args, f"{prefix}service_discovery_port"),
service_discovery_namespace=getattr(
args, f"{prefix}service_discovery_namespace", None
),
prefill_selector=cls._parse_selector(
getattr(args, f"{prefix}prefill_selector", None)
),
decode_selector=cls._parse_selector(
getattr(args, f"{prefix}decode_selector", None)
),
bootstrap_port_annotation="sglang.ai/bootstrap-port", # Mooncake-specific annotation
prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
request_id_headers=getattr(args, f"{prefix}request_id_headers", None),
request_timeout_secs=getattr(
args, f"{prefix}request_timeout_secs", RouterArgs.request_timeout_secs
),
max_concurrent_requests=getattr(
args,
f"{prefix}max_concurrent_requests",
RouterArgs.max_concurrent_requests,
),
queue_size=getattr(
args,
f"{prefix}queue_size",
RouterArgs.queue_size,
),
queue_timeout_secs=getattr(
args,
f"{prefix}queue_timeout_secs",
RouterArgs.queue_timeout_secs,
),
rate_limit_tokens_per_second=getattr(
args,
f"{prefix}rate_limit_tokens_per_second",
RouterArgs.rate_limit_tokens_per_second,
),
cors_allowed_origins=getattr(args, f"{prefix}cors_allowed_origins", []),
retry_max_retries=getattr(args, f"{prefix}retry_max_retries"),
retry_initial_backoff_ms=getattr(args, f"{prefix}retry_initial_backoff_ms"),
retry_max_backoff_ms=getattr(args, f"{prefix}retry_max_backoff_ms"),
retry_backoff_multiplier=getattr(args, f"{prefix}retry_backoff_multiplier"),
retry_jitter_factor=getattr(args, f"{prefix}retry_jitter_factor"),
cb_failure_threshold=getattr(args, f"{prefix}cb_failure_threshold"),
cb_success_threshold=getattr(args, f"{prefix}cb_success_threshold"),
cb_timeout_duration_secs=getattr(args, f"{prefix}cb_timeout_duration_secs"),
cb_window_duration_secs=getattr(args, f"{prefix}cb_window_duration_secs"),
disable_retries=getattr(args, f"{prefix}disable_retries", False),
disable_circuit_breaker=getattr(
args, f"{prefix}disable_circuit_breaker", False
),
health_failure_threshold=getattr(
args,
f"{prefix}health_failure_threshold",
RouterArgs.health_failure_threshold,
),
health_success_threshold=getattr(
args,
f"{prefix}health_success_threshold",
RouterArgs.health_success_threshold,
),
health_check_timeout_secs=getattr(
args,
f"{prefix}health_check_timeout_secs",
RouterArgs.health_check_timeout_secs,
),
health_check_interval_secs=getattr(
args,
f"{prefix}health_check_interval_secs",
RouterArgs.health_check_interval_secs,
),
health_check_endpoint=getattr(
args, f"{prefix}health_check_endpoint", RouterArgs.health_check_endpoint
),
model_path=getattr(args, f"{prefix}model_path", None),
tokenizer_path=getattr(args, f"{prefix}tokenizer_path", None),
)
@staticmethod
def _parse_selector(selector_list):
if not selector_list:
return {}
selector = {}
for item in selector_list:
if "=" in item:
key, value = item.split("=", 1)
selector[key] = value
return selector
@staticmethod
def _parse_prefill_urls(prefill_list):
"""Parse prefill URLs from --prefill arguments.
Format: --prefill URL [BOOTSTRAP_PORT]
Example:
--prefill http://prefill1:8080 9000 # With bootstrap port
--prefill http://prefill2:8080 none # Explicitly no bootstrap port
--prefill http://prefill3:8080 # Defaults to no bootstrap port
"""
if not prefill_list:
return []
prefill_urls = []
for prefill_args in prefill_list:
url = prefill_args[0]
# Handle optional bootstrap port
if len(prefill_args) >= 2:
bootstrap_port_str = prefill_args[1]
# Handle 'none' as None
if bootstrap_port_str.lower() == "none":
bootstrap_port = None
else:
try:
bootstrap_port = int(bootstrap_port_str)
except ValueError:
raise ValueError(
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
)
else:
# No bootstrap port specified, default to None
bootstrap_port = None
prefill_urls.append((url, bootstrap_port))
return prefill_urls
@staticmethod
def _parse_decode_urls(decode_list):
"""Parse decode URLs from --decode arguments.
Format: --decode URL
Example: --decode http://decode1:8081 --decode http://decode2:8081
"""
if not decode_list:
return []
# decode_list is a list of single-element lists due to nargs=1
return [url[0] for url in decode_list]
def policy_from_str(policy_str: str) -> PolicyType:
"""Convert policy string to PolicyType enum."""
policy_map = {
"random": PolicyType.Random,
"round_robin": PolicyType.RoundRobin,
"cache_aware": PolicyType.CacheAware,
"power_of_two": PolicyType.PowerOfTwo,
}
return policy_map[policy_str]
def launch_router(args: argparse.Namespace) -> Optional[Router]:
"""
......@@ -661,7 +29,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
Returns:
Router instance if successful, None if failed
"""
logger = logging.getLogger("router")
setproctitle.setproctitle("sglang::router")
try:
# Convert to RouterArgs if needed
if not isinstance(args, RouterArgs):
......@@ -669,120 +37,15 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
else:
router_args = args
# Validate configuration based on mode
if router_args.pd_disaggregation:
# Validate PD configuration - skip URL requirements if using service discovery
if not router_args.service_discovery:
if not router_args.prefill_urls:
raise ValueError("PD disaggregation mode requires --prefill")
if not router_args.decode_urls:
raise ValueError("PD disaggregation mode requires --decode")
# Warn about policy usage in PD mode
if (
router_args.prefill_policy
and router_args.decode_policy
and router_args.policy
):
logger.warning(
"Both --prefill-policy and --decode-policy are specified. "
"The main --policy flag will be ignored for PD mode."
)
elif (
router_args.prefill_policy
and not router_args.decode_policy
and router_args.policy
):
logger.info(
f"Using --prefill-policy '{router_args.prefill_policy}' for prefill nodes "
f"and --policy '{router_args.policy}' for decode nodes."
)
elif (
router_args.decode_policy
and not router_args.prefill_policy
and router_args.policy
):
logger.info(
f"Using --policy '{router_args.policy}' for prefill nodes "
f"and --decode-policy '{router_args.decode_policy}' for decode nodes."
)
# Create router with unified constructor
router = Router(
worker_urls=(
[]
if router_args.service_discovery or router_args.pd_disaggregation
else router_args.worker_urls
),
host=router_args.host,
port=router_args.port,
policy=policy_from_str(router_args.policy),
worker_startup_timeout_secs=router_args.worker_startup_timeout_secs,
worker_startup_check_interval=router_args.worker_startup_check_interval,
cache_threshold=router_args.cache_threshold,
balance_abs_threshold=router_args.balance_abs_threshold,
balance_rel_threshold=router_args.balance_rel_threshold,
eviction_interval_secs=router_args.eviction_interval,
max_tree_size=router_args.max_tree_size,
max_payload_size=router_args.max_payload_size,
dp_aware=router_args.dp_aware,
api_key=router_args.api_key,
log_dir=router_args.log_dir,
log_level=router_args.log_level,
service_discovery=router_args.service_discovery,
selector=router_args.selector,
service_discovery_port=router_args.service_discovery_port,
service_discovery_namespace=router_args.service_discovery_namespace,
prefill_selector=router_args.prefill_selector,
decode_selector=router_args.decode_selector,
prometheus_port=router_args.prometheus_port,
prometheus_host=router_args.prometheus_host,
request_timeout_secs=router_args.request_timeout_secs,
pd_disaggregation=router_args.pd_disaggregation,
prefill_urls=(
router_args.prefill_urls if router_args.pd_disaggregation else None
),
decode_urls=(
router_args.decode_urls if router_args.pd_disaggregation else None
),
prefill_policy=(
policy_from_str(router_args.prefill_policy)
if router_args.prefill_policy
else None
),
decode_policy=(
policy_from_str(router_args.decode_policy)
if router_args.decode_policy
else None
),
request_id_headers=router_args.request_id_headers,
max_concurrent_requests=router_args.max_concurrent_requests,
queue_size=router_args.queue_size,
queue_timeout_secs=router_args.queue_timeout_secs,
rate_limit_tokens_per_second=router_args.rate_limit_tokens_per_second,
cors_allowed_origins=router_args.cors_allowed_origins,
retry_max_retries=router_args.retry_max_retries,
retry_initial_backoff_ms=router_args.retry_initial_backoff_ms,
retry_max_backoff_ms=router_args.retry_max_backoff_ms,
retry_backoff_multiplier=router_args.retry_backoff_multiplier,
retry_jitter_factor=router_args.retry_jitter_factor,
cb_failure_threshold=router_args.cb_failure_threshold,
cb_success_threshold=router_args.cb_success_threshold,
cb_timeout_duration_secs=router_args.cb_timeout_duration_secs,
cb_window_duration_secs=router_args.cb_window_duration_secs,
disable_retries=router_args.disable_retries,
disable_circuit_breaker=router_args.disable_circuit_breaker,
health_failure_threshold=router_args.health_failure_threshold,
health_success_threshold=router_args.health_success_threshold,
health_check_timeout_secs=router_args.health_check_timeout_secs,
health_check_interval_secs=router_args.health_check_interval_secs,
health_check_endpoint=router_args.health_check_endpoint,
model_path=router_args.model_path,
tokenizer_path=router_args.tokenizer_path,
)
router.start()
return router
if router_args.mini_lb:
mini_lb = MiniLoadBalancer(router_args)
mini_lb.start()
else:
if Router is None:
raise RuntimeError("Rust Router is not installed")
router_args._validate_router_args()
router = Router.from_args(router_args)
router.start()
except Exception as e:
logger.error(f"Error starting router: {e}")
......
"""
Minimal HTTP load balancer for prefill and decode servers for testing.
"""
import asyncio
import ipaddress
import logging
import random
import urllib
from http import HTTPStatus
from itertools import chain
from typing import Optional
import aiohttp
import orjson
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang_router.router_args import RouterArgs
logger = logging.getLogger(__name__)
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
1024 * 64
) # 64KB, to prevent aiohttp's "Chunk too big" error
def maybe_wrap_ipv6_address(address: str) -> str:
try:
ipaddress.IPv6Address(address)
return f"[{address}]"
except ValueError:
return address
class MiniLoadBalancer:
def __init__(
self,
router_args: RouterArgs,
):
self._validate_router_args(router_args)
self.host = router_args.host
self.port = router_args.port
self.timeout = router_args.request_timeout_secs
self.prefill_urls = [url[0] for url in router_args.prefill_urls]
self.prefill_bootstrap_ports = [url[1] for url in router_args.prefill_urls]
self.decode_urls = router_args.decode_urls
def _validate_router_args(self, router_args: RouterArgs):
logger.warning(
"\x1b[33mMiniLB is only for debugging purposes, it only supports random policy!\033[0m"
)
# NOTE: too many arguments unsupported, just validate some important ones
if router_args.policy != "random":
logger.warning("[MiniLB] Overriding policy to random")
router_args.policy = "random"
if not router_args.pd_disaggregation:
raise ValueError("MiniLB only supports PD disaggregation mode")
if len(router_args.prefill_urls) == 0 or len(router_args.decode_urls) == 0:
raise ValueError(
"MiniLB requires at least one prefill and one decode server"
)
def start(self):
global lb
lb = self
uvicorn.run(app, host=self.host, port=self.port)
def select_pair(self):
assert len(self.prefill_urls) > 0, "No prefill servers available"
assert len(self.decode_urls) > 0, "No decode servers available"
pidx = random.randint(0, len(self.prefill_urls) - 1)
didx = random.randint(0, len(self.decode_urls) - 1)
return (
self.prefill_urls[pidx],
self.prefill_bootstrap_ports[pidx],
self.decode_urls[didx],
)
async def generate(
self, modified_request, prefill_server, decode_server, endpoint
) -> ORJSONResponse:
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=self.timeout
) # Add timeout for request reliability
) as session:
tasks = [
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
session.post(f"{decode_server}/{endpoint}", json=modified_request),
]
# Wait for both responses to complete. Prefill should end first.
prefill_response, decode_response = await asyncio.gather(*tasks)
if "return_logprob" in modified_request:
prefill_json = await prefill_response.json()
ret_json = await decode_response.json()
# merge `meta_info.input_token_logprobs` from prefill to decode
if "meta_info" in ret_json:
if "input_token_logprobs" in ret_json["meta_info"]:
ret_json["meta_info"]["input_token_logprobs"] = (
prefill_json["meta_info"]["input_token_logprobs"]
+ ret_json["meta_info"]["input_token_logprobs"]
)
else:
ret_json = await decode_response.json()
return ORJSONResponse(
content=ret_json,
status_code=decode_response.status,
)
async def generate_stream(
self, modified_request, prefill_server, decode_server, endpoint="generate"
):
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
async def stream_results():
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=self.timeout
) # Add timeout for request reliability
) as session:
# Create the tasks for both prefill and decode requests
tasks = [
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
session.post(f"{decode_server}/{endpoint}", json=modified_request),
]
# Wait for both responses to complete. Since this is streaming, they return immediately.
prefill_response, decode_response = await asyncio.gather(*tasks)
if modified_request.get("return_logprob", False):
prefill_chunks = []
async for chunk in prefill_response.content:
prefill_chunks.append(chunk)
first_prefill_chunk = (
prefill_chunks[0].decode("utf-8")[5:].strip("\n")
)
first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
async for chunk in decode_response.content:
# Note: This is inefficient
# merge prefill input_token_logprobs, output_token_logprobs to decode
decoded_chunk = chunk.decode("utf-8")
if (
decoded_chunk
and decoded_chunk.startswith("data:")
and "[DONE]" not in decoded_chunk
):
ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
ret_json["meta_info"]["input_token_logprobs"] = (
first_prefill_chunk_json["meta_info"][
"input_token_logprobs"
]
+ ret_json["meta_info"]["input_token_logprobs"]
)
yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
else:
yield chunk
else:
async for chunk in decode_response.content.iter_chunked(
AIOHTTP_STREAM_READ_CHUNK_SIZE
):
yield chunk
return StreamingResponse(
stream_results(),
media_type="text/event-stream",
)
app = FastAPI()
lb: Optional[MiniLoadBalancer] = None
@app.get("/health")
async def health_check():
return Response(status_code=200)
@app.get("/health_generate")
async def health_generate():
async with aiohttp.ClientSession() as session:
# Create the tasks
tasks = []
for server in chain(lb.prefill_urls, lb.decode_urls):
tasks.append(session.get(f"{server}/health_generate"))
for i, response in enumerate(asyncio.as_completed(tasks)):
await response
return Response(status_code=200)
@app.post("/flush_cache")
async def flush_cache():
async with aiohttp.ClientSession() as session:
# Create the tasks
tasks = []
for server in chain(lb.prefill_urls, lb.decode_urls):
tasks.append(session.post(f"{server}/flush_cache"))
for i, response in enumerate(asyncio.as_completed(tasks)):
await response
return Response(status_code=200)
@app.get("/get_server_info")
async def get_server_info():
prefill_infos = []
decode_infos = []
all_internal_states = []
async with aiohttp.ClientSession() as session:
for server in lb.prefill_urls:
server_info = await session.get(f"{server}/get_server_info")
prefill_infos.append(await server_info.json())
for server in lb.decode_urls:
server_info = await session.get(f"{server}/get_server_info")
info_json = await server_info.json()
decode_infos.append(info_json)
# Extract internal_states from decode servers
if "internal_states" in info_json:
all_internal_states.extend(info_json["internal_states"])
# Return format expected by bench_one_batch_server.py
if all_internal_states:
return {
"internal_states": all_internal_states,
"prefill": prefill_infos,
"decode": decode_infos,
}
else:
# Fallback with dummy data if no internal states found
return {
"internal_states": [
{
"last_gen_throughput": 0.0,
"avg_spec_accept_length": None,
}
],
"prefill": prefill_infos,
"decode": decode_infos,
}
@app.get("/get_model_info")
async def get_model_info():
if not lb or not lb.prefill_urls:
raise HTTPException(
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
detail="There is no server registered",
)
target_server_url = lb.prefill_urls[0]
endpoint_url = f"{target_server_url}/get_model_info"
async with aiohttp.ClientSession() as session:
try:
async with session.get(endpoint_url) as response:
if response.status != 200:
error_text = await response.text()
raise HTTPException(
status_code=HTTPStatus.BAD_GATEWAY,
detail=(
f"Failed to get model info from {target_server_url}"
f"Status: {response.status}, Response: {error_text}"
),
)
model_info_json = await response.json()
return ORJSONResponse(content=model_info_json)
except aiohttp.ClientError as e:
raise HTTPException(
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
detail=f"Failed to get model info from backend",
)
@app.post("/generate")
async def handle_generate_request(request_data: dict):
prefill_server, bootstrap_port, decode_server = lb.select_pair()
# Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
modified_request = request_data.copy()
batch_size = _get_request_batch_size(modified_request)
if batch_size is not None:
modified_request.update(
{
"bootstrap_host": [hostname] * batch_size,
"bootstrap_port": [bootstrap_port] * batch_size,
"bootstrap_room": [
_generate_bootstrap_room() for _ in range(batch_size)
],
}
)
else:
modified_request.update(
{
"bootstrap_host": hostname,
"bootstrap_port": bootstrap_port,
"bootstrap_room": _generate_bootstrap_room(),
}
)
if request_data.get("stream", False):
return await lb.generate_stream(
modified_request, prefill_server, decode_server, "generate"
)
else:
return await lb.generate(
modified_request, prefill_server, decode_server, "generate"
)
async def _forward_to_backend(request_data: dict, endpoint_name: str):
prefill_server, bootstrap_port, decode_server = lb.select_pair()
# Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
modified_request = request_data.copy()
modified_request.update(
{
"bootstrap_host": hostname,
"bootstrap_port": bootstrap_port,
"bootstrap_room": _generate_bootstrap_room(),
}
)
if request_data.get("stream", False):
return await lb.generate_stream(
modified_request,
prefill_server,
decode_server,
endpoint=endpoint_name,
)
else:
return await lb.generate(
modified_request,
prefill_server,
decode_server,
endpoint=endpoint_name,
)
@app.post("/v1/chat/completions")
async def handle_chat_completion_request(request_data: dict):
return await _forward_to_backend(request_data, "v1/chat/completions")
@app.post("/v1/completions")
async def handle_completion_request(request_data: dict):
return await _forward_to_backend(request_data, "v1/completions")
def _generate_bootstrap_room():
return random.randint(0, 2**63 - 1)
# We may utilize `GenerateReqInput`'s logic later
def _get_request_batch_size(request):
if (text := request.get("text")) is not None:
return None if isinstance(text, str) else len(text)
if (input_ids := request.get("input_ids")) is not None:
return None if isinstance(input_ids[0], int) else len(input_ids)
return None
@app.get("/v1/models")
async def get_models():
prefill_server = lb.prefill_urls[0] # Get the first prefill server
async with aiohttp.ClientSession() as session:
try:
response = await session.get(f"{prefill_server}/v1/models")
if response.status != 200:
raise HTTPException(
status_code=response.status,
detail=f"Prefill server error: Status {response.status}",
)
return ORJSONResponse(content=await response.json())
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
from typing import Dict, List, Optional
from typing import Optional
from sglang_router.router_args import RouterArgs
from sglang_router_rs import PolicyType
from sglang_router_rs import Router as _Router
def policy_from_str(policy_str: Optional[str]) -> PolicyType:
"""Convert policy string to PolicyType enum."""
if policy_str is None:
return None
policy_map = {
"random": PolicyType.Random,
"round_robin": PolicyType.RoundRobin,
"cache_aware": PolicyType.CacheAware,
"power_of_two": PolicyType.PowerOfTwo,
}
return policy_map[policy_str]
class Router:
"""
A high-performance router for distributing requests across worker nodes.
......@@ -78,130 +92,34 @@ class Router:
tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None
"""
def __init__(
self,
worker_urls: List[str],
policy: PolicyType = PolicyType.RoundRobin,
host: str = "127.0.0.1",
port: int = 3001,
worker_startup_timeout_secs: int = 600,
worker_startup_check_interval: int = 30,
cache_threshold: float = 0.3,
balance_abs_threshold: int = 64,
balance_rel_threshold: float = 1.5,
eviction_interval_secs: int = 120,
max_tree_size: int = 2**26,
max_payload_size: int = 512 * 1024 * 1024, # 512MB
dp_aware: bool = False,
api_key: Optional[str] = None,
log_dir: Optional[str] = None,
log_level: Optional[str] = None,
service_discovery: bool = False,
selector: Dict[str, str] = None,
service_discovery_port: int = 80,
service_discovery_namespace: Optional[str] = None,
prefill_selector: Dict[str, str] = None,
decode_selector: Dict[str, str] = None,
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port",
prometheus_port: Optional[int] = None,
prometheus_host: Optional[str] = None,
request_timeout_secs: int = 1800,
request_id_headers: Optional[List[str]] = None,
pd_disaggregation: bool = False,
prefill_urls: Optional[List[tuple]] = None,
decode_urls: Optional[List[str]] = None,
prefill_policy: Optional[PolicyType] = None,
decode_policy: Optional[PolicyType] = None,
max_concurrent_requests: int = 256,
queue_size: int = 100,
queue_timeout_secs: int = 60,
rate_limit_tokens_per_second: Optional[int] = None,
cors_allowed_origins: List[str] = None,
retry_max_retries: int = 5,
retry_initial_backoff_ms: int = 50,
retry_max_backoff_ms: int = 30_000,
retry_backoff_multiplier: float = 1.5,
retry_jitter_factor: float = 0.2,
cb_failure_threshold: int = 10,
cb_success_threshold: int = 3,
cb_timeout_duration_secs: int = 60,
cb_window_duration_secs: int = 120,
disable_retries: bool = False,
disable_circuit_breaker: bool = False,
health_failure_threshold: int = 3,
health_success_threshold: int = 2,
health_check_timeout_secs: int = 5,
health_check_interval_secs: int = 60,
health_check_endpoint: str = "/health",
model_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
):
if selector is None:
selector = {}
if prefill_selector is None:
prefill_selector = {}
if decode_selector is None:
decode_selector = {}
if cors_allowed_origins is None:
cors_allowed_origins = []
def __init__(self, router: _Router):
self._router = router
self._router = _Router(
worker_urls=worker_urls,
policy=policy,
host=host,
port=port,
worker_startup_timeout_secs=worker_startup_timeout_secs,
worker_startup_check_interval=worker_startup_check_interval,
cache_threshold=cache_threshold,
balance_abs_threshold=balance_abs_threshold,
balance_rel_threshold=balance_rel_threshold,
eviction_interval_secs=eviction_interval_secs,
max_tree_size=max_tree_size,
max_payload_size=max_payload_size,
dp_aware=dp_aware,
api_key=api_key,
log_dir=log_dir,
log_level=log_level,
service_discovery=service_discovery,
selector=selector,
service_discovery_port=service_discovery_port,
service_discovery_namespace=service_discovery_namespace,
prefill_selector=prefill_selector,
decode_selector=decode_selector,
bootstrap_port_annotation=bootstrap_port_annotation,
prometheus_port=prometheus_port,
prometheus_host=prometheus_host,
request_timeout_secs=request_timeout_secs,
request_id_headers=request_id_headers,
pd_disaggregation=pd_disaggregation,
prefill_urls=prefill_urls,
decode_urls=decode_urls,
prefill_policy=prefill_policy,
decode_policy=decode_policy,
max_concurrent_requests=max_concurrent_requests,
queue_size=queue_size,
queue_timeout_secs=queue_timeout_secs,
rate_limit_tokens_per_second=rate_limit_tokens_per_second,
cors_allowed_origins=cors_allowed_origins,
retry_max_retries=retry_max_retries,
retry_initial_backoff_ms=retry_initial_backoff_ms,
retry_max_backoff_ms=retry_max_backoff_ms,
retry_backoff_multiplier=retry_backoff_multiplier,
retry_jitter_factor=retry_jitter_factor,
cb_failure_threshold=cb_failure_threshold,
cb_success_threshold=cb_success_threshold,
cb_timeout_duration_secs=cb_timeout_duration_secs,
cb_window_duration_secs=cb_window_duration_secs,
disable_retries=disable_retries,
disable_circuit_breaker=disable_circuit_breaker,
health_failure_threshold=health_failure_threshold,
health_success_threshold=health_success_threshold,
health_check_timeout_secs=health_check_timeout_secs,
health_check_interval_secs=health_check_interval_secs,
health_check_endpoint=health_check_endpoint,
model_path=model_path,
tokenizer_path=tokenizer_path,
@staticmethod
def from_args(args: RouterArgs) -> "Router":
"""Create a router from a RouterArgs instance."""
args_dict = vars(args)
# Convert RouterArgs to _Router parameters
args_dict["worker_urls"] = (
[]
if args_dict["service_discovery"] or args_dict["pd_disaggregation"]
else args_dict["worker_urls"]
)
args_dict["policy"] = policy_from_str(args_dict["policy"])
args_dict["prefill_urls"] = (
args_dict["prefill_urls"] if args_dict["pd_disaggregation"] else None
)
args_dict["decode_urls"] = (
args_dict["decode_urls"] if args_dict["pd_disaggregation"] else None
)
args_dict["prefill_policy"] = policy_from_str(args_dict["prefill_policy"])
args_dict["decode_policy"] = policy_from_str(args_dict["decode_policy"])
# remoge mini_lb parameter
args_dict.pop("mini_lb")
return Router(_Router(**args_dict))
def start(self) -> None:
"""Start the router server.
......
import argparse
import dataclasses
import logging
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class RouterArgs:
# Worker configuration
worker_urls: List[str] = dataclasses.field(default_factory=list)
host: str = "127.0.0.1"
port: int = 30000
# PD-specific configuration
mini_lb: bool = False
pd_disaggregation: bool = False # Enable PD disaggregated mode
prefill_urls: List[tuple] = dataclasses.field(
default_factory=list
) # List of (url, bootstrap_port)
decode_urls: List[str] = dataclasses.field(default_factory=list)
# Routing policy
policy: str = "cache_aware"
prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode
decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode
worker_startup_timeout_secs: int = 600
worker_startup_check_interval: int = 30
cache_threshold: float = 0.3
balance_abs_threshold: int = 64
balance_rel_threshold: float = 1.5
eviction_interval_secs: int = 120
max_tree_size: int = 2**26
max_payload_size: int = 512 * 1024 * 1024 # 512MB default for large batches
dp_aware: bool = False
api_key: Optional[str] = None
log_dir: Optional[str] = None
log_level: Optional[str] = None
# Service discovery configuration
service_discovery: bool = False
selector: Dict[str, str] = dataclasses.field(default_factory=dict)
service_discovery_port: int = 80
service_discovery_namespace: Optional[str] = None
# PD service discovery configuration
prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port"
# Prometheus configuration
prometheus_port: Optional[int] = None
prometheus_host: Optional[str] = None
# Request ID headers configuration
request_id_headers: Optional[List[str]] = None
# Request timeout in seconds
request_timeout_secs: int = 1800
# Max concurrent requests for rate limiting
max_concurrent_requests: int = 256
# Queue size for pending requests when max concurrent limit reached
queue_size: int = 100
# Maximum time (in seconds) a request can wait in queue before timing out
queue_timeout_secs: int = 60
# Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
rate_limit_tokens_per_second: Optional[int] = None
# CORS allowed origins
cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
# Retry configuration
retry_max_retries: int = 5
retry_initial_backoff_ms: int = 50
retry_max_backoff_ms: int = 30_000
retry_backoff_multiplier: float = 1.5
retry_jitter_factor: float = 0.2
disable_retries: bool = False
# Health check configuration
health_failure_threshold: int = 3
health_success_threshold: int = 2
health_check_timeout_secs: int = 5
health_check_interval_secs: int = 60
health_check_endpoint: str = "/health"
# Circuit breaker configuration
cb_failure_threshold: int = 10
cb_success_threshold: int = 3
cb_timeout_duration_secs: int = 60
cb_window_duration_secs: int = 120
disable_circuit_breaker: bool = False
# Tokenizer configuration
model_path: Optional[str] = None
tokenizer_path: Optional[str] = None
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser,
use_router_prefix: bool = False,
exclude_host_port: bool = False,
):
"""
Add router-specific arguments to an argument parser.
Args:
parser: The argument parser to add arguments to
use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
"""
prefix = "router-" if use_router_prefix else ""
# Worker configuration
if not exclude_host_port:
parser.add_argument(
"--host",
type=str,
default=RouterArgs.host,
help="Host address to bind the router server",
)
parser.add_argument(
"--port",
type=int,
default=RouterArgs.port,
help="Port number to bind the router server",
)
parser.add_argument(
"--worker-urls",
type=str,
nargs="*",
default=[],
help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
)
# Routing policy configuration
parser.add_argument(
f"--{prefix}policy",
type=str,
default=RouterArgs.policy,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
)
parser.add_argument(
f"--{prefix}prefill-policy",
type=str,
default=None,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
)
parser.add_argument(
f"--{prefix}decode-policy",
type=str,
default=None,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
)
# PD-specific arguments
parser.add_argument(
f"--{prefix}mini-lb",
action="store_true",
help="Enable MiniLB",
)
parser.add_argument(
f"--{prefix}pd-disaggregation",
action="store_true",
help="Enable PD (Prefill-Decode) disaggregated mode",
)
parser.add_argument(
f"--{prefix}prefill",
nargs="+",
action="append",
help="Prefill server URL and optional bootstrap port. Can be specified multiple times. "
"Format: --prefill URL [BOOTSTRAP_PORT]. "
"BOOTSTRAP_PORT can be a port number, 'none', or omitted (defaults to none).",
)
parser.add_argument(
f"--{prefix}decode",
nargs=1,
action="append",
metavar=("URL",),
help="Decode server URL. Can be specified multiple times.",
)
parser.add_argument(
f"--{prefix}worker-startup-timeout-secs",
type=int,
default=RouterArgs.worker_startup_timeout_secs,
help="Timeout in seconds for worker startup",
)
parser.add_argument(
f"--{prefix}worker-startup-check-interval",
type=int,
default=RouterArgs.worker_startup_check_interval,
help="Interval in seconds between checks for worker startup",
)
parser.add_argument(
f"--{prefix}cache-threshold",
type=float,
default=RouterArgs.cache_threshold,
help="Cache threshold (0.0-1.0) for cache-aware routing",
)
parser.add_argument(
f"--{prefix}balance-abs-threshold",
type=int,
default=RouterArgs.balance_abs_threshold,
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
)
parser.add_argument(
f"--{prefix}balance-rel-threshold",
type=float,
default=RouterArgs.balance_rel_threshold,
help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
)
parser.add_argument(
f"--{prefix}eviction-interval-secs",
type=int,
default=RouterArgs.eviction_interval_secs,
help="Interval in seconds between cache eviction operations",
)
parser.add_argument(
f"--{prefix}max-tree-size",
type=int,
default=RouterArgs.max_tree_size,
help="Maximum size of the approximation tree for cache-aware routing",
)
parser.add_argument(
f"--{prefix}max-payload-size",
type=int,
default=RouterArgs.max_payload_size,
help="Maximum payload size in bytes",
)
parser.add_argument(
f"--{prefix}dp-aware",
action="store_true",
help="Enable data parallelism aware schedule",
)
parser.add_argument(
f"--{prefix}api-key",
type=str,
default=None,
help="The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enaled.",
)
parser.add_argument(
f"--{prefix}log-dir",
type=str,
default=None,
help="Directory to store log files. If not specified, logs are only output to console.",
)
parser.add_argument(
f"--{prefix}log-level",
type=str,
default="info",
choices=["debug", "info", "warning", "error", "critical"],
help="Set the logging level. If not specified, defaults to INFO.",
)
parser.add_argument(
f"--{prefix}service-discovery",
action="store_true",
help="Enable Kubernetes service discovery",
)
parser.add_argument(
f"--{prefix}selector",
type=str,
nargs="+",
default={},
help="Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)",
)
parser.add_argument(
f"--{prefix}service-discovery-port",
type=int,
default=RouterArgs.service_discovery_port,
help="Port to use for discovered worker pods",
)
parser.add_argument(
f"--{prefix}service-discovery-namespace",
type=str,
help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
)
parser.add_argument(
f"--{prefix}prefill-selector",
type=str,
nargs="+",
default={},
help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)",
)
parser.add_argument(
f"--{prefix}decode-selector",
type=str,
nargs="+",
default={},
help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)",
)
# Prometheus configuration
parser.add_argument(
f"--{prefix}prometheus-port",
type=int,
default=29000,
help="Port to expose Prometheus metrics. If not specified, Prometheus metrics are disabled",
)
parser.add_argument(
f"--{prefix}prometheus-host",
type=str,
default="127.0.0.1",
help="Host address to bind the Prometheus metrics server",
)
parser.add_argument(
f"--{prefix}request-id-headers",
type=str,
nargs="*",
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
)
parser.add_argument(
f"--{prefix}request-timeout-secs",
type=int,
default=RouterArgs.request_timeout_secs,
help="Request timeout in seconds",
)
# Retry configuration
parser.add_argument(
f"--{prefix}retry-max-retries",
type=int,
default=RouterArgs.retry_max_retries,
)
parser.add_argument(
f"--{prefix}retry-initial-backoff-ms",
type=int,
default=RouterArgs.retry_initial_backoff_ms,
)
parser.add_argument(
f"--{prefix}retry-max-backoff-ms",
type=int,
default=RouterArgs.retry_max_backoff_ms,
)
parser.add_argument(
f"--{prefix}retry-backoff-multiplier",
type=float,
default=RouterArgs.retry_backoff_multiplier,
)
parser.add_argument(
f"--{prefix}retry-jitter-factor",
type=float,
default=RouterArgs.retry_jitter_factor,
)
parser.add_argument(
f"--{prefix}disable-retries",
action="store_true",
help="Disable retries (equivalent to setting retry_max_retries=1)",
)
# Circuit breaker configuration
parser.add_argument(
f"--{prefix}cb-failure-threshold",
type=int,
default=RouterArgs.cb_failure_threshold,
)
parser.add_argument(
f"--{prefix}cb-success-threshold",
type=int,
default=RouterArgs.cb_success_threshold,
)
parser.add_argument(
f"--{prefix}cb-timeout-duration-secs",
type=int,
default=RouterArgs.cb_timeout_duration_secs,
)
parser.add_argument(
f"--{prefix}cb-window-duration-secs",
type=int,
default=RouterArgs.cb_window_duration_secs,
)
parser.add_argument(
f"--{prefix}disable-circuit-breaker",
action="store_true",
help="Disable circuit breaker (equivalent to setting cb_failure_threshold to u32::MAX)",
)
# Health check configuration
parser.add_argument(
f"--{prefix}health-failure-threshold",
type=int,
default=RouterArgs.health_failure_threshold,
help="Number of consecutive health check failures before marking worker unhealthy",
)
parser.add_argument(
f"--{prefix}health-success-threshold",
type=int,
default=RouterArgs.health_success_threshold,
help="Number of consecutive health check successes before marking worker healthy",
)
parser.add_argument(
f"--{prefix}health-check-timeout-secs",
type=int,
default=RouterArgs.health_check_timeout_secs,
help="Timeout in seconds for health check requests",
)
parser.add_argument(
f"--{prefix}health-check-interval-secs",
type=int,
default=RouterArgs.health_check_interval_secs,
help="Interval in seconds between runtime health checks",
)
parser.add_argument(
f"--{prefix}health-check-endpoint",
type=str,
default=RouterArgs.health_check_endpoint,
help="Health check endpoint path",
)
parser.add_argument(
f"--{prefix}max-concurrent-requests",
type=int,
default=RouterArgs.max_concurrent_requests,
help="Maximum number of concurrent requests allowed (for rate limiting)",
)
parser.add_argument(
f"--{prefix}queue-size",
type=int,
default=RouterArgs.queue_size,
help="Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)",
)
parser.add_argument(
f"--{prefix}queue-timeout-secs",
type=int,
default=RouterArgs.queue_timeout_secs,
help="Maximum time (in seconds) a request can wait in queue before timing out",
)
parser.add_argument(
f"--{prefix}rate-limit-tokens-per-second",
type=int,
default=RouterArgs.rate_limit_tokens_per_second,
help="Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests",
)
parser.add_argument(
f"--{prefix}cors-allowed-origins",
type=str,
nargs="*",
default=[],
help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)",
)
# Tokenizer configuration
parser.add_argument(
f"--{prefix}model-path",
type=str,
default=None,
help="Model path for loading tokenizer (HuggingFace model ID or local path)",
)
parser.add_argument(
f"--{prefix}tokenizer-path",
type=str,
default=None,
help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
)
@classmethod
def from_cli_args(
cls, args: argparse.Namespace, use_router_prefix: bool = False
) -> "RouterArgs":
"""
Create RouterArgs instance from parsed command line arguments.
Args:
args: Parsed command line arguments
use_router_prefix: If True, look for arguments with 'router-' prefix
"""
prefix = "router_" if use_router_prefix else ""
cli_args_dict = vars(args)
args_dict = {}
for attr in dataclasses.fields(cls):
# Auto strip prefix from args
if f"{prefix}{attr.name}" in cli_args_dict:
args_dict[attr.name] = cli_args_dict[f"{prefix}{attr.name}"]
elif attr.name in cli_args_dict:
args_dict[attr.name] = cli_args_dict[attr.name]
# parse special arguments and remove "--prefill" and "--decode" from cli_args_dict
args_dict["prefill_urls"] = cls._parse_prefill_urls(
cli_args_dict.get(f"{prefix}prefill", None)
)
args_dict["decode_urls"] = cls._parse_decode_urls(
cli_args_dict.get(f"{prefix}decode", None)
)
args_dict["selector"] = cls._parse_selector(
cli_args_dict.get(f"{prefix}selector", None)
)
args_dict["prefill_selector"] = cls._parse_selector(
cli_args_dict.get(f"{prefix}prefill_selector", None)
)
args_dict["decode_selector"] = cls._parse_selector(
cli_args_dict.get(f"{prefix}decode_selector", None)
)
# Mooncake-specific annotation
args_dict["bootstrap_port_annotation"] = "sglang.ai/bootstrap-port"
return cls(**args_dict)
def _validate_router_args(self):
# Validate configuration based on mode
if self.pd_disaggregation:
# Validate PD configuration - skip URL requirements if using service discovery
if not self.service_discovery:
if not self.prefill_urls:
raise ValueError("PD disaggregation mode requires --prefill")
if not self.decode_urls:
raise ValueError("PD disaggregation mode requires --decode")
# Warn about policy usage in PD mode
if self.prefill_policy and self.decode_policy and self.policy:
logger.warning(
"Both --prefill-policy and --decode-policy are specified. "
"The main --policy flag will be ignored for PD mode."
)
elif self.prefill_policy and not self.decode_policy and self.policy:
logger.info(
f"Using --prefill-policy '{self.prefill_policy}' for prefill nodes "
f"and --policy '{self.policy}' for decode nodes."
)
elif self.decode_policy and not self.prefill_policy and self.policy:
logger.info(
f"Using --policy '{self.policy}' for prefill nodes "
f"and --decode-policy '{self.decode_policy}' for decode nodes."
)
@staticmethod
def _parse_selector(selector_list):
if not selector_list:
return {}
selector = {}
for item in selector_list:
if "=" in item:
key, value = item.split("=", 1)
selector[key] = value
return selector
@staticmethod
def _parse_prefill_urls(prefill_list):
"""Parse prefill URLs from --prefill arguments.
Format: --prefill URL [BOOTSTRAP_PORT]
Example:
--prefill http://prefill1:8080 9000 # With bootstrap port
--prefill http://prefill2:8080 none # Explicitly no bootstrap port
--prefill http://prefill3:8080 # Defaults to no bootstrap port
"""
if not prefill_list:
return []
prefill_urls = []
for prefill_args in prefill_list:
url = prefill_args[0]
# Handle optional bootstrap port
if len(prefill_args) >= 2:
bootstrap_port_str = prefill_args[1]
# Handle 'none' as None
if bootstrap_port_str.lower() == "none":
bootstrap_port = None
else:
try:
bootstrap_port = int(bootstrap_port_str)
except ValueError:
raise ValueError(
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
)
else:
# No bootstrap port specified, default to None
bootstrap_port = None
prefill_urls.append((url, bootstrap_port))
return prefill_urls
@staticmethod
def _parse_decode_urls(decode_list):
"""Parse decode URLs from --decode arguments.
Format: --decode URL
Example: --decode http://decode1:8081 --decode http://decode2:8081
"""
if not decode_list:
return []
# decode_list is a list of single-element lists due to nargs=1
return [url[0] for url in decode_list]
......@@ -33,7 +33,7 @@ class TestLaunchRouter(unittest.TestCase):
cache_threshold=0.5,
balance_abs_threshold=32,
balance_rel_threshold=1.0001,
eviction_interval=60,
eviction_interval_secs=60,
max_tree_size=2**24,
max_payload_size=256 * 1024 * 1024, # 256MB
verbose=False,
......@@ -176,9 +176,8 @@ class TestLaunchRouter(unittest.TestCase):
"""Test basic PD router functionality without actually starting servers."""
# This test just verifies the PD router can be created and configured
# without actually starting it (which would require real prefill/decode servers)
from sglang_router import Router
from sglang_router.launch_router import RouterArgs
from sglang_router_rs import PolicyType
from sglang_router.router import PolicyType, Router
# Test RouterArgs parsing for PD mode
# Simulate the parsed args structure from argparse with action="append"
......@@ -209,18 +208,7 @@ class TestLaunchRouter(unittest.TestCase):
self.assertEqual(router_args.decode_urls[1], "http://decode2:8081")
# Test Router creation in PD mode
router = Router(
worker_urls=[], # Empty for PD mode
pd_disaggregation=True,
prefill_urls=[
("http://prefill1:8080", 9000),
("http://prefill2:8080", None),
],
decode_urls=["http://decode1:8081", "http://decode2:8081"],
policy=PolicyType.CacheAware,
host="127.0.0.1",
port=3001,
)
router = Router.from_args(router_args)
self.assertIsNotNone(router)
def test_policy_validation(self):
......
......@@ -77,7 +77,7 @@ def popen_launch_router(
port,
"--dp",
str(dp_size),
"--router-eviction-interval",
"--router-eviction-interval-secs",
"5",
"--router-policy",
policy,
......
......@@ -28,8 +28,3 @@ find = { where = ["py_src"] }
# workaround for https://github.com/pypa/twine/issues/1216
[tool.setuptools]
license-files = []
[[tool.setuptools-rust.ext-modules]]
target = "sglang_router_rs"
path = "Cargo.toml"
binding = "PyO3"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment