Unverified Commit 09ae5b20 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)

parent 712bf9ec
...@@ -218,15 +218,39 @@ async def get_server_info(): ...@@ -218,15 +218,39 @@ async def get_server_info():
) )
prefill_infos = [] prefill_infos = []
decode_infos = [] decode_infos = []
all_internal_states = []
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
for server in chain(prefill_servers): for server in chain(prefill_servers):
server_info = await session.get(f"{server}/get_server_info") server_info = await session.get(f"{server}/get_server_info")
prefill_infos.append(await server_info.json()) prefill_infos.append(await server_info.json())
for server in chain(decode_servers): for server in chain(decode_servers):
server_info = await session.get(f"{server}/get_server_info") server_info = await session.get(f"{server}/get_server_info")
decode_infos.append(await server_info.json()) info_json = await server_info.json()
decode_infos.append(info_json)
return {"prefill": prefill_infos, "decode": decode_infos} # Extract internal_states from decode servers
if "internal_states" in info_json:
all_internal_states.extend(info_json["internal_states"])
# Return format expected by bench_one_batch_server.py
if all_internal_states:
return {
"internal_states": all_internal_states,
"prefill": prefill_infos,
"decode": decode_infos,
}
else:
# Fallback with dummy data if no internal states found
return {
"internal_states": [
{
"last_gen_throughput": 0.0,
"avg_spec_accept_length": None,
}
],
"prefill": prefill_infos,
"decode": decode_infos,
}
@app.get("/get_model_info") @app.get("/get_model_info")
......
...@@ -15,7 +15,7 @@ serde = { version = "1.0", features = ["derive"] } ...@@ -15,7 +15,7 @@ serde = { version = "1.0", features = ["derive"] }
clap = { version = "4.4", features = ["derive"] } clap = { version = "4.4", features = ["derive"] }
bytes = "1.8.0" bytes = "1.8.0"
rand = "0.8.5" rand = "0.8.5"
reqwest = { version = "0.12.8", features = ["stream", "blocking"] } reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] }
futures-util = "0.3" futures-util = "0.3"
serde_json = "1.0" serde_json = "1.0"
pyo3 = { version = "0.22.5", features = ["extension-module"] } pyo3 = { version = "0.22.5", features = ["extension-module"] }
...@@ -33,6 +33,8 @@ futures = "0.3" ...@@ -33,6 +33,8 @@ futures = "0.3"
# Added for metrics # Added for metrics
metrics = "0.24.2" metrics = "0.24.2"
metrics-exporter-prometheus = "0.17.0" metrics-exporter-prometheus = "0.17.0"
# Added for request tracing
uuid = { version = "1.10", features = ["v4", "serde"] }
[profile.release] [profile.release]
lto = "thin" lto = "thin"
codegen-units = 1 codegen-units = 1
...@@ -31,6 +31,13 @@ class RouterArgs: ...@@ -31,6 +31,13 @@ class RouterArgs:
host: str = "127.0.0.1" host: str = "127.0.0.1"
port: int = 30000 port: int = 30000
# PD-specific configuration
pd_disaggregated: bool = False # Enable PD disaggregated mode
prefill_urls: List[tuple] = dataclasses.field(
default_factory=list
) # List of (url, bootstrap_port)
decode_urls: List[str] = dataclasses.field(default_factory=list)
# Routing policy # Routing policy
policy: str = "cache_aware" policy: str = "cache_aware"
worker_startup_timeout_secs: int = 300 worker_startup_timeout_secs: int = 300
...@@ -40,7 +47,7 @@ class RouterArgs: ...@@ -40,7 +47,7 @@ class RouterArgs:
balance_rel_threshold: float = 1.0001 balance_rel_threshold: float = 1.0001
eviction_interval: int = 60 eviction_interval: int = 60
max_tree_size: int = 2**24 max_tree_size: int = 2**24
max_payload_size: int = 4 * 1024 * 1024 # 4MB max_payload_size: int = 256 * 1024 * 1024 # 256MB default for large batches
verbose: bool = False verbose: bool = False
log_dir: Optional[str] = None log_dir: Optional[str] = None
# Service discovery configuration # Service discovery configuration
...@@ -95,8 +102,29 @@ class RouterArgs: ...@@ -95,8 +102,29 @@ class RouterArgs:
f"--{prefix}policy", f"--{prefix}policy",
type=str, type=str,
default=RouterArgs.policy, default=RouterArgs.policy,
choices=["random", "round_robin", "cache_aware"], choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Load balancing policy to use", help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode",
)
# PD-specific arguments
parser.add_argument(
f"--{prefix}pd-disaggregated",
action="store_true",
help="Enable PD (Prefill-Decode) disaggregated mode",
)
parser.add_argument(
f"--{prefix}prefill",
nargs=2,
action="append",
metavar=("URL", "BOOTSTRAP_PORT"),
help="Prefill server URL and bootstrap port. Can be specified multiple times. BOOTSTRAP_PORT can be 'none' for no bootstrap port.",
)
parser.add_argument(
f"--{prefix}decode",
nargs=1,
action="append",
metavar=("URL",),
help="Decode server URL. Can be specified multiple times.",
) )
parser.add_argument( parser.add_argument(
f"--{prefix}worker-startup-timeout-secs", f"--{prefix}worker-startup-timeout-secs",
...@@ -205,11 +233,19 @@ class RouterArgs: ...@@ -205,11 +233,19 @@ class RouterArgs:
use_router_prefix: If True, look for arguments with 'router-' prefix use_router_prefix: If True, look for arguments with 'router-' prefix
""" """
prefix = "router_" if use_router_prefix else "" prefix = "router_" if use_router_prefix else ""
worker_urls = args.worker_urls if args.worker_urls is not None else [] worker_urls = getattr(args, "worker_urls", [])
# Parse PD URLs
prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None))
decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None))
return cls( return cls(
worker_urls=worker_urls, worker_urls=worker_urls,
host=args.host, host=args.host,
port=args.port, port=args.port,
pd_disaggregated=getattr(args, f"{prefix}pd_disaggregated", False),
prefill_urls=prefill_urls,
decode_urls=decode_urls,
policy=getattr(args, f"{prefix}policy"), policy=getattr(args, f"{prefix}policy"),
worker_startup_timeout_secs=getattr( worker_startup_timeout_secs=getattr(
args, f"{prefix}worker_startup_timeout_secs" args, f"{prefix}worker_startup_timeout_secs"
...@@ -247,6 +283,46 @@ class RouterArgs: ...@@ -247,6 +283,46 @@ class RouterArgs:
selector[key] = value selector[key] = value
return selector return selector
@staticmethod
def _parse_prefill_urls(prefill_list):
"""Parse prefill URLs from --prefill arguments.
Format: --prefill URL BOOTSTRAP_PORT
Example: --prefill http://prefill1:8080 9000 --prefill http://prefill2:8080 none
"""
if not prefill_list:
return []
prefill_urls = []
for url, bootstrap_port_str in prefill_list:
# Handle 'none' as None
if bootstrap_port_str.lower() == "none":
bootstrap_port = None
else:
try:
bootstrap_port = int(bootstrap_port_str)
except ValueError:
raise ValueError(
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
)
prefill_urls.append((url, bootstrap_port))
return prefill_urls
@staticmethod
def _parse_decode_urls(decode_list):
"""Parse decode URLs from --decode arguments.
Format: --decode URL
Example: --decode http://decode1:8081 --decode http://decode2:8081
"""
if not decode_list:
return []
# decode_list is a list of single-element lists due to nargs=1
return [url[0] for url in decode_list]
def policy_from_str(policy_str: str) -> PolicyType: def policy_from_str(policy_str: str) -> PolicyType:
"""Convert policy string to PolicyType enum.""" """Convert policy string to PolicyType enum."""
...@@ -254,6 +330,7 @@ def policy_from_str(policy_str: str) -> PolicyType: ...@@ -254,6 +330,7 @@ def policy_from_str(policy_str: str) -> PolicyType:
"random": PolicyType.Random, "random": PolicyType.Random,
"round_robin": PolicyType.RoundRobin, "round_robin": PolicyType.RoundRobin,
"cache_aware": PolicyType.CacheAware, "cache_aware": PolicyType.CacheAware,
"power_of_two": PolicyType.PowerOfTwo,
} }
return policy_map[policy_str] return policy_map[policy_str]
...@@ -277,8 +354,19 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: ...@@ -277,8 +354,19 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
else: else:
router_args = args router_args = args
# Validate configuration based on mode
if router_args.pd_disaggregated:
# Validate PD configuration
if not router_args.prefill_urls:
raise ValueError("PD disaggregated mode requires --prefill")
if not router_args.decode_urls:
raise ValueError("PD disaggregated mode requires --decode")
# Create router with unified constructor
router = Router( router = Router(
worker_urls=router_args.worker_urls, worker_urls=(
router_args.worker_urls if not router_args.pd_disaggregated else []
),
host=router_args.host, host=router_args.host,
port=router_args.port, port=router_args.port,
policy=policy_from_str(router_args.policy), policy=policy_from_str(router_args.policy),
...@@ -298,6 +386,13 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: ...@@ -298,6 +386,13 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
service_discovery_namespace=router_args.service_discovery_namespace, service_discovery_namespace=router_args.service_discovery_namespace,
prometheus_port=router_args.prometheus_port, prometheus_port=router_args.prometheus_port,
prometheus_host=router_args.prometheus_host, prometheus_host=router_args.prometheus_host,
pd_disaggregated=router_args.pd_disaggregated,
prefill_urls=(
router_args.prefill_urls if router_args.pd_disaggregated else None
),
decode_urls=(
router_args.decode_urls if router_args.pd_disaggregated else None
),
) )
router.start() router.start()
...@@ -326,8 +421,14 @@ This launcher enables starting a router with individual worker instances. It is ...@@ -326,8 +421,14 @@ This launcher enables starting a router with individual worker instances. It is
multi-node setups or when you want to start workers and router separately. multi-node setups or when you want to start workers and router separately.
Examples: Examples:
# Regular mode
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --balance-abs-threshold 64 --balance-rel-threshold 1.2
# PD disaggregated mode
python -m sglang_router.launch_router --pd-disaggregated \\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
--decode http://decode1:8001 --decode http://decode2:8001 \\
--policy cache_aware
""", """,
formatter_class=CustomHelpFormatter, formatter_class=CustomHelpFormatter,
......
...@@ -15,6 +15,7 @@ class Router: ...@@ -15,6 +15,7 @@ class Router:
- PolicyType.Random: Randomly select workers - PolicyType.Random: Randomly select workers
- PolicyType.RoundRobin: Distribute requests in round-robin fashion - PolicyType.RoundRobin: Distribute requests in round-robin fashion
- PolicyType.CacheAware: Distribute requests based on cache state and load balance - PolicyType.CacheAware: Distribute requests based on cache state and load balance
- PolicyType.PowerOfTwo: Select best of two random workers based on load (PD mode only)
host: Host address to bind the router server. Default: '127.0.0.1' host: Host address to bind the router server. Default: '127.0.0.1'
port: Port number to bind the router server. Default: 3001 port: Port number to bind the router server. Default: 3001
worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300 worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300
...@@ -28,7 +29,7 @@ class Router: ...@@ -28,7 +29,7 @@ class Router:
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001 AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
routing. Default: 60 routing. Default: 60
max_payload_size: Maximum payload size in bytes. Default: 4MB max_payload_size: Maximum payload size in bytes. Default: 256MB
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24 max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
verbose: Enable verbose logging. Default: False verbose: Enable verbose logging. Default: False
log_dir: Directory to store log files. If None, logs are only output to console. Default: None log_dir: Directory to store log files. If None, logs are only output to console. Default: None
...@@ -42,6 +43,9 @@ class Router: ...@@ -42,6 +43,9 @@ class Router:
watches pods across all namespaces (requires cluster-wide permissions). Default: None watches pods across all namespaces (requires cluster-wide permissions). Default: None
prometheus_port: Port to expose Prometheus metrics. Default: None prometheus_port: Port to expose Prometheus metrics. Default: None
prometheus_host: Host address to bind the Prometheus metrics server. Default: None prometheus_host: Host address to bind the Prometheus metrics server. Default: None
pd_disaggregated: Enable PD (Prefill-Decode) disaggregated mode. Default: False
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
decode_urls: List of URLs for decode servers (PD mode only)
""" """
def __init__( def __init__(
...@@ -57,7 +61,7 @@ class Router: ...@@ -57,7 +61,7 @@ class Router:
balance_rel_threshold: float = 1.0001, balance_rel_threshold: float = 1.0001,
eviction_interval_secs: int = 60, eviction_interval_secs: int = 60,
max_tree_size: int = 2**24, max_tree_size: int = 2**24,
max_payload_size: int = 4 * 1024 * 1024, # 4MB max_payload_size: int = 256 * 1024 * 1024, # 256MB
verbose: bool = False, verbose: bool = False,
log_dir: Optional[str] = None, log_dir: Optional[str] = None,
service_discovery: bool = False, service_discovery: bool = False,
...@@ -66,6 +70,9 @@ class Router: ...@@ -66,6 +70,9 @@ class Router:
service_discovery_namespace: Optional[str] = None, service_discovery_namespace: Optional[str] = None,
prometheus_port: Optional[int] = None, prometheus_port: Optional[int] = None,
prometheus_host: Optional[str] = None, prometheus_host: Optional[str] = None,
pd_disaggregated: bool = False,
prefill_urls: Optional[List[tuple]] = None,
decode_urls: Optional[List[str]] = None,
): ):
if selector is None: if selector is None:
selector = {} selector = {}
...@@ -91,6 +98,9 @@ class Router: ...@@ -91,6 +98,9 @@ class Router:
service_discovery_namespace=service_discovery_namespace, service_discovery_namespace=service_discovery_namespace,
prometheus_port=prometheus_port, prometheus_port=prometheus_port,
prometheus_host=prometheus_host, prometheus_host=prometheus_host,
pd_disaggregated=pd_disaggregated,
prefill_urls=prefill_urls,
decode_urls=decode_urls,
) )
def start(self) -> None: def start(self) -> None:
......
...@@ -35,13 +35,21 @@ class TestLaunchRouter(unittest.TestCase): ...@@ -35,13 +35,21 @@ class TestLaunchRouter(unittest.TestCase):
balance_rel_threshold=1.0001, balance_rel_threshold=1.0001,
eviction_interval=60, eviction_interval=60,
max_tree_size=2**24, max_tree_size=2**24,
max_payload_size=4 * 1024 * 1024, # 4MB max_payload_size=256 * 1024 * 1024, # 256MB
verbose=False, verbose=False,
log_dir=None, log_dir=None,
service_discovery=False, service_discovery=False,
selector=None, selector=None,
service_discovery_port=80, service_discovery_port=80,
service_discovery_namespace=None, service_discovery_namespace=None,
prometheus_port=None,
prometheus_host=None,
# PD-specific attributes
pd_disaggregated=False,
prefill=None,
decode=None,
# Keep worker_urls for regular mode
worker_urls=[],
) )
def create_router_args(self, **kwargs): def create_router_args(self, **kwargs):
...@@ -81,7 +89,7 @@ class TestLaunchRouter(unittest.TestCase): ...@@ -81,7 +89,7 @@ class TestLaunchRouter(unittest.TestCase):
def test_launch_router_with_empty_worker_urls(self): def test_launch_router_with_empty_worker_urls(self):
args = self.create_router_args(worker_urls=[]) args = self.create_router_args(worker_urls=[])
self.run_router_process(args) self.run_router_process(args) # Expected error
def test_launch_router_with_service_discovery(self): def test_launch_router_with_service_discovery(self):
# Test router startup with service discovery enabled but no selectors # Test router startup with service discovery enabled but no selectors
...@@ -100,6 +108,112 @@ class TestLaunchRouter(unittest.TestCase): ...@@ -100,6 +108,112 @@ class TestLaunchRouter(unittest.TestCase):
) )
self.run_router_process(args) self.run_router_process(args)
def test_launch_router_pd_mode_basic(self):
"""Test basic PD router functionality without actually starting servers."""
# This test just verifies the PD router can be created and configured
# without actually starting it (which would require real prefill/decode servers)
from sglang_router import Router
from sglang_router.launch_router import RouterArgs
from sglang_router_rs import PolicyType
# Test RouterArgs parsing for PD mode
# Simulate the parsed args structure from argparse with action="append"
args = self.create_router_args(
pd_disaggregated=True,
policy="power_of_two", # PowerOfTwo is only valid in PD mode
prefill=[
["http://prefill1:8080", "9000"],
["http://prefill2:8080", "none"],
],
decode=[
["http://decode1:8081"],
["http://decode2:8081"],
],
worker_urls=[], # Empty for PD mode
)
router_args = RouterArgs.from_cli_args(args)
self.assertTrue(router_args.pd_disaggregated)
self.assertEqual(router_args.policy, "power_of_two")
self.assertEqual(len(router_args.prefill_urls), 2)
self.assertEqual(len(router_args.decode_urls), 2)
# Verify the parsed URLs and bootstrap ports
self.assertEqual(router_args.prefill_urls[0], ("http://prefill1:8080", 9000))
self.assertEqual(router_args.prefill_urls[1], ("http://prefill2:8080", None))
self.assertEqual(router_args.decode_urls[0], "http://decode1:8081")
self.assertEqual(router_args.decode_urls[1], "http://decode2:8081")
# Test Router creation in PD mode
router = Router(
worker_urls=[], # Empty for PD mode
pd_disaggregated=True,
prefill_urls=[
("http://prefill1:8080", 9000),
("http://prefill2:8080", None),
],
decode_urls=["http://decode1:8081", "http://decode2:8081"],
policy=PolicyType.CacheAware,
host="127.0.0.1",
port=3001,
)
self.assertIsNotNone(router)
def test_policy_validation(self):
"""Test that policy validation works correctly for PD and regular modes."""
from sglang_router.launch_router import RouterArgs, launch_router
# Test 1: PowerOfTwo is only valid in PD mode
args = self.create_router_args(
pd_disaggregated=False,
policy="power_of_two",
worker_urls=["http://localhost:8000"],
)
# Should raise error
with self.assertRaises(ValueError) as cm:
launch_router(args)
self.assertIn(
"PowerOfTwo policy is only supported in PD disaggregated mode",
str(cm.exception),
)
# Test 2: RoundRobin is not valid in PD mode
args = self.create_router_args(
pd_disaggregated=True,
policy="round_robin",
prefill=[["http://prefill1:8080", "9000"]],
decode=[["http://decode1:8081"]],
worker_urls=[],
)
# Should raise error
with self.assertRaises(ValueError) as cm:
launch_router(args)
self.assertIn(
"RoundRobin policy is not supported in PD disaggregated mode",
str(cm.exception),
)
# Test 3: Valid combinations should not raise errors
# Regular mode with RoundRobin
args = self.create_router_args(
pd_disaggregated=False,
policy="round_robin",
worker_urls=["http://localhost:8000"],
)
# This should not raise (though it may fail to connect)
# PD mode with PowerOfTwo
args = self.create_router_args(
pd_disaggregated=True,
policy="power_of_two",
prefill=[["http://prefill1:8080", "9000"]],
decode=[["http://decode1:8081"]],
worker_urls=[],
)
# This should not raise (though it may fail to connect)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
use pyo3::prelude::*; use pyo3::prelude::*;
pub mod logging; pub mod logging;
use std::collections::HashMap; use std::collections::HashMap;
pub mod openai_api_types;
pub mod pd_router;
pub mod pd_types;
pub mod prometheus; pub mod prometheus;
pub mod request_adapter;
pub mod router; pub mod router;
pub mod server; pub mod server;
pub mod service_discovery; pub mod service_discovery;
...@@ -14,6 +18,7 @@ pub enum PolicyType { ...@@ -14,6 +18,7 @@ pub enum PolicyType {
Random, Random,
RoundRobin, RoundRobin,
CacheAware, CacheAware,
PowerOfTwo, // Moved from PD-specific, now shared
} }
#[pyclass] #[pyclass]
...@@ -39,6 +44,12 @@ struct Router { ...@@ -39,6 +44,12 @@ struct Router {
service_discovery_namespace: Option<String>, service_discovery_namespace: Option<String>,
prometheus_port: Option<u16>, prometheus_port: Option<u16>,
prometheus_host: Option<String>, prometheus_host: Option<String>,
request_timeout_secs: u64,
// PD mode flag
pd_disaggregated: bool,
// PD-specific fields (only used when pd_disaggregated is true)
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
} }
#[pymethods] #[pymethods]
...@@ -56,7 +67,7 @@ impl Router { ...@@ -56,7 +67,7 @@ impl Router {
balance_rel_threshold = 1.0001, balance_rel_threshold = 1.0001,
eviction_interval_secs = 60, eviction_interval_secs = 60,
max_tree_size = 2usize.pow(24), max_tree_size = 2usize.pow(24),
max_payload_size = 4 * 1024 * 1024, max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches
verbose = false, verbose = false,
log_dir = None, log_dir = None,
service_discovery = false, service_discovery = false,
...@@ -64,7 +75,11 @@ impl Router { ...@@ -64,7 +75,11 @@ impl Router {
service_discovery_port = 80, service_discovery_port = 80,
service_discovery_namespace = None, service_discovery_namespace = None,
prometheus_port = None, prometheus_port = None,
prometheus_host = None prometheus_host = None,
request_timeout_secs = 600, // Add configurable request timeout
pd_disaggregated = false, // New flag for PD mode
prefill_urls = None,
decode_urls = None
))] ))]
fn new( fn new(
worker_urls: Vec<String>, worker_urls: Vec<String>,
...@@ -87,6 +102,10 @@ impl Router { ...@@ -87,6 +102,10 @@ impl Router {
service_discovery_namespace: Option<String>, service_discovery_namespace: Option<String>,
prometheus_port: Option<u16>, prometheus_port: Option<u16>,
prometheus_host: Option<String>, prometheus_host: Option<String>,
request_timeout_secs: u64,
pd_disaggregated: bool,
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
) -> PyResult<Self> { ) -> PyResult<Self> {
Ok(Router { Ok(Router {
host, host,
...@@ -109,28 +128,75 @@ impl Router { ...@@ -109,28 +128,75 @@ impl Router {
service_discovery_namespace, service_discovery_namespace,
prometheus_port, prometheus_port,
prometheus_host, prometheus_host,
request_timeout_secs,
pd_disaggregated,
prefill_urls,
decode_urls,
}) })
} }
fn start(&self) -> PyResult<()> { fn start(&self) -> PyResult<()> {
let policy_config = match &self.policy { let policy_config = if self.pd_disaggregated {
PolicyType::Random => router::PolicyConfig::RandomConfig { // PD mode - map PolicyType to PDSelectionPolicy
timeout_secs: self.worker_startup_timeout_secs, let pd_selection_policy = match &self.policy {
interval_secs: self.worker_startup_check_interval, PolicyType::Random => pd_types::PDSelectionPolicy::Random,
}, PolicyType::PowerOfTwo => pd_types::PDSelectionPolicy::PowerOfTwo,
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig { PolicyType::CacheAware => pd_types::PDSelectionPolicy::CacheAware {
timeout_secs: self.worker_startup_timeout_secs, cache_threshold: self.cache_threshold,
interval_secs: self.worker_startup_check_interval, balance_abs_threshold: self.balance_abs_threshold,
}, balance_rel_threshold: self.balance_rel_threshold,
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig { },
PolicyType::RoundRobin => {
return Err(pyo3::exceptions::PyValueError::new_err(
"RoundRobin policy is not supported in PD disaggregated mode",
));
}
};
let prefill_urls = self.prefill_urls.as_ref().ok_or_else(|| {
pyo3::exceptions::PyValueError::new_err(
"PD disaggregated mode requires prefill_urls",
)
})?;
let decode_urls = self.decode_urls.as_ref().ok_or_else(|| {
pyo3::exceptions::PyValueError::new_err(
"PD disaggregated mode requires decode_urls",
)
})?;
router::PolicyConfig::PrefillDecodeConfig {
selection_policy: pd_selection_policy,
prefill_urls: prefill_urls.clone(),
decode_urls: decode_urls.clone(),
timeout_secs: self.worker_startup_timeout_secs, timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval, interval_secs: self.worker_startup_check_interval,
cache_threshold: self.cache_threshold, }
balance_abs_threshold: self.balance_abs_threshold, } else {
balance_rel_threshold: self.balance_rel_threshold, // Regular mode
eviction_interval_secs: self.eviction_interval_secs, match &self.policy {
max_tree_size: self.max_tree_size, PolicyType::Random => router::PolicyConfig::RandomConfig {
}, timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
},
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
},
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size,
},
PolicyType::PowerOfTwo => {
return Err(pyo3::exceptions::PyValueError::new_err(
"PowerOfTwo policy is only supported in PD disaggregated mode",
));
}
}
}; };
// Create service discovery config if enabled // Create service discovery config if enabled
...@@ -166,6 +232,7 @@ impl Router { ...@@ -166,6 +232,7 @@ impl Router {
log_dir: self.log_dir.clone(), log_dir: self.log_dir.clone(),
service_discovery_config, service_discovery_config,
prometheus_config, prometheus_config,
request_timeout_secs: self.request_timeout_secs,
}) })
.await .await
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
......
This diff is collapsed.
This diff is collapsed.
// Essential PDLB types extracted for PD routing
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Role of a backend engine in PD (prefill/decode) disaggregated serving.
#[derive(Debug, Clone)]
pub enum EngineType {
    /// Engine that runs the prefill phase of a request.
    Prefill,
    /// Engine that runs the decode phase of a request.
    Decode,
}
/// Connection details for a single prefill or decode engine.
#[derive(Debug, Clone)]
pub struct EngineInfo {
    /// Whether this engine serves the prefill or the decode phase.
    pub engine_type: EngineType,
    /// Base URL of the engine, e.g. "http://host:port".
    pub url: String,
    /// KV-transfer bootstrap port; set only for prefill engines
    /// (see `new_prefill` / `new_decode` below).
    pub bootstrap_port: Option<u16>,
}
impl EngineInfo {
    /// Build an `EngineInfo` describing a prefill engine with an optional
    /// KV-transfer bootstrap port.
    pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
        Self {
            engine_type: EngineType::Prefill,
            url,
            bootstrap_port,
        }
    }

    /// Build an `EngineInfo` describing a decode engine (never has a
    /// bootstrap port).
    pub fn new_decode(url: String) -> Self {
        Self {
            engine_type: EngineType::Decode,
            url,
            bootstrap_port: None,
        }
    }

    /// Join `api_path` onto the engine's base URL, inserting a `/`
    /// separator only when the path does not already start with one.
    pub fn api_path(&self, api_path: &str) -> String {
        let separator = if api_path.starts_with('/') { "" } else { "/" };
        format!("{}{}{}", self.url, separator, api_path)
    }

    /// Extract the hostname from the URL without a URL-parsing dependency:
    /// strip the scheme prefix, then take everything before the first `:`
    /// (the port separator).
    /// NOTE(review): assumes a plain host[:port] form — an IPv6 literal
    /// such as `http://[::1]:8080` would be truncated; confirm inputs.
    pub fn get_hostname(&self) -> String {
        let without_scheme = self
            .url
            .trim_start_matches("http://")
            .trim_start_matches("https://");
        match without_scheme.split(':').next() {
            Some(host) => host.to_string(),
            None => "localhost".to_string(),
        }
    }
}
// PD-specific routing policies
/// Worker-selection policy used by the PD router when picking a
/// prefill/decode pair for a request.
#[derive(Debug, Clone, PartialEq)]
pub enum PDSelectionPolicy {
    /// Pick a worker uniformly at random.
    Random,
    /// Pick the better of two randomly chosen workers ("power of two choices").
    PowerOfTwo,
    /// Prefer workers with matching cache state, falling back to load
    /// balancing when load is too skewed.
    CacheAware {
        /// Minimum prefix-match ratio required to route by cache affinity.
        cache_threshold: f32,
        /// Absolute load gap that triggers load balancing.
        balance_abs_threshold: usize,
        /// Relative load ratio that triggers load balancing.
        balance_rel_threshold: f32,
    },
}
// Bootstrap types from PDLB
/// A value that may appear either as a single item or as a batch (array)
/// in the JSON payload; `untagged` lets serde pick whichever matches.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum SingleOrBatch<T> {
    /// One value for a single request.
    Single(T),
    /// One value per request in a batch.
    Batch(Vec<T>),
}

