Unverified Commit 6e95f5e5 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Simplify `Router` arguments passing and build it in docker image (#9964)

parent 0e9387a9
...@@ -36,7 +36,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ ...@@ -36,7 +36,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
ibverbs-providers infiniband-diags perftest \ ibverbs-providers infiniband-diags perftest \
libgoogle-glog-dev libgtest-dev libjsoncpp-dev libunwind-dev \ libgoogle-glog-dev libgtest-dev libjsoncpp-dev libunwind-dev \
libboost-all-dev libssl-dev \ libboost-all-dev libssl-dev \
libgrpc-dev libgrpc++-dev libprotobuf-dev protobuf-compiler-grpc \ libgrpc-dev libgrpc++-dev libprotobuf-dev protobuf-compiler protobuf-compiler-grpc \
pybind11-dev \ pybind11-dev \
libhiredis-dev libcurl4-openssl-dev \ libhiredis-dev libcurl4-openssl-dev \
libczmq4 libczmq-dev \ libczmq4 libczmq-dev \
...@@ -218,6 +218,19 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.1/cmake-3.31.1 ...@@ -218,6 +218,19 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.1/cmake-3.31.1
&& cp -r cmake-3.31.1-linux-x86_64/share/* /usr/local/share/ \ && cp -r cmake-3.31.1-linux-x86_64/share/* /usr/local/share/ \
&& rm -rf cmake-3.31.1-linux-x86_64 cmake-3.31.1-linux-x86_64.tar.gz && rm -rf cmake-3.31.1-linux-x86_64 cmake-3.31.1-linux-x86_64.tar.gz
# Install Rust toolchain for sgl-router
ENV PATH="/root/.cargo/bin:${PATH}"
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \
&& rustc --version && cargo --version
# Build and install sgl-router
RUN python3 -m pip install --no-cache-dir setuptools-rust \
&& cd /sgl-workspace/sglang/sgl-router \
&& cargo build --release \
&& python3 -m pip install --no-cache-dir . \
&& rm -rf /root/.cache
# Add yank script # Add yank script
COPY --chown=root:root <<-"EOF" /usr/local/bin/yank COPY --chown=root:root <<-"EOF" /usr/local/bin/yank
#!/bin/bash #!/bin/bash
......
...@@ -36,7 +36,7 @@ uv pip install mooncake-transfer-engine ...@@ -36,7 +36,7 @@ uv pip install mooncake-transfer-engine
```bash ```bash
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-ib-device mlx5_roce0 $ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-ib-device mlx5_roce0
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-ib-device mlx5_roce0 $ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-ib-device mlx5_roce0
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000 $ python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
``` ```
### DeepSeek Multi-Node ### DeepSeek Multi-Node
...@@ -100,7 +100,7 @@ pip install . --config-settings=setup-args="-Ducx_path=/path/to/ucx" ...@@ -100,7 +100,7 @@ pip install . --config-settings=setup-args="-Ducx_path=/path/to/ucx"
```bash ```bash
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend nixl $ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend nixl
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend nixl $ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend nixl
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000 $ python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
``` ```
### DeepSeek Multi-Node ### DeepSeek Multi-Node
...@@ -137,7 +137,7 @@ export ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE=true ...@@ -137,7 +137,7 @@ export ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE=true
```bash ```bash
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend ascend $ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend ascend
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend ascend $ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend ascend
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000 $ python -m sglang_router.launch_router --pd-disaggregation --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
``` ```
### DeepSeek Multi-Node ### DeepSeek Multi-Node
......
...@@ -278,7 +278,7 @@ The most sophisticated policy that combines cache optimization with load balanci ...@@ -278,7 +278,7 @@ The most sophisticated policy that combines cache optimization with load balanci
3. **Cache Management**: 3. **Cache Management**:
- Maintains approximate radix trees per worker - Maintains approximate radix trees per worker
- Periodically evicts LRU entries based on `--eviction-interval` and `--max-tree-size` - Periodically evicts LRU entries based on `--eviction-interval-secs` and `--max-tree-size`
### Data Parallelism Aware Routing ### Data Parallelism Aware Routing
...@@ -296,7 +296,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu ...@@ -296,7 +296,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### Core Settings ### Core Settings
| Parameter | Type | Default | Description | | Parameter | Type | Default | Description |
|-----------------------------|------|-------------|-----------------------------------------------------------------| | --------------------------- | ---- | ----------- | --------------------------------------------------------------- |
| `--host` | str | 127.0.0.1 | Router server host address | | `--host` | str | 127.0.0.1 | Router server host address |
| `--port` | int | 30000 | Router server port | | `--port` | int | 30000 | Router server port |
| `--worker-urls` | list | [] | Worker URLs for separate launch mode | | `--worker-urls` | list | [] | Worker URLs for separate launch mode |
...@@ -307,18 +307,18 @@ This mode coordinates with SGLang's DP controller for optimized request distribu ...@@ -307,18 +307,18 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### Cache-Aware Routing Parameters ### Cache-Aware Routing Parameters
| Parameter | Type | Default | Description | | Parameter | Type | Default | Description |
|---------------------------|-------|----------|--------------------------------------------------------| | -------------------------- | ----- | -------- | ------------------------------------------------------ |
| `--cache-threshold` | float | 0.5 | Minimum prefix match ratio for cache routing (0.0-1.0) | | `--cache-threshold` | float | 0.5 | Minimum prefix match ratio for cache routing (0.0-1.0) |
| `--balance-abs-threshold` | int | 32 | Absolute load difference threshold | | `--balance-abs-threshold` | int | 32 | Absolute load difference threshold |
| `--balance-rel-threshold` | float | 1.0001 | Relative load ratio threshold | | `--balance-rel-threshold` | float | 1.0001 | Relative load ratio threshold |
| `--eviction-interval` | int | 60 | Seconds between cache eviction cycles | | `--eviction-interval-secs` | int | 60 | Seconds between cache eviction cycles |
| `--max-tree-size` | int | 16777216 | Maximum nodes in routing tree | | `--max-tree-size` | int | 16777216 | Maximum nodes in routing tree |
### Fault Tolerance Parameters ### Fault Tolerance Parameters
| Parameter | Type | Default | Description | | Parameter | Type | Default | Description |
|------------------------------|-------|---------|---------------------------------------| | ---------------------------- | ----- | ------- | ------------------------------------- |
| `--retry-max-retries` | int | 3 | Maximum retry attempts per request | | `--retry-max-retries` | int | 3 | Maximum retry attempts per request |
| `--retry-initial-backoff-ms` | int | 100 | Initial retry backoff in milliseconds | | `--retry-initial-backoff-ms` | int | 100 | Initial retry backoff in milliseconds |
| `--retry-max-backoff-ms` | int | 10000 | Maximum retry backoff in milliseconds | | `--retry-max-backoff-ms` | int | 10000 | Maximum retry backoff in milliseconds |
...@@ -334,7 +334,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu ...@@ -334,7 +334,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### Prefill-Decode Disaggregation Parameters ### Prefill-Decode Disaggregation Parameters
| Parameter | Type | Default | Description | | Parameter | Type | Default | Description |
|-----------------------------------|------|---------|-------------------------------------------------------| | --------------------------------- | ---- | ------- | ----------------------------------------------------- |
| `--pd-disaggregation` | flag | False | Enable PD disaggregated mode | | `--pd-disaggregation` | flag | False | Enable PD disaggregated mode |
| `--prefill` | list | [] | Prefill server URLs with optional bootstrap ports | | `--prefill` | list | [] | Prefill server URLs with optional bootstrap ports |
| `--decode` | list | [] | Decode server URLs | | `--decode` | list | [] | Decode server URLs |
...@@ -346,7 +346,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu ...@@ -346,7 +346,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### Kubernetes Integration ### Kubernetes Integration
| Parameter | Type | Default | Description | | Parameter | Type | Default | Description |
|---------------------------------|------|--------------------------|------------------------------------------------------| | ------------------------------- | ---- | ------------------------ | ---------------------------------------------------- |
| `--service-discovery` | flag | False | Enable Kubernetes service discovery | | `--service-discovery` | flag | False | Enable Kubernetes service discovery |
| `--selector` | list | [] | Label selector for workers (key1=value1 key2=value2) | | `--selector` | list | [] | Label selector for workers (key1=value1 key2=value2) |
| `--prefill-selector` | list | [] | Label selector for prefill servers in PD mode | | `--prefill-selector` | list | [] | Label selector for prefill servers in PD mode |
...@@ -358,7 +358,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu ...@@ -358,7 +358,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### Observability ### Observability
| Parameter | Type | Default | Description | | Parameter | Type | Default | Description |
|------------------------|------|-----------|-------------------------------------------------------| | ---------------------- | ---- | --------- | ----------------------------------------------------- |
| `--prometheus-port` | int | 29000 | Prometheus metrics port | | `--prometheus-port` | int | 29000 | Prometheus metrics port |
| `--prometheus-host` | str | 127.0.0.1 | Prometheus metrics host | | `--prometheus-host` | str | 127.0.0.1 | Prometheus metrics host |
| `--log-dir` | str | None | Directory for log files | | `--log-dir` | str | None | Directory for log files |
...@@ -368,7 +368,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu ...@@ -368,7 +368,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### CORS Configuration ### CORS Configuration
| Parameter | Type | Default | Description | | Parameter | Type | Default | Description |
|--------------------------|------|---------|----------------------| | ------------------------ | ---- | ------- | -------------------- |
| `--cors-allowed-origins` | list | [] | Allowed CORS origins | | `--cors-allowed-origins` | list | [] | Allowed CORS origins |
## Advanced Features ## Advanced Features
...@@ -429,7 +429,7 @@ python -m sglang_router.launch_router \ ...@@ -429,7 +429,7 @@ python -m sglang_router.launch_router \
2. **High latency**: Check if cache-aware routing is causing imbalance. Try adjusting `--balance-abs-threshold` and `--balance-rel-threshold`. 2. **High latency**: Check if cache-aware routing is causing imbalance. Try adjusting `--balance-abs-threshold` and `--balance-rel-threshold`.
3. **Memory growth**: Reduce `--max-tree-size` or decrease `--eviction-interval` for more aggressive cache cleanup. 3. **Memory growth**: Reduce `--max-tree-size` or decrease `--eviction-interval-secs` for more aggressive cache cleanup.
4. **Circuit breaker triggering frequently**: Increase `--cb-failure-threshold` or extend `--cb-window-duration-secs`. 4. **Circuit breaker triggering frequently**: Increase `--cb-failure-threshold` or extend `--cb-window-duration-secs`.
......
...@@ -27,7 +27,8 @@ spec: ...@@ -27,7 +27,8 @@ spec:
command: command:
- python - python
- -m - -m
- sglang.srt.disaggregation.mini_lb - sglang_router.launch_router
- --pd-disaggregation
- --prefill - --prefill
- http://deepseekr10528-prefill-main:30000 - http://deepseekr10528-prefill-main:30000
- --decode - --decode
......
...@@ -714,7 +714,8 @@ spec: ...@@ -714,7 +714,8 @@ spec:
command: command:
- python - python
- -m - -m
- sglang.srt.disaggregation.mini_lb - sglang_router.launch_router
- --pd-disaggregation
- --prefill - --prefill
- http://deepseekr10528-prefill-main:30000 - http://deepseekr10528-prefill-main:30000
- --decode - --decode
......
import argparse
import dataclasses
from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
@dataclasses.dataclass
class LBArgs:
    """Command-line configuration for the PD disaggregation load balancer.

    ``prefill_infos`` holds ``(url, bootstrap_port)`` tuples and
    ``decode_infos`` holds the decode server URLs as given on the CLI.
    """

    host: str = "0.0.0.0"
    port: int = 8000
    policy: str = "random"
    prefill_infos: list = dataclasses.field(default_factory=list)
    decode_infos: list = dataclasses.field(default_factory=list)
    log_interval: int = 5
    timeout: int = 600

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register every load-balancer option on ``parser``."""
        add = parser.add_argument

        add(
            "--host",
            default=LBArgs.host,
            type=str,
            help=f"Host to bind the server (default: {LBArgs.host})",
        )
        add(
            "--port",
            default=LBArgs.port,
            type=int,
            help=f"Port to bind the server (default: {LBArgs.port})",
        )
        add(
            "--policy",
            default=LBArgs.policy,
            type=str,
            choices=["random", "po2"],
            help=f"Policy to use for load balancing (default: {LBArgs.policy})",
        )
        add(
            "--prefill",
            default=[],
            type=str,
            nargs="+",
            help="URLs for prefill servers",
        )
        add(
            "--decode",
            default=[],
            type=str,
            nargs="+",
            help="URLs for decode servers",
        )
        add(
            "--prefill-bootstrap-ports",
            type=int,
            nargs="+",
            help="Bootstrap ports for prefill servers",
        )
        add(
            "--log-interval",
            default=LBArgs.log_interval,
            type=int,
            help=f"Log interval in seconds (default: {LBArgs.log_interval})",
        )
        add(
            "--timeout",
            default=LBArgs.timeout,
            type=int,
            help=f"Timeout in seconds (default: {LBArgs.timeout})",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "LBArgs":
        """Build an ``LBArgs`` from a parsed namespace.

        Bootstrap ports are paired with prefill URLs positionally: no ports
        means ``None`` for every URL, a single port is reused for every URL,
        and otherwise the two lists must have the same length.

        Raises:
            ValueError: if several ports are given and their count differs
                from the number of prefill URLs.
        """
        urls = args.prefill
        ports = args.prefill_bootstrap_ports

        if ports is None:
            ports = [None] * len(urls)
        elif len(ports) == 1:
            # One port on the CLI applies to all prefill servers.
            ports = ports * len(urls)
        elif len(ports) != len(urls):
            raise ValueError(
                "Number of prefill URLs must match number of bootstrap ports"
            )

        return cls(
            host=args.host,
            port=args.port,
            policy=args.policy,
            prefill_infos=list(zip(urls, ports)),
            decode_infos=args.decode,
            log_interval=args.log_interval,
            timeout=args.timeout,
        )
def main():
    """Parse the CLI, build the prefill configs and start the load balancer."""
    parser = argparse.ArgumentParser(
        description="PD Disaggregation Load Balancer Server"
    )
    LBArgs.add_cli_args(parser)
    lb_args = LBArgs.from_cli_args(parser.parse_args())

    # Only host, port, timeout and the server lists are forwarded to run().
    run(
        [PrefillConfig(url, port) for url, port in lb_args.prefill_infos],
        lb_args.decode_infos,
        lb_args.host,
        lb_args.port,
        lb_args.timeout,
    )


if __name__ == "__main__":
    main()
""" raise RuntimeError(
Minimal HTTP load balancer for prefill and decode servers for testing. """The 'mini_lb' module has been relocated to the 'sglang_router' package.
""" We recommend installing 'sglang-router' with Rust support for optimal performance.
If you encounter issues building the router with Rust, set the environment variable
import asyncio 'SGLANG_ROUTER_BUILD_NO_RUST=1' and add '--mini-lb' to the command line to use the Python version of 'mini_lb'."""
import dataclasses )
import logging
import random
import urllib
from http import HTTPStatus
from itertools import chain
from typing import List, Optional
import aiohttp
import orjson
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang.srt.disaggregation.utils import PDRegistryRequest
from sglang.srt.utils import maybe_wrap_ipv6_address
# 64KB, to prevent aiohttp's "Chunk too big" error
AIOHTTP_STREAM_READ_CHUNK_SIZE = 64 * 1024
def setup_logger():
    """Return the shared "pdlb" logger, configured for INFO-level stream output.

    ``logging.getLogger`` returns the same object for the same name, so the
    stream handler is attached only when the logger has none yet. Without
    this guard every extra call would add another handler and each record
    would be emitted once per handler.
    """
    logger = logging.getLogger("pdlb")
    logger.setLevel(logging.INFO)

    if not logger.handlers:
        formatter = logging.Formatter(
            "[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


logger = setup_logger()
@dataclasses.dataclass
class PrefillConfig:
    """Address of one prefill server and its optional bootstrap port."""

    # Base URL of the prefill server.
    url: str
    # Bootstrap port paired with this server; None when not provided.
    bootstrap_port: Optional[int] = None
class MiniLoadBalancer:
    """Minimal load balancer that pairs one prefill and one decode server.

    Every request is POSTed to both servers of a randomly chosen pair and the
    decode server's answer is what gets returned to the client.
    """

    def __init__(
        self,
        prefill_configs: List[PrefillConfig],
        decode_servers: List[str],
        timeout: int,
    ):
        """Store the server lists and the total per-request timeout (seconds)."""
        self.prefill_configs = prefill_configs
        # URL-only view of the prefill configs, kept in the same order.
        self.prefill_servers = [p.url for p in prefill_configs]
        self.decode_servers = decode_servers
        self.timeout = timeout

    def add_prefill_server(self, new_prefill_config: PrefillConfig):
        """Append a prefill server to both the config list and the URL list."""
        self.prefill_configs.append(new_prefill_config)
        self.prefill_servers.append(new_prefill_config.url)

    def add_decode_server(self, new_decode_server: str):
        """Append a decode server URL."""
        self.decode_servers.append(new_decode_server)

    def select_pair(self):
        """Pick a random prefill config and a random decode server.

        Returns:
            Tuple ``(prefill_url, prefill_bootstrap_port, decode_url)``.

        Raises:
            AssertionError: if either server list is empty.
        """
        # TODO: return some message instead of panic
        assert len(self.prefill_configs) > 0, "No prefill servers available"
        assert len(self.decode_servers) > 0, "No decode servers available"

        prefill_config = random.choice(self.prefill_configs)
        decode_server = random.choice(self.decode_servers)
        return prefill_config.url, prefill_config.bootstrap_port, decode_server

    async def generate(
        self, modified_request, prefill_server, decode_server, endpoint
    ) -> ORJSONResponse:
        """Send a non-streaming request to both servers and return the decode reply.

        When the request asks for logprobs, the prefill server's
        ``meta_info.input_token_logprobs`` are prepended to the decode
        server's list before the response is returned.
        """
        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"

        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(
                total=self.timeout
            )  # Add timeout for request reliability
        ) as session:
            tasks = [
                session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                session.post(f"{decode_server}/{endpoint}", json=modified_request),
            ]

            # Wait for both responses to complete. Prefill should end first.
            prefill_response, decode_response = await asyncio.gather(*tasks)

            # Test the flag's value (as generate_stream does), not merely its
            # presence: an explicit `"return_logprob": false` must not force
            # the prefill body to be read and parsed.
            if modified_request.get("return_logprob", False):
                prefill_json = await prefill_response.json()
                ret_json = await decode_response.json()

                # merge `meta_info.input_token_logprobs` from prefill to decode
                if "meta_info" in ret_json:
                    if "input_token_logprobs" in ret_json["meta_info"]:
                        ret_json["meta_info"]["input_token_logprobs"] = (
                            prefill_json["meta_info"]["input_token_logprobs"]
                            + ret_json["meta_info"]["input_token_logprobs"]
                        )
            else:
                ret_json = await decode_response.json()

            return ORJSONResponse(
                content=ret_json,
                status_code=decode_response.status,
            )

    async def generate_stream(
        self, modified_request, prefill_server, decode_server, endpoint="generate"
    ):
        """Send a streaming request to both servers and relay the decode stream.

        Returns a ``StreamingResponse`` with media type ``text/event-stream``.
        With ``return_logprob`` set, each decode ``data:`` event is re-encoded
        with the input logprobs of the first prefill event prepended;
        otherwise decode bytes are passed through in fixed-size chunks.
        """
        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"

        async def stream_results():
            async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(
                    total=self.timeout
                )  # Add timeout for request reliability
            ) as session:
                # Create the tasks for both prefill and decode requests
                tasks = [
                    session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                    session.post(f"{decode_server}/{endpoint}", json=modified_request),
                ]
                # Wait for both responses to complete. Since this is streaming, they return immediately.
                prefill_response, decode_response = await asyncio.gather(*tasks)

                if modified_request.get("return_logprob", False):
                    prefill_chunks = []
                    async for chunk in prefill_response.content:
                        prefill_chunks.append(chunk)

                    # Drop the 5-character "data:" prefix of the first event.
                    # NOTE(review): assumes the prefill server sent at least
                    # one event; an empty stream raises IndexError here.
                    first_prefill_chunk = (
                        prefill_chunks[0].decode("utf-8")[5:].strip("\n")
                    )
                    first_prefill_chunk_json = orjson.loads(first_prefill_chunk)

                    async for chunk in decode_response.content:
                        # Note: This is inefficient
                        # merge prefill input_token_logprobs, output_token_logprobs to decode
                        decoded_chunk = chunk.decode("utf-8")
                        if (
                            decoded_chunk
                            and decoded_chunk.startswith("data:")
                            and "[DONE]" not in decoded_chunk
                        ):
                            ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
                            ret_json["meta_info"]["input_token_logprobs"] = (
                                first_prefill_chunk_json["meta_info"][
                                    "input_token_logprobs"
                                ]
                                + ret_json["meta_info"]["input_token_logprobs"]
                            )

                            yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
                        else:
                            yield chunk
                else:
                    async for chunk in decode_response.content.iter_chunked(
                        AIOHTTP_STREAM_READ_CHUNK_SIZE
                    ):
                        yield chunk

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
        )
# FastAPI application on which the route handlers in this module are registered.
app = FastAPI()
# Balancer instance shared by the handlers; stays None until run() assigns it.
load_balancer: Optional[MiniLoadBalancer] = None
@app.get("/health")
async def health_check():
    """Answer with an empty 200 response without contacting any backend."""
    ok_response = Response(status_code=200)
    return ok_response
@app.get("/health_generate")
async def health_generate():
    """Call ``/health_generate`` on every prefill and decode server.

    Returns 200 once every request has completed. Backend status codes are
    not inspected; only an exception raised by a request propagates.
    """
    backends = chain(load_balancer.prefill_servers, load_balancer.decode_servers)

    async with aiohttp.ClientSession() as session:
        pending = [session.get(f"{url}/health_generate") for url in backends]
        for finished in asyncio.as_completed(pending):
            await finished

    return Response(status_code=200)
@app.post("/flush_cache")
async def flush_cache():
    """POST ``/flush_cache`` to every prefill and decode server.

    Returns 200 once every request has completed. Backend status codes are
    not inspected; only an exception raised by a request propagates.
    """
    backends = chain(load_balancer.prefill_servers, load_balancer.decode_servers)

    async with aiohttp.ClientSession() as session:
        pending = [session.post(f"{url}/flush_cache") for url in backends]
        for finished in asyncio.as_completed(pending):
            await finished

    return Response(status_code=200)
@app.get("/get_server_info")
async def get_server_info():
    """Collect ``/get_server_info`` from all backends, one server at a time.

    The result always contains ``internal_states``, ``prefill`` and
    ``decode``. ``internal_states`` is gathered from the decode servers; when
    none report any, a single placeholder entry is returned instead.
    """
    prefill_servers = load_balancer.prefill_servers
    decode_servers = load_balancer.decode_servers

    prefill_infos, decode_infos, internal_states = [], [], []

    async with aiohttp.ClientSession() as session:
        for url in prefill_servers:
            reply = await session.get(f"{url}/get_server_info")
            prefill_infos.append(await reply.json())

        for url in decode_servers:
            reply = await session.get(f"{url}/get_server_info")
            info = await reply.json()
            decode_infos.append(info)
            # Extract internal_states from decode servers
            if "internal_states" in info:
                internal_states.extend(info["internal_states"])

    if not internal_states:
        # Fallback with dummy data if no internal states found
        internal_states = [
            {
                "last_gen_throughput": 0.0,
                "avg_spec_accept_length": None,
            }
        ]

    # Return format expected by bench_one_batch_server.py
    return {
        "internal_states": internal_states,
        "prefill": prefill_infos,
        "decode": decode_infos,
    }
@app.get("/get_model_info")
async def get_model_info():
    """Proxy ``/get_model_info`` from the first registered prefill server.

    Raises:
        HTTPException: 503 when no prefill server is registered or the
            backend cannot be reached, 502 when the backend answers with a
            non-200 status.
    """
    if not load_balancer or not load_balancer.prefill_servers:
        raise HTTPException(
            status_code=HTTPStatus.SERVICE_UNAVAILABLE,
            detail="There is no server registered",
        )

    target_server_url = load_balancer.prefill_servers[0]
    endpoint_url = f"{target_server_url}/get_model_info"

    async with aiohttp.ClientSession() as session:
        try:
            async with session.get(endpoint_url) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise HTTPException(
                        status_code=HTTPStatus.BAD_GATEWAY,
                        detail=(
                            # The two fragments are concatenated, so the
                            # first must end with its own separator.
                            f"Failed to get model info from {target_server_url}. "
                            f"Status: {response.status}, Response: {error_text}"
                        ),
                    )

                model_info_json = await response.json()
                return ORJSONResponse(content=model_info_json)

        except aiohttp.ClientError as e:
            # Keep the transport error as the cause for server-side debugging.
            raise HTTPException(
                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
                detail="Failed to get model info from backend",
            ) from e
@app.post("/generate")
async def handle_generate_request(request_data: dict):
    """Add bootstrap fields to a ``/generate`` request and dispatch it.

    For a batched request each bootstrap field is a list with one entry per
    item (a fresh room id for each); otherwise the fields are scalars. The
    request is then sent through the streaming or non-streaming path
    depending on its ``stream`` flag.
    """
    prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()

    # Parse and transform prefill_server for bootstrap data
    hostname = maybe_wrap_ipv6_address(urllib.parse.urlparse(prefill_server).hostname)

    modified_request = request_data.copy()
    batch_size = _get_request_batch_size(modified_request)

    if batch_size is None:
        bootstrap_fields = {
            "bootstrap_host": hostname,
            "bootstrap_port": bootstrap_port,
            "bootstrap_room": _generate_bootstrap_room(),
        }
    else:
        bootstrap_fields = {
            "bootstrap_host": [hostname] * batch_size,
            "bootstrap_port": [bootstrap_port] * batch_size,
            "bootstrap_room": [_generate_bootstrap_room() for _ in range(batch_size)],
        }
    modified_request.update(bootstrap_fields)

    if request_data.get("stream", False):
        forward = load_balancer.generate_stream
    else:
        forward = load_balancer.generate
    return await forward(modified_request, prefill_server, decode_server, "generate")
async def _forward_to_backend(request_data: dict, endpoint_name: str):
    """Add scalar bootstrap fields to a request and send it to ``endpoint_name``.

    The request goes through the streaming or non-streaming path of the load
    balancer depending on its ``stream`` flag.
    """
    prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()

    # Parse and transform prefill_server for bootstrap data
    hostname = maybe_wrap_ipv6_address(urllib.parse.urlparse(prefill_server).hostname)

    modified_request = request_data.copy()
    modified_request.update(
        {
            "bootstrap_host": hostname,
            "bootstrap_port": bootstrap_port,
            "bootstrap_room": _generate_bootstrap_room(),
        }
    )

    if request_data.get("stream", False):
        forward = load_balancer.generate_stream
    else:
        forward = load_balancer.generate
    return await forward(
        modified_request,
        prefill_server,
        decode_server,
        endpoint=endpoint_name,
    )
@app.post("/v1/chat/completions")
async def handle_chat_completion_request(request_data: dict):
    """Forward a chat-completion request to the matching backend endpoint."""
    endpoint = "v1/chat/completions"
    return await _forward_to_backend(request_data, endpoint)
@app.post("/v1/completions")
async def handle_completion_request(request_data: dict):
    """Forward a completion request to the matching backend endpoint."""
    endpoint = "v1/completions"
    return await _forward_to_backend(request_data, endpoint)
def _generate_bootstrap_room():
    """Return a random integer in ``[0, 2**63 - 1]`` to use as a bootstrap room id."""
    # randrange excludes its stop value, so this covers 0 .. 2**63 - 1.
    return random.randrange(0, 2**63)
# We may utilize `GenerateReqInput`'s logic later
def _get_request_batch_size(request):
    """Return the batch size of a generate request, or None if it is not batched.

    ``text`` is checked first: a string is a single request, anything else is
    a batch of ``len(text)`` items. Otherwise ``input_ids`` decides: a list
    whose first element is an int is one token sequence, any other non-empty
    list is a batch. An empty ``input_ids`` list is treated as a single
    request instead of raising IndexError on ``input_ids[0]``.
    """
    if (text := request.get("text")) is not None:
        return None if isinstance(text, str) else len(text)
    if (input_ids := request.get("input_ids")) is not None:
        if not input_ids or isinstance(input_ids[0], int):
            return None
        return len(input_ids)
    return None
@app.get("/v1/models")
async def get_models():
    """Proxy ``/v1/models`` from the first prefill server.

    Raises:
        HTTPException: with the backend's own status code when it answers
            non-200, or 500 for any other failure while contacting it.
    """
    prefill_server = load_balancer.prefill_servers[0]  # Get the first prefill server
    async with aiohttp.ClientSession() as session:
        try:
            # The context manager releases the response on every exit path.
            async with session.get(f"{prefill_server}/v1/models") as response:
                if response.status != 200:
                    raise HTTPException(
                        status_code=response.status,
                        detail=f"Prefill server error: Status {response.status}",
                    )
                return ORJSONResponse(content=await response.json())
        except HTTPException:
            # Let the status chosen above through; the generic handler below
            # would otherwise rewrite it to a 500.
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/register")
async def register(obj: PDRegistryRequest):
    """Add a prefill or decode server to the load balancer.

    Raises:
        HTTPException: 400 when ``obj.mode`` is neither "prefill" nor "decode".
    """
    mode = obj.mode
    if mode not in ("prefill", "decode"):
        raise HTTPException(
            status_code=400,
            detail="Invalid mode. Must be either PREFILL or DECODE.",
        )

    if mode == "prefill":
        load_balancer.add_prefill_server(
            PrefillConfig(obj.registry_url, obj.bootstrap_port)
        )
        logger.info(
            f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
        )
    else:
        load_balancer.add_decode_server(obj.registry_url)
        logger.info(f"Registered decode server: {obj.registry_url}")

    logger.info(
        f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
        f"#Decode servers: {len(load_balancer.decode_servers)}"
    )

    return Response(status_code=200)
def run(prefill_configs, decode_addrs, host, port, timeout):
    """Create the module-level load balancer and serve the app with uvicorn."""
    global load_balancer
    balancer = MiniLoadBalancer(prefill_configs, decode_addrs, timeout=timeout)
    load_balancer = balancer
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    # FIXME: remove this, use the unified entry point: sglang.srt.disaggregation.launch_lb
    from sglang.srt.disaggregation.launch_lb import main

    main()
from __future__ import annotations from __future__ import annotations
import dataclasses
import os import os
import random import random
import threading
import warnings
from collections import deque from collections import deque
from contextlib import nullcontext from contextlib import nullcontext
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Optional
import numpy as np import numpy as np
import requests
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from sglang.srt.utils import get_ip, is_npu from sglang.srt.utils import is_npu
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.schedule_batch import Req
...@@ -305,49 +301,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int): ...@@ -305,49 +301,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
return (num_kv_indices + page_size - 1) // page_size return (num_kv_indices + page_size - 1) // page_size
#########################
# PDLB Registry
#########################
@dataclasses.dataclass
class PDRegistryRequest:
    """A request to register a machine itself to the LB."""

    mode: str
    registry_url: str
    bootstrap_port: Optional[int] = None

    def __post_init__(self):
        """Validate ``mode`` and its pairing with ``bootstrap_port``.

        Raises:
            ValueError: for an unknown mode, a prefill request without a
                bootstrap port, or a decode request with one.
        """
        if self.mode not in ["prefill", "decode"]:
            raise ValueError(
                f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
            )

        has_port = self.bootstrap_port is not None
        if self.mode == "prefill" and not has_port:
            raise ValueError("Bootstrap port must be set in PREFILL mode.")
        if self.mode == "decode" and has_port:
            raise ValueError("Bootstrap port must not be set in DECODE mode.")
def register_disaggregation_server(
    mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
):
    """Register this server with the PD load balancer at ``pdlb_url``.

    The bootstrap port is included only in "prefill" mode. A non-200 reply
    is reported with ``warnings.warn`` rather than raised.
    """
    registry_request = PDRegistryRequest(
        mode=mode,
        registry_url=f"http://{get_ip()}:{server_port}",
        bootstrap_port=bootstrap_port if mode == "prefill" else None,
    )
    payload = dataclasses.asdict(registry_request)

    # NOTE(review): no timeout is passed to requests.post here — confirm
    # whether callers can tolerate an unreachable load balancer.
    response = requests.post(f"{pdlb_url}/register", json=payload)

    if response.status_code == 200:
        return
    warnings.warn(
        f"Failed to register disaggregation server: {response.status_code} {response.text}"
    )
######################### #########################
# Misc # Misc
######################### #########################
......
...@@ -47,11 +47,7 @@ from fastapi.exceptions import RequestValidationError ...@@ -47,11 +47,7 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang.srt.disaggregation.utils import ( from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
FAKE_BOOTSTRAP_HOST,
DisaggregationMode,
register_disaggregation_server,
)
from sglang.srt.entrypoints.engine import _launch_subprocesses from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.entrypoints.openai.protocol import ( from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
...@@ -1405,13 +1401,5 @@ def _wait_and_warmup( ...@@ -1405,13 +1401,5 @@ def _wait_and_warmup(
if server_args.debug_tensor_dump_input_file: if server_args.debug_tensor_dump_input_file:
kill_process_tree(os.getpid()) kill_process_tree(os.getpid())
if server_args.pdlb_url is not None:
register_disaggregation_server(
server_args.disaggregation_mode,
server_args.port,
server_args.disaggregation_bootstrap_port,
server_args.pdlb_url,
)
if launch_callback is not None: if launch_callback is not None:
launch_callback() launch_callback()
...@@ -367,7 +367,6 @@ class ServerArgs: ...@@ -367,7 +367,6 @@ class ServerArgs:
disaggregation_prefill_pp: Optional[int] = 1 disaggregation_prefill_pp: Optional[int] = 1
disaggregation_ib_device: Optional[str] = None disaggregation_ib_device: Optional[str] = None
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
pdlb_url: Optional[str] = None
# For model weight update # For model weight update
custom_weight_loader: Optional[List[str]] = None custom_weight_loader: Optional[List[str]] = None
...@@ -2071,12 +2070,6 @@ class ServerArgs: ...@@ -2071,12 +2070,6 @@ class ServerArgs:
default=ServerArgs.num_reserved_decode_tokens, default=ServerArgs.num_reserved_decode_tokens,
help="Number of decode tokens that will have memory reserved when adding new request to the running batch.", help="Number of decode tokens that will have memory reserved when adding new request to the running batch.",
) )
parser.add_argument(
"--pdlb-url",
type=str,
default=None,
help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
)
# Custom weight loader # Custom weight loader
parser.add_argument( parser.add_argument(
......
...@@ -466,6 +466,25 @@ def try_cached_model(model_repo: str): ...@@ -466,6 +466,25 @@ def try_cached_model(model_repo: str):
return model_dir if model_dir else model_repo return model_dir if model_dir else model_repo
def popen_with_error_check(command: list[str], allow_exit: bool = False):
    """Launch ``command`` and watch it from a background thread.

    Args:
        command: Argument vector passed to ``subprocess.Popen``.
        allow_exit: When False, any exit of the process is treated as an
            error. When True, only a non-zero exit code is an error.

    Returns:
        The ``subprocess.Popen`` handle, returned immediately after launch.

    Note:
        The error is raised inside the watcher thread, so it is not propagated
        to the caller of this function.
    """
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    def _run_and_check():
        # communicate() drains both pipes and only returns once the process
        # has terminated, so returncode is already set here and no extra
        # polling loop is needed.
        stdout, stderr = process.communicate()

        if not allow_exit or process.returncode != 0:
            raise Exception(
                f"{command} exited with code {process.returncode}\n{stdout=}\n{stderr=}"
            )

    t = threading.Thread(target=_run_and_check)
    t.start()
    return process
def popen_launch_server( def popen_launch_server(
model: str, model: str,
base_url: str, base_url: str,
......
...@@ -45,6 +45,10 @@ fi ...@@ -45,6 +45,10 @@ fi
# Install the main package # Install the main package
$PIP_CMD install -e "python[dev]" --extra-index-url https://download.pytorch.org/whl/${CU_VERSION} $PIP_INSTALL_SUFFIX $PIP_CMD install -e "python[dev]" --extra-index-url https://download.pytorch.org/whl/${CU_VERSION} $PIP_INSTALL_SUFFIX
# Install router for pd-disagg test
SGLANG_ROUTER_BUILD_NO_RUST=1 $PIP_CMD install -e "sgl-router" $PIP_INSTALL_SUFFIX
if [ "$IS_BLACKWELL" = "1" ]; then if [ "$IS_BLACKWELL" = "1" ]; then
# TODO auto determine sgl-kernel version # TODO auto determine sgl-kernel version
SGL_KERNEL_VERSION=0.3.8 SGL_KERNEL_VERSION=0.3.8
......
# a lightweight wrapper on router with argument type and comments
# no wrapper on policy type => direct export
from sglang_router.router import Router
from sglang_router.version import __version__ from sglang_router.version import __version__
from sglang_router_rs import PolicyType
__all__ = ["Router", "PolicyType", "__version__"] __all__ = ["__version__"]
"""
Minimal HTTP load balancer for prefill and decode servers for testing.
"""
import asyncio
import ipaddress
import logging
import random
import urllib
from http import HTTPStatus
from itertools import chain
from typing import Optional
import aiohttp
import orjson
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang_router.router_args import RouterArgs
logger = logging.getLogger(__name__)
# Chunk size passed to `iter_chunked` when relaying the decode server's
# streaming response body.
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
    1024 * 64
)  # 64KB, to prevent aiohttp's "Chunk too big" error
def maybe_wrap_ipv6_address(address: str) -> str:
    """Return ``address`` in square brackets if it is an IPv6 literal.

    Anything that does not parse as an IPv6 address (IPv4 literals, host
    names, the empty string) is returned unchanged.
    """
    try:
        ipaddress.IPv6Address(address)
    except ValueError:
        return address
    return f"[{address}]"
class MiniLoadBalancer:
    """Minimal prefill/decode load balancer meant for debugging.

    Every request is sent to one randomly chosen prefill server and one
    randomly chosen decode server; the decode server's reply is what the
    client receives.
    """

    def __init__(
        self,
        router_args: RouterArgs,
    ):
        self._validate_router_args(router_args)

        self.host = router_args.host
        self.port = router_args.port
        self.timeout = router_args.request_timeout_secs
        # Each prefill entry is indexed as (url, bootstrap_port); keep the two
        # parts in parallel lists so one index selects both.
        self.prefill_urls = [url[0] for url in router_args.prefill_urls]
        self.prefill_bootstrap_ports = [url[1] for url in router_args.prefill_urls]
        self.decode_urls = router_args.decode_urls

    def _validate_router_args(self, router_args: RouterArgs):
        """Check the few arguments MiniLB depends on.

        Note that a non-random policy is overwritten in place on
        ``router_args``.

        Raises:
            ValueError: if PD disaggregation is off, or if there is no prefill
                or no decode URL.
        """
        logger.warning(
            "\x1b[33mMiniLB is only for debugging purposes, it only supports random policy!\033[0m"
        )

        # NOTE: too many arguments unsupported, just validate some important ones
        if router_args.policy != "random":
            logger.warning("[MiniLB] Overriding policy to random")
            router_args.policy = "random"

        if not router_args.pd_disaggregation:
            raise ValueError("MiniLB only supports PD disaggregation mode")

        if len(router_args.prefill_urls) == 0 or len(router_args.decode_urls) == 0:
            raise ValueError(
                "MiniLB requires at least one prefill and one decode server"
            )

    def start(self):
        """Publish this instance as the module-level ``lb`` and run uvicorn."""
        global lb
        lb = self
        uvicorn.run(app, host=self.host, port=self.port)

    def select_pair(self):
        """Pick a random ``(prefill_url, bootstrap_port, decode_url)`` triple."""
        assert len(self.prefill_urls) > 0, "No prefill servers available"
        assert len(self.decode_urls) > 0, "No decode servers available"
        pidx = random.randint(0, len(self.prefill_urls) - 1)
        didx = random.randint(0, len(self.decode_urls) - 1)
        return (
            self.prefill_urls[pidx],
            self.prefill_bootstrap_ports[pidx],
            self.decode_urls[didx],
        )

    async def generate(
        self, modified_request, prefill_server, decode_server, endpoint
    ) -> ORJSONResponse:
        """POST the request to both servers and return the decode reply.

        When ``return_logprob`` is truthy, the prefill reply's
        ``meta_info.input_token_logprobs`` is prepended to the decode reply's.
        ``endpoint`` must not start with "/".
        """
        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"

        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(
                total=self.timeout
            )  # Add timeout for request reliability
        ) as session:
            tasks = [
                session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                session.post(f"{decode_server}/{endpoint}", json=modified_request),
            ]

            # Wait for both responses to complete. Prefill should end first.
            prefill_response, decode_response = await asyncio.gather(*tasks)

            # Test truthiness (as generate_stream does) rather than key
            # presence, so an explicit `return_logprob: False` skips the merge.
            if modified_request.get("return_logprob", False):
                prefill_json = await prefill_response.json()
                ret_json = await decode_response.json()

                # merge `meta_info.input_token_logprobs` from prefill to decode
                if "meta_info" in ret_json:
                    if "input_token_logprobs" in ret_json["meta_info"]:
                        ret_json["meta_info"]["input_token_logprobs"] = (
                            prefill_json["meta_info"]["input_token_logprobs"]
                            + ret_json["meta_info"]["input_token_logprobs"]
                        )
            else:
                ret_json = await decode_response.json()

            return ORJSONResponse(
                content=ret_json,
                status_code=decode_response.status,
            )

    async def generate_stream(
        self, modified_request, prefill_server, decode_server, endpoint="generate"
    ):
        """POST the request to both servers and stream the decode reply.

        Returns a ``StreamingResponse`` with media type ``text/event-stream``.
        ``endpoint`` must not start with "/".
        """
        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"

        async def stream_results():
            async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(
                    total=self.timeout
                )  # Add timeout for request reliability
            ) as session:
                # Create the tasks for both prefill and decode requests
                tasks = [
                    session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                    session.post(f"{decode_server}/{endpoint}", json=modified_request),
                ]
                # Wait for both responses to complete. Since this is streaming, they return immediately.
                prefill_response, decode_response = await asyncio.gather(*tasks)

                if modified_request.get("return_logprob", False):
                    # Read the whole prefill stream; only its first chunk is
                    # parsed below.
                    prefill_chunks = []
                    async for chunk in prefill_response.content:
                        prefill_chunks.append(chunk)

                    # Drop the first 5 characters (the "data:" prefix checked
                    # for on decode chunks) before parsing the JSON payload.
                    first_prefill_chunk = (
                        prefill_chunks[0].decode("utf-8")[5:].strip("\n")
                    )
                    first_prefill_chunk_json = orjson.loads(first_prefill_chunk)

                    async for chunk in decode_response.content:
                        # Note: This is inefficient
                        # Prepend the prefill input_token_logprobs to every
                        # decode data chunk.
                        decoded_chunk = chunk.decode("utf-8")
                        if (
                            decoded_chunk
                            and decoded_chunk.startswith("data:")
                            and "[DONE]" not in decoded_chunk
                        ):
                            ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
                            ret_json["meta_info"]["input_token_logprobs"] = (
                                first_prefill_chunk_json["meta_info"][
                                    "input_token_logprobs"
                                ]
                                + ret_json["meta_info"]["input_token_logprobs"]
                            )

                            yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
                        else:
                            yield chunk
                else:
                    async for chunk in decode_response.content.iter_chunked(
                        AIOHTTP_STREAM_READ_CHUNK_SIZE
                    ):
                        yield chunk

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
        )
# FastAPI application that the route handlers below are registered on.
app = FastAPI()
# Balancer instance used by the route handlers; assigned in
# MiniLoadBalancer.start() before the server begins serving.
lb: Optional[MiniLoadBalancer] = None
@app.get("/health")
async def health_check():
    """Return an empty 200 response without contacting any backend server."""
    return Response(status_code=200)
@app.get("/health_generate")
async def health_generate():
    """Call ``/health_generate`` on every prefill and decode server.

    Returns 200 once all requests have completed; the backend status codes
    are not inspected.
    """
    async with aiohttp.ClientSession() as session:
        pending = [
            session.get(f"{server}/health_generate")
            for server in chain(lb.prefill_urls, lb.decode_urls)
        ]
        for finished in asyncio.as_completed(pending):
            await finished
    return Response(status_code=200)
@app.post("/flush_cache")
async def flush_cache():
    """POST ``/flush_cache`` to every prefill and decode server.

    Returns 200 once all requests have completed; the backend status codes
    are not inspected.
    """
    async with aiohttp.ClientSession() as session:
        pending = [
            session.post(f"{server}/flush_cache")
            for server in chain(lb.prefill_urls, lb.decode_urls)
        ]
        for finished in asyncio.as_completed(pending):
            await finished
    return Response(status_code=200)
@app.get("/get_server_info")
async def get_server_info():
    """Collect ``/get_server_info`` from every prefill and decode server.

    The result has the keys ``internal_states``, ``prefill`` and ``decode``.
    ``internal_states`` is gathered from the decode servers only; if none of
    them reports any, a single placeholder entry is returned instead.
    """
    prefill_infos, decode_infos, internal_states = [], [], []

    async with aiohttp.ClientSession() as session:
        for server in lb.prefill_urls:
            reply = await session.get(f"{server}/get_server_info")
            prefill_infos.append(await reply.json())

        for server in lb.decode_urls:
            reply = await session.get(f"{server}/get_server_info")
            decode_json = await reply.json()
            decode_infos.append(decode_json)
            # Only decode servers contribute internal states.
            if "internal_states" in decode_json:
                internal_states.extend(decode_json["internal_states"])

    # Return format expected by bench_one_batch_server.py; fall back to dummy
    # data when no internal states were found.
    if not internal_states:
        internal_states = [
            {
                "last_gen_throughput": 0.0,
                "avg_spec_accept_length": None,
            }
        ]

    return {
        "internal_states": internal_states,
        "prefill": prefill_infos,
        "decode": decode_infos,
    }
@app.get("/get_model_info")
async def get_model_info():
    """Return the model info reported by the first prefill server.

    Raises:
        HTTPException: 503 when no balancer or prefill server is available or
            the backend cannot be reached; 502 when the backend answers with a
            non-200 status.
    """
    if not lb or not lb.prefill_urls:
        raise HTTPException(
            status_code=HTTPStatus.SERVICE_UNAVAILABLE,
            detail="There is no server registered",
        )

    target_server_url = lb.prefill_urls[0]
    endpoint_url = f"{target_server_url}/get_model_info"

    async with aiohttp.ClientSession() as session:
        try:
            async with session.get(endpoint_url) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise HTTPException(
                        status_code=HTTPStatus.BAD_GATEWAY,
                        detail=(
                            # Separator added so the URL and the status no
                            # longer run together in the message.
                            f"Failed to get model info from {target_server_url}. "
                            f"Status: {response.status}, Response: {error_text}"
                        ),
                    )

                model_info_json = await response.json()
                return ORJSONResponse(content=model_info_json)

        except aiohttp.ClientError as e:
            # Chain the cause so the underlying client error stays visible in
            # tracebacks.
            raise HTTPException(
                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
                detail="Failed to get model info from backend",
            ) from e
@app.post("/generate")
async def handle_generate_request(request_data: dict):
    """Add bootstrap fields to a ``/generate`` request and forward it.

    For a batched request the bootstrap fields are lists with one entry per
    item; otherwise they are scalars. Streaming is chosen from the request's
    ``stream`` flag.
    """
    prefill_server, bootstrap_port, decode_server = lb.select_pair()

    # The bootstrap host is the prefill server's host name, bracketed when it
    # is an IPv6 literal.
    hostname = maybe_wrap_ipv6_address(urllib.parse.urlparse(prefill_server).hostname)

    modified_request = request_data.copy()
    batch_size = _get_request_batch_size(modified_request)

    if batch_size is None:
        bootstrap_fields = {
            "bootstrap_host": hostname,
            "bootstrap_port": bootstrap_port,
            "bootstrap_room": _generate_bootstrap_room(),
        }
    else:
        bootstrap_fields = {
            "bootstrap_host": [hostname] * batch_size,
            "bootstrap_port": [bootstrap_port] * batch_size,
            "bootstrap_room": [_generate_bootstrap_room() for _ in range(batch_size)],
        }
    modified_request.update(bootstrap_fields)

    forward = lb.generate_stream if request_data.get("stream", False) else lb.generate
    return await forward(modified_request, prefill_server, decode_server, "generate")
async def _forward_to_backend(request_data: dict, endpoint_name: str):
    """Add scalar bootstrap fields to a request and forward it to ``endpoint_name``.

    Streaming is chosen from the request's ``stream`` flag.
    """
    prefill_server, bootstrap_port, decode_server = lb.select_pair()

    # The bootstrap host is the prefill server's host name, bracketed when it
    # is an IPv6 literal.
    hostname = maybe_wrap_ipv6_address(urllib.parse.urlparse(prefill_server).hostname)

    modified_request = request_data.copy()
    modified_request.update(
        {
            "bootstrap_host": hostname,
            "bootstrap_port": bootstrap_port,
            "bootstrap_room": _generate_bootstrap_room(),
        }
    )

    forward = lb.generate_stream if request_data.get("stream", False) else lb.generate
    return await forward(
        modified_request,
        prefill_server,
        decode_server,
        endpoint=endpoint_name,
    )
@app.post("/v1/chat/completions")
async def handle_chat_completion_request(request_data: dict):
    """Forward the request body to the backends' ``v1/chat/completions`` endpoint."""
    return await _forward_to_backend(request_data, "v1/chat/completions")
@app.post("/v1/completions")
async def handle_completion_request(request_data: dict):
    """Forward the request body to the backends' ``v1/completions`` endpoint."""
    return await _forward_to_backend(request_data, "v1/completions")
def _generate_bootstrap_room():
    """Return a random bootstrap room id in the range [0, 2**63 - 1]."""
    return random.randrange(2**63)
# We may utilize `GenerateReqInput`'s logic later
def _get_request_batch_size(request):
    """Return the batch size of a generate request, or None if it is not batched.

    ``text`` is checked first: a string is a single request, anything else is
    a batch of ``len(text)`` items. Otherwise ``input_ids`` is checked: a list
    whose first element is an int is a single request, anything else is a
    batch of ``len(input_ids)`` items. A request with neither key is treated
    as a single request.
    """
    if (text := request.get("text")) is not None:
        return None if isinstance(text, str) else len(text)
    if (input_ids := request.get("input_ids")) is not None:
        # An empty list has no first element to inspect; treat it as a single
        # (non-batched) request instead of raising IndexError.
        if not input_ids or isinstance(input_ids[0], int):
            return None
        return len(input_ids)
    return None
@app.get("/v1/models")
async def get_models():
    """Return the ``/v1/models`` reply of the first prefill server.

    Raises:
        HTTPException: with the backend's status code when it answers with a
            non-200 status, or 500 for any other failure.
    """
    prefill_server = lb.prefill_urls[0]  # Get the first prefill server
    async with aiohttp.ClientSession() as session:
        try:
            async with session.get(f"{prefill_server}/v1/models") as response:
                if response.status != 200:
                    raise HTTPException(
                        status_code=response.status,
                        detail=f"Prefill server error: Status {response.status}",
                    )
                return ORJSONResponse(content=await response.json())
        except HTTPException:
            # Let the deliberately raised error through; the generic handler
            # below would otherwise turn its status code into 500.
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
from typing import Dict, List, Optional from typing import Optional
from sglang_router.router_args import RouterArgs
from sglang_router_rs import PolicyType from sglang_router_rs import PolicyType
from sglang_router_rs import Router as _Router from sglang_router_rs import Router as _Router
def policy_from_str(policy_str: Optional[str]) -> Optional[PolicyType]:
    """Convert a policy name to the matching ``PolicyType`` member.

    Args:
        policy_str: One of "random", "round_robin", "cache_aware" or
            "power_of_two", or None.

    Returns:
        The matching ``PolicyType``, or None when ``policy_str`` is None.

    Raises:
        KeyError: if ``policy_str`` is not one of the known names.
    """
    if policy_str is None:
        return None
    policy_map = {
        "random": PolicyType.Random,
        "round_robin": PolicyType.RoundRobin,
        "cache_aware": PolicyType.CacheAware,
        "power_of_two": PolicyType.PowerOfTwo,
    }
    return policy_map[policy_str]
class Router: class Router:
""" """
A high-performance router for distributing requests across worker nodes. A high-performance router for distributing requests across worker nodes.
...@@ -78,130 +92,34 @@ class Router: ...@@ -78,130 +92,34 @@ class Router:
tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None
""" """
def __init__( def __init__(self, router: _Router):
self, self._router = router
worker_urls: List[str],
policy: PolicyType = PolicyType.RoundRobin,
host: str = "127.0.0.1",
port: int = 3001,
worker_startup_timeout_secs: int = 600,
worker_startup_check_interval: int = 30,
cache_threshold: float = 0.3,
balance_abs_threshold: int = 64,
balance_rel_threshold: float = 1.5,
eviction_interval_secs: int = 120,
max_tree_size: int = 2**26,
max_payload_size: int = 512 * 1024 * 1024, # 512MB
dp_aware: bool = False,
api_key: Optional[str] = None,
log_dir: Optional[str] = None,
log_level: Optional[str] = None,
service_discovery: bool = False,
selector: Dict[str, str] = None,
service_discovery_port: int = 80,
service_discovery_namespace: Optional[str] = None,
prefill_selector: Dict[str, str] = None,
decode_selector: Dict[str, str] = None,
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port",
prometheus_port: Optional[int] = None,
prometheus_host: Optional[str] = None,
request_timeout_secs: int = 1800,
request_id_headers: Optional[List[str]] = None,
pd_disaggregation: bool = False,
prefill_urls: Optional[List[tuple]] = None,
decode_urls: Optional[List[str]] = None,
prefill_policy: Optional[PolicyType] = None,
decode_policy: Optional[PolicyType] = None,
max_concurrent_requests: int = 256,
queue_size: int = 100,
queue_timeout_secs: int = 60,
rate_limit_tokens_per_second: Optional[int] = None,
cors_allowed_origins: List[str] = None,
retry_max_retries: int = 5,
retry_initial_backoff_ms: int = 50,
retry_max_backoff_ms: int = 30_000,
retry_backoff_multiplier: float = 1.5,
retry_jitter_factor: float = 0.2,
cb_failure_threshold: int = 10,
cb_success_threshold: int = 3,
cb_timeout_duration_secs: int = 60,
cb_window_duration_secs: int = 120,
disable_retries: bool = False,
disable_circuit_breaker: bool = False,
health_failure_threshold: int = 3,
health_success_threshold: int = 2,
health_check_timeout_secs: int = 5,
health_check_interval_secs: int = 60,
health_check_endpoint: str = "/health",
model_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
):
if selector is None:
selector = {}
if prefill_selector is None:
prefill_selector = {}
if decode_selector is None:
decode_selector = {}
if cors_allowed_origins is None:
cors_allowed_origins = []
self._router = _Router( @staticmethod
worker_urls=worker_urls, def from_args(args: RouterArgs) -> "Router":
policy=policy, """Create a router from a RouterArgs instance."""
host=host,
port=port, args_dict = vars(args)
worker_startup_timeout_secs=worker_startup_timeout_secs, # Convert RouterArgs to _Router parameters
worker_startup_check_interval=worker_startup_check_interval, args_dict["worker_urls"] = (
cache_threshold=cache_threshold, []
balance_abs_threshold=balance_abs_threshold, if args_dict["service_discovery"] or args_dict["pd_disaggregation"]
balance_rel_threshold=balance_rel_threshold, else args_dict["worker_urls"]
eviction_interval_secs=eviction_interval_secs, )
max_tree_size=max_tree_size, args_dict["policy"] = policy_from_str(args_dict["policy"])
max_payload_size=max_payload_size, args_dict["prefill_urls"] = (
dp_aware=dp_aware, args_dict["prefill_urls"] if args_dict["pd_disaggregation"] else None
api_key=api_key,
log_dir=log_dir,
log_level=log_level,
service_discovery=service_discovery,
selector=selector,
service_discovery_port=service_discovery_port,
service_discovery_namespace=service_discovery_namespace,
prefill_selector=prefill_selector,
decode_selector=decode_selector,
bootstrap_port_annotation=bootstrap_port_annotation,
prometheus_port=prometheus_port,
prometheus_host=prometheus_host,
request_timeout_secs=request_timeout_secs,
request_id_headers=request_id_headers,
pd_disaggregation=pd_disaggregation,
prefill_urls=prefill_urls,
decode_urls=decode_urls,
prefill_policy=prefill_policy,
decode_policy=decode_policy,
max_concurrent_requests=max_concurrent_requests,
queue_size=queue_size,
queue_timeout_secs=queue_timeout_secs,
rate_limit_tokens_per_second=rate_limit_tokens_per_second,
cors_allowed_origins=cors_allowed_origins,
retry_max_retries=retry_max_retries,
retry_initial_backoff_ms=retry_initial_backoff_ms,
retry_max_backoff_ms=retry_max_backoff_ms,
retry_backoff_multiplier=retry_backoff_multiplier,
retry_jitter_factor=retry_jitter_factor,
cb_failure_threshold=cb_failure_threshold,
cb_success_threshold=cb_success_threshold,
cb_timeout_duration_secs=cb_timeout_duration_secs,
cb_window_duration_secs=cb_window_duration_secs,
disable_retries=disable_retries,
disable_circuit_breaker=disable_circuit_breaker,
health_failure_threshold=health_failure_threshold,
health_success_threshold=health_success_threshold,
health_check_timeout_secs=health_check_timeout_secs,
health_check_interval_secs=health_check_interval_secs,
health_check_endpoint=health_check_endpoint,
model_path=model_path,
tokenizer_path=tokenizer_path,
) )
args_dict["decode_urls"] = (
args_dict["decode_urls"] if args_dict["pd_disaggregation"] else None
)
args_dict["prefill_policy"] = policy_from_str(args_dict["prefill_policy"])
args_dict["decode_policy"] = policy_from_str(args_dict["decode_policy"])
# remove mini_lb parameter
args_dict.pop("mini_lb")
return Router(_Router(**args_dict))
def start(self) -> None: def start(self) -> None:
"""Start the router server. """Start the router server.
......
This diff is collapsed.
...@@ -33,7 +33,7 @@ class TestLaunchRouter(unittest.TestCase): ...@@ -33,7 +33,7 @@ class TestLaunchRouter(unittest.TestCase):
cache_threshold=0.5, cache_threshold=0.5,
balance_abs_threshold=32, balance_abs_threshold=32,
balance_rel_threshold=1.0001, balance_rel_threshold=1.0001,
eviction_interval=60, eviction_interval_secs=60,
max_tree_size=2**24, max_tree_size=2**24,
max_payload_size=256 * 1024 * 1024, # 256MB max_payload_size=256 * 1024 * 1024, # 256MB
verbose=False, verbose=False,
...@@ -176,9 +176,8 @@ class TestLaunchRouter(unittest.TestCase): ...@@ -176,9 +176,8 @@ class TestLaunchRouter(unittest.TestCase):
"""Test basic PD router functionality without actually starting servers.""" """Test basic PD router functionality without actually starting servers."""
# This test just verifies the PD router can be created and configured # This test just verifies the PD router can be created and configured
# without actually starting it (which would require real prefill/decode servers) # without actually starting it (which would require real prefill/decode servers)
from sglang_router import Router
from sglang_router.launch_router import RouterArgs from sglang_router.launch_router import RouterArgs
from sglang_router_rs import PolicyType from sglang_router.router import PolicyType, Router
# Test RouterArgs parsing for PD mode # Test RouterArgs parsing for PD mode
# Simulate the parsed args structure from argparse with action="append" # Simulate the parsed args structure from argparse with action="append"
...@@ -209,18 +208,7 @@ class TestLaunchRouter(unittest.TestCase): ...@@ -209,18 +208,7 @@ class TestLaunchRouter(unittest.TestCase):
self.assertEqual(router_args.decode_urls[1], "http://decode2:8081") self.assertEqual(router_args.decode_urls[1], "http://decode2:8081")
# Test Router creation in PD mode # Test Router creation in PD mode
router = Router( router = Router.from_args(router_args)
worker_urls=[], # Empty for PD mode
pd_disaggregation=True,
prefill_urls=[
("http://prefill1:8080", 9000),
("http://prefill2:8080", None),
],
decode_urls=["http://decode1:8081", "http://decode2:8081"],
policy=PolicyType.CacheAware,
host="127.0.0.1",
port=3001,
)
self.assertIsNotNone(router) self.assertIsNotNone(router)
def test_policy_validation(self): def test_policy_validation(self):
......
...@@ -77,7 +77,7 @@ def popen_launch_router( ...@@ -77,7 +77,7 @@ def popen_launch_router(
port, port,
"--dp", "--dp",
str(dp_size), str(dp_size),
"--router-eviction-interval", "--router-eviction-interval-secs",
"5", "5",
"--router-policy", "--router-policy",
policy, policy,
......
...@@ -28,8 +28,3 @@ find = { where = ["py_src"] } ...@@ -28,8 +28,3 @@ find = { where = ["py_src"] }
# workaround for https://github.com/pypa/twine/issues/1216 # workaround for https://github.com/pypa/twine/issues/1216
[tool.setuptools] [tool.setuptools]
license-files = [] license-files = []
[[tool.setuptools-rust.ext-modules]]
target = "sglang_router_rs"
path = "Cargo.toml"
binding = "PyO3"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment