"tests/vscode:/vscode.git/clone" did not exist on "65d136e067350dcea8cf9f72da1df79599e4bbb8"
Unverified commit 09ae5b20, authored by Simo Lin, committed by GitHub

Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)

parent 712bf9ec
@@ -218,15 +218,39 @@ async def get_server_info():
     )
     prefill_infos = []
     decode_infos = []
+    all_internal_states = []

     async with aiohttp.ClientSession() as session:
         for server in chain(prefill_servers):
             server_info = await session.get(f"{server}/get_server_info")
             prefill_infos.append(await server_info.json())
         for server in chain(decode_servers):
             server_info = await session.get(f"{server}/get_server_info")
-            decode_infos.append(await server_info.json())
+            info_json = await server_info.json()
+            decode_infos.append(info_json)
+            # Extract internal_states from decode servers
+            if "internal_states" in info_json:
+                all_internal_states.extend(info_json["internal_states"])

-    return {"prefill": prefill_infos, "decode": decode_infos}
+    # Return format expected by bench_one_batch_server.py
+    if all_internal_states:
+        return {
+            "internal_states": all_internal_states,
+            "prefill": prefill_infos,
+            "decode": decode_infos,
+        }
+    else:
+        # Fallback with dummy data if no internal states found
+        return {
+            "internal_states": [
+                {
+                    "last_gen_throughput": 0.0,
+                    "avg_spec_accept_length": None,
+                }
+            ],
+            "prefill": prefill_infos,
+            "decode": decode_infos,
+        }

 @app.get("/get_model_info")
...
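Note on the new response shape: after this change, /get_server_info always carries an "internal_states" key (real values when the decode servers report them, the dummy fallback otherwise), so bench_one_batch_server.py can index into it unconditionally. A minimal client sketch under that assumption; the load-balancer address and the helper name fetch_gen_throughput are illustrative, not part of this commit:

import asyncio

import aiohttp


async def fetch_gen_throughput(lb_url: str) -> float:
    # Hypothetical helper: read last_gen_throughput from the merged response.
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{lb_url}/get_server_info") as resp:
            info = await resp.json()
    # "internal_states" is guaranteed non-empty thanks to the dummy fallback above
    return info["internal_states"][0]["last_gen_throughput"]


print(asyncio.run(fetch_gen_throughput("http://localhost:8000")))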
@@ -15,7 +15,7 @@ serde = { version = "1.0", features = ["derive"] }
 clap = { version = "4.4", features = ["derive"] }
 bytes = "1.8.0"
 rand = "0.8.5"
-reqwest = { version = "0.12.8", features = ["stream", "blocking"] }
+reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] }
 futures-util = "0.3"
 serde_json = "1.0"
 pyo3 = { version = "0.22.5", features = ["extension-module"] }
@@ -33,6 +33,8 @@ futures = "0.3"
 # Added for metrics
 metrics = "0.24.2"
 metrics-exporter-prometheus = "0.17.0"
+# Added for request tracing
+uuid = { version = "1.10", features = ["v4", "serde"] }

 [profile.release]
 lto = "thin"
 codegen-units = 1
@@ -31,6 +31,13 @@ class RouterArgs:
     host: str = "127.0.0.1"
     port: int = 30000

+    # PD-specific configuration
+    pd_disaggregated: bool = False  # Enable PD disaggregated mode
+    prefill_urls: List[tuple] = dataclasses.field(
+        default_factory=list
+    )  # List of (url, bootstrap_port)
+    decode_urls: List[str] = dataclasses.field(default_factory=list)
+
     # Routing policy
     policy: str = "cache_aware"
     worker_startup_timeout_secs: int = 300
@@ -40,7 +47,7 @@ class RouterArgs:
     balance_rel_threshold: float = 1.0001
     eviction_interval: int = 60
     max_tree_size: int = 2**24
-    max_payload_size: int = 4 * 1024 * 1024  # 4MB
+    max_payload_size: int = 256 * 1024 * 1024  # 256MB default for large batches
     verbose: bool = False
     log_dir: Optional[str] = None
     # Service discovery configuration
@@ -95,8 +102,29 @@ class RouterArgs:
             f"--{prefix}policy",
             type=str,
             default=RouterArgs.policy,
-            choices=["random", "round_robin", "cache_aware"],
-            help="Load balancing policy to use",
+            choices=["random", "round_robin", "cache_aware", "power_of_two"],
+            help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode",
         )
+        # PD-specific arguments
+        parser.add_argument(
+            f"--{prefix}pd-disaggregated",
+            action="store_true",
+            help="Enable PD (Prefill-Decode) disaggregated mode",
+        )
+        parser.add_argument(
+            f"--{prefix}prefill",
+            nargs=2,
+            action="append",
+            metavar=("URL", "BOOTSTRAP_PORT"),
+            help="Prefill server URL and bootstrap port. Can be specified multiple times. BOOTSTRAP_PORT can be 'none' for no bootstrap port.",
+        )
+        parser.add_argument(
+            f"--{prefix}decode",
+            nargs=1,
+            action="append",
+            metavar=("URL",),
+            help="Decode server URL. Can be specified multiple times.",
+        )
         parser.add_argument(
             f"--{prefix}worker-startup-timeout-secs",
@@ -205,11 +233,19 @@ class RouterArgs:
             use_router_prefix: If True, look for arguments with 'router-' prefix
         """
         prefix = "router_" if use_router_prefix else ""
-        worker_urls = args.worker_urls if args.worker_urls is not None else []
+        worker_urls = getattr(args, "worker_urls", [])
+
+        # Parse PD URLs
+        prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None))
+        decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None))
+
         return cls(
             worker_urls=worker_urls,
             host=args.host,
             port=args.port,
+            pd_disaggregated=getattr(args, f"{prefix}pd_disaggregated", False),
+            prefill_urls=prefill_urls,
+            decode_urls=decode_urls,
             policy=getattr(args, f"{prefix}policy"),
             worker_startup_timeout_secs=getattr(
                 args, f"{prefix}worker_startup_timeout_secs"
@@ -247,6 +283,46 @@ class RouterArgs:
                 selector[key] = value
         return selector

+    @staticmethod
+    def _parse_prefill_urls(prefill_list):
+        """Parse prefill URLs from --prefill arguments.
+
+        Format: --prefill URL BOOTSTRAP_PORT
+        Example: --prefill http://prefill1:8080 9000 --prefill http://prefill2:8080 none
+        """
+        if not prefill_list:
+            return []
+
+        prefill_urls = []
+        for url, bootstrap_port_str in prefill_list:
+            # Handle 'none' as None
+            if bootstrap_port_str.lower() == "none":
+                bootstrap_port = None
+            else:
+                try:
+                    bootstrap_port = int(bootstrap_port_str)
+                except ValueError:
+                    raise ValueError(
+                        f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
+                    )
+            prefill_urls.append((url, bootstrap_port))
+
+        return prefill_urls
+
+    @staticmethod
+    def _parse_decode_urls(decode_list):
+        """Parse decode URLs from --decode arguments.
+
+        Format: --decode URL
+        Example: --decode http://decode1:8081 --decode http://decode2:8081
+        """
+        if not decode_list:
+            return []
+
+        # decode_list is a list of single-element lists due to nargs=1
+        return [url[0] for url in decode_list]

 def policy_from_str(policy_str: str) -> PolicyType:
     """Convert policy string to PolicyType enum."""
@@ -254,6 +330,7 @@ def policy_from_str(policy_str: str) -> PolicyType:
         "random": PolicyType.Random,
         "round_robin": PolicyType.RoundRobin,
         "cache_aware": PolicyType.CacheAware,
+        "power_of_two": PolicyType.PowerOfTwo,
     }
     return policy_map[policy_str]
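The new power_of_two policy is the classic "power of two choices" strategy: sample two workers at random and send the request to the less-loaded one. The real implementation lives in the Rust pd_router (fed by the background load monitor shown later in this commit); the following Python sketch only illustrates the selection rule, and its names are hypothetical:

import random


def power_of_two_choice(workers: list, loads: dict) -> str:
    """Pick two distinct workers at random; route to the less-loaded one."""
    if len(workers) < 2:
        return workers[0]
    a, b = random.sample(workers, 2)
    return a if loads.get(a, 0) <= loads.get(b, 0) else b


loads = {"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 9}
print(power_of_two_choice(list(loads), loads))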
@@ -277,8 +354,19 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
     else:
         router_args = args

+    # Validate configuration based on mode
+    if router_args.pd_disaggregated:
+        # Validate PD configuration
+        if not router_args.prefill_urls:
+            raise ValueError("PD disaggregated mode requires --prefill")
+        if not router_args.decode_urls:
+            raise ValueError("PD disaggregated mode requires --decode")
+
+    # Create router with unified constructor
     router = Router(
-        worker_urls=router_args.worker_urls,
+        worker_urls=(
+            router_args.worker_urls if not router_args.pd_disaggregated else []
+        ),
         host=router_args.host,
         port=router_args.port,
         policy=policy_from_str(router_args.policy),
@@ -298,6 +386,13 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
         service_discovery_namespace=router_args.service_discovery_namespace,
         prometheus_port=router_args.prometheus_port,
         prometheus_host=router_args.prometheus_host,
+        pd_disaggregated=router_args.pd_disaggregated,
+        prefill_urls=(
+            router_args.prefill_urls if router_args.pd_disaggregated else None
+        ),
+        decode_urls=(
+            router_args.decode_urls if router_args.pd_disaggregated else None
+        ),
     )

     router.start()
@@ -326,8 +421,14 @@ This launcher enables starting a router with individual worker instances. It is
 multi-node setups or when you want to start workers and router separately.

 Examples:
+    # Regular mode
     python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
-    python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --balance-abs-threshold 64 --balance-rel-threshold 1.2
+
+    # PD disaggregated mode
+    python -m sglang_router.launch_router --pd-disaggregated \\
+        --prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
+        --decode http://decode1:8001 --decode http://decode2:8001 \\
+        --policy cache_aware
 """,
     formatter_class=CustomHelpFormatter,
...
@@ -15,6 +15,7 @@ class Router:
     - PolicyType.Random: Randomly select workers
     - PolicyType.RoundRobin: Distribute requests in round-robin fashion
     - PolicyType.CacheAware: Distribute requests based on cache state and load balance
+    - PolicyType.PowerOfTwo: Select best of two random workers based on load (PD mode only)
     host: Host address to bind the router server. Default: '127.0.0.1'
     port: Port number to bind the router server. Default: 3001
     worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300
@@ -28,7 +29,7 @@ class Router:
         AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
     eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
         routing. Default: 60
-    max_payload_size: Maximum payload size in bytes. Default: 4MB
+    max_payload_size: Maximum payload size in bytes. Default: 256MB
     max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
     verbose: Enable verbose logging. Default: False
     log_dir: Directory to store log files. If None, logs are only output to console. Default: None
@@ -42,6 +43,9 @@ class Router:
         watches pods across all namespaces (requires cluster-wide permissions). Default: None
     prometheus_port: Port to expose Prometheus metrics. Default: None
     prometheus_host: Host address to bind the Prometheus metrics server. Default: None
+    pd_disaggregated: Enable PD (Prefill-Decode) disaggregated mode. Default: False
+    prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
+    decode_urls: List of URLs for decode servers (PD mode only)
     """

     def __init__(
@@ -57,7 +61,7 @@ class Router:
         balance_rel_threshold: float = 1.0001,
         eviction_interval_secs: int = 60,
         max_tree_size: int = 2**24,
-        max_payload_size: int = 4 * 1024 * 1024,  # 4MB
+        max_payload_size: int = 256 * 1024 * 1024,  # 256MB
         verbose: bool = False,
         log_dir: Optional[str] = None,
         service_discovery: bool = False,
@@ -66,6 +70,9 @@ class Router:
         service_discovery_namespace: Optional[str] = None,
         prometheus_port: Optional[int] = None,
         prometheus_host: Optional[str] = None,
+        pd_disaggregated: bool = False,
+        prefill_urls: Optional[List[tuple]] = None,
+        decode_urls: Optional[List[str]] = None,
     ):
         if selector is None:
             selector = {}
@@ -91,6 +98,9 @@ class Router:
             service_discovery_namespace=service_discovery_namespace,
             prometheus_port=prometheus_port,
             prometheus_host=prometheus_host,
+            pd_disaggregated=pd_disaggregated,
+            prefill_urls=prefill_urls,
+            decode_urls=decode_urls,
         )

     def start(self) -> None:
...
@@ -35,13 +35,21 @@ class TestLaunchRouter(unittest.TestCase):
             balance_rel_threshold=1.0001,
             eviction_interval=60,
             max_tree_size=2**24,
-            max_payload_size=4 * 1024 * 1024,  # 4MB
+            max_payload_size=256 * 1024 * 1024,  # 256MB
             verbose=False,
             log_dir=None,
             service_discovery=False,
             selector=None,
             service_discovery_port=80,
             service_discovery_namespace=None,
+            prometheus_port=None,
+            prometheus_host=None,
+            # PD-specific attributes
+            pd_disaggregated=False,
+            prefill=None,
+            decode=None,
+            # Keep worker_urls for regular mode
+            worker_urls=[],
         )

     def create_router_args(self, **kwargs):
@@ -81,7 +89,7 @@ class TestLaunchRouter(unittest.TestCase):
     def test_launch_router_with_empty_worker_urls(self):
         args = self.create_router_args(worker_urls=[])
-        self.run_router_process(args)
+        self.run_router_process(args)  # Expected error

     def test_launch_router_with_service_discovery(self):
         # Test router startup with service discovery enabled but no selectors
@@ -100,6 +108,112 @@ class TestLaunchRouter(unittest.TestCase):
         )
         self.run_router_process(args)

+    def test_launch_router_pd_mode_basic(self):
+        """Test basic PD router functionality without actually starting servers."""
+        # This test just verifies the PD router can be created and configured
+        # without actually starting it (which would require real prefill/decode servers)
+        from sglang_router import Router
+        from sglang_router.launch_router import RouterArgs
+        from sglang_router_rs import PolicyType
+
+        # Test RouterArgs parsing for PD mode
+        # Simulate the parsed args structure from argparse with action="append"
+        args = self.create_router_args(
+            pd_disaggregated=True,
+            policy="power_of_two",  # PowerOfTwo is only valid in PD mode
+            prefill=[
+                ["http://prefill1:8080", "9000"],
+                ["http://prefill2:8080", "none"],
+            ],
+            decode=[
+                ["http://decode1:8081"],
+                ["http://decode2:8081"],
+            ],
+            worker_urls=[],  # Empty for PD mode
+        )
+        router_args = RouterArgs.from_cli_args(args)
+
+        self.assertTrue(router_args.pd_disaggregated)
+        self.assertEqual(router_args.policy, "power_of_two")
+        self.assertEqual(len(router_args.prefill_urls), 2)
+        self.assertEqual(len(router_args.decode_urls), 2)
+
+        # Verify the parsed URLs and bootstrap ports
+        self.assertEqual(router_args.prefill_urls[0], ("http://prefill1:8080", 9000))
+        self.assertEqual(router_args.prefill_urls[1], ("http://prefill2:8080", None))
+        self.assertEqual(router_args.decode_urls[0], "http://decode1:8081")
+        self.assertEqual(router_args.decode_urls[1], "http://decode2:8081")
+
+        # Test Router creation in PD mode
+        router = Router(
+            worker_urls=[],  # Empty for PD mode
+            pd_disaggregated=True,
+            prefill_urls=[
+                ("http://prefill1:8080", 9000),
+                ("http://prefill2:8080", None),
+            ],
+            decode_urls=["http://decode1:8081", "http://decode2:8081"],
+            policy=PolicyType.CacheAware,
+            host="127.0.0.1",
+            port=3001,
+        )
+        self.assertIsNotNone(router)
+
+    def test_policy_validation(self):
+        """Test that policy validation works correctly for PD and regular modes."""
+        from sglang_router.launch_router import RouterArgs, launch_router
+
+        # Test 1: PowerOfTwo is only valid in PD mode
+        args = self.create_router_args(
+            pd_disaggregated=False,
+            policy="power_of_two",
+            worker_urls=["http://localhost:8000"],
+        )
+        # Should raise error
+        with self.assertRaises(ValueError) as cm:
+            launch_router(args)
+        self.assertIn(
+            "PowerOfTwo policy is only supported in PD disaggregated mode",
+            str(cm.exception),
+        )
+
+        # Test 2: RoundRobin is not valid in PD mode
+        args = self.create_router_args(
+            pd_disaggregated=True,
+            policy="round_robin",
+            prefill=[["http://prefill1:8080", "9000"]],
+            decode=[["http://decode1:8081"]],
+            worker_urls=[],
+        )
+        # Should raise error
+        with self.assertRaises(ValueError) as cm:
+            launch_router(args)
+        self.assertIn(
+            "RoundRobin policy is not supported in PD disaggregated mode",
+            str(cm.exception),
+        )
+
+        # Test 3: Valid combinations should not raise errors
+        # Regular mode with RoundRobin
+        args = self.create_router_args(
+            pd_disaggregated=False,
+            policy="round_robin",
+            worker_urls=["http://localhost:8000"],
+        )
+        # This should not raise (though it may fail to connect)
+
+        # PD mode with PowerOfTwo
+        args = self.create_router_args(
+            pd_disaggregated=True,
+            policy="power_of_two",
+            prefill=[["http://prefill1:8080", "9000"]],
+            decode=[["http://decode1:8081"]],
+            worker_urls=[],
+        )
+        # This should not raise (though it may fail to connect)

 if __name__ == "__main__":
     unittest.main()
 use pyo3::prelude::*;
 pub mod logging;
 use std::collections::HashMap;
+pub mod openai_api_types;
+pub mod pd_router;
+pub mod pd_types;
 pub mod prometheus;
+pub mod request_adapter;
 pub mod router;
 pub mod server;
 pub mod service_discovery;
@@ -14,6 +18,7 @@ pub enum PolicyType {
     Random,
     RoundRobin,
     CacheAware,
+    PowerOfTwo, // Moved from PD-specific, now shared
 }

 #[pyclass]
@@ -39,6 +44,12 @@ struct Router {
     service_discovery_namespace: Option<String>,
     prometheus_port: Option<u16>,
     prometheus_host: Option<String>,
+    request_timeout_secs: u64,
+    // PD mode flag
+    pd_disaggregated: bool,
+    // PD-specific fields (only used when pd_disaggregated is true)
+    prefill_urls: Option<Vec<(String, Option<u16>)>>,
+    decode_urls: Option<Vec<String>>,
 }

 #[pymethods]
@@ -56,7 +67,7 @@ impl Router {
         balance_rel_threshold = 1.0001,
         eviction_interval_secs = 60,
         max_tree_size = 2usize.pow(24),
-        max_payload_size = 4 * 1024 * 1024,
+        max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches
         verbose = false,
         log_dir = None,
         service_discovery = false,
@@ -64,7 +75,11 @@ impl Router {
         service_discovery_port = 80,
         service_discovery_namespace = None,
         prometheus_port = None,
-        prometheus_host = None
+        prometheus_host = None,
+        request_timeout_secs = 600, // Add configurable request timeout
+        pd_disaggregated = false, // New flag for PD mode
+        prefill_urls = None,
+        decode_urls = None
     ))]
     fn new(
         worker_urls: Vec<String>,
@@ -87,6 +102,10 @@ impl Router {
         service_discovery_namespace: Option<String>,
         prometheus_port: Option<u16>,
         prometheus_host: Option<String>,
+        request_timeout_secs: u64,
+        pd_disaggregated: bool,
+        prefill_urls: Option<Vec<(String, Option<u16>)>>,
+        decode_urls: Option<Vec<String>>,
     ) -> PyResult<Self> {
         Ok(Router {
             host,
@@ -109,28 +128,75 @@ impl Router {
             service_discovery_namespace,
             prometheus_port,
             prometheus_host,
+            request_timeout_secs,
+            pd_disaggregated,
+            prefill_urls,
+            decode_urls,
         })
     }

     fn start(&self) -> PyResult<()> {
-        let policy_config = match &self.policy {
-            PolicyType::Random => router::PolicyConfig::RandomConfig {
-                timeout_secs: self.worker_startup_timeout_secs,
-                interval_secs: self.worker_startup_check_interval,
-            },
-            PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
-                timeout_secs: self.worker_startup_timeout_secs,
-                interval_secs: self.worker_startup_check_interval,
-            },
-            PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
-                timeout_secs: self.worker_startup_timeout_secs,
-                interval_secs: self.worker_startup_check_interval,
-                cache_threshold: self.cache_threshold,
-                balance_abs_threshold: self.balance_abs_threshold,
-                balance_rel_threshold: self.balance_rel_threshold,
-                eviction_interval_secs: self.eviction_interval_secs,
-                max_tree_size: self.max_tree_size,
-            },
+        let policy_config = if self.pd_disaggregated {
+            // PD mode - map PolicyType to PDSelectionPolicy
+            let pd_selection_policy = match &self.policy {
+                PolicyType::Random => pd_types::PDSelectionPolicy::Random,
+                PolicyType::PowerOfTwo => pd_types::PDSelectionPolicy::PowerOfTwo,
+                PolicyType::CacheAware => pd_types::PDSelectionPolicy::CacheAware {
+                    cache_threshold: self.cache_threshold,
+                    balance_abs_threshold: self.balance_abs_threshold,
+                    balance_rel_threshold: self.balance_rel_threshold,
+                },
+                PolicyType::RoundRobin => {
+                    return Err(pyo3::exceptions::PyValueError::new_err(
+                        "RoundRobin policy is not supported in PD disaggregated mode",
+                    ));
+                }
+            };
+
+            let prefill_urls = self.prefill_urls.as_ref().ok_or_else(|| {
+                pyo3::exceptions::PyValueError::new_err(
+                    "PD disaggregated mode requires prefill_urls",
+                )
+            })?;
+            let decode_urls = self.decode_urls.as_ref().ok_or_else(|| {
+                pyo3::exceptions::PyValueError::new_err(
+                    "PD disaggregated mode requires decode_urls",
+                )
+            })?;
+
+            router::PolicyConfig::PrefillDecodeConfig {
+                selection_policy: pd_selection_policy,
+                prefill_urls: prefill_urls.clone(),
+                decode_urls: decode_urls.clone(),
+                timeout_secs: self.worker_startup_timeout_secs,
+                interval_secs: self.worker_startup_check_interval,
+            }
+        } else {
+            // Regular mode
+            match &self.policy {
+                PolicyType::Random => router::PolicyConfig::RandomConfig {
+                    timeout_secs: self.worker_startup_timeout_secs,
+                    interval_secs: self.worker_startup_check_interval,
+                },
+                PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
+                    timeout_secs: self.worker_startup_timeout_secs,
+                    interval_secs: self.worker_startup_check_interval,
+                },
+                PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
+                    timeout_secs: self.worker_startup_timeout_secs,
+                    interval_secs: self.worker_startup_check_interval,
+                    cache_threshold: self.cache_threshold,
+                    balance_abs_threshold: self.balance_abs_threshold,
+                    balance_rel_threshold: self.balance_rel_threshold,
+                    eviction_interval_secs: self.eviction_interval_secs,
+                    max_tree_size: self.max_tree_size,
+                },
+                PolicyType::PowerOfTwo => {
+                    return Err(pyo3::exceptions::PyValueError::new_err(
+                        "PowerOfTwo policy is only supported in PD disaggregated mode",
+                    ));
+                }
+            }
         };

         // Create service discovery config if enabled
@@ -166,6 +232,7 @@ impl Router {
             log_dir: self.log_dir.clone(),
             service_discovery_config,
             prometheus_config,
+            request_timeout_secs: self.request_timeout_secs,
         })
         .await
         .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
...
// OpenAI-compatible API types for text generation
// Based on OpenAI's API specification: https://platform.openai.com/docs/api-reference
// Reference: Azure OpenAI API documentation which follows OpenAI's specification
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
/// Common trait for all generation requests
pub trait GenerationRequest: Send + Sync {
/// Check if the request is for streaming
fn is_stream(&self) -> bool;
/// Get the model name if specified
fn get_model(&self) -> Option<&str>;
/// Extract text content for routing decisions
fn extract_text_for_routing(&self) -> String;
}
// ============= Completions API (v1/completions) - DEPRECATED but still supported =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionRequest {
/// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang)
pub model: String,
/// The prompt(s) to generate completions for
pub prompt: StringOrArray,
/// The suffix that comes after a completion of inserted text
#[serde(skip_serializing_if = "Option::is_none")]
pub suffix: Option<String>,
/// The maximum number of tokens to generate
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// What sampling temperature to use, between 0 and 2
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature (nucleus sampling)
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// How many completions to generate for each prompt
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
/// Whether to stream back partial progress
#[serde(default)]
pub stream: bool,
/// Include the log probabilities on the logprobs most likely tokens
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<u32>,
/// Echo back the prompt in addition to the completion
#[serde(default)]
pub echo: bool,
/// Up to 4 sequences where the API will stop generating further tokens
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<StringOrArray>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
/// Generates best_of completions server-side and returns the "best"
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of: Option<u32>,
/// Modify the likelihood of specified tokens appearing in the completion
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<String, f32>>,
/// A unique identifier representing your end-user
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
/// If specified, our system will make a best effort to sample deterministically
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
}
impl GenerationRequest for CompletionRequest {
fn is_stream(&self) -> bool {
self.stream
}
fn get_model(&self) -> Option<&str> {
Some(&self.model)
}
fn extract_text_for_routing(&self) -> String {
match &self.prompt {
StringOrArray::String(s) => s.clone(),
StringOrArray::Array(v) => v.join(" "),
}
}
}
// ============= Chat Completions API (v1/chat/completions) =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionRequest {
/// ID of the model to use
pub model: String,
/// A list of messages comprising the conversation so far
pub messages: Vec<ChatMessage>,
/// What sampling temperature to use, between 0 and 2
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// How many chat completion choices to generate for each input message
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
/// If set, partial message deltas will be sent
#[serde(default)]
pub stream: bool,
/// Up to 4 sequences where the API will stop generating further tokens
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<StringOrArray>,
/// The maximum number of tokens to generate
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// An upper bound for the number of tokens that can be generated for a completion
#[serde(skip_serializing_if = "Option::is_none")]
pub max_completion_tokens: Option<u32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
/// Modify the likelihood of specified tokens appearing in the completion
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<String, i32>>,
/// A unique identifier representing your end-user
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
/// If specified, our system will make a best effort to sample deterministically
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
/// Whether to return log probabilities of the output tokens
#[serde(default)]
pub logprobs: bool,
/// An integer between 0 and 20 specifying the number of most likely tokens to return
#[serde(skip_serializing_if = "Option::is_none")]
pub top_logprobs: Option<u32>,
/// An object specifying the format that the model must output
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
/// A list of tools the model may call
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
/// Controls which (if any) tool is called by the model
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
/// Whether to enable parallel function calling during tool use
#[serde(skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
/// Deprecated: use tools instead
#[serde(skip_serializing_if = "Option::is_none")]
pub functions: Option<Vec<Function>>,
/// Deprecated: use tool_choice instead
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ChatMessage {
System {
role: String, // "system"
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
User {
role: String, // "user"
content: UserMessageContent,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
Assistant {
role: String, // "assistant"
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
function_call: Option<FunctionCallResponse>,
},
Tool {
role: String, // "tool"
content: String,
tool_call_id: String,
},
Function {
role: String, // "function"
content: String,
name: String,
},
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum UserMessageContent {
Text(String),
Parts(Vec<ContentPart>),
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum ContentPart {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image_url")]
ImageUrl { image_url: ImageUrl },
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ImageUrl {
pub url: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>, // "auto", "low", or "high"
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum ResponseFormat {
#[serde(rename = "text")]
Text,
#[serde(rename = "json_object")]
JsonObject,
#[serde(rename = "json_schema")]
JsonSchema { json_schema: JsonSchemaFormat },
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct JsonSchemaFormat {
pub name: String,
pub schema: Value,
#[serde(skip_serializing_if = "Option::is_none")]
pub strict: Option<bool>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: String, // "function"
pub function: Function,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Function {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub parameters: Value, // JSON Schema
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ToolChoice {
None,
Auto,
Required,
Function {
#[serde(rename = "type")]
tool_type: String, // "function"
function: FunctionChoice,
},
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionChoice {
pub name: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCall {
pub id: String,
#[serde(rename = "type")]
pub tool_type: String, // "function"
pub function: FunctionCallResponse,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum FunctionCall {
None,
Auto,
Function { name: String },
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallResponse {
pub name: String,
pub arguments: String, // JSON string
}
impl GenerationRequest for ChatCompletionRequest {
fn is_stream(&self) -> bool {
self.stream
}
fn get_model(&self) -> Option<&str> {
Some(&self.model)
}
fn extract_text_for_routing(&self) -> String {
// Extract text from messages for routing decisions
self.messages
.iter()
.filter_map(|msg| match msg {
ChatMessage::System { content, .. } => Some(content.clone()),
ChatMessage::User { content, .. } => match content {
UserMessageContent::Text(text) => Some(text.clone()),
UserMessageContent::Parts(parts) => {
let texts: Vec<String> = parts
.iter()
.filter_map(|part| match part {
ContentPart::Text { text } => Some(text.clone()),
_ => None,
})
.collect();
Some(texts.join(" "))
}
},
ChatMessage::Assistant { content, .. } => content.clone(),
ChatMessage::Tool { content, .. } => Some(content.clone()),
ChatMessage::Function { content, .. } => Some(content.clone()),
})
.collect::<Vec<String>>()
.join(" ")
}
}
// ============= Generate API (/generate) =============
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GenerateRequest {
/// The prompt to generate from (OpenAI style)
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt: Option<StringOrArray>,
/// Text input - SGLang native format
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
/// Input IDs for tokenized input
#[serde(skip_serializing_if = "Option::is_none")]
pub input_ids: Option<InputIds>,
/// Generation parameters
#[serde(default, skip_serializing_if = "Option::is_none")]
pub parameters: Option<GenerateParameters>,
/// Sampling parameters (sglang style)
#[serde(skip_serializing_if = "Option::is_none")]
pub sampling_params: Option<SamplingParams>,
/// Whether to stream the response
#[serde(default)]
pub stream: bool,
/// Whether to return logprobs
#[serde(default)]
pub return_logprob: bool,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum InputIds {
Single(Vec<i32>),
Batch(Vec<Vec<i32>>),
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct GenerateParameters {
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub decoder_input_details: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub do_sample: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_new_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub repetition_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub return_full_text: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub truncate: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub typical_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub watermark: Option<bool>,
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct SamplingParams {
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_new_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub repetition_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<StringOrArray>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ignore_eos: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub skip_special_tokens: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub json_schema: Option<String>,
}
impl GenerationRequest for GenerateRequest {
fn is_stream(&self) -> bool {
self.stream
}
fn get_model(&self) -> Option<&str> {
// Generate requests typically don't have a model field
None
}
fn extract_text_for_routing(&self) -> String {
// Check fields in priority order: text, prompt, input_ids
if let Some(ref text) = self.text {
return text.clone();
}
if let Some(ref prompt) = self.prompt {
return match prompt {
StringOrArray::String(s) => s.clone(),
StringOrArray::Array(v) => v.join(" "),
};
}
if let Some(ref input_ids) = self.input_ids {
return match input_ids {
InputIds::Single(ids) => ids
.iter()
.map(|&id| id.to_string())
.collect::<Vec<String>>()
.join(" "),
InputIds::Batch(batches) => batches
.iter()
.flat_map(|batch| batch.iter().map(|&id| id.to_string()))
.collect::<Vec<String>>()
.join(" "),
};
}
// No text input found
String::new()
}
}
// ============= Helper Types =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum StringOrArray {
String(String),
Array(Vec<String>),
}
// ============= Response Types =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionResponse {
pub id: String,
pub object: String, // "text_completion"
pub created: u64,
pub model: String,
pub choices: Vec<CompletionChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionChoice {
pub text: String,
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<LogProbs>,
pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LogProbs {
pub tokens: Vec<String>,
pub token_logprobs: Vec<Option<f32>>,
pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
pub text_offset: Vec<u32>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String, // "chat.completion"
pub created: u64,
pub model: String,
pub choices: Vec<ChatChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatChoice {
pub index: u32,
pub message: ChatMessage,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatLogProbs>,
pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatLogProbs {
pub content: Option<Vec<ChatLogProbsContent>>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatLogProbsContent {
pub token: String,
pub logprob: f32,
pub bytes: Option<Vec<u8>>,
pub top_logprobs: Vec<TopLogProb>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TopLogProb {
pub token: String,
pub logprob: f32,
pub bytes: Option<Vec<u8>>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub completion_tokens_details: Option<CompletionTokensDetails>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionTokensDetails {
pub reasoning_tokens: Option<u32>,
}
// ============= Streaming Response Types =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionStreamResponse {
pub id: String,
pub object: String, // "text_completion"
pub created: u64,
pub choices: Vec<CompletionStreamChoice>,
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionStreamChoice {
pub text: String,
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<LogProbs>,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionStreamResponse {
pub id: String,
pub object: String, // "chat.completion.chunk"
pub created: u64,
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
pub choices: Vec<ChatStreamChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatStreamChoice {
pub index: u32,
pub delta: ChatMessageDelta,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatLogProbs>,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatMessageDelta {
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallDelta>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCallDelta>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCallDelta {
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "type")]
pub tool_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function: Option<FunctionCallDelta>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallDelta {
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>,
}
// ============= Error Response Types =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorResponse {
pub error: ErrorDetail,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorDetail {
pub message: String,
#[serde(rename = "type")]
pub error_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub param: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub code: Option<String>,
}
// PD (Prefill-Decode) Router Implementation
// This module handles routing for disaggregated prefill-decode systems
use crate::pd_types::{Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDSelectionPolicy};
use crate::tree::Tree;
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use futures_util::{StreamExt, TryStreamExt};
use metrics::{counter, histogram};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
use tracing::{debug, error, info, warn};
use uuid::Uuid;
// Removed over-engineered ProxyResponse - using HttpResponse directly
#[derive(Debug)]
pub struct PDRouter {
pub prefill_workers: Arc<RwLock<Vec<EngineInfo>>>,
pub decode_workers: Arc<RwLock<Vec<EngineInfo>>>,
pub selection_policy: PDSelectionPolicy,
pub load_tracking: Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
pub prefill_tree: Option<Arc<Mutex<Tree>>>,
pub timeout_secs: u64,
pub interval_secs: u64,
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
pub http_client: reqwest::Client,
}
// RAII guard for load tracking to ensure cleanup even on panic
struct LoadGuard<'a> {
tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
urls: Vec<String>,
}
impl<'a> LoadGuard<'a> {
fn new(
tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
urls: Vec<String>,
) -> Self {
// Increment counters
for url in &urls {
let counter = tracking
.entry(url.clone())
.or_insert_with(|| Arc::new(AtomicUsize::new(0)));
counter.fetch_add(1, Ordering::Relaxed);
}
LoadGuard { tracking, urls }
}
}
impl Drop for LoadGuard<'_> {
fn drop(&mut self) {
// Guaranteed cleanup even on panic
for url in &self.urls {
if let Some(counter) = self.tracking.get(url) {
counter.fetch_sub(1, Ordering::Relaxed);
}
}
}
}
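For readers coming from the Python side: LoadGuard is the RAII counterpart of a context manager. Counters are incremented on construction and decremented in Drop, so cleanup happens even if the dispatch path panics. A rough Python analogue of the same idea (illustrative only, not part of this commit):

from collections import defaultdict
from contextlib import contextmanager

load_tracking = defaultdict(int)  # in-flight requests per worker URL


@contextmanager
def load_guard(urls):
    """Increment in-flight counters for urls; always decrement, even on error."""
    for url in urls:
        load_tracking[url] += 1
    try:
        yield
    finally:
        for url in urls:
            load_tracking[url] -= 1


with load_guard(["http://prefill1:8080", "http://decode1:8081"]):
    pass  # dispatch the paired prefill/decode requests here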
impl PDRouter {
// TODO: Add methods for dynamic worker management to support /register endpoint:
// - add_prefill_server(url: String, bootstrap_port: Option<u16>)
// - add_decode_server(url: String)
// - remove_prefill_server(url: &str)
// - remove_decode_server(url: &str)
// These methods will be used when service discovery is implemented for PD mode
pub fn new(
prefill_urls: Vec<(String, Option<u16>)>,
decode_urls: Vec<String>,
selection_policy: PDSelectionPolicy,
timeout_secs: u64,
interval_secs: u64,
) -> Result<Self, String> {
// Convert URLs to EngineInfo
let prefill_workers: Vec<EngineInfo> = prefill_urls
.into_iter()
.map(|(url, port)| EngineInfo::new_prefill(url, port))
.collect();
let decode_workers: Vec<EngineInfo> = decode_urls
.into_iter()
.map(EngineInfo::new_decode)
.collect();
// Wait for PD workers to be healthy
let all_urls: Vec<String> = prefill_workers
.iter()
.chain(decode_workers.iter())
.map(|engine| engine.url.clone())
.collect();
crate::router::Router::wait_for_healthy_workers(&all_urls, timeout_secs, interval_secs)?;
// Initialize load tracking with atomic counters
let load_tracking = Arc::new(dashmap::DashMap::new());
for engine in &prefill_workers {
load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
}
for engine in &decode_workers {
load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
}
// Initialize cache-aware components if needed
let prefill_tree = match &selection_policy {
PDSelectionPolicy::CacheAware { .. } => {
let tree = Arc::new(Mutex::new(Tree::new()));
// Initialize tree with prefill workers
for engine in &prefill_workers {
tree.lock().unwrap().insert("", &engine.url);
}
Some(tree)
}
_ => None,
};
// Set up background load monitoring for power-of-two selection
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx);
// Create a shared HTTP client for all operations
let http_client = reqwest::Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
let load_monitor_handle = if matches!(selection_policy, PDSelectionPolicy::PowerOfTwo) {
let monitor_urls = all_urls.clone();
let monitor_interval = interval_secs;
let monitor_client = http_client.clone();
Some(Arc::new(tokio::spawn(async move {
Self::monitor_worker_loads_with_client(
monitor_urls,
tx,
monitor_interval,
monitor_client,
)
.await;
})))
} else {
None
};
Ok(PDRouter {
prefill_workers: Arc::new(RwLock::new(prefill_workers)),
decode_workers: Arc::new(RwLock::new(decode_workers)),
selection_policy,
load_tracking,
prefill_tree,
timeout_secs,
interval_secs,
worker_loads,
load_monitor_handle,
http_client,
})
}
// Route a typed generate request
pub async fn route_generate(
&self,
client: &reqwest::Client,
req: &HttpRequest,
mut typed_req: GenerateReqInput,
route: &str,
) -> HttpResponse {
let start = Instant::now();
let _request_id = Uuid::new_v4();
// Get stream flag and return_logprob flag before moving the request
let is_stream = typed_req.is_stream();
let return_logprob = typed_req
.other
.get("return_logprob")
.and_then(|v| v.as_bool())
.unwrap_or(false);
// Select servers
let (prefill, decode) = match self.select_pd_pair(client).await {
Ok(pair) => pair,
Err(e) => {
error!("Failed to select PD pair: {}", e);
counter!("sgl_router_pd_errors_total", "error" => "server_selection").increment(1);
return HttpResponse::ServiceUnavailable()
.body(format!("No available servers: {}", e));
}
};
// Log routing decision
info!(
"PD routing: {} -> prefill={}, decode={}",
route, prefill.url, decode.url
);
// Add bootstrap info using the trait method
if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
error!("Failed to add bootstrap info: {}", e);
counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
return HttpResponse::InternalServerError()
.body(format!("Bootstrap injection failed: {}", e));
}
// Convert to JSON after bootstrap injection
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
Ok(json) => json,
Err(e) => {
error!("Failed to serialize request: {}", e);
return HttpResponse::InternalServerError().body("Failed to serialize request");
}
};
// Execute dual dispatch
self.execute_dual_dispatch(
client,
req,
json_with_bootstrap,
route,
&prefill,
&decode,
is_stream,
return_logprob,
start,
)
.await
}
// Route a typed chat request
pub async fn route_chat(
&self,
client: &reqwest::Client,
req: &HttpRequest,
mut typed_req: ChatReqInput,
route: &str,
) -> HttpResponse {
let start = Instant::now();
// Get stream flag and return_logprob flag before moving the request
let is_stream = typed_req.is_stream();
let return_logprob = typed_req
.other
.get("return_logprob")
.and_then(|v| v.as_bool())
.unwrap_or(false);
// Select servers
let (prefill, decode) = match self.select_pd_pair(client).await {
Ok(pair) => pair,
Err(e) => {
error!("Failed to select PD pair: {}", e);
counter!("sgl_router_pd_errors_total", "error" => "server_selection").increment(1);
return HttpResponse::ServiceUnavailable()
.body(format!("No available servers: {}", e));
}
};
// Log routing decision
info!(
"PD routing: {} -> prefill={}, decode={}",
route, prefill.url, decode.url
);
// Add bootstrap info using the trait method
if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
error!("Failed to add bootstrap info: {}", e);
counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
return HttpResponse::InternalServerError()
.body(format!("Bootstrap injection failed: {}", e));
}
// Convert to JSON after bootstrap injection
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
Ok(json) => json,
Err(e) => {
error!("Failed to serialize request: {}", e);
return HttpResponse::InternalServerError().body("Failed to serialize request");
}
};
// Execute dual dispatch
self.execute_dual_dispatch(
client,
req,
json_with_bootstrap,
route,
&prefill,
&decode,
is_stream,
return_logprob,
start,
)
.await
}
// Execute the dual dispatch to prefill and decode servers
#[allow(clippy::too_many_arguments)]
async fn execute_dual_dispatch(
&self,
client: &reqwest::Client,
req: &HttpRequest,
json_request: serde_json::Value,
route: &str,
prefill: &EngineInfo,
decode: &EngineInfo,
is_stream: bool,
return_logprob: bool,
start_time: Instant,
) -> HttpResponse {
// Update load tracking for both workers
let _guard = LoadGuard::new(
&self.load_tracking,
vec![prefill.url.clone(), decode.url.clone()],
);
// Build requests using .json() method
let mut prefill_request = client.post(prefill.api_path(route)).json(&json_request);
let mut decode_request = client.post(decode.api_path(route)).json(&json_request);
// Copy headers from original request
for (name, value) in crate::router::copy_request_headers(req) {
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" {
prefill_request = prefill_request.header(&name, &value);
decode_request = decode_request.header(&name, &value);
}
}
// Send both requests concurrently
let (prefill_result, decode_result) =
tokio::join!(prefill_request.send(), decode_request.send());
// Update metrics
let duration = start_time.elapsed();
histogram!("sgl_router_pd_request_duration_seconds", "route" => route.to_string())
.record(duration.as_secs_f64());
counter!("sgl_router_pd_requests_total", "route" => route.to_string()).increment(1);
counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url.to_string())
.increment(1);
counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url.to_string())
.increment(1);
// Process decode response
match decode_result {
Ok(res) => {
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
if !status.is_success() {
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string()).increment(1);
error!(
"Decode server {} returned error status: {}",
decode.url, status
);
// Return the error response from decode server
match res.bytes().await {
Ok(error_body) => {
return HttpResponse::build(status).body(error_body.to_vec());
}
Err(e) => {
return HttpResponse::build(status)
.body(format!("Decode server error: {}", e));
}
}
}
// Log prefill errors for debugging
if let Err(e) = &prefill_result {
error!(
"Prefill server {} failed (non-critical): {}",
prefill.url, e
);
counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url.to_string()).increment(1);
}
if is_stream {
// Streaming response
if return_logprob {
// Get prefill logprobs for merging
let prefill_logprobs =
match prefill_result {
Ok(prefill_res) => match prefill_res.bytes().await {
Ok(body) => serde_json::from_slice::<Value>(&body)
.ok()
.and_then(|json| {
json.pointer("/meta_info/input_token_logprobs").cloned()
}),
Err(_) => None,
},
Err(_) => None,
};
// Stream with logprob merging
HttpResponse::build(status)
.insert_header((
CONTENT_TYPE,
HeaderValue::from_static("text/event-stream"),
))
.streaming(res.bytes_stream().map(move |chunk_result| {
match chunk_result {
Ok(chunk) => {
// Try to merge logprobs
if let Ok(merged) = Self::merge_streaming_logprobs(
prefill_logprobs.clone(),
&chunk,
) {
Ok(merged)
} else {
Ok(chunk)
}
}
Err(e) => Err(actix_web::error::ErrorInternalServerError(
format!("Stream error: {}", e),
)),
}
}))
} else {
// No logprob merging needed
HttpResponse::build(status)
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
.streaming({
let decode_url = decode.url.clone();
res.bytes_stream().map_err(move |e| {
error!("Stream error from decode server {}: {}", decode_url, e);
counter!("sgl_router_pd_stream_errors_total", "worker" => decode_url.to_string()).increment(1);
actix_web::error::ErrorInternalServerError(format!("Stream error: {}", e))
})
})
}
} else {
// Non-streaming response
match res.bytes().await {
Ok(decode_body) => {
if return_logprob {
self.merge_logprobs(prefill_result, decode_body, status)
.await
} else {
HttpResponse::build(status).body(decode_body.to_vec())
}
}
Err(e) => {
error!("Failed to read decode response: {}", e);
HttpResponse::InternalServerError().body("Failed to read response")
}
}
}
}
Err(e) => {
error!("Decode request failed: {}", e);
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string())
.increment(1);
HttpResponse::BadGateway().body(format!("Decode server error: {}", e))
}
}
}
// Merge logprobs from prefill and decode responses
async fn merge_logprobs(
&self,
prefill_result: Result<reqwest::Response, reqwest::Error>,
decode_body: bytes::Bytes,
status: actix_web::http::StatusCode,
) -> HttpResponse {
match prefill_result {
Ok(prefill_res) => {
match prefill_res.bytes().await {
Ok(prefill_body) => {
match (
serde_json::from_slice::<Value>(&prefill_body),
serde_json::from_slice::<Value>(&decode_body),
) {
(Ok(prefill_json), Ok(mut decode_json)) => {
// Merge input_token_logprobs
if let (Some(prefill_meta), Some(decode_meta)) = (
prefill_json.get("meta_info"),
decode_json.get_mut("meta_info"),
) {
if let (Some(prefill_logprobs), Some(decode_logprobs)) = (
prefill_meta.get("input_token_logprobs"),
decode_meta.get_mut("input_token_logprobs"),
) {
if let (Some(p_arr), Some(d_arr)) = (
prefill_logprobs.as_array(),
decode_logprobs.as_array(),
) {
let mut merged = p_arr.clone();
merged.extend(d_arr.clone());
decode_meta["input_token_logprobs"] =
Value::Array(merged);
}
}
}
HttpResponse::build(status).json(&decode_json)
}
_ => {
warn!("Failed to parse responses for logprob merging");
HttpResponse::build(status).body(decode_body.to_vec())
}
}
}
Err(e) => {
warn!("Failed to read prefill response: {}", e);
HttpResponse::build(status).body(decode_body.to_vec())
}
}
}
Err(_) => HttpResponse::build(status).body(decode_body.to_vec()),
}
}
// Select a pair of prefill and decode servers
async fn select_pd_pair(
&self,
_client: &reqwest::Client,
) -> Result<(EngineInfo, EngineInfo), String> {
// Check we have workers
if self
.prefill_workers
.read()
.map_err(|e| format!("Failed to acquire prefill workers lock: {}", e))?
.is_empty()
{
return Err("No prefill workers available. Please check if prefill servers are configured and healthy.".to_string());
}
if self
.decode_workers
.read()
.map_err(|e| format!("Failed to acquire decode workers lock: {}", e))?
.is_empty()
{
return Err("No decode workers available. Please check if decode servers are configured and healthy.".to_string());
}
match &self.selection_policy {
PDSelectionPolicy::Random => self.select_random(),
PDSelectionPolicy::PowerOfTwo => self.select_power_of_two().await,
PDSelectionPolicy::CacheAware { .. } => {
// TODO: Implement cache-aware selection
self.select_power_of_two().await
}
}
}
fn select_random(&self) -> Result<(EngineInfo, EngineInfo), String> {
let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
let prefill = prefill_list[rand::random::<usize>() % prefill_list.len()].clone();
let decode = decode_list[rand::random::<usize>() % decode_list.len()].clone();
Ok((prefill, decode))
}
async fn select_power_of_two(&self) -> Result<(EngineInfo, EngineInfo), String> {
let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
let (p1_idx, p2_idx) = get_two_random_indices(prefill_list.len());
let (d1_idx, d2_idx) = get_two_random_indices(decode_list.len());
let loads = self.worker_loads.borrow();
let p1_load = loads.get(&prefill_list[p1_idx].url).copied().unwrap_or(0);
let p2_load = loads.get(&prefill_list[p2_idx].url).copied().unwrap_or(0);
let d1_load = loads.get(&decode_list[d1_idx].url).copied().unwrap_or(0);
let d2_load = loads.get(&decode_list[d2_idx].url).copied().unwrap_or(0);
info!(
"Power-of-two selection - Prefill: {}={} vs {}={} | Decode: {}={} vs {}={}",
prefill_list[p1_idx].url,
p1_load,
prefill_list[p2_idx].url,
p2_load,
decode_list[d1_idx].url,
d1_load,
decode_list[d2_idx].url,
d2_load
);
let selected_prefill = if p1_load <= p2_load {
prefill_list[p1_idx].clone()
} else {
prefill_list[p2_idx].clone()
};
let selected_decode = if d1_load <= d2_load {
decode_list[d1_idx].clone()
} else {
decode_list[d2_idx].clone()
};
Ok((selected_prefill, selected_decode))
}
// Background task to monitor worker loads with shared client
async fn monitor_worker_loads_with_client(
worker_urls: Vec<String>,
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
interval_secs: u64,
client: reqwest::Client,
) {
loop {
let mut loads = HashMap::new();
let futures: Vec<_> = worker_urls
.iter()
.map(|url| {
let client = client.clone();
let url = url.clone();
async move {
let load = get_worker_load(&client, &url).await.unwrap_or(0);
(url, load)
}
})
.collect();
let results = futures_util::future::join_all(futures).await;
for (url, load) in results {
loads.insert(url, load);
}
debug!("Worker loads updated: {:?}", loads);
// Check if receiver is still active
if tx.send(loads).is_err() {
info!("Load monitor receiver dropped, shutting down monitor task");
break;
}
tokio::time::sleep(Duration::from_secs(interval_secs)).await;
}
}
// Simple helper to merge logprobs in streaming responses
fn merge_streaming_logprobs(
prefill_logprobs: Option<Value>,
decode_chunk: &[u8],
) -> Result<bytes::Bytes, ()> {
// Skip non-data chunks
let chunk_str = std::str::from_utf8(decode_chunk).map_err(|_| ())?;
if !chunk_str.starts_with("data: ") || chunk_str.contains("[DONE]") {
return Err(());
}
// Parse JSON from chunk
let json_str = chunk_str.trim_start_matches("data: ").trim();
let mut decode_json: Value = serde_json::from_str(json_str).map_err(|_| ())?;
// Merge prefill logprobs if available
if let Some(ref p_logprobs) = prefill_logprobs {
if let Some(meta) = decode_json.get_mut("meta_info") {
if let Some(d_logprobs) = meta.get_mut("input_token_logprobs") {
if let (Some(p_arr), Some(d_arr)) =
(p_logprobs.as_array(), d_logprobs.as_array())
{
let mut merged = p_arr.clone();
merged.extend(d_arr.clone());
*d_logprobs = Value::Array(merged);
}
}
}
}
// Re-serialize
let merged_str = format!(
"data: {}\n\n",
serde_json::to_string(&decode_json).unwrap_or_default()
);
Ok(bytes::Bytes::from(merged_str))
}
}
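// Illustrative sketch (not part of this commit): the load monitor publishes
// worker-load snapshots over a tokio watch channel, and select_power_of_two
// reads the latest snapshot via `worker_loads.borrow()`. Assumes tokio's
// "sync" and "macros" features are enabled.
#[cfg(test)]
mod load_watch_channel_example {
    use super::*;
    #[tokio::test]
    async fn watch_channel_carries_load_snapshots() {
        let (tx, rx) = tokio::sync::watch::channel(HashMap::<String, isize>::new());
        let mut loads = HashMap::new();
        loads.insert("http://prefill-1:8000".to_string(), 3isize);
        tx.send(loads).unwrap();
        assert_eq!(rx.borrow().get("http://prefill-1:8000"), Some(&3));
    }
}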
// Helper functions
fn get_two_random_indices(len: usize) -> (usize, usize) {
if len == 1 {
(0, 0)
} else {
let idx1 = rand::random::<usize>() % len;
let mut idx2 = rand::random::<usize>() % len;
while idx2 == idx1 {
idx2 = rand::random::<usize>() % len;
}
(idx1, idx2)
}
}
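// A minimal sanity-check sketch of the sampling helper above: with a single
// worker both indices collapse to 0; otherwise the two indices are distinct.
#[cfg(test)]
mod index_sampling_example {
    use super::*;
    #[test]
    fn two_indices_are_distinct_when_possible() {
        assert_eq!(get_two_random_indices(1), (0, 0));
        for _ in 0..100 {
            let (a, b) = get_two_random_indices(5);
            assert!(a < 5 && b < 5);
            assert_ne!(a, b);
        }
    }
}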
async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
match client.get(format!("{}/get_load", worker_url)).send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<Value>(&bytes) {
Ok(data) => data
.get("load")
.and_then(|v| v.as_i64())
.map(|v| v as isize),
Err(e) => {
debug!("Failed to parse load response from {}: {}", worker_url, e);
None
}
},
Err(e) => {
debug!("Failed to read load response from {}: {}", worker_url, e);
None
}
},
Ok(res) => {
debug!(
"Worker {} returned non-success status: {}",
worker_url,
res.status()
);
None
}
Err(e) => {
debug!("Failed to get load from {}: {}", worker_url, e);
None
}
}
}
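// Sketch of the parsing get_worker_load performs; the `/get_load` endpoint is
// expected to return JSON like {"load": 42} (example body assumed here):
#[cfg(test)]
mod load_parsing_example {
    use super::*;
    #[test]
    fn load_field_is_extracted_from_json() {
        let data: Value = serde_json::from_slice(br#"{"load": 42}"#).unwrap();
        let load = data.get("load").and_then(|v| v.as_i64()).map(|v| v as isize);
        assert_eq!(load, Some(42));
    }
}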
// PD-specific endpoints
impl PDRouter {
pub async fn health_generate(&self, client: &reqwest::Client) -> HttpResponse {
let mut all_healthy = true;
let mut unhealthy_servers = Vec::new();
// Collect all worker URLs with their types
let mut worker_infos = Vec::new();
for worker in self.prefill_workers.read().unwrap().iter() {
worker_infos.push((worker.url.clone(), "prefill"));
}
for worker in self.decode_workers.read().unwrap().iter() {
worker_infos.push((worker.url.clone(), "decode"));
}
// Create tasks with URL tracking
let tasks: Vec<_> = worker_infos
.iter()
.map(|(url, _)| {
let health_url = format!("{}/health_generate", url);
client.get(&health_url).send()
})
.collect();
let results = futures_util::future::join_all(tasks).await;
for ((url, worker_type), result) in worker_infos.iter().zip(results.into_iter()) {
match result {
Ok(res) if res.status().is_success() => {
debug!("Health check passed for {} server: {}", worker_type, url);
}
Ok(res) => {
all_healthy = false;
let msg = format!(
"{} server {} returned status {}",
worker_type,
url,
res.status()
);
error!("{}", msg);
unhealthy_servers.push(msg);
}
Err(e) => {
all_healthy = false;
let msg = format!("{} server {} error: {}", worker_type, url, e);
error!("{}", msg);
unhealthy_servers.push(msg);
}
}
}
if all_healthy {
HttpResponse::Ok().body("Health check passed on all servers")
} else {
HttpResponse::ServiceUnavailable()
.body(format!("Health check failed: {:?}", unhealthy_servers))
}
}
pub async fn get_server_info(&self, client: &reqwest::Client) -> HttpResponse {
// Get info from all decode servers (where generation happens)
let mut all_internal_states = Vec::new();
let mut decode_infos = Vec::new();
// Clone URLs to avoid holding lock across await
let worker_urls: Vec<String> = self
.decode_workers
.read()
.unwrap()
.iter()
.map(|w| w.url.clone())
.collect();
for worker_url in worker_urls {
match client
.get(format!("{}/get_server_info", worker_url))
.send()
.await
{
Ok(res) if res.status().is_success() => {
match res.json::<Value>().await {
Ok(info) => {
// Extract internal_states from each decode server
if let Some(states) = info.get("internal_states") {
if let Some(states_array) = states.as_array() {
all_internal_states.extend(states_array.clone());
}
}
decode_infos.push(info);
}
Err(e) => error!("Failed to parse server info: {}", e),
}
}
_ => {}
}
}
        // If we have internal states, return in the format expected by bench_one_batch_server.py
        if !all_internal_states.is_empty() {
            // Return the internal states collected from every decode server (callers typically read the first entry)
HttpResponse::Ok().json(serde_json::json!({
"internal_states": all_internal_states,
// Include original format for compatibility
"decode_servers": decode_infos,
}))
} else {
// Fallback: create a dummy internal_states entry
HttpResponse::Ok().json(serde_json::json!({
"internal_states": [{
"last_gen_throughput": 0.0,
"avg_spec_accept_length": null,
}],
"decode_servers": decode_infos,
}))
}
}
pub async fn get_models(&self, client: &reqwest::Client, req: &HttpRequest) -> HttpResponse {
// Get first prefill worker URL to avoid holding lock across await
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
workers.first().map(|w| w.url.clone())
} else {
return HttpResponse::InternalServerError().body("Failed to access prefill workers");
};
if let Some(worker_url) = first_worker_url {
// Send request directly without going through Router
let mut request_builder = client.get(format!("{}/v1/models", worker_url));
for (name, value) in crate::router::copy_request_headers(req) {
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
{
request_builder = request_builder.header(name, value);
}
}
match request_builder.send().await {
Ok(res) => {
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(e) => HttpResponse::InternalServerError()
.body(format!("Failed to read response body: {}", e)),
}
}
Err(e) => HttpResponse::InternalServerError()
.body(format!("Failed to send request: {}", e)),
}
} else {
HttpResponse::ServiceUnavailable().body("No prefill servers available")
}
}
pub async fn get_loads(&self, client: &reqwest::Client) -> HttpResponse {
let p_urls: Vec<_> = self
.prefill_workers
.read()
.unwrap()
.iter()
.map(|w| w.url.clone())
.collect();
let d_urls: Vec<_> = self
.decode_workers
.read()
.unwrap()
.iter()
.map(|w| w.url.clone())
.collect();
let mut prefill_loads = Vec::new();
let mut decode_loads = Vec::new();
for url in &p_urls {
let load = get_worker_load(client, url).await.unwrap_or(-1);
prefill_loads.push(serde_json::json!({
"engine": format!("(Prefill@{})", url),
"load": load as i64
}));
}
for url in &d_urls {
let load = get_worker_load(client, url).await.unwrap_or(-1);
decode_loads.push(serde_json::json!({
"engine": format!("(Decode@{})", url),
"load": load as i64
}));
}
HttpResponse::Ok().json(serde_json::json!({
"prefill": prefill_loads,
"decode": decode_loads
}))
}
pub async fn get_model_info(
&self,
client: &reqwest::Client,
req: &HttpRequest,
) -> HttpResponse {
// Get model info from the first prefill server (matches original Rust PDLB behavior)
// Get first prefill worker URL to avoid holding lock across await
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
workers.first().map(|w| w.url.clone())
} else {
return HttpResponse::InternalServerError().body("Failed to access prefill workers");
};
if let Some(worker_url) = first_worker_url {
let mut request_builder = client.get(format!("{}/get_model_info", worker_url));
for (name, value) in crate::router::copy_request_headers(req) {
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
{
request_builder = request_builder.header(name, value);
}
}
match request_builder.send().await {
Ok(res) => {
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(e) => HttpResponse::InternalServerError()
.body(format!("Failed to read response body: {}", e)),
}
}
Err(e) => HttpResponse::InternalServerError()
.body(format!("Failed to send request: {}", e)),
}
} else {
HttpResponse::ServiceUnavailable().body("No prefill servers available")
}
}
pub async fn flush_cache(&self, client: &reqwest::Client) -> HttpResponse {
let mut tasks = Vec::new();
// Flush cache on all prefill servers
for worker in self.prefill_workers.read().unwrap().iter() {
let url = format!("{}/flush_cache", worker.url);
tasks.push(client.post(&url).send());
}
// Flush cache on all decode servers
for worker in self.decode_workers.read().unwrap().iter() {
let url = format!("{}/flush_cache", worker.url);
tasks.push(client.post(&url).send());
}
let results = futures_util::future::join_all(tasks).await;
let mut all_success = true;
for (i, result) in results.into_iter().enumerate() {
match result {
Ok(res) if res.status().is_success() => {}
Ok(res) => {
all_success = false;
warn!(
"Server {} returned status {} for flush_cache",
i,
res.status()
);
}
Err(e) => {
all_success = false;
error!("Server {} error during flush_cache: {}", i, e);
}
}
}
if all_success {
HttpResponse::Ok().body("Cache flushed on all servers")
} else {
HttpResponse::InternalServerError().body("Cache flush failed on one or more servers")
}
}
}
// Essential PDLB types extracted for PD routing
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Debug, Clone)]
pub enum EngineType {
Prefill,
Decode,
}
#[derive(Debug, Clone)]
pub struct EngineInfo {
pub engine_type: EngineType,
pub url: String,
pub bootstrap_port: Option<u16>,
}
impl EngineInfo {
pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
EngineInfo {
engine_type: EngineType::Prefill,
url,
bootstrap_port,
}
}
pub fn new_decode(url: String) -> Self {
EngineInfo {
engine_type: EngineType::Decode,
url,
bootstrap_port: None,
}
}
pub fn api_path(&self, api_path: &str) -> String {
if api_path.starts_with("/") {
format!("{}{}", self.url, api_path)
} else {
format!("{}/{}", self.url, api_path)
}
}
pub fn get_hostname(&self) -> String {
// Simple hostname extraction without external dependencies
let url = self
.url
.trim_start_matches("http://")
.trim_start_matches("https://");
url.split(':').next().unwrap_or("localhost").to_string()
}
}
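// Illustrative behavior of the helpers above (the URL is a hypothetical example):
#[cfg(test)]
mod engine_info_example {
    use super::*;
    #[test]
    fn api_path_and_hostname() {
        let engine = EngineInfo::new_prefill("http://prefill-1:8000".to_string(), Some(8998));
        assert_eq!(engine.api_path("/generate"), "http://prefill-1:8000/generate");
        assert_eq!(engine.api_path("generate"), "http://prefill-1:8000/generate");
        assert_eq!(engine.get_hostname(), "prefill-1");
    }
}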
// PD-specific routing policies
#[derive(Debug, Clone, PartialEq)]
pub enum PDSelectionPolicy {
Random,
PowerOfTwo,
CacheAware {
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
},
}
// Bootstrap types from PDLB
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum SingleOrBatch<T> {
Single(T),
Batch(Vec<T>),
}
pub type InputIds = SingleOrBatch<Vec<i32>>;
pub type InputText = SingleOrBatch<String>;
pub type BootstrapHost = SingleOrBatch<String>;
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
pub type BootstrapRoom = SingleOrBatch<u64>;
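// Because SingleOrBatch is untagged, a single value and a batch serialize to a
// bare value and a JSON array respectively; a small sketch:
#[cfg(test)]
mod single_or_batch_example {
    use super::*;
    #[test]
    fn untagged_serialization_matches_wire_format() {
        let single: BootstrapHost = SingleOrBatch::Single("prefill-1".to_string());
        let batch: BootstrapHost = SingleOrBatch::Batch(vec!["a".to_string(), "b".to_string()]);
        assert_eq!(serde_json::to_string(&single).unwrap(), r#""prefill-1""#);
        assert_eq!(serde_json::to_string(&batch).unwrap(), r#"["a","b"]"#);
    }
}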
// Bootstrap trait for request handling
pub trait Bootstrap: Send + Sync {
fn is_stream(&self) -> bool;
fn get_batch_size(&self) -> Result<Option<usize>, String>;
fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
);
fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), String> {
let batch_size = self.get_batch_size()?;
if let Some(batch_size) = batch_size {
self.set_bootstrap_info(
BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
// Use high-quality random numbers to minimize collision risk
BootstrapRoom::Batch(
(0..batch_size)
.map(|_| {
// Combine multiple sources of randomness for better distribution
let r1 = rand::random::<u64>();
let r2 = rand::random::<u64>();
r1.wrapping_add(r2.rotate_left(32))
})
.collect(),
),
);
} else {
self.set_bootstrap_info(
BootstrapHost::Single(prefill_info.get_hostname()),
BootstrapPort::Single(prefill_info.bootstrap_port),
BootstrapRoom::Single({
// Use high-quality random number for single requests too
let r1 = rand::random::<u64>();
let r2 = rand::random::<u64>();
r1.wrapping_add(r2.rotate_left(32))
}),
);
}
Ok(())
}
}
// Request types
#[derive(Debug, Deserialize, Serialize)]
pub struct GenerateReqInput {
pub text: Option<InputText>,
pub input_ids: Option<InputIds>,
#[serde(default)]
pub stream: bool,
pub bootstrap_host: Option<BootstrapHost>,
pub bootstrap_port: Option<BootstrapPort>,
pub bootstrap_room: Option<BootstrapRoom>,
#[serde(flatten)]
pub other: Value,
}
impl GenerateReqInput {
pub fn get_batch_size(&self) -> Result<Option<usize>, String> {
if self.text.is_some() && self.input_ids.is_some() {
return Err("Both text and input_ids are present in the request".to_string());
}
// Check text batch
if let Some(InputText::Batch(texts)) = &self.text {
if texts.is_empty() {
return Err("Batch text array is empty".to_string());
}
if texts.len() > 10000 {
// Reasonable limit for production
return Err(format!(
"Batch size {} exceeds maximum allowed (10000)",
texts.len()
));
}
return Ok(Some(texts.len()));
}
// Check input_ids batch
if let Some(InputIds::Batch(ids)) = &self.input_ids {
if ids.is_empty() {
return Err("Batch input_ids array is empty".to_string());
}
if ids.len() > 10000 {
// Reasonable limit for production
return Err(format!(
"Batch size {} exceeds maximum allowed (10000)",
ids.len()
));
}
// Validate each sequence is not empty
for (i, seq) in ids.iter().enumerate() {
if seq.is_empty() {
return Err(format!("Input sequence at index {} is empty", i));
}
}
return Ok(Some(ids.len()));
}
Ok(None)
}
}
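// Example of the batch-size rules above: a two-element text batch reports
// Some(2), while mixing text and input_ids is rejected.
#[cfg(test)]
mod batch_size_example {
    use super::*;
    fn req(text: Option<InputText>, input_ids: Option<InputIds>) -> GenerateReqInput {
        GenerateReqInput {
            text,
            input_ids,
            stream: false,
            bootstrap_host: None,
            bootstrap_port: None,
            bootstrap_room: None,
            other: Value::Object(serde_json::Map::new()),
        }
    }
    #[test]
    fn batch_text_reports_its_length() {
        let r = req(Some(InputText::Batch(vec!["a".into(), "b".into()])), None);
        assert_eq!(r.get_batch_size().unwrap(), Some(2));
    }
    #[test]
    fn text_and_input_ids_together_are_rejected() {
        let r = req(
            Some(InputText::Single("a".into())),
            Some(InputIds::Single(vec![1, 2])),
        );
        assert!(r.get_batch_size().is_err());
    }
}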
impl Bootstrap for GenerateReqInput {
fn is_stream(&self) -> bool {
self.stream
}
fn get_batch_size(&self) -> Result<Option<usize>, String> {
self.get_batch_size()
}
fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
) {
self.bootstrap_host = Some(bootstrap_host);
self.bootstrap_port = Some(bootstrap_port);
self.bootstrap_room = Some(bootstrap_room);
}
}
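// End-to-end sketch of bootstrap injection for a non-batch request, using a
// hypothetical prefill engine at http://prefill-1:8000 with bootstrap port 8998:
#[cfg(test)]
mod bootstrap_injection_example {
    use super::*;
    #[test]
    fn single_request_gets_single_bootstrap_fields() {
        let prefill = EngineInfo::new_prefill("http://prefill-1:8000".to_string(), Some(8998));
        let mut req = GenerateReqInput {
            text: Some(InputText::Single("hello".to_string())),
            input_ids: None,
            stream: false,
            bootstrap_host: None,
            bootstrap_port: None,
            bootstrap_room: None,
            other: Value::Object(serde_json::Map::new()),
        };
        req.add_bootstrap_info(&prefill).unwrap();
        assert!(matches!(req.bootstrap_host, Some(BootstrapHost::Single(ref h)) if h.as_str() == "prefill-1"));
        assert!(matches!(req.bootstrap_port, Some(BootstrapPort::Single(Some(8998)))));
        assert!(matches!(req.bootstrap_room, Some(BootstrapRoom::Single(_))));
    }
}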
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatReqInput {
#[serde(default)]
pub stream: bool,
pub bootstrap_host: Option<BootstrapHost>,
pub bootstrap_port: Option<BootstrapPort>,
pub bootstrap_room: Option<BootstrapRoom>,
#[serde(flatten)]
pub other: Value,
}
impl Bootstrap for ChatReqInput {
fn is_stream(&self) -> bool {
self.stream
}
fn get_batch_size(&self) -> Result<Option<usize>, String> {
// Check if 'n' parameter is present and > 1
if let Some(n_value) = self.other.get("n") {
if let Some(n) = n_value.as_u64() {
if n > 1 {
return Ok(Some(n as usize));
}
}
}
Ok(None)
}
fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
) {
self.bootstrap_host = Some(bootstrap_host);
self.bootstrap_port = Some(bootstrap_port);
self.bootstrap_room = Some(bootstrap_room);
}
}
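// Sketch: for chat requests the batch size comes from the OpenAI `n`
// parameter carried in the flattened `other` fields.
#[cfg(test)]
mod chat_batch_size_example {
    use super::*;
    #[test]
    fn n_parameter_drives_batch_size() {
        let req = ChatReqInput {
            stream: false,
            bootstrap_host: None,
            bootstrap_port: None,
            bootstrap_room: None,
            other: serde_json::json!({ "n": 3 }),
        };
        assert_eq!(req.get_batch_size().unwrap(), Some(3));
    }
}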
// Request adapter to bridge OpenAI API types with PD routing requirements
use crate::openai_api_types::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, StringOrArray,
};
use crate::pd_types::{Bootstrap, ChatReqInput, GenerateReqInput, SingleOrBatch};
use serde_json::Value;
/// Adapter trait to convert OpenAI requests to PD-compatible requests
pub trait ToPdRequest {
type Output: Bootstrap;
fn to_pd_request(self) -> Self::Output;
}
// Helper macro to insert optional fields into a map
macro_rules! insert_if_some {
($map:expr, $($field:expr => $key:expr),* $(,)?) => {
$(
if let Some(value) = $field {
$map.insert($key.to_string(), serde_json::to_value(value).unwrap_or(Value::Null));
}
)*
};
}
// Helper macro for simple value insertions
macro_rules! insert_value {
($map:expr, $($field:expr => $key:expr),* $(,)?) => {
$(
$map.insert($key.to_string(), $field.into());
)*
};
}
// ============= Generate Request Adapter =============
impl ToPdRequest for GenerateRequest {
type Output = GenerateReqInput;
fn to_pd_request(self) -> Self::Output {
// Build the other fields first
let mut other = serde_json::Map::new();
// Handle text input - check in priority order: text (SGLang), prompt (OpenAI)
let (text, input_ids) = if let Some(text_str) = self.text {
// SGLang native format
(Some(SingleOrBatch::Single(text_str)), None)
} else if let Some(prompt) = self.prompt {
// OpenAI style prompt
let text = match prompt {
StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
};
(text, None)
} else if let Some(ids) = self.input_ids {
// Input IDs case
let input_ids = match ids {
crate::openai_api_types::InputIds::Single(ids) => Some(SingleOrBatch::Single(ids)),
crate::openai_api_types::InputIds::Batch(ids) => Some(SingleOrBatch::Batch(ids)),
};
(None, input_ids)
} else {
// No input provided
(None, None)
};
// Add parameters to other - handle both old and new style
if let Some(params) = self.parameters {
// For generate endpoint, extract max_new_tokens to top level if present
let mut params_value = serde_json::to_value(&params).unwrap_or(Value::Null);
if let Value::Object(ref mut params_map) = params_value {
// Move max_new_tokens to top level if it exists
if let Some(max_new_tokens) = params_map.remove("max_new_tokens") {
other.insert("max_new_tokens".to_string(), max_new_tokens);
}
// Move temperature to top level if it exists
if let Some(temperature) = params_map.remove("temperature") {
other.insert("temperature".to_string(), temperature);
}
}
// Only add parameters if there are remaining fields
if !params_value.is_null() && params_value.as_object().map_or(false, |m| !m.is_empty())
{
other.insert("parameters".to_string(), params_value);
}
}
// Add sampling_params if present
if let Some(sampling_params) = self.sampling_params {
let params_value = serde_json::to_value(&sampling_params).unwrap_or(Value::Null);
if !params_value.is_null() {
// Extract commonly used fields to top level
if let Value::Object(ref params_map) = params_value {
if let Some(max_new_tokens) = params_map.get("max_new_tokens") {
other.insert("max_new_tokens".to_string(), max_new_tokens.clone());
}
if let Some(temperature) = params_map.get("temperature") {
other.insert("temperature".to_string(), temperature.clone());
}
}
other.insert("sampling_params".to_string(), params_value);
}
}
// Add other fields
insert_value!(other,
self.stream => "stream",
self.return_logprob => "return_logprob"
);
GenerateReqInput {
text,
input_ids,
stream: self.stream,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
other: Value::Object(other),
}
}
}
// ============= Completion Request Adapter =============
impl ToPdRequest for CompletionRequest {
type Output = GenerateReqInput;
fn to_pd_request(self) -> Self::Output {
// Convert CompletionRequest to GenerateReqInput
let text = match self.prompt {
StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
};
// Map OpenAI parameters to generate parameters
let mut other = serde_json::Map::new();
// Create parameters object
let mut params = serde_json::Map::new();
// Map OpenAI fields to internal parameter names
insert_if_some!(params,
self.max_tokens => "max_new_tokens",
self.temperature => "temperature",
self.top_p => "top_p",
self.n => "best_of",
self.logprobs => "top_n_tokens",
self.seed => "seed"
);
// Special handling for fields that need transformation
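        // Note: OpenAI's additive presence_penalty has no exact SGLang equivalent;
        // (1.0 + presence_penalty) below is a rough mapping onto a multiplicative
        // repetition_penalty, not a faithful translation.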
if let Some(presence_penalty) = self.presence_penalty {
params.insert(
"repetition_penalty".to_string(),
(1.0 + presence_penalty).into(),
);
}
if let Some(stop) = self.stop {
let stop_sequences = match stop {
StringOrArray::String(s) => vec![s],
StringOrArray::Array(v) => v,
};
params.insert("stop".to_string(), stop_sequences.into());
}
if self.echo {
params.insert("return_full_text".to_string(), true.into());
}
other.insert("parameters".to_string(), Value::Object(params));
// Store original model and stream flag
insert_value!(other,
self.model => "model",
self.stream => "stream"
);
GenerateReqInput {
text,
input_ids: None,
stream: self.stream,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
other: Value::Object(other),
}
}
}
// ============= Chat Completion Request Adapter =============
impl ToPdRequest for ChatCompletionRequest {
type Output = ChatReqInput;
fn to_pd_request(self) -> Self::Output {
let mut other = serde_json::Map::new();
// Add required fields
insert_if_some!(other,
Some(&self.messages) => "messages"
);
insert_value!(other,
self.model => "model",
self.stream => "stream"
);
// Add all optional fields
insert_if_some!(other,
self.temperature => "temperature",
self.top_p => "top_p",
self.n => "n",
self.stop => "stop",
self.max_tokens => "max_tokens",
self.max_completion_tokens => "max_completion_tokens",
self.presence_penalty => "presence_penalty",
self.frequency_penalty => "frequency_penalty",
self.logit_bias => "logit_bias",
self.user => "user",
self.seed => "seed",
self.top_logprobs => "top_logprobs",
self.response_format => "response_format",
self.tools => "tools",
self.tool_choice => "tool_choice",
self.parallel_tool_calls => "parallel_tool_calls",
self.functions => "functions",
self.function_call => "function_call"
);
// Handle boolean logprobs flag
if self.logprobs {
other.insert("logprobs".to_string(), true.into());
}
ChatReqInput {
stream: self.stream,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
other: Value::Object(other),
}
}
}
// ============= Direct routing support for regular router =============
/// Extension trait for routing without PD conversion
pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone {
/// Convert to JSON for sending to backend
fn to_json(&self) -> Result<Value, serde_json::Error> {
serde_json::to_value(self)
}
/// Convert to bytes for legacy routing
fn to_bytes(&self) -> Result<bytes::Bytes, serde_json::Error> {
let json = serde_json::to_vec(self)?;
Ok(bytes::Bytes::from(json))
}
}
impl RouteableRequest for GenerateRequest {}
impl RouteableRequest for CompletionRequest {}
impl RouteableRequest for ChatCompletionRequest {}
use crate::pd_router::PDRouter;
use crate::pd_types::PDSelectionPolicy;
use crate::tree::Tree;
use ::metrics::{counter, gauge, histogram};
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use futures_util::{StreamExt, TryStreamExt};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
@@ -15,7 +15,7 @@ use std::time::Instant;
use tokio;
use tracing::{debug, error, info, warn};
pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
    req.headers()
        .iter()
        .filter_map(|(name, value)| {
@@ -40,6 +40,9 @@ pub enum Router {
        timeout_secs: u64,
        interval_secs: u64,
    },
PrefillDecode {
pd_router: Arc<PDRouter>,
},
    CacheAware {
        /*
        Cache-Aware Load Balancing Router
@@ -133,6 +136,13 @@ pub enum PolicyConfig {
        timeout_secs: u64,
        interval_secs: u64,
    },
PrefillDecodeConfig {
selection_policy: PDSelectionPolicy,
prefill_urls: Vec<(String, Option<u16>)>, // (url, bootstrap_port)
decode_urls: Vec<String>,
timeout_secs: u64,
interval_secs: u64,
},
}
impl Router {
@@ -155,10 +165,24 @@ impl Router {
            interval_secs,
            ..
        } => (*timeout_secs, *interval_secs),
PolicyConfig::PrefillDecodeConfig {
timeout_secs,
interval_secs,
..
} => (*timeout_secs, *interval_secs),
        };
        // For PrefillDecode, we need to handle workers differently
        match &policy_config {
PolicyConfig::PrefillDecodeConfig { .. } => {
// PD mode doesn't use the worker_urls parameter
// We'll validate PD workers separately
}
_ => {
// Wait until all workers are healthy for regular modes
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
}
}
        // Create router based on policy...
        Ok(match policy_config {
@@ -226,7 +250,7 @@ impl Router {
                });
                for url in &worker_urls {
                    tree.lock().unwrap().insert("", url);
                }
                Router::CacheAware {
@@ -242,6 +266,26 @@ impl Router {
                    _eviction_thread: Some(eviction_thread),
                }
            }
PolicyConfig::PrefillDecodeConfig {
selection_policy,
prefill_urls,
decode_urls,
timeout_secs,
interval_secs,
} => {
// Create PDRouter instance
let pd_router = PDRouter::new(
prefill_urls,
decode_urls,
selection_policy,
timeout_secs,
interval_secs,
)?;
Router::PrefillDecode {
pd_router: Arc::new(pd_router),
}
}
        })
    }
@@ -251,16 +295,23 @@ impl Router {
            Router::RoundRobin { worker_urls, .. } => Arc::clone(worker_urls),
            Router::Random { worker_urls, .. } => Arc::clone(worker_urls),
            Router::CacheAware { worker_urls, .. } => Arc::clone(worker_urls),
Router::PrefillDecode { .. } => {
// For PD mode, return empty list since we manage workers differently
Arc::new(RwLock::new(Vec::new()))
}
        }
    }
    pub fn wait_for_healthy_workers(
        worker_urls: &[String],
        timeout_secs: u64,
        interval_secs: u64,
    ) -> Result<(), String> {
        let start_time = std::time::Instant::now();
        let sync_client = reqwest::blocking::Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
        loop {
            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
@@ -323,10 +374,14 @@ impl Router {
                    Ok(worker_urls.read().unwrap()[0].clone())
                }
            }
Router::PrefillDecode { .. } => {
// For PD mode, we don't need this method as routing is handled by PDRouter
Err("PrefillDecode mode doesn't use select_first_worker".to_string())
}
        }
    }
    pub async fn send_request(
        &self,
        client: &reqwest::Client,
        worker_url: &str,
@@ -339,7 +394,11 @@ impl Router {
        // Copy all headers from original request except for /health because it does not need authorization
        if route != "/health" {
            for (name, value) in copy_request_headers(req) {
                // Skip Content-Type and Content-Length as .json() sets them
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
{
request_builder = request_builder.header(name, value);
}
            }
        }
@@ -433,50 +492,193 @@ impl Router {
        HttpResponse::InternalServerError().body("All retry attempts failed")
    }
    pub async fn route_to_all(
        &self,
        client: &reqwest::Client,
        route: &str,
        req: &HttpRequest,
    ) -> HttpResponse {
        // Get all worker URLs based on router type
        let worker_urls = match self {
            Router::PrefillDecode { .. } => {
                // For PD mode, route_to_all is not supported directly
                // It should be handled by PDRouter if needed
                return HttpResponse::NotImplemented()
                    .body("route_to_all not implemented for PrefillDecode mode");
            }
            _ => self.get_worker_urls().read().unwrap().clone(),
        };
        // Send requests to all workers concurrently
        let mut tasks = Vec::new();
        for worker_url in &worker_urls {
            let mut request_builder = client.post(format!("{}{}", worker_url, route));
            // Copy headers from original request
            for (name, value) in copy_request_headers(req) {
                request_builder = request_builder.header(name, value);
            }
            tasks.push(request_builder.send());
        }
        // Wait for all responses
        let results = futures_util::future::join_all(tasks).await;
        // Check if all succeeded
        let all_success = results.iter().all(|r| {
            r.as_ref()
                .map(|res| res.status().is_success())
                .unwrap_or(false)
        });
        if all_success {
            HttpResponse::Ok().body("Operation completed on all servers")
        } else {
            HttpResponse::InternalServerError().body("Operation failed on one or more servers")
        }
    }
    pub async fn get_all_loads(
        &self,
        client: &reqwest::Client,
        _req: &HttpRequest,
    ) -> HttpResponse {
        // For PD mode, delegate to PDRouter
        match self {
            Router::PrefillDecode { pd_router } => {
                return pd_router.get_loads(client).await;
            }
            _ => {
                // For non-PD routers, handle normally
            }
        }
        let urls = self.get_worker_urls().read().unwrap().clone();
        let prefill_urls: Vec<String> = Vec::new();
        let decode_urls = urls;
        // Collect loads from all servers
        let mut prefill_loads = Vec::new();
        let mut decode_loads = Vec::new();
        // Get prefill loads
        for url in &prefill_urls {
            let load = self.get_worker_load(client, url).await.unwrap_or(-1);
            prefill_loads.push(serde_json::json!({
                "engine": format!("(Prefill@{})", url),
                "load": load as i64
            }));
        }
        // Get decode loads
        for url in &decode_urls {
            let load = self.get_worker_load(client, url).await.unwrap_or(-1);
            decode_loads.push(serde_json::json!({
                "engine": format!("(Decode@{})", url),
                "load": load as i64
            }));
        }
        HttpResponse::Ok().json(serde_json::json!({
            "prefill": prefill_loads,
            "decode": decode_loads
        }))
    }
    // New method to route typed requests directly
    pub async fn route_typed_request<
        T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
>(
&self,
client: &reqwest::Client,
req: &HttpRequest,
typed_req: &T,
route: &str,
) -> HttpResponse {
match self {
Router::PrefillDecode { .. } => HttpResponse::InternalServerError()
.body("PD routing should use specialized typed handlers"),
_ => {
// Handle retries like the original implementation
let start = Instant::now();
const MAX_REQUEST_RETRIES: u32 = 3;
const MAX_TOTAL_RETRIES: u32 = 6;
let mut total_retries = 0;
while total_retries < MAX_TOTAL_RETRIES {
// Extract routing text directly from typed request
let text = typed_req.extract_text_for_routing();
let is_stream = typed_req.is_stream();
// Select worker based on text
let worker_url = self.select_generate_worker_from_text(&text);
let mut request_retries = 0;
// Try the same worker multiple times
while request_retries < MAX_REQUEST_RETRIES {
if total_retries >= 1 {
info!("Retrying request after {} failed attempts", total_retries);
counter!("sgl_router_retries_total", "route" => route.to_string())
.increment(1);
}
// Send typed request directly
let response = self
.send_typed_request(
client,
req,
typed_req,
route,
&worker_url,
is_stream,
)
.await;
if response.status().is_success() {
let duration = start.elapsed();
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string())
.record(duration.as_secs_f64());
return response;
} else {
// if the worker is healthy, it means the request is bad, so return the error response
let health_response =
self.send_request(client, &worker_url, "/health", req).await;
if health_response.status().is_success() {
counter!("sgl_router_request_errors_total", "route" => route.to_string())
.increment(1);
return response;
}
}
warn!(
"Generate request to {} failed (attempt {}/{})",
worker_url,
request_retries + 1,
MAX_REQUEST_RETRIES
);
request_retries += 1;
total_retries += 1;
                    if request_retries == MAX_REQUEST_RETRIES {
warn!("Removing failed worker: {}", worker_url);
self.remove_worker(&worker_url);
break;
}
}
}
counter!("sgl_router_request_errors_total", "route" => route.to_string())
.increment(1);
HttpResponse::InternalServerError().body("All retry attempts failed")
}
}
}
// Helper method to select worker from text
fn select_generate_worker_from_text(&self, text: &str) -> String {
match self {
            Router::RoundRobin {
                worker_urls,
                current_index,
@@ -506,8 +708,6 @@ impl Router {
                balance_rel_threshold,
                ..
            } => {
                let tree = tree.lock().unwrap();
                let mut running_queue = running_queue.lock().unwrap();
@@ -572,35 +772,48 @@ impl Router {
                selected_url
            }
            Router::PrefillDecode { .. } => {
                // For PD mode, we don't use this method
                return "PD_MODE_ERROR".to_string();
            }
        }
    }
    // Send typed request directly without conversion
    async fn send_typed_request<T: serde::Serialize>(
        &self,
        client: &reqwest::Client,
        req: &HttpRequest,
        typed_req: &T,
        route: &str,
        worker_url: &str,
        is_stream: bool,
    ) -> HttpResponse {
        let start = Instant::now();
        // Debug: Log what we're sending
        if let Ok(json_str) = serde_json::to_string_pretty(typed_req) {
            debug!("Sending request to {}: {}", route, json_str);
        }
        let mut request_builder = client
            .post(format!("{}{}", worker_url, route))
            .json(typed_req); // Use json() directly with typed request
        // Copy all headers from original request
        for (name, value) in copy_request_headers(req) {
            // Skip Content-Type and Content-Length as .json() sets them
            if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" {
                request_builder = request_builder.header(&name, &value);
            }
        }
        let res = match request_builder.send().await {
            Ok(res) => res,
            Err(e) => {
                error!("Failed to send request to {}: {}", worker_url, e);
                return HttpResponse::InternalServerError().body(format!("Request failed: {}", e));
            }
        };
        let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
@@ -625,6 +838,12 @@ impl Router {
            }
        }
// Record metrics
let duration = start.elapsed();
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string())
.record(duration.as_secs_f64());
counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1);
            response
        } else if let Router::CacheAware { running_queue, .. } = self {
            let running_queue = Arc::clone(running_queue);
@@ -660,70 +879,6 @@ impl Router {
        }
    }
    pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
        let (timeout_secs, interval_secs) = match self {
            Router::Random {
@@ -741,10 +896,17 @@ impl Router {
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
Router::PrefillDecode { .. } => {
// For PD mode, we don't support adding workers via this method
return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
}
        };
        let start_time = std::time::Instant::now();
        let client = reqwest::Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
        loop {
            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
@@ -774,6 +936,9 @@ impl Router {
                    urls.push(worker_url.to_string());
                    gauge!("sgl_router_active_workers").set(urls.len() as f64);
                }
Router::PrefillDecode { .. } => {
return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
}
            }
            // If cache aware, initialize the queues for the new worker
@@ -797,7 +962,7 @@ impl Router {
                .insert(worker_url.to_string(), 0);
            // Add worker to tree
            tree.lock().unwrap().insert("", worker_url);
            }
            return Ok(format!("Successfully added worker: {}", worker_url));
@@ -850,6 +1015,10 @@ impl Router {
                    return;
                }
            }
Router::PrefillDecode { .. } => {
warn!("Removing workers from PrefillDecode router not supported via remove_worker. Use dedicated PD management methods.");
return;
}
        }
        // if cache aware, remove the worker from the tree
@@ -875,4 +1044,133 @@ impl Router {
        );
    }
}
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
match client.get(&format!("{}/get_load", worker_url)).send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
Ok(data) => data
.get("load")
.and_then(|v| v.as_i64())
.map(|v| v as isize),
Err(e) => {
debug!("Failed to parse load response from {}: {}", worker_url, e);
None
}
},
Err(e) => {
debug!("Failed to read load response from {}: {}", worker_url, e);
None
}
},
Ok(res) => {
debug!(
"Worker {} returned non-success status: {}",
worker_url,
res.status()
);
None
}
Err(e) => {
debug!("Failed to get load from {}: {}", worker_url, e);
None
}
}
}
// PD-specific wrapper methods that delegate to PDRouter
pub async fn route_pd_health_generate(
&self,
_client: &reqwest::Client,
_req: &HttpRequest,
) -> HttpResponse {
match self {
Router::PrefillDecode { pd_router } => {
pd_router.health_generate(&pd_router.http_client).await
}
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
}
}
pub async fn route_pd_generate_typed(
&self,
_client: &reqwest::Client,
req: &HttpRequest,
typed_req: crate::pd_types::GenerateReqInput,
route: &str,
) -> HttpResponse {
match self {
Router::PrefillDecode { pd_router } => {
pd_router
.route_generate(&pd_router.http_client, req, typed_req, route)
.await
}
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
}
}
pub async fn route_pd_chat_typed(
&self,
_client: &reqwest::Client,
req: &HttpRequest,
typed_req: crate::pd_types::ChatReqInput,
route: &str,
) -> HttpResponse {
match self {
Router::PrefillDecode { pd_router } => {
pd_router
.route_chat(&pd_router.http_client, req, typed_req, route)
.await
}
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
}
}
pub async fn get_pd_server_info(
&self,
_client: &reqwest::Client,
_req: &HttpRequest,
) -> HttpResponse {
match self {
Router::PrefillDecode { pd_router } => {
pd_router.get_server_info(&pd_router.http_client).await
}
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
}
}
pub async fn get_pd_models(
&self,
_client: &reqwest::Client,
req: &HttpRequest,
) -> HttpResponse {
match self {
Router::PrefillDecode { pd_router } => {
pd_router.get_models(&pd_router.http_client, req).await
}
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
}
}
pub async fn route_pd_flush_cache(&self, _client: &reqwest::Client) -> HttpResponse {
match self {
Router::PrefillDecode { pd_router } => {
pd_router.flush_cache(&pd_router.http_client).await
}
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
}
}
pub async fn get_pd_model_info(
&self,
_client: &reqwest::Client,
req: &HttpRequest,
) -> HttpResponse {
match self {
Router::PrefillDecode { pd_router } => {
pd_router.get_model_info(&pd_router.http_client, req).await
}
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
}
}
}
use crate::logging::{self, LoggingConfig};
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::prometheus::{self, PrometheusConfig};
use crate::request_adapter::ToPdRequest;
use crate::router::PolicyConfig;
use crate::router::Router;
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use actix_web::{
    error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
};
use futures_util::StreamExt;
use reqwest::Client;
use std::collections::HashMap;
@@ -20,6 +21,7 @@ use tracing::{error, info, warn, Level};
pub struct AppState {
    router: Arc<Router>,
    client: Client,
is_pd_mode: bool, // Add flag to track PD mode
}
impl AppState {
@@ -28,9 +30,16 @@ impl AppState {
        client: Client,
        policy_config: PolicyConfig,
    ) -> Result<Self, String> {
// Check if this is PD mode from policy config
let is_pd_mode = matches!(policy_config, PolicyConfig::PrefillDecodeConfig { .. });
        // Create router based on policy
        let router = Arc::new(Router::new(worker_urls, policy_config)?);
        Ok(Self {
router,
client,
is_pd_mode,
})
    }
}
@@ -46,8 +55,25 @@ async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<Ht
}
// Custom error handler for JSON payload errors.
fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
    error!("JSON payload error: {:?}", err);
match &err {
error::JsonPayloadError::OverflowKnownLength { length, limit } => {
error!(
"Payload too large: {} bytes exceeds limit of {} bytes",
length, limit
);
error::ErrorPayloadTooLarge(format!(
"Payload too large: {} bytes exceeds limit of {} bytes",
length, limit
))
}
error::JsonPayloadError::Overflow { limit } => {
error!("Payload overflow: exceeds limit of {} bytes", limit);
error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit))
}
_ => error::ErrorBadRequest(format!("Invalid JSON payload: {}", err)),
}
}
#[get("/health")]
@@ -59,59 +85,134 @@ async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
#[get("/health_generate")]
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    // Check if we're in PD mode
    if data.is_pd_mode {
        // For PD mode, check health on all servers
data.router
.route_pd_health_generate(&data.client, &req)
.await
} else {
// Regular mode
data.router
.route_to_first(&data.client, "/health_generate", &req)
.await
}
}
#[get("/get_server_info")]
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    if data.is_pd_mode {
        // For PD mode, aggregate info from both prefill and decode servers
        data.router.get_pd_server_info(&data.client, &req).await
} else {
// Regular mode - return first server's info
data.router
.route_to_first(&data.client, "/get_server_info", &req)
.await
}
}
#[get("/v1/models")]
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    if data.is_pd_mode {
        // For PD mode, return models from the first prefill server
        data.router.get_pd_models(&data.client, &req).await
} else {
// Regular mode
data.router
.route_to_first(&data.client, "/v1/models", &req)
.await
}
}
#[get("/get_model_info")]
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    if data.is_pd_mode {
        // For PD mode, get model info from the first prefill server
        data.router.get_pd_model_info(&data.client, &req).await
} else {
data.router
.route_to_first(&data.client, "/get_model_info", &req)
.await
}
} }
#[post("/generate")] #[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder { async fn generate(
data.router req: HttpRequest,
.route_generate_request(&data.client, &req, &body, "/generate") body: web::Json<GenerateRequest>,
.await state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let client = &state.client;
let router = &state.router;
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_generate_typed(&client, &req, pd_request, "/generate")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/generate")
.await)
}
} }
#[post("/v1/chat/completions")] #[post("/v1/chat/completions")]
async fn v1_chat_completions( async fn v1_chat_completions(
req: HttpRequest, req: HttpRequest,
body: Bytes, body: web::Json<ChatCompletionRequest>,
data: web::Data<AppState>, state: web::Data<AppState>,
) -> impl Responder { ) -> Result<HttpResponse, Error> {
data.router let client = &state.client;
.route_generate_request(&data.client, &req, &body, "/v1/chat/completions") let router = &state.router;
.await
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_chat_typed(&client, &req, pd_request, "/v1/chat/completions")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/v1/chat/completions")
.await)
}
} }
#[post("/v1/completions")] #[post("/v1/completions")]
async fn v1_completions( async fn v1_completions(
req: HttpRequest, req: HttpRequest,
body: Bytes, body: web::Json<CompletionRequest>,
data: web::Data<AppState>, state: web::Data<AppState>,
) -> impl Responder { ) -> Result<HttpResponse, Error> {
data.router let client = &state.client;
.route_generate_request(&data.client, &req, &body, "/v1/completions") let router = &state.router;
.await
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_generate_typed(&client, &req, pd_request, "/v1/completions")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/v1/completions")
.await)
}
} }
#[post("/add_worker")] #[post("/add_worker")]
...@@ -153,6 +254,25 @@ async fn remove_worker( ...@@ -153,6 +254,25 @@ async fn remove_worker(
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url)) HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
} }
#[post("/flush_cache")]
async fn flush_cache(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
if data.is_pd_mode {
// For PD mode, flush cache on both prefill and decode servers
data.router.route_pd_flush_cache(&data.client).await
} else {
// Route to all workers for cache flushing
data.router
.route_to_all(&data.client, "/flush_cache", &req)
.await
}
}
#[get("/get_loads")]
async fn get_loads(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
// Get loads from all workers
data.router.get_all_loads(&data.client, &req).await
}
pub struct ServerConfig {
    pub host: String,
    pub port: u16,
...@@ -163,6 +283,7 @@ pub struct ServerConfig {
    pub log_dir: Option<String>,
    pub service_discovery_config: Option<ServiceDiscoveryConfig>,
    pub prometheus_config: Option<PrometheusConfig>,
    pub request_timeout_secs: u64,
}
pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
...@@ -215,6 +336,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
    let client = Client::builder()
        .pool_idle_timeout(Some(Duration::from_secs(50)))
        .timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout
        .build()
        .expect("Failed to create HTTP client");
...@@ -276,7 +398,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
        .service(add_worker)
        .service(remove_worker)
        .service(list_workers)
        .service(flush_cache)
        .service(get_loads)
        .default_service(web::route().to(sink_handler))
    })
    .bind_auto_h2c((config.host, config.port))?
...
//! Comprehensive tests for PrefillDecode (PD) routing functionality
//!
//! This test suite covers:
//! - Phase 1: Basic PD router creation and configuration
//! - Phase 2: Bootstrap injection and request handling
//! - Phase 3: Cache-aware selection (when implemented)
//!
//! Note: PD mode is enabled via the pd_disaggregated flag, not as a policy type.
//! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode.
#[cfg(test)]
mod test_pd_routing {
use rand::Rng;
use serde_json::json;
use sglang_router_rs::pd_types::{EngineInfo, EngineType, PDSelectionPolicy};
use sglang_router_rs::router::{PolicyConfig, Router};
// Test-only struct to help validate PD request parsing
#[derive(Debug)]
struct PDRequest {
pub is_stream: bool,
pub batch_size: Option<usize>,
}
impl PDRequest {
// Extract PD-relevant info from JSON for testing
pub fn from_json(json: &serde_json::Value) -> Self {
let is_stream = json
.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false);
// Detect batch size from text or input_ids
let batch_size = if let Some(text) = json.get("text") {
text.as_array().map(|arr| arr.len())
} else if let Some(input_ids) = json.get("input_ids") {
input_ids.as_array().map(|arr| arr.len())
} else {
None
};
PDRequest {
is_stream,
batch_size,
}
}
}
// ========================================================================
// Phase 1: Basic PD Components and Router Creation
// ========================================================================
#[test]
fn test_engine_info_creation() {
// Test EngineInfo creation for prefill servers
let prefill_engine = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
match prefill_engine.engine_type {
EngineType::Prefill => (),
_ => panic!("Expected Prefill engine type"),
}
assert_eq!(prefill_engine.url, "http://prefill:8080");
assert_eq!(prefill_engine.bootstrap_port, Some(9000));
assert_eq!(prefill_engine.get_hostname(), "prefill");
// Test EngineInfo creation for decode servers
let decode_engine = EngineInfo::new_decode("http://decode:8080".to_string());
match decode_engine.engine_type {
EngineType::Decode => (),
_ => panic!("Expected Decode engine type"),
}
assert_eq!(decode_engine.url, "http://decode:8080");
assert_eq!(decode_engine.bootstrap_port, None);
assert_eq!(decode_engine.get_hostname(), "decode");
// Test API path generation
assert_eq!(
prefill_engine.api_path("/generate"),
"http://prefill:8080/generate"
);
assert_eq!(
prefill_engine.api_path("health"),
"http://prefill:8080/health"
);
assert_eq!(
decode_engine.api_path("/v1/chat/completions"),
"http://decode:8080/v1/chat/completions"
);
}
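    // Sketch of an api_path implementation consistent with the assertions
    // above (the real method lives in pd_types and may differ): concatenate
    // directly when the route already starts with '/', otherwise insert one.
    #[allow(dead_code)]
    fn api_path_sketch(url: &str, route: &str) -> String {
        if route.starts_with('/') {
            format!("{}{}", url, route)
        } else {
            format!("{}/{}", url, route)
        }
    }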
#[test]
fn test_pd_selection_policies() {
// Test all PD selection policy variants
// Note: These policies are only used when pd_disaggregated=true
let policies = vec![
PDSelectionPolicy::Random,
PDSelectionPolicy::PowerOfTwo,
PDSelectionPolicy::CacheAware {
cache_threshold: 0.5,
balance_abs_threshold: 32,
balance_rel_threshold: 1.1,
},
];
for policy in policies {
// Verify each policy can be created and matched
match &policy {
PDSelectionPolicy::Random => {
assert!(matches!(policy, PDSelectionPolicy::Random));
}
PDSelectionPolicy::PowerOfTwo => {
assert!(matches!(policy, PDSelectionPolicy::PowerOfTwo));
}
PDSelectionPolicy::CacheAware {
cache_threshold, ..
} => {
assert!(*cache_threshold >= 0.0 && *cache_threshold <= 1.0);
}
}
}
}
#[test]
fn test_pd_router_configuration() {
// Test PrefillDecodeConfig creation with various policies
// This config is used when pd_disaggregated=true
let configs = vec![
PolicyConfig::PrefillDecodeConfig {
selection_policy: PDSelectionPolicy::Random,
prefill_urls: vec![
("http://prefill1:8080".to_string(), Some(9000)),
("http://prefill2:8080".to_string(), None),
],
decode_urls: vec![
"http://decode1:8080".to_string(),
"http://decode2:8080".to_string(),
],
timeout_secs: 10,
interval_secs: 1,
},
PolicyConfig::PrefillDecodeConfig {
selection_policy: PDSelectionPolicy::PowerOfTwo,
prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))],
decode_urls: vec!["http://decode:8080".to_string()],
timeout_secs: 5,
interval_secs: 1,
},
PolicyConfig::PrefillDecodeConfig {
selection_policy: PDSelectionPolicy::CacheAware {
cache_threshold: 0.7,
balance_abs_threshold: 20,
balance_rel_threshold: 1.2,
},
prefill_urls: vec![
("http://p1:8080".to_string(), Some(9000)),
("http://p2:8080".to_string(), Some(9001)),
("http://p3:8080".to_string(), Some(9002)),
],
decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()],
timeout_secs: 10,
interval_secs: 2,
},
];
for config in configs {
// Router creation will fail due to health checks, but config should be valid
let result = Router::new(vec![], config);
assert!(result.is_err());
let error_msg = result.unwrap_err();
// Error should be about health/timeout, not configuration
assert!(
error_msg.contains("healthy") || error_msg.contains("timeout"),
"Unexpected error: {}",
error_msg
);
}
}
// ========================================================================
// Phase 2: Bootstrap Injection and Request Handling
// ========================================================================
#[test]
fn test_pd_request_from_json() {
// Test PDRequest parsing from single text request
let single_json = json!({
"text": "Hello world",
"stream": false,
"temperature": 0.7,
"max_tokens": 100
});
let pd_req = PDRequest::from_json(&single_json);
assert!(!pd_req.is_stream);
assert_eq!(pd_req.batch_size, None);
// Test PDRequest parsing from batch text request
let batch_json = json!({
"text": ["Hello", "World", "Test"],
"stream": true,
"temperature": 0.5
});
let pd_req = PDRequest::from_json(&batch_json);
assert!(pd_req.is_stream);
assert_eq!(pd_req.batch_size, Some(3));
// Test PDRequest parsing from input_ids request
let ids_json = json!({
"input_ids": [[1, 2, 3], [4, 5, 6]],
"stream": false
});
let pd_req = PDRequest::from_json(&ids_json);
assert!(!pd_req.is_stream);
assert_eq!(pd_req.batch_size, Some(2));
// Test PDRequest parsing from chat request
let chat_json = json!({
"messages": [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"}
],
"stream": true
});
let pd_req = PDRequest::from_json(&chat_json);
assert!(pd_req.is_stream);
assert_eq!(pd_req.batch_size, None);
}
#[test]
fn test_bootstrap_injection_simulation() {
// Since we can't test the actual inject_bootstrap_fields function here
    // (it's private in the router module), we'll test the expected behavior;
    // a sketch of a plausible implementation follows this test.
// Simulate bootstrap injection for single request
let mut single_json = json!({
"text": "Hello world",
"stream": false,
"temperature": 0.7
});
// Simulate what inject_bootstrap_fields would do
let prefill_info = EngineInfo::new_prefill("http://prefill1:8080".to_string(), Some(9000));
single_json["bootstrap_host"] = json!(prefill_info.get_hostname());
single_json["bootstrap_port"] = json!(prefill_info.bootstrap_port);
single_json["bootstrap_room"] = json!(12345u64); // Random room ID
// Verify bootstrap fields are added correctly
assert_eq!(single_json["bootstrap_host"], "prefill1");
assert_eq!(single_json["bootstrap_port"], 9000);
assert!(single_json["bootstrap_room"].is_u64());
assert_eq!(single_json["temperature"], 0.7); // Original field preserved
// Simulate bootstrap injection for batch request
let mut batch_json = json!({
"text": ["Hello", "World", "Test"],
"stream": true
});
let batch_size = 3;
batch_json["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
batch_json["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);
// Verify batch bootstrap fields
assert!(batch_json["bootstrap_host"].is_array());
assert_eq!(
batch_json["bootstrap_host"].as_array().unwrap().len(),
batch_size
);
assert!(batch_json["bootstrap_port"].is_array());
assert!(batch_json["bootstrap_room"].is_array());
assert_eq!(batch_json["stream"], true); // Original field preserved
}
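    // Sketch of what a private inject_bootstrap_fields plausibly does,
    // mirroring the behavior simulated above; this is an illustration, not the
    // router's actual code. Batch requests get one bootstrap entry per element.
    #[allow(dead_code)]
    fn inject_bootstrap_fields_sketch(
        req: &mut serde_json::Value,
        prefill: &EngineInfo,
        batch_size: Option<usize>,
    ) {
        let mut rng = rand::thread_rng();
        match batch_size {
            Some(n) => {
                req["bootstrap_host"] = json!(vec![prefill.get_hostname(); n]);
                req["bootstrap_port"] = json!(vec![prefill.bootstrap_port; n]);
                req["bootstrap_room"] =
                    json!((0..n).map(|_| rng.gen::<u64>()).collect::<Vec<u64>>());
            }
            None => {
                req["bootstrap_host"] = json!(prefill.get_hostname());
                req["bootstrap_port"] = json!(prefill.bootstrap_port);
                req["bootstrap_room"] = json!(rng.gen::<u64>());
            }
        }
    }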
#[test]
fn test_request_serialization() {
// Test that requests can be properly serialized and deserialized
let request = json!({
"text": "Test prompt",
"stream": false,
"temperature": 0.7,
"max_tokens": 100,
"top_p": 0.9,
"frequency_penalty": 0.5,
"bootstrap_host": "prefill1",
"bootstrap_port": 9000,
"bootstrap_room": 12345u64
});
// Convert to bytes (as would happen in the router)
let bytes = serde_json::to_vec(&request).unwrap();
// Parse back from bytes
let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
// Verify all fields are preserved
assert_eq!(parsed["text"], "Test prompt");
assert_eq!(parsed["stream"], false);
assert_eq!(parsed["temperature"], 0.7);
assert_eq!(parsed["max_tokens"], 100);
assert_eq!(parsed["bootstrap_host"], "prefill1");
assert_eq!(parsed["bootstrap_port"], 9000);
assert_eq!(parsed["bootstrap_room"], 12345);
}
#[test]
fn test_engine_info_hostname_extraction() {
// Test various URL formats
let test_cases = vec![
("http://localhost:8080", "localhost"),
("http://10.0.0.1:8080", "10.0.0.1"),
("https://api.example.com:443", "api.example.com"),
("http://prefill-server", "prefill-server"),
("http://[::1]:8080", "["), // IPv6 edge case
("prefill:8080", "prefill"), // No protocol
];
for (url, expected_hostname) in test_cases {
let engine = EngineInfo::new_prefill(url.to_string(), None);
assert_eq!(engine.get_hostname(), expected_hostname);
}
}
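    // Sketch of a hostname extraction consistent with the cases above: strip a
    // leading scheme, then take everything before the first ':'. This explains
    // the "[" result for bracketed IPv6 literals -- a known limitation.
    #[allow(dead_code)]
    fn get_hostname_sketch(url: &str) -> String {
        url.trim_start_matches("http://")
            .trim_start_matches("https://")
            .split(':')
            .next()
            .unwrap_or("")
            .to_string()
    }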
#[test]
fn test_pd_request_edge_cases() {
// Test empty request
let empty_json = json!({});
let pd_req = PDRequest::from_json(&empty_json);
assert!(!pd_req.is_stream);
assert_eq!(pd_req.batch_size, None);
// Test request with only stream field
let stream_only = json!({
"stream": true
});
let pd_req = PDRequest::from_json(&stream_only);
assert!(pd_req.is_stream);
assert_eq!(pd_req.batch_size, None);
// Test request with empty text array
let empty_batch = json!({
"text": []
});
let pd_req = PDRequest::from_json(&empty_batch);
assert_eq!(pd_req.batch_size, Some(0));
// Test request with non-array text (should be None)
let non_array_text = json!({
"text": "single string"
});
let pd_req = PDRequest::from_json(&non_array_text);
assert_eq!(pd_req.batch_size, None);
}
// ========================================================================
// Phase 2: Background Load Monitoring Tests
// ========================================================================
#[tokio::test]
async fn test_background_load_monitoring() {
use std::collections::HashMap;
use tokio::sync::watch;
// Create a watch channel for testing
let (tx, rx) = watch::channel(HashMap::new());
// Simulate load updates
let mut loads = HashMap::new();
loads.insert("http://prefill1:8080".to_string(), 10);
loads.insert("http://prefill2:8080".to_string(), 20);
loads.insert("http://decode1:8080".to_string(), 5);
loads.insert("http://decode2:8080".to_string(), 15);
// Send the loads
tx.send(loads.clone()).unwrap();
// Verify receiver gets the update
let received_loads = rx.borrow();
assert_eq!(received_loads.get("http://prefill1:8080"), Some(&10));
assert_eq!(received_loads.get("http://prefill2:8080"), Some(&20));
assert_eq!(received_loads.get("http://decode1:8080"), Some(&5));
assert_eq!(received_loads.get("http://decode2:8080"), Some(&15));
}
#[test]
fn test_power_of_two_load_selection() {
// Test the power-of-two selection logic with different load scenarios
// Scenario 1: Clear winner for both prefill and decode
let _loads = vec![
("prefill1", 100),
("prefill2", 10), // Should be selected
("decode1", 50),
("decode2", 5), // Should be selected
];
        // In the actual implementation, the lower load is selected
        // (see the sketch after this test)
assert!(10 < 100);
assert!(5 < 50);
// Scenario 2: Equal loads (should select first)
let _equal_loads = vec![
("prefill1", 20),
("prefill2", 20), // Either could be selected
("decode1", 30),
("decode2", 30), // Either could be selected
];
// When loads are equal, <= comparison means first is selected
assert!(20 <= 20);
assert!(30 <= 30);
// Scenario 3: Missing load data (should default to usize::MAX)
// This tests the unwrap_or(usize::MAX) behavior
let missing_load = usize::MAX;
assert!(10 < missing_load);
assert!(missing_load > 0);
}
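    // Sketch of the power-of-two rule these assertions encode: given two
    // sampled candidates, read each load (missing entries count as usize::MAX)
    // and pick the first on ties. Candidate sampling itself is elided.
    #[allow(dead_code)]
    fn pick_power_of_two<'a>(
        a: &'a str,
        b: &'a str,
        loads: &std::collections::HashMap<String, usize>,
    ) -> &'a str {
        let load_a = loads.get(a).copied().unwrap_or(usize::MAX);
        let load_b = loads.get(b).copied().unwrap_or(usize::MAX);
        if load_a <= load_b {
            a
        } else {
            b
        }
    }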
#[test]
fn test_load_monitoring_configuration() {
// Test that load monitoring is only enabled for PowerOfTwo policy
let policies = vec![
(PDSelectionPolicy::Random, false),
(PDSelectionPolicy::PowerOfTwo, true),
(
PDSelectionPolicy::CacheAware {
cache_threshold: 0.5,
balance_abs_threshold: 32,
balance_rel_threshold: 1.1,
},
false,
),
];
for (policy, should_monitor) in policies {
match policy {
PDSelectionPolicy::PowerOfTwo => assert!(should_monitor),
_ => assert!(!should_monitor),
}
}
}
#[tokio::test]
async fn test_watch_channel_behavior() {
use std::collections::HashMap;
use tokio::sync::watch;
// Test watch channel's broadcast behavior
let (tx, rx1) = watch::channel(HashMap::new());
let rx2 = rx1.clone();
// Initial state - empty map
assert!(rx1.borrow().is_empty());
assert!(rx2.borrow().is_empty());
// Update 1
let mut loads = HashMap::new();
loads.insert("worker1".to_string(), 10);
tx.send(loads.clone()).unwrap();
// Both receivers see the update
assert_eq!(rx1.borrow().get("worker1"), Some(&10));
assert_eq!(rx2.borrow().get("worker1"), Some(&10));
// Update 2 - overwrites previous
loads.insert("worker1".to_string(), 20);
loads.insert("worker2".to_string(), 30);
tx.send(loads).unwrap();
// Both receivers see the latest state
assert_eq!(rx1.borrow().get("worker1"), Some(&20));
assert_eq!(rx2.borrow().get("worker2"), Some(&30));
}
// ========================================================================
// Tests based on bench_one_batch_server.py patterns
// ========================================================================
#[test]
fn test_generate_request_formats() {
// Based on bench_one_batch_server.py request patterns
// Test 1: Batch request with input_ids (most common in benchmarks)
let batch_request = json!({
"input_ids": [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 16,
"ignore_eos": true,
},
"return_logprob": false,
"stream": true
});
let pd_req = PDRequest::from_json(&batch_request);
assert!(pd_req.is_stream);
assert_eq!(pd_req.batch_size, Some(3));
// Test 2: Request with return_logprob (critical for PD)
let logprob_request = json!({
"input_ids": [[1, 2, 3]],
"sampling_params": {
"temperature": 0.7,
"max_new_tokens": 8,
},
"return_logprob": true,
"stream": false
});
assert_eq!(logprob_request["return_logprob"], true);
assert_eq!(logprob_request["stream"], false);
// Test 3: Large batch sizes from benchmark
let batch_sizes = vec![1, 16, 64]; // From bench_one_batch_server.py
for bs in batch_sizes {
let request = json!({
"input_ids": vec![vec![1, 2, 3]; bs],
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 16,
},
"stream": true
});
let pd_req = PDRequest::from_json(&request);
assert_eq!(pd_req.batch_size, Some(bs));
}
}
#[test]
fn test_sampling_params_handling() {
// Test various sampling parameters from bench_one_batch_server.py
let sampling_params_variations = vec![
json!({
"temperature": 0.0,
"max_new_tokens": 8,
"ignore_eos": true
}),
json!({
"temperature": 0.7,
"max_new_tokens": 16,
"ignore_eos": false,
"top_p": 0.9,
"frequency_penalty": 0.5
}),
json!({
"temperature": 1.0,
"max_new_tokens": 64,
"json_schema": "$$ANY$$" // Structured output
}),
];
for params in sampling_params_variations {
let request = json!({
"input_ids": [[1, 2, 3]],
"sampling_params": params.clone(),
"stream": false
});
// Verify params are preserved
assert_eq!(request["sampling_params"], params);
}
}
#[test]
fn test_streaming_response_parsing() {
// Test SSE format parsing from streaming responses
let sse_chunks = vec![
"data: {\"text\":\"Hello\",\"meta_info\":{\"completion_tokens\":1,\"finish_reason\":null}}",
"data: {\"text\":\" world\",\"meta_info\":{\"completion_tokens\":2,\"finish_reason\":null}}",
"data: {\"text\":\"!\",\"meta_info\":{\"completion_tokens\":3,\"finish_reason\":{\"type\":\"length\"}}}",
"data: [DONE]",
];
for chunk in &sse_chunks[..3] {
assert!(chunk.starts_with("data: "));
let json_str = &chunk[6..]; // Skip "data: "
let parsed: serde_json::Value = serde_json::from_str(json_str).unwrap();
assert!(parsed["meta_info"]["completion_tokens"].is_u64());
}
// Test [DONE] detection
assert_eq!(sse_chunks[3], "data: [DONE]");
}
#[test]
fn test_ttft_calculation() {
// Test Time To First Token calculation pattern
let first_token_response = json!({
"text": "Hello",
"meta_info": {
"completion_tokens": 1,
"finish_reason": null
}
});
        // TTFT is calculated when completion_tokens == 1
        // (see the sketch after this test)
assert_eq!(first_token_response["meta_info"]["completion_tokens"], 1);
assert!(first_token_response["meta_info"]["finish_reason"].is_null());
}
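    // Sketch of the TTFT measurement this pattern implies: stamp the request
    // start, then record the elapsed time on the first streamed chunk whose
    // completion_tokens reaches 1 (field names as in the response above).
    #[allow(dead_code)]
    fn record_ttft(
        start: std::time::Instant,
        chunk: &serde_json::Value,
    ) -> Option<std::time::Duration> {
        if chunk["meta_info"]["completion_tokens"].as_u64() == Some(1) {
            Some(start.elapsed())
        } else {
            None
        }
    }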
#[test]
fn test_throughput_metrics() {
// Test throughput calculation patterns from bench_one_batch_server.py
let batch_size = 16;
let input_len = 1024;
let output_len = 16;
let ttft = 0.5; // seconds
let total_latency = 2.0; // seconds
// Input throughput = batch_size * input_len / ttft
let input_throughput = (batch_size as f64) * (input_len as f64) / ttft;
assert!((input_throughput - 32768.0).abs() < 0.01);
// Output throughput = batch_size * output_len / (latency - ttft)
let output_throughput = (batch_size as f64) * (output_len as f64) / (total_latency - ttft);
assert!((output_throughput - 170.67).abs() < 0.01);
}
#[test]
fn test_error_response_handling() {
// Test error response format from bench_one_batch_server.py
let error_response = json!({
"error": "Request has failed. Invalid input format."
});
assert!(error_response.get("error").is_some());
assert!(error_response["error"].as_str().unwrap().contains("failed"));
}
#[test]
fn test_structured_output_request() {
// Test structured output format (json_schema)
let structured_request = json!({
"text": "What is the capital of France? Answer in JSON.",
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 64,
"json_schema": "$$ANY$$"
},
"stream": false
});
assert_eq!(
structured_request["sampling_params"]["json_schema"],
"$$ANY$$"
);
}
#[test]
fn test_bootstrap_injection_with_benchmark_requests() {
// Test bootstrap injection with actual benchmark request patterns
let mut benchmark_request = json!({
"input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 8,
"ignore_eos": true
},
"return_logprob": true,
"stream": true
});
// Simulate bootstrap injection
let prefill_info = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
let batch_size = 16;
benchmark_request["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
benchmark_request["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
benchmark_request["bootstrap_room"] =
json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());
// Verify bootstrap fields match batch size
assert_eq!(
benchmark_request["bootstrap_host"]
.as_array()
.unwrap()
.len(),
batch_size
);
assert_eq!(
benchmark_request["bootstrap_port"]
.as_array()
.unwrap()
.len(),
batch_size
);
assert_eq!(
benchmark_request["bootstrap_room"]
.as_array()
.unwrap()
.len(),
batch_size
);
// Verify original fields are preserved
assert_eq!(benchmark_request["return_logprob"], true);
assert_eq!(benchmark_request["stream"], true);
}
#[test]
fn test_server_info_response_format() {
// Test server info format expected by bench_one_batch_server.py
let server_info = json!({
"internal_states": [{
"avg_spec_accept_length": 3.5,
"last_gen_throughput": 2048.5,
"load": 16
}],
"prefill": [
{"url": "http://prefill1:8080", "load": 10},
{"url": "http://prefill2:8080", "load": 20}
],
"decode": [
{"url": "http://decode1:8080", "load": 5},
{"url": "http://decode2:8080", "load": 15}
]
});
// Verify structure matches what benchmark expects
assert!(server_info["internal_states"][0]["avg_spec_accept_length"].is_f64());
assert!(server_info["internal_states"][0]["last_gen_throughput"].is_f64());
assert!(server_info["prefill"].is_array());
assert!(server_info["decode"].is_array());
}
// ========================================================================
// Comprehensive Endpoint Coverage Test
// ========================================================================
#[test]
fn test_pd_endpoints_coverage() {
// Document all endpoints from Python mini_lb.py and verify implementation status
let implemented_endpoints = vec![
("/health", "GET", true),
("/health_generate", "GET", true), // Note: Python uses POST, we use GET
("/get_server_info", "GET", true),
("/v1/models", "GET", true),
("/get_model_info", "GET", true),
("/generate", "POST", true),
("/v1/chat/completions", "POST", true),
("/v1/completions", "POST", true),
("/flush_cache", "POST", true),
("/get_loads", "GET", true),
("/register", "POST", false), // NOT IMPLEMENTED - needs dynamic worker management
];
let implemented_count = implemented_endpoints
.iter()
.filter(|(_, _, impl_status)| *impl_status)
.count();
let total_count = implemented_endpoints.len();
// We've implemented 10 out of 11 endpoints (register is not needed for Phase 1/2)
assert_eq!(implemented_count, 10);
assert_eq!(total_count, 11);
// Document the missing endpoint
let missing: Vec<_> = implemented_endpoints
.iter()
.filter(|(_, _, impl_status)| !impl_status)
.map(|(endpoint, method, _)| format!("{} {}", method, endpoint))
.collect();
assert_eq!(missing, vec!["POST /register"]);
}
#[test]
fn test_large_batch_bootstrap_injection() {
// Test bootstrap injection performance with very large batches
// This simulates the bench_one_batch_server.py scenario
let large_batch_sizes = vec![1024, 4096, 8192];
for batch_size in large_batch_sizes {
let start = std::time::Instant::now();
// Simulate a large batch request
let mut large_batch_request = json!({
"input_ids": vec![vec![1, 2, 3, 4]; batch_size],
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 16,
},
"stream": true
});
// Simulate bootstrap injection
let prefill_info =
EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
large_batch_request["bootstrap_host"] =
json!(vec![prefill_info.get_hostname(); batch_size]);
large_batch_request["bootstrap_port"] =
json!(vec![prefill_info.bootstrap_port; batch_size]);
large_batch_request["bootstrap_room"] = json!((0..batch_size)
.map(|_| rand::thread_rng().gen::<u64>())
.collect::<Vec<_>>());
let elapsed = start.elapsed();
// Verify bootstrap fields are correctly sized
assert_eq!(
large_batch_request["bootstrap_host"]
.as_array()
.unwrap()
.len(),
batch_size
);
assert_eq!(
large_batch_request["bootstrap_port"]
.as_array()
.unwrap()
.len(),
batch_size
);
assert_eq!(
large_batch_request["bootstrap_room"]
.as_array()
.unwrap()
.len(),
batch_size
);
// Bootstrap injection should be reasonably fast even for large batches
println!(
"Bootstrap injection for batch_size {} took {:?}",
batch_size, elapsed
);
assert!(
elapsed.as_millis() < 1000,
"Bootstrap injection took too long for batch size {}",
batch_size
);
}
}
#[test]
fn test_payload_size_calculation() {
// Test payload size estimation for bench_one_batch_server.py scenarios
let test_cases = vec![
(1, 1024, 16), // Small batch
(16, 1024, 16), // Medium batch
(64, 1024, 16), // Large batch
(8192, 4096, 5), // Benchmark scenario
];
for (batch_size, input_len, _output_len) in test_cases {
// Estimate payload size (rough calculation)
// Each token is ~4 bytes (i32), plus JSON overhead
let tokens_size = batch_size * input_len * 4; // 4 bytes per token
let json_overhead = batch_size * 100; // ~100 bytes overhead per request
let total_size = tokens_size + json_overhead;
println!(
"Batch size: {}, Input len: {}, Estimated payload: {} MB",
batch_size,
input_len,
total_size / (1024 * 1024)
);
// For the benchmark case (8192, 4096), this should be ~134 MB
if batch_size == 8192 && input_len == 4096 {
assert!(
total_size > 100 * 1024 * 1024,
"Benchmark payload should be > 100MB"
);
assert!(
total_size < 200 * 1024 * 1024,
"Benchmark payload should be < 200MB"
);
}
}
}
#[test]
fn test_policy_type_to_pd_selection_policy_mapping() {
// Document the mapping from PolicyType to PDSelectionPolicy
// This mapping happens in lib.rs when pd_disaggregated=true
// PolicyType::Random -> PDSelectionPolicy::Random
// PolicyType::PowerOfTwo -> PDSelectionPolicy::PowerOfTwo
// PolicyType::CacheAware -> PDSelectionPolicy::CacheAware { ... }
        // PolicyType::RoundRobin -> ERROR (not supported in PD mode)
        // (a stand-in sketch of this mapping follows this test)
// Test that PDSelectionPolicy doesn't include RoundRobin
let pd_policy_count = 3; // Random, PowerOfTwo, CacheAware
assert_eq!(
pd_policy_count, 3,
"PDSelectionPolicy should have exactly 3 variants"
);
// Verify that each PDSelectionPolicy variant can be created
let _random = PDSelectionPolicy::Random;
let _po2 = PDSelectionPolicy::PowerOfTwo;
let _cache_aware = PDSelectionPolicy::CacheAware {
cache_threshold: 0.5,
balance_abs_threshold: 32,
balance_rel_threshold: 1.1,
};
}
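    // Stand-in sketch of the lib.rs mapping documented in the test above.
    // `PolicyType` is not imported here, so a hypothetical enum is declared
    // purely for illustration; the real mapping also carries the CacheAware
    // thresholds through from the CLI arguments.
    #[allow(dead_code)]
    enum PolicyTypeSketch {
        Random,
        PowerOfTwo,
        CacheAware,
        RoundRobin,
    }

    #[allow(dead_code)]
    fn to_pd_selection_policy(policy: PolicyTypeSketch) -> Result<PDSelectionPolicy, String> {
        match policy {
            PolicyTypeSketch::Random => Ok(PDSelectionPolicy::Random),
            PolicyTypeSketch::PowerOfTwo => Ok(PDSelectionPolicy::PowerOfTwo),
            PolicyTypeSketch::CacheAware => Ok(PDSelectionPolicy::CacheAware {
                cache_threshold: 0.5,
                balance_abs_threshold: 32,
                balance_rel_threshold: 1.1,
            }),
            PolicyTypeSketch::RoundRobin => {
                Err("round_robin is not supported in PD mode".to_string())
            }
        }
    }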
}