/// Token-id input: one sequence, or one sequence per batch item.
pub type InputIds = SingleOrBatch<Vec<i32>>;
/// Text input: one prompt, or one prompt per batch item.
pub type InputText = SingleOrBatch<String>;
/// Bootstrap hostname(s) of the selected prefill engine.
pub type BootstrapHost = SingleOrBatch<String>;
/// Bootstrap port(s) of the selected prefill engine (may be absent).
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
/// Random room id(s) used to pair prefill and decode sides.
pub type BootstrapRoom = SingleOrBatch<u64>;
// Bootstrap trait for request handling
/// Implemented by request types that can carry PD bootstrap metadata
/// (host, port, room) injected by the router before forwarding.
pub trait Bootstrap: Send + Sync {
    /// Whether the client requested a streaming response.
    fn is_stream(&self) -> bool;

    /// Batch size of the request, `None` for single requests, `Err` if
    /// the request is malformed.
    fn get_batch_size(&self) -> Result<Option<usize>, String>;

    /// Store the bootstrap triple on the request body.
    fn set_bootstrap_info(
        &mut self,
        bootstrap_host: BootstrapHost,
        bootstrap_port: BootstrapPort,
        bootstrap_room: BootstrapRoom,
    );

    /// Populate bootstrap host/port/room from the chosen prefill engine,
    /// expanding to per-item values when the request is a batch.
    fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), String> {
        // Combine two independent random draws to minimize collision risk
        // across concurrently issued room ids.
        fn random_room() -> u64 {
            let r1 = rand::random::<u64>();
            let r2 = rand::random::<u64>();
            r1.wrapping_add(r2.rotate_left(32))
        }

        let host = prefill_info.get_hostname();
        let port = prefill_info.bootstrap_port;

        match self.get_batch_size()? {
            Some(batch_size) => self.set_bootstrap_info(
                BootstrapHost::Batch(vec![host; batch_size]),
                BootstrapPort::Batch(vec![port; batch_size]),
                BootstrapRoom::Batch((0..batch_size).map(|_| random_room()).collect()),
            ),
            None => self.set_bootstrap_info(
                BootstrapHost::Single(host),
                BootstrapPort::Single(port),
                BootstrapRoom::Single(random_room()),
            ),
        }

        Ok(())
    }
}
// Request types
/// The /generate request body as seen by the PD router. Only the fields
/// the router needs are typed; everything else is captured in `other`
/// and forwarded verbatim.
#[derive(Debug, Deserialize, Serialize)]
pub struct GenerateReqInput {
    /// Text prompt(s); mutually exclusive with `input_ids`.
    pub text: Option<InputText>,
    /// Pre-tokenized prompt(s); mutually exclusive with `text`.
    pub input_ids: Option<InputIds>,
    /// Whether to stream the response (defaults to false when absent).
    #[serde(default)]
    pub stream: bool,
    /// Bootstrap metadata injected by the router (see `Bootstrap` trait).
    pub bootstrap_host: Option<BootstrapHost>,
    pub bootstrap_port: Option<BootstrapPort>,
    pub bootstrap_room: Option<BootstrapRoom>,
    /// All unrecognized fields, preserved and passed through unchanged.
    #[serde(flatten)]
    pub other: Value,
}
impl GenerateReqInput {
    /// Upper bound on accepted batch size; guards against pathological
    /// requests (same production limit as the original implementation).
    const MAX_BATCH_SIZE: usize = 10000;

    /// Determine the batch size of the request, if it is a batch.
    ///
    /// Returns `Ok(None)` for single (non-batch) requests, `Ok(Some(n))`
    /// for a batch of `n` items, and `Err` when the request is malformed:
    /// both `text` and `input_ids` present, an empty batch array, a batch
    /// larger than `MAX_BATCH_SIZE`, or an empty token sequence inside an
    /// `input_ids` batch.
    pub fn get_batch_size(&self) -> Result<Option<usize>, String> {
        // text and input_ids are mutually exclusive ways to supply input.
        if self.text.is_some() && self.input_ids.is_some() {
            return Err("Both text and input_ids are present in the request".to_string());
        }

        // Check text batch
        if let Some(InputText::Batch(texts)) = &self.text {
            Self::validate_batch_len(texts.len(), "text")?;
            return Ok(Some(texts.len()));
        }

        // Check input_ids batch
        if let Some(InputIds::Batch(ids)) = &self.input_ids {
            Self::validate_batch_len(ids.len(), "input_ids")?;
            // Every sequence in the batch must contain at least one token.
            for (i, seq) in ids.iter().enumerate() {
                if seq.is_empty() {
                    return Err(format!("Input sequence at index {} is empty", i));
                }
            }
            return Ok(Some(ids.len()));
        }

        // Single (non-batch) request.
        Ok(None)
    }

    /// Shared batch-array validation: non-empty and within `MAX_BATCH_SIZE`.
    /// `what` names the field for the error message ("text" / "input_ids").
    fn validate_batch_len(len: usize, what: &str) -> Result<(), String> {
        if len == 0 {
            return Err(format!("Batch {} array is empty", what));
        }
        if len > Self::MAX_BATCH_SIZE {
            return Err(format!(
                "Batch size {} exceeds maximum allowed ({})",
                len,
                Self::MAX_BATCH_SIZE
            ));
        }
        Ok(())
    }
}
/// Wire the generate request into the PD bootstrap protocol.
impl Bootstrap for GenerateReqInput {
    fn is_stream(&self) -> bool {
        self.stream
    }

    fn get_batch_size(&self) -> Result<Option<usize>, String> {
        // Delegates to the inherent method of the same name (inherent
        // methods win over trait methods in resolution, so no recursion).
        self.get_batch_size()
    }

    fn set_bootstrap_info(
        &mut self,
        bootstrap_host: BootstrapHost,
        bootstrap_port: BootstrapPort,
        bootstrap_room: BootstrapRoom,
    ) {
        // Record the routing triple chosen by the router; order of the
        // assignments is irrelevant.
        self.bootstrap_room = Some(bootstrap_room);
        self.bootstrap_port = Some(bootstrap_port);
        self.bootstrap_host = Some(bootstrap_host);
    }
}
/// The chat-completions request body as seen by the PD router. Only the
/// fields the router needs are typed; the rest is captured in `other`
/// and forwarded verbatim.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatReqInput {
    /// Whether to stream the response (defaults to false when absent).
    #[serde(default)]
    pub stream: bool,
    /// Bootstrap metadata injected by the router (see `Bootstrap` trait).
    pub bootstrap_host: Option<BootstrapHost>,
    pub bootstrap_port: Option<BootstrapPort>,
    pub bootstrap_room: Option<BootstrapRoom>,
    /// All unrecognized fields, preserved and passed through unchanged.
    #[serde(flatten)]
    pub other: Value,
}
/// Wire the chat request into the PD bootstrap protocol.
impl Bootstrap for ChatReqInput {
    fn is_stream(&self) -> bool {
        self.stream
    }

    /// Chat requests batch only through the OpenAI `n` parameter:
    /// `n > 1` yields `n` generations for one prompt, otherwise the
    /// request counts as a single item.
    fn get_batch_size(&self) -> Result<Option<usize>, String> {
        let batch = self
            .other
            .get("n")
            .and_then(|v| v.as_u64())
            .filter(|&n| n > 1)
            .map(|n| n as usize);
        Ok(batch)
    }

    fn set_bootstrap_info(
        &mut self,
        bootstrap_host: BootstrapHost,
        bootstrap_port: BootstrapPort,
        bootstrap_room: BootstrapRoom,
    ) {
        // Record the routing triple chosen by the router.
        self.bootstrap_host = Some(bootstrap_host);
        self.bootstrap_port = Some(bootstrap_port);
        self.bootstrap_room = Some(bootstrap_room);
    }
}
// Request adapter to bridge OpenAI API types with PD routing requirements
use crate::openai_api_types::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, StringOrArray,
};
use crate::pd_types::{Bootstrap, ChatReqInput, GenerateReqInput, SingleOrBatch};
use serde_json::Value;
/// Adapter trait to convert OpenAI requests to PD-compatible requests.
///
/// Implementors consume the typed OpenAI request and produce a PD request
/// type that supports bootstrap-info injection (`Bootstrap`), so the PD
/// router can coordinate the prefill and decode servers.
pub trait ToPdRequest {
    /// The PD request type produced by the conversion.
    type Output: Bootstrap;
    /// Consume the OpenAI-style request and produce its PD equivalent.
    fn to_pd_request(self) -> Self::Output;
}
// Helper macro to insert optional fields into a map.
// Usage: insert_if_some!(map, self.field => "key", ...).
// Each field is inserted under its key only when it is `Some`; the value is
// serialized with serde_json, degrading to `Value::Null` on serialization
// failure rather than propagating an error.
macro_rules! insert_if_some {
    ($map:expr, $($field:expr => $key:expr),* $(,)?) => {
        $(
            if let Some(value) = $field {
                $map.insert($key.to_string(), serde_json::to_value(value).unwrap_or(Value::Null));
            }
        )*
    };
}
// Helper macro for simple value insertions.
// Usage: insert_value!(map, self.field => "key", ...).
// Unconditionally inserts each field under its key; the field must be
// convertible to `Value` via `Into` (no serde round-trip involved).
macro_rules! insert_value {
    ($map:expr, $($field:expr => $key:expr),* $(,)?) => {
        $(
            $map.insert($key.to_string(), $field.into());
        )*
    };
}
// ============= Generate Request Adapter =============
impl ToPdRequest for GenerateRequest {
    type Output = GenerateReqInput;

    /// Convert a native `/generate` request into the PD router's form.
    ///
    /// Input is chosen in priority order: `text` (SGLang native), then
    /// `prompt` (OpenAI style), then `input_ids` (pre-tokenized). Commonly
    /// used sampling fields (`max_new_tokens`, `temperature`) are hoisted to
    /// the top level of `other` so downstream servers can read them without
    /// unwrapping nested parameter objects.
    fn to_pd_request(self) -> Self::Output {
        // Passthrough fields accumulate here.
        let mut other = serde_json::Map::new();

        // Resolve the input source; at most one of (text, input_ids) is set.
        let (text, input_ids) = if let Some(text_str) = self.text {
            // SGLang native single-string form.
            (Some(SingleOrBatch::Single(text_str)), None)
        } else if let Some(prompt) = self.prompt {
            // OpenAI-style prompt: one string or a batch of strings.
            let text = match prompt {
                StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
                StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
            };
            (text, None)
        } else if let Some(ids) = self.input_ids {
            // Pre-tokenized input IDs.
            let input_ids = match ids {
                crate::openai_api_types::InputIds::Single(ids) => Some(SingleOrBatch::Single(ids)),
                crate::openai_api_types::InputIds::Batch(ids) => Some(SingleOrBatch::Batch(ids)),
            };
            (None, input_ids)
        } else {
            // No input provided; downstream validation handles this.
            (None, None)
        };

        // Old-style `parameters`: hoist common fields, keep the rest nested.
        if let Some(params) = self.parameters {
            let mut params_value = serde_json::to_value(&params).unwrap_or(Value::Null);
            if let Value::Object(ref mut params_map) = params_value {
                // Move (not copy) the hoisted fields so they are not
                // duplicated inside the nested `parameters` object.
                if let Some(max_new_tokens) = params_map.remove("max_new_tokens") {
                    other.insert("max_new_tokens".to_string(), max_new_tokens);
                }
                if let Some(temperature) = params_map.remove("temperature") {
                    other.insert("temperature".to_string(), temperature);
                }
            }
            // Only keep `parameters` if something remains after hoisting.
            // `as_object()` already returns None for non-objects (including
            // Null), so no separate null check is needed.
            if params_value.as_object().map_or(false, |m| !m.is_empty()) {
                other.insert("parameters".to_string(), params_value);
            }
        }

        // New-style `sampling_params`: hoist copies of the common fields.
        // NOTE(review): unlike the `parameters` branch above, the hoisted
        // fields are cloned and also left inside `sampling_params`, so they
        // appear twice in the forwarded payload — confirm whether this
        // duplication is intentional.
        if let Some(sampling_params) = self.sampling_params {
            let params_value = serde_json::to_value(&sampling_params).unwrap_or(Value::Null);
            if !params_value.is_null() {
                if let Value::Object(ref params_map) = params_value {
                    if let Some(max_new_tokens) = params_map.get("max_new_tokens") {
                        other.insert("max_new_tokens".to_string(), max_new_tokens.clone());
                    }
                    if let Some(temperature) = params_map.get("temperature") {
                        other.insert("temperature".to_string(), temperature.clone());
                    }
                }
                other.insert("sampling_params".to_string(), params_value);
            }
        }

        // Always-present passthrough flags.
        insert_value!(other,
            self.stream => "stream",
            self.return_logprob => "return_logprob"
        );

        GenerateReqInput {
            text,
            input_ids,
            stream: self.stream,
            bootstrap_host: None,
            bootstrap_port: None,
            bootstrap_room: None,
            other: Value::Object(other),
        }
    }
}
// ============= Completion Request Adapter =============
impl ToPdRequest for CompletionRequest {
    type Output = GenerateReqInput;

