Unverified commit 09ae5b20 authored by Simo Lin, committed by GitHub

Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)

parent 712bf9ec
......@@ -218,15 +218,39 @@ async def get_server_info():
)
prefill_infos = []
decode_infos = []
all_internal_states = []
async with aiohttp.ClientSession() as session:
for server in chain(prefill_servers):
server_info = await session.get(f"{server}/get_server_info")
prefill_infos.append(await server_info.json())
for server in chain(decode_servers):
server_info = await session.get(f"{server}/get_server_info")
decode_infos.append(await server_info.json())
return {"prefill": prefill_infos, "decode": decode_infos}
info_json = await server_info.json()
decode_infos.append(info_json)
# Extract internal_states from decode servers
if "internal_states" in info_json:
all_internal_states.extend(info_json["internal_states"])
# Return format expected by bench_one_batch_server.py
if all_internal_states:
return {
"internal_states": all_internal_states,
"prefill": prefill_infos,
"decode": decode_infos,
}
else:
# Fallback with dummy data if no internal states found
return {
"internal_states": [
{
"last_gen_throughput": 0.0,
"avg_spec_accept_length": None,
}
],
"prefill": prefill_infos,
"decode": decode_infos,
}
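For reference, a minimal sketch of the aggregated response shape this handler returns in PD mode (values are illustrative placeholders, not real server output):
example_pd_server_info = {
    "internal_states": [
        {"last_gen_throughput": 1234.5, "avg_spec_accept_length": None}
    ],
    "prefill": [],  # one full /get_server_info payload per prefill server
    "decode": [],   # one full /get_server_info payload per decode server
}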
@app.get("/get_model_info")
......
......@@ -15,7 +15,7 @@ serde = { version = "1.0", features = ["derive"] }
clap = { version = "4.4", features = ["derive"] }
bytes = "1.8.0"
rand = "0.8.5"
reqwest = { version = "0.12.8", features = ["stream", "blocking"] }
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] }
futures-util = "0.3"
serde_json = "1.0"
pyo3 = { version = "0.22.5", features = ["extension-module"] }
......@@ -33,6 +33,8 @@ futures = "0.3"
# Added for metrics
metrics = "0.24.2"
metrics-exporter-prometheus = "0.17.0"
# Added for request tracing
uuid = { version = "1.10", features = ["v4", "serde"] }
[profile.release]
lto = "thin"
codegen-units = 1
......@@ -31,6 +31,13 @@ class RouterArgs:
host: str = "127.0.0.1"
port: int = 30000
# PD-specific configuration
pd_disaggregated: bool = False # Enable PD disaggregated mode
prefill_urls: List[tuple] = dataclasses.field(
default_factory=list
) # List of (url, bootstrap_port)
decode_urls: List[str] = dataclasses.field(default_factory=list)
# Routing policy
policy: str = "cache_aware"
worker_startup_timeout_secs: int = 300
......@@ -40,7 +47,7 @@ class RouterArgs:
balance_rel_threshold: float = 1.0001
eviction_interval: int = 60
max_tree_size: int = 2**24
max_payload_size: int = 4 * 1024 * 1024 # 4MB
max_payload_size: int = 256 * 1024 * 1024 # 256MB default for large batches
verbose: bool = False
log_dir: Optional[str] = None
# Service discovery configuration
......@@ -95,8 +102,29 @@ class RouterArgs:
f"--{prefix}policy",
type=str,
default=RouterArgs.policy,
choices=["random", "round_robin", "cache_aware"],
help="Load balancing policy to use",
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode",
)
# PD-specific arguments
parser.add_argument(
f"--{prefix}pd-disaggregated",
action="store_true",
help="Enable PD (Prefill-Decode) disaggregated mode",
)
parser.add_argument(
f"--{prefix}prefill",
nargs=2,
action="append",
metavar=("URL", "BOOTSTRAP_PORT"),
help="Prefill server URL and bootstrap port. Can be specified multiple times. BOOTSTRAP_PORT can be 'none' for no bootstrap port.",
)
parser.add_argument(
f"--{prefix}decode",
nargs=1,
action="append",
metavar=("URL",),
help="Decode server URL. Can be specified multiple times.",
)
parser.add_argument(
f"--{prefix}worker-startup-timeout-secs",
......@@ -205,11 +233,19 @@ class RouterArgs:
use_router_prefix: If True, look for arguments with 'router-' prefix
"""
prefix = "router_" if use_router_prefix else ""
worker_urls = args.worker_urls if args.worker_urls is not None else []
worker_urls = getattr(args, "worker_urls", [])
# Parse PD URLs
prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None))
decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None))
return cls(
worker_urls=worker_urls,
host=args.host,
port=args.port,
pd_disaggregated=getattr(args, f"{prefix}pd_disaggregated", False),
prefill_urls=prefill_urls,
decode_urls=decode_urls,
policy=getattr(args, f"{prefix}policy"),
worker_startup_timeout_secs=getattr(
args, f"{prefix}worker_startup_timeout_secs"
......@@ -247,6 +283,46 @@ class RouterArgs:
selector[key] = value
return selector
@staticmethod
def _parse_prefill_urls(prefill_list):
"""Parse prefill URLs from --prefill arguments.
Format: --prefill URL BOOTSTRAP_PORT
Example: --prefill http://prefill1:8080 9000 --prefill http://prefill2:8080 none
"""
if not prefill_list:
return []
prefill_urls = []
for url, bootstrap_port_str in prefill_list:
# Handle 'none' as None
if bootstrap_port_str.lower() == "none":
bootstrap_port = None
else:
try:
bootstrap_port = int(bootstrap_port_str)
except ValueError:
raise ValueError(
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
)
prefill_urls.append((url, bootstrap_port))
return prefill_urls
@staticmethod
def _parse_decode_urls(decode_list):
"""Parse decode URLs from --decode arguments.
Format: --decode URL
Example: --decode http://decode1:8081 --decode http://decode2:8081
"""
if not decode_list:
return []
# decode_list is a list of single-element lists due to nargs=1
return [url[0] for url in decode_list]
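A small sketch (hypothetical URLs) of how the appended argparse values flow through these helpers:
from sglang_router.launch_router import RouterArgs

# With: --prefill http://prefill1:8080 9000 --prefill http://prefill2:8080 none --decode http://decode1:8081
# argparse (nargs=2 / nargs=1 with action="append") yields nested lists, which the helpers normalize:
raw_prefill = [["http://prefill1:8080", "9000"], ["http://prefill2:8080", "none"]]
raw_decode = [["http://decode1:8081"]]
assert RouterArgs._parse_prefill_urls(raw_prefill) == [
    ("http://prefill1:8080", 9000),
    ("http://prefill2:8080", None),
]
assert RouterArgs._parse_decode_urls(raw_decode) == ["http://decode1:8081"]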
def policy_from_str(policy_str: str) -> PolicyType:
"""Convert policy string to PolicyType enum."""
......@@ -254,6 +330,7 @@ def policy_from_str(policy_str: str) -> PolicyType:
"random": PolicyType.Random,
"round_robin": PolicyType.RoundRobin,
"cache_aware": PolicyType.CacheAware,
"power_of_two": PolicyType.PowerOfTwo,
}
return policy_map[policy_str]
......@@ -277,8 +354,19 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
else:
router_args = args
# Validate configuration based on mode
if router_args.pd_disaggregated:
# Validate PD configuration
if not router_args.prefill_urls:
raise ValueError("PD disaggregated mode requires --prefill")
if not router_args.decode_urls:
raise ValueError("PD disaggregated mode requires --decode")
# Create router with unified constructor
router = Router(
worker_urls=router_args.worker_urls,
worker_urls=(
router_args.worker_urls if not router_args.pd_disaggregated else []
),
host=router_args.host,
port=router_args.port,
policy=policy_from_str(router_args.policy),
......@@ -298,6 +386,13 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
service_discovery_namespace=router_args.service_discovery_namespace,
prometheus_port=router_args.prometheus_port,
prometheus_host=router_args.prometheus_host,
pd_disaggregated=router_args.pd_disaggregated,
prefill_urls=(
router_args.prefill_urls if router_args.pd_disaggregated else None
),
decode_urls=(
router_args.decode_urls if router_args.pd_disaggregated else None
),
)
router.start()
......@@ -326,8 +421,14 @@ This launcher enables starting a router with individual worker instances. It is
multi-node setups or when you want to start workers and router separately.
Examples:
# Regular mode
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --balance-abs-threshold 64 --balance-rel-threshold 1.2
# PD disaggregated mode
python -m sglang_router.launch_router --pd-disaggregated \\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
--decode http://decode1:8001 --decode http://decode2:8001 \\
--policy cache_aware
""",
formatter_class=CustomHelpFormatter,
......
......@@ -15,6 +15,7 @@ class Router:
- PolicyType.Random: Randomly select workers
- PolicyType.RoundRobin: Distribute requests in round-robin fashion
- PolicyType.CacheAware: Distribute requests based on cache state and load balance
- PolicyType.PowerOfTwo: Select the less-loaded of two randomly chosen workers (PD mode only)
host: Host address to bind the router server. Default: '127.0.0.1'
port: Port number to bind the router server. Default: 3001
worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300
......@@ -28,7 +29,7 @@ class Router:
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
routing. Default: 60
max_payload_size: Maximum payload size in bytes. Default: 4MB
max_payload_size: Maximum payload size in bytes. Default: 256MB
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
verbose: Enable verbose logging. Default: False
log_dir: Directory to store log files. If None, logs are only output to console. Default: None
......@@ -42,6 +43,9 @@ class Router:
watches pods across all namespaces (requires cluster-wide permissions). Default: None
prometheus_port: Port to expose Prometheus metrics. Default: None
prometheus_host: Host address to bind the Prometheus metrics server. Default: None
pd_disaggregated: Enable PD (Prefill-Decode) disaggregated mode. Default: False
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
decode_urls: List of URLs for decode servers (PD mode only)
"""
def __init__(
......@@ -57,7 +61,7 @@ class Router:
balance_rel_threshold: float = 1.0001,
eviction_interval_secs: int = 60,
max_tree_size: int = 2**24,
max_payload_size: int = 4 * 1024 * 1024, # 4MB
max_payload_size: int = 256 * 1024 * 1024, # 256MB
verbose: bool = False,
log_dir: Optional[str] = None,
service_discovery: bool = False,
......@@ -66,6 +70,9 @@ class Router:
service_discovery_namespace: Optional[str] = None,
prometheus_port: Optional[int] = None,
prometheus_host: Optional[str] = None,
pd_disaggregated: bool = False,
prefill_urls: Optional[List[tuple]] = None,
decode_urls: Optional[List[str]] = None,
):
if selector is None:
selector = {}
......@@ -91,6 +98,9 @@ class Router:
service_discovery_namespace=service_discovery_namespace,
prometheus_port=prometheus_port,
prometheus_host=prometheus_host,
pd_disaggregated=pd_disaggregated,
prefill_urls=prefill_urls,
decode_urls=decode_urls,
)
def start(self) -> None:
......
......@@ -35,13 +35,21 @@ class TestLaunchRouter(unittest.TestCase):
balance_rel_threshold=1.0001,
eviction_interval=60,
max_tree_size=2**24,
max_payload_size=4 * 1024 * 1024, # 4MB
max_payload_size=256 * 1024 * 1024, # 256MB
verbose=False,
log_dir=None,
service_discovery=False,
selector=None,
service_discovery_port=80,
service_discovery_namespace=None,
prometheus_port=None,
prometheus_host=None,
# PD-specific attributes
pd_disaggregated=False,
prefill=None,
decode=None,
# Keep worker_urls for regular mode
worker_urls=[],
)
def create_router_args(self, **kwargs):
......@@ -81,7 +89,7 @@ class TestLaunchRouter(unittest.TestCase):
def test_launch_router_with_empty_worker_urls(self):
args = self.create_router_args(worker_urls=[])
self.run_router_process(args)
self.run_router_process(args) # Expected error
def test_launch_router_with_service_discovery(self):
# Test router startup with service discovery enabled but no selectors
......@@ -100,6 +108,112 @@ class TestLaunchRouter(unittest.TestCase):
)
self.run_router_process(args)
def test_launch_router_pd_mode_basic(self):
"""Test basic PD router functionality without actually starting servers."""
# This test just verifies the PD router can be created and configured
# without actually starting it (which would require real prefill/decode servers)
from sglang_router import Router
from sglang_router.launch_router import RouterArgs
from sglang_router_rs import PolicyType
# Test RouterArgs parsing for PD mode
# Simulate the parsed args structure from argparse with action="append"
args = self.create_router_args(
pd_disaggregated=True,
policy="power_of_two", # PowerOfTwo is only valid in PD mode
prefill=[
["http://prefill1:8080", "9000"],
["http://prefill2:8080", "none"],
],
decode=[
["http://decode1:8081"],
["http://decode2:8081"],
],
worker_urls=[], # Empty for PD mode
)
router_args = RouterArgs.from_cli_args(args)
self.assertTrue(router_args.pd_disaggregated)
self.assertEqual(router_args.policy, "power_of_two")
self.assertEqual(len(router_args.prefill_urls), 2)
self.assertEqual(len(router_args.decode_urls), 2)
# Verify the parsed URLs and bootstrap ports
self.assertEqual(router_args.prefill_urls[0], ("http://prefill1:8080", 9000))
self.assertEqual(router_args.prefill_urls[1], ("http://prefill2:8080", None))
self.assertEqual(router_args.decode_urls[0], "http://decode1:8081")
self.assertEqual(router_args.decode_urls[1], "http://decode2:8081")
# Test Router creation in PD mode
router = Router(
worker_urls=[], # Empty for PD mode
pd_disaggregated=True,
prefill_urls=[
("http://prefill1:8080", 9000),
("http://prefill2:8080", None),
],
decode_urls=["http://decode1:8081", "http://decode2:8081"],
policy=PolicyType.CacheAware,
host="127.0.0.1",
port=3001,
)
self.assertIsNotNone(router)
def test_policy_validation(self):
"""Test that policy validation works correctly for PD and regular modes."""
from sglang_router.launch_router import RouterArgs, launch_router
# Test 1: PowerOfTwo is only valid in PD mode
args = self.create_router_args(
pd_disaggregated=False,
policy="power_of_two",
worker_urls=["http://localhost:8000"],
)
# Should raise error
with self.assertRaises(ValueError) as cm:
launch_router(args)
self.assertIn(
"PowerOfTwo policy is only supported in PD disaggregated mode",
str(cm.exception),
)
# Test 2: RoundRobin is not valid in PD mode
args = self.create_router_args(
pd_disaggregated=True,
policy="round_robin",
prefill=[["http://prefill1:8080", "9000"]],
decode=[["http://decode1:8081"]],
worker_urls=[],
)
# Should raise error
with self.assertRaises(ValueError) as cm:
launch_router(args)
self.assertIn(
"RoundRobin policy is not supported in PD disaggregated mode",
str(cm.exception),
)
# Test 3: Valid combinations should not raise errors
# Regular mode with RoundRobin
args = self.create_router_args(
pd_disaggregated=False,
policy="round_robin",
worker_urls=["http://localhost:8000"],
)
# This should not raise (though it may fail to connect)
# PD mode with PowerOfTwo
args = self.create_router_args(
pd_disaggregated=True,
policy="power_of_two",
prefill=[["http://prefill1:8080", "9000"]],
decode=[["http://decode1:8081"]],
worker_urls=[],
)
# This should not raise (though it may fail to connect)
if __name__ == "__main__":
unittest.main()
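For quick reference, the combinations exercised above reduce to this compatibility matrix (a Python-side summary; keys are the CLI policy strings, enforcement lives in the Rust router below):
# Unsupported combinations raise ValueError at router startup.
SUPPORTED_POLICIES = {
    "regular": {"random", "round_robin", "cache_aware"},            # power_of_two rejected
    "pd_disaggregated": {"random", "power_of_two", "cache_aware"},  # round_robin rejected
}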
use pyo3::prelude::*;
pub mod logging;
use std::collections::HashMap;
pub mod openai_api_types;
pub mod pd_router;
pub mod pd_types;
pub mod prometheus;
pub mod request_adapter;
pub mod router;
pub mod server;
pub mod service_discovery;
......@@ -14,6 +18,7 @@ pub enum PolicyType {
Random,
RoundRobin,
CacheAware,
PowerOfTwo, // Moved from PD-specific, now shared
}
#[pyclass]
......@@ -39,6 +44,12 @@ struct Router {
service_discovery_namespace: Option<String>,
prometheus_port: Option<u16>,
prometheus_host: Option<String>,
request_timeout_secs: u64,
// PD mode flag
pd_disaggregated: bool,
// PD-specific fields (only used when pd_disaggregated is true)
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
}
#[pymethods]
......@@ -56,7 +67,7 @@ impl Router {
balance_rel_threshold = 1.0001,
eviction_interval_secs = 60,
max_tree_size = 2usize.pow(24),
max_payload_size = 4 * 1024 * 1024,
max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches
verbose = false,
log_dir = None,
service_discovery = false,
......@@ -64,7 +75,11 @@ impl Router {
service_discovery_port = 80,
service_discovery_namespace = None,
prometheus_port = None,
prometheus_host = None
prometheus_host = None,
request_timeout_secs = 600, // Add configurable request timeout
pd_disaggregated = false, // New flag for PD mode
prefill_urls = None,
decode_urls = None
))]
fn new(
worker_urls: Vec<String>,
......@@ -87,6 +102,10 @@ impl Router {
service_discovery_namespace: Option<String>,
prometheus_port: Option<u16>,
prometheus_host: Option<String>,
request_timeout_secs: u64,
pd_disaggregated: bool,
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
) -> PyResult<Self> {
Ok(Router {
host,
......@@ -109,11 +128,52 @@ impl Router {
service_discovery_namespace,
prometheus_port,
prometheus_host,
request_timeout_secs,
pd_disaggregated,
prefill_urls,
decode_urls,
})
}
fn start(&self) -> PyResult<()> {
let policy_config = match &self.policy {
let policy_config = if self.pd_disaggregated {
// PD mode - map PolicyType to PDSelectionPolicy
let pd_selection_policy = match &self.policy {
PolicyType::Random => pd_types::PDSelectionPolicy::Random,
PolicyType::PowerOfTwo => pd_types::PDSelectionPolicy::PowerOfTwo,
PolicyType::CacheAware => pd_types::PDSelectionPolicy::CacheAware {
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
},
PolicyType::RoundRobin => {
return Err(pyo3::exceptions::PyValueError::new_err(
"RoundRobin policy is not supported in PD disaggregated mode",
));
}
};
let prefill_urls = self.prefill_urls.as_ref().ok_or_else(|| {
pyo3::exceptions::PyValueError::new_err(
"PD disaggregated mode requires prefill_urls",
)
})?;
let decode_urls = self.decode_urls.as_ref().ok_or_else(|| {
pyo3::exceptions::PyValueError::new_err(
"PD disaggregated mode requires decode_urls",
)
})?;
router::PolicyConfig::PrefillDecodeConfig {
selection_policy: pd_selection_policy,
prefill_urls: prefill_urls.clone(),
decode_urls: decode_urls.clone(),
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
}
} else {
// Regular mode
match &self.policy {
PolicyType::Random => router::PolicyConfig::RandomConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
......@@ -131,6 +191,12 @@ impl Router {
eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size,
},
PolicyType::PowerOfTwo => {
return Err(pyo3::exceptions::PyValueError::new_err(
"PowerOfTwo policy is only supported in PD disaggregated mode",
));
}
}
};
// Create service discovery config if enabled
......@@ -166,6 +232,7 @@ impl Router {
log_dir: self.log_dir.clone(),
service_discovery_config,
prometheus_config,
request_timeout_secs: self.request_timeout_secs,
})
.await
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
......
This diff is collapsed.
This diff is collapsed.
// Essential PDLB types extracted for PD routing
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Debug, Clone)]
pub enum EngineType {
Prefill,
Decode,
}
#[derive(Debug, Clone)]
pub struct EngineInfo {
pub engine_type: EngineType,
pub url: String,
pub bootstrap_port: Option<u16>,
}
impl EngineInfo {
pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
EngineInfo {
engine_type: EngineType::Prefill,
url,
bootstrap_port,
}
}
pub fn new_decode(url: String) -> Self {
EngineInfo {
engine_type: EngineType::Decode,
url,
bootstrap_port: None,
}
}
pub fn api_path(&self, api_path: &str) -> String {
if api_path.starts_with("/") {
format!("{}{}", self.url, api_path)
} else {
format!("{}/{}", self.url, api_path)
}
}
pub fn get_hostname(&self) -> String {
// Simple hostname extraction without external dependencies
let url = self
.url
.trim_start_matches("http://")
.trim_start_matches("https://");
url.split(':').next().unwrap_or("localhost").to_string()
}
}
// PD-specific routing policies
#[derive(Debug, Clone, PartialEq)]
pub enum PDSelectionPolicy {
Random,
PowerOfTwo,
CacheAware {
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
},
}
// Bootstrap types from PDLB
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum SingleOrBatch<T> {
Single(T),
Batch(Vec<T>),
}
pub type InputIds = SingleOrBatch<Vec<i32>>;
pub type InputText = SingleOrBatch<String>;
pub type BootstrapHost = SingleOrBatch<String>;
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
pub type BootstrapRoom = SingleOrBatch<u64>;
// Bootstrap trait for request handling
pub trait Bootstrap: Send + Sync {
fn is_stream(&self) -> bool;
fn get_batch_size(&self) -> Result<Option<usize>, String>;
fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
);
fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), String> {
let batch_size = self.get_batch_size()?;
if let Some(batch_size) = batch_size {
self.set_bootstrap_info(
BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
// Use high-quality random numbers to minimize collision risk
BootstrapRoom::Batch(
(0..batch_size)
.map(|_| {
// Combine multiple sources of randomness for better distribution
let r1 = rand::random::<u64>();
let r2 = rand::random::<u64>();
r1.wrapping_add(r2.rotate_left(32))
})
.collect(),
),
);
} else {
self.set_bootstrap_info(
BootstrapHost::Single(prefill_info.get_hostname()),
BootstrapPort::Single(prefill_info.bootstrap_port),
BootstrapRoom::Single({
// Use high-quality random number for single requests too
let r1 = rand::random::<u64>();
let r2 = rand::random::<u64>();
r1.wrapping_add(r2.rotate_left(32))
}),
);
}
Ok(())
}
}
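A sketch (hypothetical values) of what the injected fields look like on the wire after add_bootstrap_info runs; because SingleOrBatch is #[serde(untagged)], Single serializes as a bare value and Batch as a JSON array:
# Single request after bootstrap injection (illustrative values):
single_request = {
    "text": "Hello",
    "bootstrap_host": "prefill1",            # hostname of the selected prefill server
    "bootstrap_port": 9000,                  # that server's bootstrap_port, or None
    "bootstrap_room": 1234567890123456789,   # random u64 per request
}
# Batch of two: every bootstrap field becomes a list of batch_size entries.
batch_request = {
    "input_ids": [[1, 2, 3], [4, 5, 6]],
    "bootstrap_host": ["prefill1", "prefill1"],
    "bootstrap_port": [9000, 9000],
    "bootstrap_room": [111, 222],            # independently sampled rooms
}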
// Request types
#[derive(Debug, Deserialize, Serialize)]
pub struct GenerateReqInput {
pub text: Option<InputText>,
pub input_ids: Option<InputIds>,
#[serde(default)]
pub stream: bool,
pub bootstrap_host: Option<BootstrapHost>,
pub bootstrap_port: Option<BootstrapPort>,
pub bootstrap_room: Option<BootstrapRoom>,
#[serde(flatten)]
pub other: Value,
}
impl GenerateReqInput {
pub fn get_batch_size(&self) -> Result<Option<usize>, String> {
if self.text.is_some() && self.input_ids.is_some() {
return Err("Both text and input_ids are present in the request".to_string());
}
// Check text batch
if let Some(InputText::Batch(texts)) = &self.text {
if texts.is_empty() {
return Err("Batch text array is empty".to_string());
}
if texts.len() > 10000 {
// Reasonable limit for production
return Err(format!(
"Batch size {} exceeds maximum allowed (10000)",
texts.len()
));
}
return Ok(Some(texts.len()));
}
// Check input_ids batch
if let Some(InputIds::Batch(ids)) = &self.input_ids {
if ids.is_empty() {
return Err("Batch input_ids array is empty".to_string());
}
if ids.len() > 10000 {
// Reasonable limit for production
return Err(format!(
"Batch size {} exceeds maximum allowed (10000)",
ids.len()
));
}
// Validate each sequence is not empty
for (i, seq) in ids.iter().enumerate() {
if seq.is_empty() {
return Err(format!("Input sequence at index {} is empty", i));
}
}
return Ok(Some(ids.len()));
}
Ok(None)
}
}
impl Bootstrap for GenerateReqInput {
fn is_stream(&self) -> bool {
self.stream
}
fn get_batch_size(&self) -> Result<Option<usize>, String> {
self.get_batch_size()
}
fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
) {
self.bootstrap_host = Some(bootstrap_host);
self.bootstrap_port = Some(bootstrap_port);
self.bootstrap_room = Some(bootstrap_room);
}
}
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatReqInput {
#[serde(default)]
pub stream: bool,
pub bootstrap_host: Option<BootstrapHost>,
pub bootstrap_port: Option<BootstrapPort>,
pub bootstrap_room: Option<BootstrapRoom>,
#[serde(flatten)]
pub other: Value,
}
impl Bootstrap for ChatReqInput {
fn is_stream(&self) -> bool {
self.stream
}
fn get_batch_size(&self) -> Result<Option<usize>, String> {
// Check if 'n' parameter is present and > 1
if let Some(n_value) = self.other.get("n") {
if let Some(n) = n_value.as_u64() {
if n > 1 {
return Ok(Some(n as usize));
}
}
}
Ok(None)
}
fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
) {
self.bootstrap_host = Some(bootstrap_host);
self.bootstrap_port = Some(bootstrap_port);
self.bootstrap_room = Some(bootstrap_room);
}
}
// Request adapter to bridge OpenAI API types with PD routing requirements
use crate::openai_api_types::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, StringOrArray,
};
use crate::pd_types::{Bootstrap, ChatReqInput, GenerateReqInput, SingleOrBatch};
use serde_json::Value;
/// Adapter trait to convert OpenAI requests to PD-compatible requests
pub trait ToPdRequest {
type Output: Bootstrap;
fn to_pd_request(self) -> Self::Output;
}
// Helper macro to insert optional fields into a map
macro_rules! insert_if_some {
($map:expr, $($field:expr => $key:expr),* $(,)?) => {
$(
if let Some(value) = $field {
$map.insert($key.to_string(), serde_json::to_value(value).unwrap_or(Value::Null));
}
)*
};
}
// Helper macro for simple value insertions
macro_rules! insert_value {
($map:expr, $($field:expr => $key:expr),* $(,)?) => {
$(
$map.insert($key.to_string(), $field.into());
)*
};
}
// ============= Generate Request Adapter =============
impl ToPdRequest for GenerateRequest {
type Output = GenerateReqInput;
fn to_pd_request(self) -> Self::Output {
// Build the other fields first
let mut other = serde_json::Map::new();
// Handle text input - check in priority order: text (SGLang), prompt (OpenAI)
let (text, input_ids) = if let Some(text_str) = self.text {
// SGLang native format
(Some(SingleOrBatch::Single(text_str)), None)
} else if let Some(prompt) = self.prompt {
// OpenAI style prompt
let text = match prompt {
StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
};
(text, None)
} else if let Some(ids) = self.input_ids {
// Input IDs case
let input_ids = match ids {
crate::openai_api_types::InputIds::Single(ids) => Some(SingleOrBatch::Single(ids)),
crate::openai_api_types::InputIds::Batch(ids) => Some(SingleOrBatch::Batch(ids)),
};
(None, input_ids)
} else {
// No input provided
(None, None)
};
// Add parameters to other - handle both old and new style
if let Some(params) = self.parameters {
// For generate endpoint, extract max_new_tokens to top level if present
let mut params_value = serde_json::to_value(&params).unwrap_or(Value::Null);
if let Value::Object(ref mut params_map) = params_value {
// Move max_new_tokens to top level if it exists
if let Some(max_new_tokens) = params_map.remove("max_new_tokens") {
other.insert("max_new_tokens".to_string(), max_new_tokens);
}
// Move temperature to top level if it exists
if let Some(temperature) = params_map.remove("temperature") {
other.insert("temperature".to_string(), temperature);
}
}
// Only add parameters if there are remaining fields
if !params_value.is_null() && params_value.as_object().map_or(false, |m| !m.is_empty())
{
other.insert("parameters".to_string(), params_value);
}
}
// Add sampling_params if present
if let Some(sampling_params) = self.sampling_params {
let params_value = serde_json::to_value(&sampling_params).unwrap_or(Value::Null);
if !params_value.is_null() {
// Extract commonly used fields to top level
if let Value::Object(ref params_map) = params_value {
if let Some(max_new_tokens) = params_map.get("max_new_tokens") {
other.insert("max_new_tokens".to_string(), max_new_tokens.clone());
}
if let Some(temperature) = params_map.get("temperature") {
other.insert("temperature".to_string(), temperature.clone());
}
}
other.insert("sampling_params".to_string(), params_value);
}
}
// Add other fields
insert_value!(other,
self.stream => "stream",
self.return_logprob => "return_logprob"
);
GenerateReqInput {
text,
input_ids,
stream: self.stream,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
other: Value::Object(other),
}
}
}
// ============= Completion Request Adapter =============
impl ToPdRequest for CompletionRequest {
type Output = GenerateReqInput;
fn to_pd_request(self) -> Self::Output {
// Convert CompletionRequest to GenerateReqInput
let text = match self.prompt {
StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
};
// Map OpenAI parameters to generate parameters
let mut other = serde_json::Map::new();
// Create parameters object
let mut params = serde_json::Map::new();
// Map OpenAI fields to internal parameter names
insert_if_some!(params,
self.max_tokens => "max_new_tokens",
self.temperature => "temperature",
self.top_p => "top_p",
self.n => "best_of",
self.logprobs => "top_n_tokens",
self.seed => "seed"
);
// Special handling for fields that need transformation
if let Some(presence_penalty) = self.presence_penalty {
params.insert(
"repetition_penalty".to_string(),
(1.0 + presence_penalty).into(),
);
}
if let Some(stop) = self.stop {
let stop_sequences = match stop {
StringOrArray::String(s) => vec![s],
StringOrArray::Array(v) => v,
};
params.insert("stop".to_string(), stop_sequences.into());
}
if self.echo {
params.insert("return_full_text".to_string(), true.into());
}
other.insert("parameters".to_string(), Value::Object(params));
// Store original model and stream flag
insert_value!(other,
self.model => "model",
self.stream => "stream"
);
GenerateReqInput {
text,
input_ids: None,
stream: self.stream,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
other: Value::Object(other),
}
}
}
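A sketch (hypothetical request) of the field mapping this adapter applies to an OpenAI /v1/completions body:
openai_body = {
    "model": "my-model", "prompt": "Hello", "max_tokens": 32,
    "temperature": 0.7, "presence_penalty": 0.2, "stop": ["\n"], "echo": True,
}
# Roughly what to_pd_request() produces (bootstrap fields are added later by the router):
pd_body = {
    "text": "Hello",
    "stream": False,
    "model": "my-model",
    "parameters": {
        "max_new_tokens": 32,         # max_tokens -> max_new_tokens
        "temperature": 0.7,
        "repetition_penalty": 1.2,    # 1.0 + presence_penalty
        "stop": ["\n"],
        "return_full_text": True,     # echo=True
    },
}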
// ============= Chat Completion Request Adapter =============
impl ToPdRequest for ChatCompletionRequest {
type Output = ChatReqInput;
fn to_pd_request(self) -> Self::Output {
let mut other = serde_json::Map::new();
// Add required fields
insert_if_some!(other,
Some(&self.messages) => "messages"
);
insert_value!(other,
self.model => "model",
self.stream => "stream"
);
// Add all optional fields
insert_if_some!(other,
self.temperature => "temperature",
self.top_p => "top_p",
self.n => "n",
self.stop => "stop",
self.max_tokens => "max_tokens",
self.max_completion_tokens => "max_completion_tokens",
self.presence_penalty => "presence_penalty",
self.frequency_penalty => "frequency_penalty",
self.logit_bias => "logit_bias",
self.user => "user",
self.seed => "seed",
self.top_logprobs => "top_logprobs",
self.response_format => "response_format",
self.tools => "tools",
self.tool_choice => "tool_choice",
self.parallel_tool_calls => "parallel_tool_calls",
self.functions => "functions",
self.function_call => "function_call"
);
// Handle boolean logprobs flag
if self.logprobs {
other.insert("logprobs".to_string(), true.into());
}
ChatReqInput {
stream: self.stream,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
other: Value::Object(other),
}
}
}
// ============= Direct routing support for regular router =============
/// Extension trait for routing without PD conversion
pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone {
/// Convert to JSON for sending to backend
fn to_json(&self) -> Result<Value, serde_json::Error> {
serde_json::to_value(self)
}
/// Convert to bytes for legacy routing
fn to_bytes(&self) -> Result<bytes::Bytes, serde_json::Error> {
let json = serde_json::to_vec(self)?;
Ok(bytes::Bytes::from(json))
}
}
impl RouteableRequest for GenerateRequest {}
impl RouteableRequest for CompletionRequest {}
impl RouteableRequest for ChatCompletionRequest {}
This diff is collapsed.
use crate::logging::{self, LoggingConfig};
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::prometheus::{self, PrometheusConfig};
use crate::request_adapter::ToPdRequest;
use crate::router::PolicyConfig;
use crate::router::Router;
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use actix_web::{
error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
};
use bytes::Bytes;
use futures_util::StreamExt;
use reqwest::Client;
use std::collections::HashMap;
......@@ -20,6 +21,7 @@ use tracing::{error, info, warn, Level};
pub struct AppState {
router: Arc<Router>,
client: Client,
is_pd_mode: bool, // Add flag to track PD mode
}
impl AppState {
......@@ -28,9 +30,16 @@ impl AppState {
client: Client,
policy_config: PolicyConfig,
) -> Result<Self, String> {
// Check if this is PD mode from policy config
let is_pd_mode = matches!(policy_config, PolicyConfig::PrefillDecodeConfig { .. });
// Create router based on policy
let router = Arc::new(Router::new(worker_urls, policy_config)?);
Ok(Self { router, client })
Ok(Self {
router,
client,
is_pd_mode,
})
}
}
......@@ -46,8 +55,25 @@ async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<Ht
}
// Custom error handler for JSON payload errors.
fn json_error_handler(_err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
error::ErrorPayloadTooLarge("Payload too large")
fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
error!("JSON payload error: {:?}", err);
match &err {
error::JsonPayloadError::OverflowKnownLength { length, limit } => {
error!(
"Payload too large: {} bytes exceeds limit of {} bytes",
length, limit
);
error::ErrorPayloadTooLarge(format!(
"Payload too large: {} bytes exceeds limit of {} bytes",
length, limit
))
}
error::JsonPayloadError::Overflow { limit } => {
error!("Payload overflow: exceeds limit of {} bytes", limit);
error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit))
}
_ => error::ErrorBadRequest(format!("Invalid JSON payload: {}", err)),
}
}
#[get("/health")]
......@@ -59,59 +85,134 @@ async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
#[get("/health_generate")]
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
// Check if we're in PD mode
if data.is_pd_mode {
// For PD mode, check health on all servers
data.router
.route_pd_health_generate(&data.client, &req)
.await
} else {
// Regular mode
data.router
.route_to_first(&data.client, "/health_generate", &req)
.await
}
}
#[get("/get_server_info")]
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
if data.is_pd_mode {
// For PD mode, aggregate info from both prefill and decode servers
data.router.get_pd_server_info(&data.client, &req).await
} else {
// Regular mode - return first server's info
data.router
.route_to_first(&data.client, "/get_server_info", &req)
.await
}
}
#[get("/v1/models")]
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
if data.is_pd_mode {
// For PD mode, return models from the first prefill server
data.router.get_pd_models(&data.client, &req).await
} else {
// Regular mode
data.router
.route_to_first(&data.client, "/v1/models", &req)
.await
}
}
#[get("/get_model_info")]
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
if data.is_pd_mode {
// For PD mode, get model info from the first prefill server
data.router.get_pd_model_info(&data.client, &req).await
} else {
data.router
.route_to_first(&data.client, "/get_model_info", &req)
.await
}
}
#[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
data.router
.route_generate_request(&data.client, &req, &body, "/generate")
.await
async fn generate(
req: HttpRequest,
body: web::Json<GenerateRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let client = &state.client;
let router = &state.router;
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_generate_typed(&client, &req, pd_request, "/generate")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/generate")
.await)
}
}
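A minimal client-side sketch (assumed default router address, illustrative payload) of hitting this endpoint through the router:
import requests

# 127.0.0.1:30000 is the Python launcher's default host/port; adjust as needed.
resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={"text": "Hello", "sampling_params": {"max_new_tokens": 16, "temperature": 0.0}},
    timeout=60,
)
print(resp.status_code, resp.json())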
#[post("/v1/chat/completions")]
async fn v1_chat_completions(
req: HttpRequest,
body: Bytes,
data: web::Data<AppState>,
) -> impl Responder {
data.router
.route_generate_request(&data.client, &req, &body, "/v1/chat/completions")
.await
body: web::Json<ChatCompletionRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let client = &state.client;
let router = &state.router;
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_chat_typed(&client, &req, pd_request, "/v1/chat/completions")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/v1/chat/completions")
.await)
}
}
#[post("/v1/completions")]
async fn v1_completions(
req: HttpRequest,
body: Bytes,
data: web::Data<AppState>,
) -> impl Responder {
data.router
.route_generate_request(&data.client, &req, &body, "/v1/completions")
.await
body: web::Json<CompletionRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let client = &state.client;
let router = &state.router;
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_generate_typed(&client, &req, pd_request, "/v1/completions")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/v1/completions")
.await)
}
}
#[post("/add_worker")]
......@@ -153,6 +254,25 @@ async fn remove_worker(
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
}
#[post("/flush_cache")]
async fn flush_cache(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
if data.is_pd_mode {
// For PD mode, flush cache on both prefill and decode servers
data.router.route_pd_flush_cache(&data.client).await
} else {
// Route to all workers for cache flushing
data.router
.route_to_all(&data.client, "/flush_cache", &req)
.await
}
}
#[get("/get_loads")]
async fn get_loads(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
// Get loads from all workers
data.router.get_all_loads(&data.client, &req).await
}
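A short client-side sketch (assumed default router address) of exercising the two new admin endpoints:
import requests

base = "http://127.0.0.1:30000"  # Python launcher default
requests.post(f"{base}/flush_cache", timeout=30)       # PD mode: flushes prefill and decode caches
loads = requests.get(f"{base}/get_loads", timeout=30)  # aggregated per-worker load info
print(loads.status_code, loads.text)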
pub struct ServerConfig {
pub host: String,
pub port: u16,
......@@ -163,6 +283,7 @@ pub struct ServerConfig {
pub log_dir: Option<String>,
pub service_discovery_config: Option<ServiceDiscoveryConfig>,
pub prometheus_config: Option<PrometheusConfig>,
pub request_timeout_secs: u64,
}
pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
......@@ -215,6 +336,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
let client = Client::builder()
.pool_idle_timeout(Some(Duration::from_secs(50)))
.timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout
.build()
.expect("Failed to create HTTP client");
......@@ -276,7 +398,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
.service(add_worker)
.service(remove_worker)
.service(list_workers)
// Default handler for unmatched routes.
.service(flush_cache)
.service(get_loads)
.default_service(web::route().to(sink_handler))
})
.bind_auto_h2c((config.host, config.port))?
......
This diff is collapsed.