    /// Map an OpenAI `/v1/completions` request onto the PD generate form,
    /// translating OpenAI parameter names into the internal ones.
    fn to_pd_request(self) -> Self::Output {
        // The prompt maps directly onto single/batch text input.
        let text = match self.prompt {
            StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
            StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
        };

        // Translate the optional OpenAI sampling knobs into internal names.
        let mut params = serde_json::Map::new();
        insert_if_some!(params,
            self.max_tokens => "max_new_tokens",
            self.temperature => "temperature",
            self.top_p => "top_p",
            self.n => "best_of",
            self.logprobs => "top_n_tokens",
            self.seed => "seed"
        );

        // presence_penalty is offset by 1.0 to form repetition_penalty.
        if let Some(presence_penalty) = self.presence_penalty {
            let repetition = 1.0 + presence_penalty;
            params.insert("repetition_penalty".to_string(), repetition.into());
        }

        // Normalize stop sequences into list form regardless of input shape.
        if let Some(stop) = self.stop {
            let stop_sequences = match stop {
                StringOrArray::String(s) => vec![s],
                StringOrArray::Array(v) => v,
            };
            params.insert("stop".to_string(), stop_sequences.into());
        }

        // `echo` maps to the internal `return_full_text` flag.
        if self.echo {
            params.insert("return_full_text".to_string(), true.into());
        }

        // Assemble the passthrough map: parameters plus model/stream.
        let mut other = serde_json::Map::new();
        other.insert("parameters".to_string(), Value::Object(params));
        insert_value!(other,
            self.model => "model",
            self.stream => "stream"
        );

        GenerateReqInput {
            text,
            input_ids: None,
            stream: self.stream,
            bootstrap_host: None,
            bootstrap_port: None,
            bootstrap_room: None,
            other: Value::Object(other),
        }
    }
}
// ============= Chat Completion Request Adapter =============
impl ToPdRequest for ChatCompletionRequest {
    type Output = ChatReqInput;

    /// Flatten an OpenAI chat-completion request into the opaque `other`
    /// map carried by `ChatReqInput`, keeping the stream flag typed.
    fn to_pd_request(self) -> Self::Output {
        let mut other = serde_json::Map::new();

        // `messages` is required, so insert it directly; serialization
        // failures degrade to Null, matching the optional-field helper.
        other.insert(
            "messages".to_string(),
            serde_json::to_value(&self.messages).unwrap_or(Value::Null),
        );
        insert_value!(other,
            self.model => "model",
            self.stream => "stream"
        );

        // Optional OpenAI fields are forwarded only when present.
        insert_if_some!(other,
            self.temperature => "temperature",
            self.top_p => "top_p",
            self.n => "n",
            self.stop => "stop",
            self.max_tokens => "max_tokens",
            self.max_completion_tokens => "max_completion_tokens",
            self.presence_penalty => "presence_penalty",
            self.frequency_penalty => "frequency_penalty",
            self.logit_bias => "logit_bias",
            self.user => "user",
            self.seed => "seed",
            self.top_logprobs => "top_logprobs",
            self.response_format => "response_format",
            self.tools => "tools",
            self.tool_choice => "tool_choice",
            self.parallel_tool_calls => "parallel_tool_calls",
            self.functions => "functions",
            self.function_call => "function_call"
        );

        // `logprobs` is a plain bool; only forward it when enabled.
        if self.logprobs {
            other.insert("logprobs".to_string(), true.into());
        }

        ChatReqInput {
            stream: self.stream,
            bootstrap_host: None,
            bootstrap_port: None,
            bootstrap_room: None,
            other: Value::Object(other),
        }
    }
}
// ============= Direct routing support for regular router =============
/// Extension trait for routing without PD conversion.
///
/// Any serializable generation request can be sent to a backend verbatim;
/// the default methods below provide the two wire formats used by the
/// regular (non-PD) router.
pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone {
    /// Convert to JSON for sending to backend.
    fn to_json(&self) -> Result<Value, serde_json::Error> {
        serde_json::to_value(self)
    }
    /// Convert to bytes for legacy routing (serialized JSON body).
    fn to_bytes(&self) -> Result<bytes::Bytes, serde_json::Error> {
        let json = serde_json::to_vec(self)?;
        Ok(bytes::Bytes::from(json))
    }
}
// All request types can be routed directly in non-PD mode using the
// trait's default serialization methods; no per-type overrides are needed.
impl RouteableRequest for GenerateRequest {}
impl RouteableRequest for CompletionRequest {}
impl RouteableRequest for ChatCompletionRequest {}
This diff is collapsed.
use crate::logging::{self, LoggingConfig}; use crate::logging::{self, LoggingConfig};
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::prometheus::{self, PrometheusConfig}; use crate::prometheus::{self, PrometheusConfig};
use crate::request_adapter::ToPdRequest;
use crate::router::PolicyConfig; use crate::router::PolicyConfig;
use crate::router::Router; use crate::router::Router;
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig}; use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use actix_web::{ use actix_web::{
error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder, error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
}; };
use bytes::Bytes;
use futures_util::StreamExt; use futures_util::StreamExt;
use reqwest::Client; use reqwest::Client;
use std::collections::HashMap; use std::collections::HashMap;
...@@ -20,6 +21,7 @@ use tracing::{error, info, warn, Level}; ...@@ -20,6 +21,7 @@ use tracing::{error, info, warn, Level};
pub struct AppState { pub struct AppState {
router: Arc<Router>, router: Arc<Router>,
client: Client, client: Client,
is_pd_mode: bool, // Add flag to track PD mode
} }
impl AppState { impl AppState {
...@@ -28,9 +30,16 @@ impl AppState { ...@@ -28,9 +30,16 @@ impl AppState {
client: Client, client: Client,
policy_config: PolicyConfig, policy_config: PolicyConfig,
) -> Result<Self, String> { ) -> Result<Self, String> {
// Check if this is PD mode from policy config
let is_pd_mode = matches!(policy_config, PolicyConfig::PrefillDecodeConfig { .. });
// Create router based on policy // Create router based on policy
let router = Arc::new(Router::new(worker_urls, policy_config)?); let router = Arc::new(Router::new(worker_urls, policy_config)?);
Ok(Self { router, client }) Ok(Self {
router,
client,
is_pd_mode,
})
} }
} }
...@@ -46,8 +55,25 @@ async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<Ht ...@@ -46,8 +55,25 @@ async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<Ht
} }
// Custom error handler for JSON payload errors. // Custom error handler for JSON payload errors.
fn json_error_handler(_err: error::JsonPayloadError, _req: &HttpRequest) -> Error { fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
error::ErrorPayloadTooLarge("Payload too large") error!("JSON payload error: {:?}", err);
match &err {
error::JsonPayloadError::OverflowKnownLength { length, limit } => {
error!(
"Payload too large: {} bytes exceeds limit of {} bytes",
length, limit
);
error::ErrorPayloadTooLarge(format!(
"Payload too large: {} bytes exceeds limit of {} bytes",
length, limit
))
}
error::JsonPayloadError::Overflow { limit } => {
error!("Payload overflow: exceeds limit of {} bytes", limit);
error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit))
}
_ => error::ErrorBadRequest(format!("Invalid JSON payload: {}", err)),
}
} }
#[get("/health")] #[get("/health")]
...@@ -59,59 +85,134 @@ async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder { ...@@ -59,59 +85,134 @@ async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
#[get("/health_generate")] #[get("/health_generate")]
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder { async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router // Check if we're in PD mode
.route_to_first(&data.client, "/health_generate", &req) if data.is_pd_mode {
.await // For PD mode, check health on all servers
data.router
.route_pd_health_generate(&data.client, &req)
.await
} else {
// Regular mode
data.router
.route_to_first(&data.client, "/health_generate", &req)
.await
}
} }
#[get("/get_server_info")] #[get("/get_server_info")]
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder { async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router if data.is_pd_mode {
.route_to_first(&data.client, "/get_server_info", &req) // For PD mode, aggregate info from both prefill and decode servers
.await data.router.get_pd_server_info(&data.client, &req).await
} else {
// Regular mode - return first server's info
data.router
.route_to_first(&data.client, "/get_server_info", &req)
.await
}
} }
#[get("/v1/models")] #[get("/v1/models")]
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder { async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router if data.is_pd_mode {
.route_to_first(&data.client, "/v1/models", &req) // For PD mode, return models from the first prefill server
.await data.router.get_pd_models(&data.client, &req).await
} else {
// Regular mode
data.router
.route_to_first(&data.client, "/v1/models", &req)
.await
}
} }
#[get("/get_model_info")] #[get("/get_model_info")]
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder { async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router if data.is_pd_mode {
.route_to_first(&data.client, "/get_model_info", &req) // For PD mode, get model info from the first prefill server
.await data.router.get_pd_model_info(&data.client, &req).await
} else {
data.router
.route_to_first(&data.client, "/get_model_info", &req)
.await
}
} }
#[post("/generate")] #[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder { async fn generate(
data.router req: HttpRequest,
.route_generate_request(&data.client, &req, &body, "/generate") body: web::Json<GenerateRequest>,
.await state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let client = &state.client;
let router = &state.router;
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_generate_typed(&client, &req, pd_request, "/generate")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/generate")
.await)
}
} }
#[post("/v1/chat/completions")] #[post("/v1/chat/completions")]
async fn v1_chat_completions( async fn v1_chat_completions(
req: HttpRequest, req: HttpRequest,
body: Bytes, body: web::Json<ChatCompletionRequest>,
data: web::Data<AppState>, state: web::Data<AppState>,
) -> impl Responder { ) -> Result<HttpResponse, Error> {
data.router let client = &state.client;
.route_generate_request(&data.client, &req, &body, "/v1/chat/completions") let router = &state.router;
.await
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_chat_typed(&client, &req, pd_request, "/v1/chat/completions")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/v1/chat/completions")
.await)
}
} }
#[post("/v1/completions")] #[post("/v1/completions")]
async fn v1_completions( async fn v1_completions(
req: HttpRequest, req: HttpRequest,
body: Bytes, body: web::Json<CompletionRequest>,
data: web::Data<AppState>, state: web::Data<AppState>,
) -> impl Responder { ) -> Result<HttpResponse, Error> {
data.router let client = &state.client;
.route_generate_request(&data.client, &req, &body, "/v1/completions") let router = &state.router;
.await
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_generate_typed(&client, &req, pd_request, "/v1/completions")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/v1/completions")
.await)
}
} }
#[post("/add_worker")] #[post("/add_worker")]
...@@ -153,6 +254,25 @@ async fn remove_worker( ...@@ -153,6 +254,25 @@ async fn remove_worker(
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url)) HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
} }
#[post("/flush_cache")]
async fn flush_cache(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
if data.is_pd_mode {
// For PD mode, flush cache on both prefill and decode servers
data.router.route_pd_flush_cache(&data.client).await
} else {
// Route to all workers for cache flushing
data.router
.route_to_all(&data.client, "/flush_cache", &req)
.await
}
}
#[get("/get_loads")]
async fn get_loads(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
// Get loads from all workers
data.router.get_all_loads(&data.client, &req).await
}
pub struct ServerConfig { pub struct ServerConfig {
pub host: String, pub host: String,
pub port: u16, pub port: u16,
...@@ -163,6 +283,7 @@ pub struct ServerConfig { ...@@ -163,6 +283,7 @@ pub struct ServerConfig {
pub log_dir: Option<String>, pub log_dir: Option<String>,
pub service_discovery_config: Option<ServiceDiscoveryConfig>, pub service_discovery_config: Option<ServiceDiscoveryConfig>,
pub prometheus_config: Option<PrometheusConfig>, pub prometheus_config: Option<PrometheusConfig>,
pub request_timeout_secs: u64,
} }
pub async fn startup(config: ServerConfig) -> std::io::Result<()> { pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
...@@ -215,6 +336,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { ...@@ -215,6 +336,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
let client = Client::builder() let client = Client::builder()
.pool_idle_timeout(Some(Duration::from_secs(50))) .pool_idle_timeout(Some(Duration::from_secs(50)))
.timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout
.build() .build()
.expect("Failed to create HTTP client"); .expect("Failed to create HTTP client");
...@@ -276,7 +398,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { ...@@ -276,7 +398,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
.service(add_worker) .service(add_worker)
.service(remove_worker) .service(remove_worker)
.service(list_workers) .service(list_workers)
// Default handler for unmatched routes. .service(flush_cache)
.service(get_loads)
.default_service(web::route().to(sink_handler)) .default_service(web::route().to(sink_handler))
}) })
.bind_auto_h2c((config.host, config.port))? .bind_auto_h2c((config.host, config.port))?
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment