Unverified Commit 6d6e24bc authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] Add builder pattern for RouterConfig with zero duplication (#12030)

parent 2c057fbf
use super::{
CircuitBreakerConfig, ConfigResult, DiscoveryConfig, HealthCheckConfig, HistoryBackend,
MetricsConfig, OracleConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
TokenizerCacheConfig,
};
use crate::core::ConnectionMode;
/// Builder for [`RouterConfig`] that wraps the config itself.
///
/// The builder stores a complete `RouterConfig` and every setter writes straight
/// into it, so there is no duplicated field list to keep in sync with the config
/// struct.
///
/// All setters take the builder by value and return the updated builder, which is
/// why the type is `#[must_use]`: dropping the returned value discards the change.
#[must_use = "builder methods return the updated builder; call a `build*` method to obtain the config"]
#[derive(Debug, Clone, Default)]
pub struct RouterConfigBuilder {
    /// The configuration being assembled; starts as `RouterConfig::default()`
    /// when the builder is created through `Default`/`new`.
    config: RouterConfig,
}
impl RouterConfigBuilder {
    /// Create a new builder with default values
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a builder from an existing configuration (takes ownership)
    pub fn from_config(config: RouterConfig) -> Self {
        Self { config }
    }

    /// Create a builder from a reference to an existing configuration
    ///
    /// The configuration is cloned; the original is left untouched.
    pub fn from_config_ref(config: &RouterConfig) -> Self {
        Self::from_config(config.clone())
    }

    // ==================== Routing Mode Setters ====================

    /// Set regular routing mode with worker URLs
    pub fn regular_mode(mut self, worker_urls: Vec<String>) -> Self {
        self.config.mode = RoutingMode::Regular { worker_urls };
        self
    }

    /// Set prefill-decode routing mode
    ///
    /// Both per-role policies are left as `None`; use
    /// [`prefill_decode_mode_with_policies`](Self::prefill_decode_mode_with_policies)
    /// to supply them.
    pub fn prefill_decode_mode(
        mut self,
        prefill_urls: Vec<(String, Option<u16>)>,
        decode_urls: Vec<String>,
    ) -> Self {
        self.config.mode = RoutingMode::PrefillDecode {
            prefill_urls,
            decode_urls,
            prefill_policy: None,
            decode_policy: None,
        };
        self
    }

    /// Set prefill-decode mode with separate policies
    pub fn prefill_decode_mode_with_policies(
        mut self,
        prefill_urls: Vec<(String, Option<u16>)>,
        decode_urls: Vec<String>,
        prefill_policy: Option<PolicyConfig>,
        decode_policy: Option<PolicyConfig>,
    ) -> Self {
        self.config.mode = RoutingMode::PrefillDecode {
            prefill_urls,
            decode_urls,
            prefill_policy,
            decode_policy,
        };
        self
    }

    /// Set OpenAI routing mode
    pub fn openai_mode(mut self, worker_urls: Vec<String>) -> Self {
        self.config.mode = RoutingMode::OpenAI { worker_urls };
        self
    }

    /// Set the routing mode directly
    pub fn mode(mut self, mode: RoutingMode) -> Self {
        self.config.mode = mode;
        self
    }

    // ==================== Policy Setters ====================

    /// Set the routing policy
    pub fn policy(mut self, policy: PolicyConfig) -> Self {
        self.config.policy = policy;
        self
    }

    /// Set random policy
    pub fn random_policy(mut self) -> Self {
        self.config.policy = PolicyConfig::Random;
        self
    }

    /// Set round-robin policy
    pub fn round_robin_policy(mut self) -> Self {
        self.config.policy = PolicyConfig::RoundRobin;
        self
    }

    /// Set cache-aware policy with parameters
    ///
    /// The arguments are stored unchanged in the same-named fields of
    /// `PolicyConfig::CacheAware`.
    pub fn cache_aware_policy(
        mut self,
        cache_threshold: f32,
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
        eviction_interval_secs: u64,
        max_tree_size: usize,
    ) -> Self {
        self.config.policy = PolicyConfig::CacheAware {
            cache_threshold,
            balance_abs_threshold,
            balance_rel_threshold,
            eviction_interval_secs,
            max_tree_size,
        };
        self
    }

    /// Set power-of-two policy
    pub fn power_of_two_policy(mut self, load_check_interval_secs: u64) -> Self {
        self.config.policy = PolicyConfig::PowerOfTwo {
            load_check_interval_secs,
        };
        self
    }

    // ==================== Connection Settings ====================

    /// Set connection mode
    pub fn connection_mode(mut self, mode: ConnectionMode) -> Self {
        self.config.connection_mode = mode;
        self
    }

    /// Set HTTP connection mode
    pub fn http_connection(mut self) -> Self {
        self.config.connection_mode = ConnectionMode::Http;
        self
    }

    /// Set gRPC connection mode with optional port
    pub fn grpc_connection(mut self, port: Option<u16>) -> Self {
        self.config.connection_mode = ConnectionMode::Grpc { port };
        self
    }

    /// Set gRPC connection mode without specifying a port
    pub fn grpc_connection_default(mut self) -> Self {
        self.config.connection_mode = ConnectionMode::Grpc { port: None };
        self
    }

    /// Set host address
    pub fn host<S: Into<String>>(mut self, host: S) -> Self {
        self.config.host = host.into();
        self
    }

    /// Set port number
    pub fn port(mut self, port: u16) -> Self {
        self.config.port = port;
        self
    }

    // ==================== Request Settings ====================

    /// Set maximum payload size in bytes
    pub fn max_payload_size(mut self, size: usize) -> Self {
        self.config.max_payload_size = size;
        self
    }

    /// Set request timeout in seconds
    pub fn request_timeout_secs(mut self, timeout: u64) -> Self {
        self.config.request_timeout_secs = timeout;
        self
    }

    /// Set worker startup timeout in seconds
    pub fn worker_startup_timeout_secs(mut self, timeout: u64) -> Self {
        self.config.worker_startup_timeout_secs = timeout;
        self
    }

    /// Set worker startup check interval in seconds
    pub fn worker_startup_check_interval_secs(mut self, interval: u64) -> Self {
        self.config.worker_startup_check_interval_secs = interval;
        self
    }

    // ==================== Rate Limiting ====================

    /// Set maximum concurrent requests
    pub fn max_concurrent_requests(mut self, max: i32) -> Self {
        self.config.max_concurrent_requests = max;
        self
    }

    /// Disable rate limiting
    ///
    /// Implemented by storing `-1` in `max_concurrent_requests`.
    pub fn disable_rate_limiting(mut self) -> Self {
        self.config.max_concurrent_requests = -1;
        self
    }

    /// Set queue size for pending requests
    pub fn queue_size(mut self, size: usize) -> Self {
        self.config.queue_size = size;
        self
    }

    /// Set queue timeout in seconds
    pub fn queue_timeout_secs(mut self, timeout: u64) -> Self {
        self.config.queue_timeout_secs = timeout;
        self
    }

    /// Set rate limit tokens per second
    pub fn rate_limit_tokens_per_second(mut self, tokens: i32) -> Self {
        self.config.rate_limit_tokens_per_second = Some(tokens);
        self
    }

    // ==================== Security & CORS ====================

    /// Set API key for worker authorization
    pub fn api_key<S: Into<String>>(mut self, key: S) -> Self {
        self.config.api_key = Some(key.into());
        self
    }

    /// Set CORS allowed origins (replaces any origins configured so far)
    pub fn cors_allowed_origins(mut self, origins: Vec<String>) -> Self {
        self.config.cors_allowed_origins = origins;
        self
    }

    /// Add a single CORS origin (appended to the existing list)
    pub fn add_cors_origin<S: Into<String>>(mut self, origin: S) -> Self {
        self.config.cors_allowed_origins.push(origin.into());
        self
    }

    // ==================== Retry Configuration ====================

    /// Set retry configuration
    pub fn retry_config(mut self, retry: RetryConfig) -> Self {
        self.config.retry = retry;
        self
    }

    /// Disable retries
    pub fn disable_retries(mut self) -> Self {
        self.config.disable_retries = true;
        self
    }

    /// Enable retries
    pub fn enable_retries(mut self) -> Self {
        self.config.disable_retries = false;
        self
    }

    // ==================== Circuit Breaker Configuration ====================

    /// Set circuit breaker configuration
    pub fn circuit_breaker_config(mut self, circuit_breaker: CircuitBreakerConfig) -> Self {
        self.config.circuit_breaker = circuit_breaker;
        self
    }

    /// Disable circuit breaker
    pub fn disable_circuit_breaker(mut self) -> Self {
        self.config.disable_circuit_breaker = true;
        self
    }

    /// Enable circuit breaker
    pub fn enable_circuit_breaker(mut self) -> Self {
        self.config.disable_circuit_breaker = false;
        self
    }

    // ==================== Health Check Configuration ====================

    /// Set health check configuration
    pub fn health_check_config(mut self, health_check: HealthCheckConfig) -> Self {
        self.config.health_check = health_check;
        self
    }

    // ==================== Discovery Configuration ====================

    /// Set service discovery configuration
    pub fn discovery_config(mut self, discovery: DiscoveryConfig) -> Self {
        self.config.discovery = Some(discovery);
        self
    }

    /// Enable service discovery with default settings
    ///
    /// Any discovery configuration set earlier is replaced by a default
    /// `DiscoveryConfig` whose `enabled` flag is `true`.
    pub fn enable_discovery(mut self) -> Self {
        self.config.discovery = Some(DiscoveryConfig {
            enabled: true,
            ..Default::default()
        });
        self
    }

    // ==================== Metrics Configuration ====================

    /// Set metrics configuration
    pub fn metrics_config(mut self, metrics: MetricsConfig) -> Self {
        self.config.metrics = Some(metrics);
        self
    }

    /// Enable metrics with host and port
    pub fn enable_metrics<S: Into<String>>(mut self, host: S, port: u16) -> Self {
        self.config.metrics = Some(MetricsConfig {
            host: host.into(),
            port,
        });
        self
    }

    // ==================== Logging Configuration ====================

    /// Set log directory
    pub fn log_dir<S: Into<String>>(mut self, dir: S) -> Self {
        self.config.log_dir = Some(dir.into());
        self
    }

    /// Set log level
    pub fn log_level<S: Into<String>>(mut self, level: S) -> Self {
        self.config.log_level = Some(level.into());
        self
    }

    /// Set custom request ID headers
    pub fn request_id_headers(mut self, headers: Vec<String>) -> Self {
        self.config.request_id_headers = Some(headers);
        self
    }

    // ==================== IGW Mode Configuration ====================

    /// Enable Inference Gateway mode
    pub fn enable_igw(mut self) -> Self {
        self.config.enable_igw = true;
        self
    }

    /// Disable Inference Gateway mode (use proxy mode)
    pub fn disable_igw(mut self) -> Self {
        self.config.enable_igw = false;
        self
    }

    /// Set model path for tokenizer
    pub fn model_path<S: Into<String>>(mut self, path: S) -> Self {
        self.config.model_path = Some(path.into());
        self
    }

    /// Set tokenizer path (overrides model_path tokenizer)
    pub fn tokenizer_path<S: Into<String>>(mut self, path: S) -> Self {
        self.config.tokenizer_path = Some(path.into());
        self
    }

    /// Set chat template path
    pub fn chat_template<S: Into<String>>(mut self, path: S) -> Self {
        self.config.chat_template = Some(path.into());
        self
    }

    // ==================== History Backend Configuration ====================

    /// Set history backend
    pub fn history_backend(mut self, backend: HistoryBackend) -> Self {
        self.config.history_backend = backend;
        self
    }

    /// Use memory history backend
    pub fn memory_history(mut self) -> Self {
        self.config.history_backend = HistoryBackend::Memory;
        self
    }

    /// Disable history storage
    pub fn no_history(mut self) -> Self {
        self.config.history_backend = HistoryBackend::None;
        self
    }

    /// Use Oracle history backend
    ///
    /// Switches `history_backend` to `HistoryBackend::Oracle` and stores the
    /// supplied Oracle settings in one step.
    pub fn oracle_history(mut self, oracle_config: OracleConfig) -> Self {
        self.config.history_backend = HistoryBackend::Oracle;
        self.config.oracle = Some(oracle_config);
        self
    }

    // ==================== Parsers Configuration ====================

    /// Set reasoning parser
    pub fn reasoning_parser<S: Into<String>>(mut self, parser: S) -> Self {
        self.config.reasoning_parser = Some(parser.into());
        self
    }

    /// Set tool call parser
    pub fn tool_call_parser<S: Into<String>>(mut self, parser: S) -> Self {
        self.config.tool_call_parser = Some(parser.into());
        self
    }

    // ==================== Tokenizer Cache Configuration ====================

    /// Set tokenizer cache configuration
    pub fn tokenizer_cache(mut self, cache: TokenizerCacheConfig) -> Self {
        self.config.tokenizer_cache = cache;
        self
    }

    /// Enable L0 cache with entry limit
    pub fn enable_l0_cache(mut self, max_entries: usize) -> Self {
        self.config.tokenizer_cache.enable_l0 = true;
        self.config.tokenizer_cache.l0_max_entries = max_entries;
        self
    }

    /// Enable L1 cache with memory limit
    pub fn enable_l1_cache(mut self, max_memory: usize) -> Self {
        self.config.tokenizer_cache.enable_l1 = true;
        self.config.tokenizer_cache.l1_max_memory = max_memory;
        self
    }

    // ==================== Data Parallelism ====================

    /// Enable data parallelism aware scheduling
    pub fn enable_dp_aware(mut self) -> Self {
        self.config.dp_aware = true;
        self
    }

    /// Disable data parallelism aware scheduling
    pub fn disable_dp_aware(mut self) -> Self {
        self.config.dp_aware = false;
        self
    }

    // ==================== Conditional Boolean Setters ====================
    // These methods accept bool parameters to conditionally set flags,
    // eliminating the need for if statements in calling code

    /// Set dp_aware flag conditionally
    pub fn dp_aware(mut self, enable: bool) -> Self {
        self.config.dp_aware = enable;
        self
    }

    /// Enable or disable retries (inverse of disable_retries field)
    pub fn retries(mut self, enable: bool) -> Self {
        self.config.disable_retries = !enable;
        self
    }

    /// Enable or disable circuit breaker (inverse of disable_circuit_breaker field)
    pub fn circuit_breaker(mut self, enable: bool) -> Self {
        self.config.disable_circuit_breaker = !enable;
        self
    }

    /// Set enable_igw flag conditionally
    pub fn igw(mut self, enable: bool) -> Self {
        self.config.enable_igw = enable;
        self
    }

    // ==================== Option-Aware Setters ====================
    // These methods accept Option<T> and only set if Some, making it easier
    // to conditionally set values without if-let chains.
    //
    // `None` leaves the builder untouched: a value set earlier (or inherited via
    // `from_config` / `to_builder`) is never cleared by these methods. Each one
    // forwards to the matching plain setter so the two can never drift apart.

    /// Set API key if Some
    pub fn maybe_api_key(self, key: Option<impl Into<String>>) -> Self {
        match key {
            Some(k) => self.api_key(k),
            None => self,
        }
    }

    /// Set discovery config if Some
    pub fn maybe_discovery(self, discovery: Option<DiscoveryConfig>) -> Self {
        match discovery {
            Some(d) => self.discovery_config(d),
            None => self,
        }
    }

    /// Set metrics config if Some
    pub fn maybe_metrics(self, metrics: Option<MetricsConfig>) -> Self {
        match metrics {
            Some(m) => self.metrics_config(m),
            None => self,
        }
    }

    /// Set log directory if Some
    pub fn maybe_log_dir(self, dir: Option<impl Into<String>>) -> Self {
        match dir {
            Some(d) => self.log_dir(d),
            None => self,
        }
    }

    /// Set log level if Some
    pub fn maybe_log_level(self, level: Option<impl Into<String>>) -> Self {
        match level {
            Some(l) => self.log_level(l),
            None => self,
        }
    }

    /// Set request ID headers if Some
    pub fn maybe_request_id_headers(self, headers: Option<Vec<String>>) -> Self {
        match headers {
            Some(h) => self.request_id_headers(h),
            None => self,
        }
    }

    /// Set rate limit tokens per second if Some
    pub fn maybe_rate_limit_tokens_per_second(self, tokens: Option<i32>) -> Self {
        match tokens {
            Some(t) => self.rate_limit_tokens_per_second(t),
            None => self,
        }
    }

    /// Set model path if Some
    pub fn maybe_model_path(self, path: Option<impl Into<String>>) -> Self {
        match path {
            Some(p) => self.model_path(p),
            None => self,
        }
    }

    /// Set tokenizer path if Some
    pub fn maybe_tokenizer_path(self, path: Option<impl Into<String>>) -> Self {
        match path {
            Some(p) => self.tokenizer_path(p),
            None => self,
        }
    }

    /// Set chat template if Some
    pub fn maybe_chat_template(self, template: Option<impl Into<String>>) -> Self {
        match template {
            Some(t) => self.chat_template(t),
            None => self,
        }
    }

    /// Set oracle config if Some
    ///
    /// On `Some`, this also switches the history backend to Oracle, exactly like
    /// [`oracle_history`](Self::oracle_history).
    pub fn maybe_oracle(self, oracle: Option<OracleConfig>) -> Self {
        match oracle {
            Some(cfg) => self.oracle_history(cfg),
            None => self,
        }
    }

    /// Set reasoning parser if Some
    pub fn maybe_reasoning_parser(self, parser: Option<impl Into<String>>) -> Self {
        match parser {
            Some(p) => self.reasoning_parser(p),
            None => self,
        }
    }

    /// Set tool call parser if Some
    pub fn maybe_tool_call_parser(self, parser: Option<impl Into<String>>) -> Self {
        match parser {
            Some(p) => self.tool_call_parser(p),
            None => self,
        }
    }

    // ==================== Builder Methods ====================

    /// Build the RouterConfig, always running validation
    ///
    /// Equivalent to `build_with_validation(true)`.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `RouterConfig::validate` when the
    /// assembled configuration is rejected.
    pub fn build(self) -> ConfigResult<RouterConfig> {
        self.build_with_validation(true)
    }

    /// Build the RouterConfig without validation
    pub fn build_unchecked(self) -> RouterConfig {
        self.into()
    }

    /// Build with optional validation
    ///
    /// # Errors
    ///
    /// When `validate` is `true`, propagates the error returned by
    /// `RouterConfig::validate`. With `validate == false` this always returns `Ok`.
    pub fn build_with_validation(self, validate: bool) -> ConfigResult<RouterConfig> {
        let config = self.build_unchecked();
        if validate {
            config.validate()?;
        }
        Ok(config)
    }
}
impl From<RouterConfigBuilder> for RouterConfig {
fn from(builder: RouterConfigBuilder) -> Self {
builder.config
}
}
impl RouterConfig {
    /// Start a fresh builder whose wrapped configuration is the default one.
    pub fn builder() -> RouterConfigBuilder {
        RouterConfigBuilder::default()
    }

    /// Produce a builder seeded with a clone of this configuration.
    ///
    /// `self` is only borrowed, so it is unaffected by later builder calls.
    pub fn to_builder(&self) -> RouterConfigBuilder {
        let snapshot = self.clone();
        RouterConfigBuilder::from_config(snapshot)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A config converted back into a builder via `.to_builder()` can be modified
    /// and rebuilt, and the modifications show up in the result.
    #[test]
    fn test_builder_from_existing_config() {
        let workers = vec!["http://worker1:8000".to_string()];
        let base = RouterConfigBuilder::new()
            .regular_mode(workers)
            .port(3000)
            .build()
            .unwrap();

        let derived_builder = base
            .to_builder()
            .port(4000)
            .enable_metrics("0.0.0.0", 29000);
        let derived = derived_builder.build().unwrap();

        assert_eq!(derived.port, 4000);
        assert!(derived.metrics.is_some());
    }

    /// The prefill-decode helper yields a PD-mode config; with one prefill URL
    /// and one decode URL the reported worker count is 2.
    #[test]
    fn test_builder_prefill_decode_mode() {
        let prefill = vec![("http://prefill:8000".to_string(), Some(8001))];
        let decode = vec!["http://decode:8000".to_string()];

        let built = RouterConfigBuilder::new()
            .prefill_decode_mode(prefill, decode)
            .power_of_two_policy(60)
            .build()
            .unwrap();

        assert!(built.mode.is_pd_mode());
        assert_eq!(built.mode.worker_count(), 2);
    }

    /// The multi-parameter cache-aware helper stores the threshold it was given.
    #[test]
    fn test_builder_cache_aware_policy() {
        let built = RouterConfigBuilder::new()
            .regular_mode(vec!["http://worker1:8000".to_string()])
            .cache_aware_policy(0.8, 10, 1.5, 300, 1000)
            .build()
            .unwrap();

        // Pull the threshold out first, then compare with a float tolerance.
        let threshold = match built.policy {
            PolicyConfig::CacheAware {
                cache_threshold, ..
            } => cache_threshold,
            _ => panic!("Expected CacheAware policy"),
        };
        assert!((threshold - 0.8).abs() < 0.0001);
    }
}
pub mod builder;
pub mod types;
pub mod validation;
pub use builder::*;
pub use types::*;
pub use validation::*;
......
......@@ -609,17 +609,14 @@ mod tests {
#[test]
fn test_router_config_serialization() {
let config = RouterConfig {
mode: RoutingMode::Regular {
worker_urls: vec!["http://worker1".to_string()],
},
policy: PolicyConfig::Random,
host: "0.0.0.0".to_string(),
port: 8080,
log_dir: Some("/var/log".to_string()),
log_level: Some("debug".to_string()),
..Default::default()
};
let config = RouterConfig::builder()
.regular_mode(vec!["http://worker1".to_string()])
.random_policy()
.host("0.0.0.0")
.port(8080)
.log_dir("/var/log")
.log_level("debug")
.build_unchecked();
let json = serde_json::to_string(&config).unwrap();
let deserialized: RouterConfig = serde_json::from_str(&json).unwrap();
......@@ -866,23 +863,14 @@ mod tests {
#[test]
fn test_mode_type() {
let config = RouterConfig {
mode: RoutingMode::Regular {
worker_urls: vec![],
},
..Default::default()
};
let config = RouterConfig::builder()
.regular_mode(vec![])
.build_unchecked();
assert_eq!(config.mode_type(), "regular");
let config = RouterConfig {
mode: RoutingMode::PrefillDecode {
prefill_urls: vec![],
decode_urls: vec![],
prefill_policy: None,
decode_policy: None,
},
..Default::default()
};
let config = RouterConfig::builder()
.prefill_decode_mode(vec![], vec![])
.build_unchecked();
assert_eq!(config.mode_type(), "prefill_decode");
}
......@@ -891,22 +879,15 @@ mod tests {
let config = RouterConfig::default();
assert!(!config.has_service_discovery());
let config = RouterConfig {
discovery: Some(DiscoveryConfig {
let config = RouterConfig::builder()
.discovery_config(DiscoveryConfig {
enabled: false,
..Default::default()
}),
..Default::default()
};
})
.build_unchecked();
assert!(!config.has_service_discovery());
let config = RouterConfig {
discovery: Some(DiscoveryConfig {
enabled: true,
..Default::default()
}),
..Default::default()
};
let config = RouterConfig::builder().enable_discovery().build_unchecked();
assert!(config.has_service_discovery());
}
......@@ -915,10 +896,9 @@ mod tests {
let config = RouterConfig::default();
assert!(!config.has_metrics());
let config = RouterConfig {
metrics: Some(MetricsConfig::default()),
..Default::default()
};
let config = RouterConfig::builder()
.metrics_config(MetricsConfig::default())
.build_unchecked();
assert!(config.has_metrics());
}
......@@ -926,16 +906,11 @@ mod tests {
fn test_large_worker_lists() {
let large_urls: Vec<String> = (0..1000).map(|i| format!("http://worker{}", i)).collect();
let mode = RoutingMode::Regular {
worker_urls: large_urls.clone(),
};
let config = RouterConfig::builder()
.regular_mode(large_urls.clone())
.build_unchecked();
assert_eq!(mode.worker_count(), 1000);
let config = RouterConfig {
mode,
..Default::default()
};
assert_eq!(config.mode.worker_count(), 1000);
let json = serde_json::to_string(&config).unwrap();
let deserialized: RouterConfig = serde_json::from_str(&json).unwrap();
......@@ -950,13 +925,13 @@ mod tests {
#[test]
fn test_unicode_in_config() {
let config = RouterConfig {
mode: RoutingMode::Regular {
worker_urls: vec!["http://работник1".to_string(), "http://工作者2".to_string()],
},
log_dir: Some("/日志/目录".to_string()),
..Default::default()
};
let config = RouterConfig::builder()
.regular_mode(vec![
"http://работник1".to_string(),
"http://工作者2".to_string(),
])
.log_dir("/日志/目录")
.build_unchecked();
let json = serde_json::to_string(&config).unwrap();
let deserialized: RouterConfig = serde_json::from_str(&json).unwrap();
......@@ -974,12 +949,11 @@ mod tests {
#[test]
fn test_empty_string_fields() {
let config = RouterConfig {
host: "".to_string(),
log_dir: Some("".to_string()),
log_level: Some("".to_string()),
..Default::default()
};
let config = RouterConfig::builder()
.host("")
.log_dir("")
.log_level("")
.build_unchecked();
assert_eq!(config.host, "");
assert_eq!(config.log_dir, Some("".to_string()));
......@@ -988,63 +962,34 @@ mod tests {
#[test]
fn test_full_pd_mode_config() {
let config = RouterConfig {
mode: RoutingMode::PrefillDecode {
prefill_urls: vec![
let config = RouterConfig::builder()
.prefill_decode_mode(
vec![
("http://prefill1:8000".to_string(), Some(8001)),
("http://prefill2:8000".to_string(), None),
],
decode_urls: vec![
vec![
"http://decode1:8000".to_string(),
"http://decode2:8000".to_string(),
],
prefill_policy: None,
decode_policy: None,
},
policy: PolicyConfig::PowerOfTwo {
load_check_interval_secs: 30,
},
host: "0.0.0.0".to_string(),
port: 3000,
max_payload_size: 1048576,
request_timeout_secs: 120,
worker_startup_timeout_secs: 60,
worker_startup_check_interval_secs: 5,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
)
.power_of_two_policy(30)
.host("0.0.0.0")
.port(3000)
.max_payload_size(1048576)
.request_timeout_secs(120)
.worker_startup_timeout_secs(60)
.worker_startup_check_interval_secs(5)
.discovery_config(DiscoveryConfig {
enabled: true,
namespace: Some("sglang".to_string()),
..Default::default()
}),
metrics: Some(MetricsConfig {
port: 9090,
host: "0.0.0.0".to_string(),
}),
log_dir: Some("/var/log/sglang".to_string()),
log_level: Some("info".to_string()),
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
chat_template: None,
history_backend: default_history_backend(),
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: TokenizerCacheConfig::default(),
};
})
.enable_metrics("0.0.0.0", 9090)
.log_dir("/var/log/sglang")
.log_level("info")
.max_concurrent_requests(64)
.build_unchecked();
assert!(config.mode.is_pd_mode());
assert_eq!(config.mode.worker_count(), 4);
......@@ -1058,62 +1003,31 @@ mod tests {
let mut selector = HashMap::new();
selector.insert("app".to_string(), "sglang".to_string());
let config = RouterConfig {
mode: RoutingMode::Regular {
worker_urls: vec![
"http://worker1:8000".to_string(),
"http://worker2:8000".to_string(),
"http://worker3:8000".to_string(),
],
},
policy: PolicyConfig::CacheAware {
cache_threshold: 0.9,
balance_abs_threshold: 5,
balance_rel_threshold: 1.2,
eviction_interval_secs: 600,
max_tree_size: 10000,
},
host: "0.0.0.0".to_string(),
port: 3001,
max_payload_size: 536870912,
request_timeout_secs: 300,
worker_startup_timeout_secs: 180,
worker_startup_check_interval_secs: 15,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
let config = RouterConfig::builder()
.regular_mode(vec![
"http://worker1:8000".to_string(),
"http://worker2:8000".to_string(),
"http://worker3:8000".to_string(),
])
.cache_aware_policy(0.9, 5, 1.2, 600, 10000)
.host("0.0.0.0")
.port(3001)
.max_payload_size(536870912)
.request_timeout_secs(300)
.worker_startup_timeout_secs(180)
.worker_startup_check_interval_secs(15)
.discovery_config(DiscoveryConfig {
enabled: true,
namespace: None,
port: 8080,
check_interval_secs: 45,
selector,
..Default::default()
}),
metrics: Some(MetricsConfig::default()),
log_dir: None,
log_level: Some("debug".to_string()),
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
chat_template: None,
history_backend: default_history_backend(),
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: TokenizerCacheConfig::default(),
};
})
.metrics_config(MetricsConfig::default())
.log_level("debug")
.max_concurrent_requests(64)
.build_unchecked();
assert!(!config.mode.is_pd_mode());
assert_eq!(config.mode.worker_count(), 3);
......@@ -1128,20 +1042,16 @@ mod tests {
selectors.insert("env".to_string(), "prod".to_string());
selectors.insert("version".to_string(), "v1".to_string());
let config = RouterConfig {
mode: RoutingMode::Regular {
worker_urls: vec!["http://worker1".to_string()],
},
policy: PolicyConfig::RoundRobin,
host: "::1".to_string(), // IPv6
port: 8888,
max_payload_size: 1024 * 1024 * 512, // 512MB
request_timeout_secs: 900,
worker_startup_timeout_secs: 600,
worker_startup_check_interval_secs: 20,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
let config = RouterConfig::builder()
.regular_mode(vec!["http://worker1".to_string()])
.round_robin_policy()
.host("::1") // IPv6
.port(8888)
.max_payload_size(1024 * 1024 * 512) // 512MB
.request_timeout_secs(900)
.worker_startup_timeout_secs(600)
.worker_startup_check_interval_secs(20)
.discovery_config(DiscoveryConfig {
enabled: true,
namespace: Some("production".to_string()),
port: 8443,
......@@ -1150,35 +1060,12 @@ mod tests {
prefill_selector: selectors.clone(),
decode_selector: selectors,
bootstrap_port_annotation: "mycompany.io/bootstrap".to_string(),
}),
metrics: Some(MetricsConfig {
port: 9999,
host: "::".to_string(), // IPv6 any
}),
log_dir: Some("/opt/logs/sglang".to_string()),
log_level: Some("trace".to_string()),
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
chat_template: None,
history_backend: default_history_backend(),
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: TokenizerCacheConfig::default(),
};
})
.enable_metrics("::", 9999) // IPv6 any
.log_dir("/opt/logs/sglang")
.log_level("trace")
.max_concurrent_requests(64)
.build_unchecked();
assert!(config.has_service_discovery());
assert!(config.has_metrics());
......
......@@ -302,65 +302,66 @@ impl Router {
None
};
Ok(config::RouterConfig {
mode,
policy,
host: self.host.clone(),
port: self.port,
connection_mode: self.connection_mode.clone(),
max_payload_size: self.max_payload_size,
request_timeout_secs: self.request_timeout_secs,
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
worker_startup_check_interval_secs: self.worker_startup_check_interval,
dp_aware: self.dp_aware,
api_key: self.api_key.clone(),
discovery,
metrics,
log_dir: self.log_dir.clone(),
log_level: self.log_level.clone(),
request_id_headers: self.request_id_headers.clone(),
max_concurrent_requests: self.max_concurrent_requests,
queue_size: self.queue_size,
queue_timeout_secs: self.queue_timeout_secs,
rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: config::RetryConfig {
let builder = config::RouterConfig::builder()
.mode(mode)
.policy(policy)
.host(&self.host)
.port(self.port)
.connection_mode(self.connection_mode.clone())
.max_payload_size(self.max_payload_size)
.request_timeout_secs(self.request_timeout_secs)
.worker_startup_timeout_secs(self.worker_startup_timeout_secs)
.worker_startup_check_interval_secs(self.worker_startup_check_interval)
.max_concurrent_requests(self.max_concurrent_requests)
.queue_size(self.queue_size)
.queue_timeout_secs(self.queue_timeout_secs)
.cors_allowed_origins(self.cors_allowed_origins.clone())
.retry_config(config::RetryConfig {
max_retries: self.retry_max_retries,
initial_backoff_ms: self.retry_initial_backoff_ms,
max_backoff_ms: self.retry_max_backoff_ms,
backoff_multiplier: self.retry_backoff_multiplier,
jitter_factor: self.retry_jitter_factor,
},
circuit_breaker: config::CircuitBreakerConfig {
})
.circuit_breaker_config(config::CircuitBreakerConfig {
failure_threshold: self.cb_failure_threshold,
success_threshold: self.cb_success_threshold,
timeout_duration_secs: self.cb_timeout_duration_secs,
window_duration_secs: self.cb_window_duration_secs,
},
disable_retries: self.disable_retries,
disable_circuit_breaker: self.disable_circuit_breaker,
health_check: config::HealthCheckConfig {
})
.health_check_config(config::HealthCheckConfig {
failure_threshold: self.health_failure_threshold,
success_threshold: self.health_success_threshold,
timeout_secs: self.health_check_timeout_secs,
check_interval_secs: self.health_check_interval_secs,
endpoint: self.health_check_endpoint.clone(),
},
enable_igw: self.enable_igw,
model_path: self.model_path.clone(),
tokenizer_path: self.tokenizer_path.clone(),
chat_template: self.chat_template.clone(),
history_backend,
oracle,
reasoning_parser: self.reasoning_parser.clone(),
tool_call_parser: self.tool_call_parser.clone(),
tokenizer_cache: config::TokenizerCacheConfig {
})
.tokenizer_cache(config::TokenizerCacheConfig {
enable_l0: self.tokenizer_cache_enable_l0,
l0_max_entries: self.tokenizer_cache_l0_max_entries,
enable_l1: self.tokenizer_cache_enable_l1,
l1_max_memory: self.tokenizer_cache_l1_max_memory,
},
})
})
.history_backend(history_backend)
.maybe_api_key(self.api_key.as_ref())
.maybe_discovery(discovery)
.maybe_metrics(metrics)
.maybe_log_dir(self.log_dir.as_ref())
.maybe_log_level(self.log_level.as_ref())
.maybe_request_id_headers(self.request_id_headers.clone())
.maybe_rate_limit_tokens_per_second(self.rate_limit_tokens_per_second)
.maybe_model_path(self.model_path.as_ref())
.maybe_tokenizer_path(self.tokenizer_path.as_ref())
.maybe_chat_template(self.chat_template.as_ref())
.maybe_oracle(oracle)
.maybe_reasoning_parser(self.reasoning_parser.as_ref())
.maybe_tool_call_parser(self.tool_call_parser.as_ref())
.dp_aware(self.dp_aware)
.retries(!self.disable_retries)
.circuit_breaker(!self.disable_circuit_breaker)
.igw(self.enable_igw);
builder.build()
}
}
......
......@@ -538,69 +538,68 @@ impl CliArgs {
None
};
Ok(RouterConfig {
mode,
policy,
connection_mode,
host: self.host.clone(),
port: self.port,
max_payload_size: self.max_payload_size,
request_timeout_secs: self.request_timeout_secs,
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
worker_startup_check_interval_secs: self.worker_startup_check_interval,
dp_aware: self.dp_aware,
api_key: self.api_key.clone(),
discovery,
metrics,
log_dir: self.log_dir.clone(),
log_level: Some(self.log_level.clone()),
request_id_headers: if self.request_id_headers.is_empty() {
None
} else {
Some(self.request_id_headers.clone())
},
max_concurrent_requests: self.max_concurrent_requests,
queue_size: self.queue_size,
queue_timeout_secs: self.queue_timeout_secs,
cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: RetryConfig {
let builder = RouterConfig::builder()
.mode(mode)
.policy(policy)
.connection_mode(connection_mode)
.host(&self.host)
.port(self.port)
.max_payload_size(self.max_payload_size)
.request_timeout_secs(self.request_timeout_secs)
.worker_startup_timeout_secs(self.worker_startup_timeout_secs)
.worker_startup_check_interval_secs(self.worker_startup_check_interval)
.max_concurrent_requests(self.max_concurrent_requests)
.queue_size(self.queue_size)
.queue_timeout_secs(self.queue_timeout_secs)
.cors_allowed_origins(self.cors_allowed_origins.clone())
.retry_config(RetryConfig {
max_retries: self.retry_max_retries,
initial_backoff_ms: self.retry_initial_backoff_ms,
max_backoff_ms: self.retry_max_backoff_ms,
backoff_multiplier: self.retry_backoff_multiplier,
jitter_factor: self.retry_jitter_factor,
},
circuit_breaker: CircuitBreakerConfig {
})
.circuit_breaker_config(CircuitBreakerConfig {
failure_threshold: self.cb_failure_threshold,
success_threshold: self.cb_success_threshold,
timeout_duration_secs: self.cb_timeout_duration_secs,
window_duration_secs: self.cb_window_duration_secs,
},
disable_retries: self.disable_retries,
disable_circuit_breaker: self.disable_circuit_breaker,
health_check: HealthCheckConfig {
})
.health_check_config(HealthCheckConfig {
failure_threshold: self.health_failure_threshold,
success_threshold: self.health_success_threshold,
timeout_secs: self.health_check_timeout_secs,
check_interval_secs: self.health_check_interval_secs,
endpoint: self.health_check_endpoint.clone(),
},
enable_igw: self.enable_igw,
rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
model_path: self.model_path.clone(),
tokenizer_path: self.tokenizer_path.clone(),
chat_template: self.chat_template.clone(),
history_backend,
oracle,
reasoning_parser: self.reasoning_parser.clone(),
tool_call_parser: self.tool_call_parser.clone(),
tokenizer_cache: TokenizerCacheConfig {
})
.tokenizer_cache(TokenizerCacheConfig {
enable_l0: self.tokenizer_cache_enable_l0,
l0_max_entries: self.tokenizer_cache_l0_max_entries,
enable_l1: self.tokenizer_cache_enable_l1,
l1_max_memory: self.tokenizer_cache_l1_max_memory,
},
})
})
.history_backend(history_backend)
.log_level(&self.log_level)
.maybe_api_key(self.api_key.as_ref())
.maybe_discovery(discovery)
.maybe_metrics(metrics)
.maybe_log_dir(self.log_dir.as_ref())
.maybe_request_id_headers(
(!self.request_id_headers.is_empty()).then(|| self.request_id_headers.clone()),
)
.maybe_rate_limit_tokens_per_second(self.rate_limit_tokens_per_second)
.maybe_model_path(self.model_path.as_ref())
.maybe_tokenizer_path(self.tokenizer_path.as_ref())
.maybe_chat_template(self.chat_template.as_ref())
.maybe_oracle(oracle)
.maybe_reasoning_parser(self.reasoning_parser.as_ref())
.maybe_tool_call_parser(self.tool_call_parser.as_ref())
.dp_aware(self.dp_aware)
.retries(!self.disable_retries)
.circuit_breaker(!self.disable_circuit_breaker)
.igw(self.enable_igw);
builder.build()
}
fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig {
......
......@@ -565,10 +565,9 @@ mod tests {
async fn create_test_app_context() -> Arc<AppContext> {
use crate::{config::RouterConfig, middleware::TokenBucket};
let router_config = RouterConfig {
worker_startup_timeout_secs: 1,
..Default::default()
};
let router_config = RouterConfig::builder()
.worker_startup_timeout_secs(1)
.build_unchecked();
// Note: Using uninitialized queue for tests to avoid spawning background workers
// Jobs submitted during tests will queue but not be processed
......
......@@ -11,8 +11,8 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
use reqwest::Client;
use serde_json::json;
use sglang_router_rs::{
config::{CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode},
core::{ConnectionMode, Job},
config::{RouterConfig, RoutingMode},
core::Job,
routers::{RouterFactory, RouterTrait},
server::AppContext,
};
......@@ -30,45 +30,18 @@ struct TestContext {
impl TestContext {
/// Builds a `TestContext` from the default test router configuration.
///
/// The configuration selects regular routing with an empty worker URL list,
/// the random policy, host `127.0.0.1`, port `3002`, a 256 MiB payload limit,
/// a 600 s request timeout, 1 s worker startup timeout and check interval,
/// 64 max concurrent requests and a 60 s queue timeout. Every other field
/// keeps whatever the builder starts with. The finished config and
/// `worker_configs` are passed to `Self::new_with_config`.
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
    // Create default router config.
    //
    // Only the builder-produced value is bound: the previous exhaustive
    // `RouterConfig { .. }` literal was shadowed before it was ever read, so
    // it was dead code that also had to be updated for every new field.
    // NOTE(review): `build_unchecked` presumably skips config validation —
    // confirm against the builder's definition.
    let config = RouterConfig::builder()
        .regular_mode(vec![])
        .random_policy()
        .host("127.0.0.1")
        .port(3002)
        .max_payload_size(256 * 1024 * 1024)
        .request_timeout_secs(600)
        .worker_startup_timeout_secs(1)
        .worker_startup_check_interval_secs(1)
        .max_concurrent_requests(64)
        .queue_timeout_secs(60)
        .build_unchecked();

    Self::new_with_config(config, worker_configs).await
}
......@@ -1182,45 +1155,18 @@ mod error_tests {
#[tokio::test]
async fn test_payload_too_large() {
// Create context with small payload limit
let config = RouterConfig {
chat_template: None,
mode: RoutingMode::Regular {
worker_urls: vec![],
},
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 3010,
max_payload_size: 1024, // 1KB limit
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let config = RouterConfig::builder()
.regular_mode(vec![])
.random_policy()
.host("127.0.0.1")
.port(3010)
.max_payload_size(1024) // 1KB limit
.request_timeout_secs(600)
.worker_startup_timeout_secs(1)
.worker_startup_check_interval_secs(1)
.max_concurrent_requests(64)
.queue_timeout_secs(60)
.build_unchecked();
let ctx = TestContext::new_with_config(
config,
......@@ -1509,48 +1455,18 @@ mod pd_mode_tests {
.and_then(|p| p.trim_end_matches('/').parse::<u16>().ok())
.unwrap_or(9000);
let config = RouterConfig {
chat_template: None,
mode: RoutingMode::PrefillDecode {
prefill_urls: vec![(prefill_url, Some(prefill_port))],
decode_urls: vec![decode_url],
prefill_policy: None,
decode_policy: None,
},
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 3011,
max_payload_size: 256 * 1024 * 1024,
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
discovery: None,
metrics: None,
log_dir: None,
dp_aware: false,
api_key: None,
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let config = RouterConfig::builder()
.prefill_decode_mode(vec![(prefill_url, Some(prefill_port))], vec![decode_url])
.random_policy()
.host("127.0.0.1")
.port(3011)
.max_payload_size(256 * 1024 * 1024)
.request_timeout_secs(600)
.worker_startup_timeout_secs(1)
.worker_startup_check_interval_secs(1)
.max_concurrent_requests(64)
.queue_timeout_secs(60)
.build_unchecked();
// Create app context
let app_context = common::create_test_context(config);
......@@ -1676,45 +1592,19 @@ mod request_id_tests {
#[tokio::test]
async fn test_request_id_with_custom_headers() {
// Create config with custom request ID headers
let config = RouterConfig {
chat_template: None,
mode: RoutingMode::Regular {
worker_urls: vec![],
},
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 3002,
max_payload_size: 256 * 1024 * 1024,
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
discovery: None,
metrics: None,
dp_aware: false,
api_key: None,
log_dir: None,
log_level: None,
request_id_headers: Some(vec!["custom-id".to_string(), "trace-id".to_string()]),
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let config = RouterConfig::builder()
.regular_mode(vec![])
.random_policy()
.host("127.0.0.1")
.port(3002)
.max_payload_size(256 * 1024 * 1024)
.request_timeout_secs(600)
.worker_startup_timeout_secs(1)
.worker_startup_check_interval_secs(1)
.request_id_headers(vec!["custom-id".to_string(), "trace-id".to_string()])
.max_concurrent_requests(64)
.queue_timeout_secs(60)
.build_unchecked();
let ctx = TestContext::new_with_config(
config,
......
......@@ -19,16 +19,12 @@ struct TestContext {
impl TestContext {
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
let mut config = RouterConfig {
chat_template: None,
mode: RoutingMode::Regular {
worker_urls: vec![],
},
port: 3003,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
..Default::default()
};
let mut config = RouterConfig::builder()
.regular_mode(vec![])
.port(3003)
.worker_startup_timeout_secs(1)
.worker_startup_check_interval_secs(1)
.build_unchecked();
let mut workers = Vec::new();
let mut worker_urls = Vec::new();
......
......@@ -14,14 +14,7 @@ use common::{
mock_mcp_server::MockMCPServer,
mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType},
};
use sglang_router_rs::{
config::{
CircuitBreakerConfig, HealthCheckConfig, PolicyConfig, RetryConfig, RouterConfig,
RoutingMode,
},
core::ConnectionMode,
routers::RouterFactory,
};
use sglang_router_rs::{config::RouterConfig, routers::RouterFactory};
#[tokio::test]
async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
......@@ -48,45 +41,19 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
let worker_url = worker.start().await.expect("start worker");
// Build router config (HTTP OpenAI mode)
let router_cfg = RouterConfig {
chat_template: None,
mode: RoutingMode::OpenAI {
worker_urls: vec![worker_url],
},
connection_mode: ConnectionMode::Http,
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 0,
max_payload_size: 8 * 1024 * 1024,
request_timeout_secs: 60,
worker_startup_timeout_secs: 5,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: Some("warn".to_string()),
request_id_headers: None,
max_concurrent_requests: 32,
queue_size: 0,
queue_timeout_secs: 5,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let router_cfg = RouterConfig::builder()
.openai_mode(vec![worker_url])
.random_policy()
.host("127.0.0.1")
.port(0)
.max_payload_size(8 * 1024 * 1024)
.request_timeout_secs(60)
.worker_startup_timeout_secs(5)
.worker_startup_check_interval_secs(1)
.log_level("warn")
.max_concurrent_requests(32)
.queue_timeout_secs(5)
.build_unchecked();
// Create router and context
let ctx = common::create_test_context(router_cfg);
......@@ -249,45 +216,19 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
#[tokio::test]
async fn test_conversations_crud_basic() {
// Router in OpenAI mode (no actual upstream calls in these tests)
let router_cfg = RouterConfig {
chat_template: None,
mode: RoutingMode::OpenAI {
worker_urls: vec!["http://localhost".to_string()],
},
connection_mode: ConnectionMode::Http,
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 0,
max_payload_size: 8 * 1024 * 1024,
request_timeout_secs: 60,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: Some("warn".to_string()),
request_id_headers: None,
max_concurrent_requests: 8,
queue_size: 0,
queue_timeout_secs: 5,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let router_cfg = RouterConfig::builder()
.openai_mode(vec!["http://localhost".to_string()])
.random_policy()
.host("127.0.0.1")
.port(0)
.max_payload_size(8 * 1024 * 1024)
.request_timeout_secs(60)
.worker_startup_timeout_secs(1)
.worker_startup_check_interval_secs(1)
.log_level("warn")
.max_concurrent_requests(8)
.queue_timeout_secs(5)
.build_unchecked();
let ctx = common::create_test_context(router_cfg);
let router = RouterFactory::create_router(&ctx).await.expect("router");
......@@ -585,45 +526,19 @@ async fn test_multi_turn_loop_with_mcp() {
let worker_url = worker.start().await.expect("start worker");
// Build router config
let router_cfg = RouterConfig {
chat_template: None,
mode: RoutingMode::OpenAI {
worker_urls: vec![worker_url],
},
connection_mode: ConnectionMode::Http,
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 0,
max_payload_size: 8 * 1024 * 1024,
request_timeout_secs: 60,
worker_startup_timeout_secs: 5,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: Some("info".to_string()),
request_id_headers: None,
max_concurrent_requests: 32,
queue_size: 0,
queue_timeout_secs: 5,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let router_cfg = RouterConfig::builder()
.openai_mode(vec![worker_url])
.random_policy()
.host("127.0.0.1")
.port(0)
.max_payload_size(8 * 1024 * 1024)
.request_timeout_secs(60)
.worker_startup_timeout_secs(5)
.worker_startup_check_interval_secs(1)
.log_level("info")
.max_concurrent_requests(32)
.queue_timeout_secs(5)
.build_unchecked();
let ctx = common::create_test_context(router_cfg);
let router = RouterFactory::create_router(&ctx).await.expect("router");
......@@ -762,45 +677,19 @@ async fn test_max_tool_calls_limit() {
});
let worker_url = worker.start().await.expect("start worker");
let router_cfg = RouterConfig {
chat_template: None,
mode: RoutingMode::OpenAI {
worker_urls: vec![worker_url],
},
connection_mode: ConnectionMode::Http,
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 0,
max_payload_size: 8 * 1024 * 1024,
request_timeout_secs: 60,
worker_startup_timeout_secs: 5,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: Some("info".to_string()),
request_id_headers: None,
max_concurrent_requests: 32,
queue_size: 0,
queue_timeout_secs: 5,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let router_cfg = RouterConfig::builder()
.openai_mode(vec![worker_url])
.random_policy()
.host("127.0.0.1")
.port(0)
.max_payload_size(8 * 1024 * 1024)
.request_timeout_secs(60)
.worker_startup_timeout_secs(5)
.worker_startup_check_interval_secs(1)
.log_level("info")
.max_concurrent_requests(32)
.queue_timeout_secs(5)
.build_unchecked();
let ctx = common::create_test_context(router_cfg);
let router = RouterFactory::create_router(&ctx).await.expect("router");
......@@ -905,45 +794,19 @@ async fn setup_streaming_mcp_test() -> (
});
let worker_url = worker.start().await.expect("start worker");
let router_cfg = RouterConfig {
chat_template: None,
mode: RoutingMode::OpenAI {
worker_urls: vec![worker_url],
},
connection_mode: ConnectionMode::Http,
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 0,
max_payload_size: 8 * 1024 * 1024,
request_timeout_secs: 60,
worker_startup_timeout_secs: 5,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: Some("info".to_string()),
request_id_headers: None,
max_concurrent_requests: 32,
queue_size: 0,
queue_timeout_secs: 5,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let router_cfg = RouterConfig::builder()
.openai_mode(vec![worker_url])
.random_policy()
.host("127.0.0.1")
.port(0)
.max_payload_size(8 * 1024 * 1024)
.request_timeout_secs(60)
.worker_startup_timeout_secs(5)
.worker_startup_check_interval_secs(1)
.log_level("info")
.max_concurrent_requests(32)
.queue_timeout_secs(5)
.build_unchecked();
let ctx = common::create_test_context(router_cfg);
let router = RouterFactory::create_router(&ctx).await.expect("router");
......@@ -1347,45 +1210,19 @@ async fn test_streaming_multi_turn_with_mcp() {
#[tokio::test]
async fn test_conversation_items_create_and_get() {
// Test creating items and getting a specific item
let router_cfg = RouterConfig {
chat_template: None,
mode: RoutingMode::OpenAI {
worker_urls: vec!["http://localhost".to_string()],
},
connection_mode: ConnectionMode::Http,
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 0,
max_payload_size: 8 * 1024 * 1024,
request_timeout_secs: 60,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: Some("warn".to_string()),
request_id_headers: None,
max_concurrent_requests: 8,
queue_size: 0,
queue_timeout_secs: 5,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let router_cfg = RouterConfig::builder()
.openai_mode(vec!["http://localhost".to_string()])
.random_policy()
.host("127.0.0.1")
.port(0)
.max_payload_size(8 * 1024 * 1024)
.request_timeout_secs(60)
.worker_startup_timeout_secs(1)
.worker_startup_check_interval_secs(1)
.log_level("warn")
.max_concurrent_requests(8)
.queue_timeout_secs(5)
.build_unchecked();
let ctx = common::create_test_context(router_cfg);
let router = RouterFactory::create_router(&ctx).await.expect("router");
......@@ -1449,45 +1286,19 @@ async fn test_conversation_items_create_and_get() {
#[tokio::test]
async fn test_conversation_items_delete() {
// Test deleting an item from a conversation
let router_cfg = RouterConfig {
chat_template: None,
mode: RoutingMode::OpenAI {
worker_urls: vec!["http://localhost".to_string()],
},
connection_mode: ConnectionMode::Http,
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 0,
max_payload_size: 8 * 1024 * 1024,
request_timeout_secs: 60,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: Some("warn".to_string()),
request_id_headers: None,
max_concurrent_requests: 8,
queue_size: 0,
queue_timeout_secs: 5,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let router_cfg = RouterConfig::builder()
.openai_mode(vec!["http://localhost".to_string()])
.random_policy()
.host("127.0.0.1")
.port(0)
.max_payload_size(8 * 1024 * 1024)
.request_timeout_secs(60)
.worker_startup_timeout_secs(1)
.worker_startup_check_interval_secs(1)
.log_level("warn")
.max_concurrent_requests(8)
.queue_timeout_secs(5)
.build_unchecked();
let ctx = common::create_test_context(router_cfg);
let router = RouterFactory::create_router(&ctx).await.expect("router");
......@@ -1557,45 +1368,19 @@ async fn test_conversation_items_delete() {
#[tokio::test]
async fn test_conversation_items_max_limit() {
// Test that creating > 20 items returns error
let router_cfg = RouterConfig {
chat_template: None,
mode: RoutingMode::OpenAI {
worker_urls: vec!["http://localhost".to_string()],
},
connection_mode: ConnectionMode::Http,
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 0,
max_payload_size: 8 * 1024 * 1024,
request_timeout_secs: 60,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: Some("warn".to_string()),
request_id_headers: None,
max_concurrent_requests: 8,
queue_size: 0,
queue_timeout_secs: 5,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let router_cfg = RouterConfig::builder()
.openai_mode(vec!["http://localhost".to_string()])
.random_policy()
.host("127.0.0.1")
.port(0)
.max_payload_size(8 * 1024 * 1024)
.request_timeout_secs(60)
.worker_startup_timeout_secs(1)
.worker_startup_check_interval_secs(1)
.log_level("warn")
.max_concurrent_requests(8)
.queue_timeout_secs(5)
.build_unchecked();
let ctx = common::create_test_context(router_cfg);
let router = RouterFactory::create_router(&ctx).await.expect("router");
......@@ -1635,45 +1420,19 @@ async fn test_conversation_items_max_limit() {
#[tokio::test]
async fn test_conversation_items_unsupported_type() {
// Test that unsupported item types return error
let router_cfg = RouterConfig {
chat_template: None,
mode: RoutingMode::OpenAI {
worker_urls: vec!["http://localhost".to_string()],
},
connection_mode: ConnectionMode::Http,
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 0,
max_payload_size: 8 * 1024 * 1024,
request_timeout_secs: 60,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: Some("warn".to_string()),
request_id_headers: None,
max_concurrent_requests: 8,
queue_size: 0,
queue_timeout_secs: 5,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let router_cfg = RouterConfig::builder()
.openai_mode(vec!["http://localhost".to_string()])
.random_policy()
.host("127.0.0.1")
.port(0)
.max_payload_size(8 * 1024 * 1024)
.request_timeout_secs(60)
.worker_startup_timeout_secs(1)
.worker_startup_check_interval_secs(1)
.log_level("warn")
.max_concurrent_requests(8)
.queue_timeout_secs(5)
.build_unchecked();
let ctx = common::create_test_context(router_cfg);
let router = RouterFactory::create_router(&ctx).await.expect("router");
......@@ -1712,45 +1471,19 @@ async fn test_conversation_items_unsupported_type() {
#[tokio::test]
async fn test_conversation_items_multi_conversation_sharing() {
// Test that items can be shared across conversations via soft delete
let router_cfg = RouterConfig {
chat_template: None,
mode: RoutingMode::OpenAI {
worker_urls: vec!["http://localhost".to_string()],
},
connection_mode: ConnectionMode::Http,
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 0,
max_payload_size: 8 * 1024 * 1024,
request_timeout_secs: 60,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: Some("warn".to_string()),
request_id_headers: None,
max_concurrent_requests: 8,
queue_size: 0,
queue_timeout_secs: 5,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let router_cfg = RouterConfig::builder()
.openai_mode(vec!["http://localhost".to_string()])
.random_policy()
.host("127.0.0.1")
.port(0)
.max_payload_size(8 * 1024 * 1024)
.request_timeout_secs(60)
.worker_startup_timeout_secs(1)
.worker_startup_check_interval_secs(1)
.log_level("warn")
.max_concurrent_requests(8)
.queue_timeout_secs(5)
.build_unchecked();
let ctx = common::create_test_context(router_cfg);
let router = RouterFactory::create_router(&ctx).await.expect("router");
......
......@@ -20,16 +20,12 @@ struct TestContext {
impl TestContext {
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
let mut config = RouterConfig {
chat_template: None,
mode: RoutingMode::Regular {
worker_urls: vec![],
},
port: 3004,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
..Default::default()
};
let mut config = RouterConfig::builder()
.regular_mode(vec![])
.port(3004)
.worker_startup_timeout_secs(1)
.worker_startup_check_interval_secs(1)
.build_unchecked();
let mut workers = Vec::new();
let mut worker_urls = Vec::new();
......
......@@ -902,17 +902,10 @@ async fn test_openai_router_models_auth_forwarding() {
#[test]
fn oracle_config_validation_requires_config_when_enabled() {
let config = RouterConfig {
chat_template: None,
mode: RoutingMode::OpenAI {
worker_urls: vec!["https://api.openai.com".to_string()],
},
history_backend: HistoryBackend::Oracle,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
..Default::default()
};
let config = RouterConfig::builder()
.openai_mode(vec!["https://api.openai.com".to_string()])
.history_backend(HistoryBackend::Oracle)
.build_unchecked();
let err =
ConfigValidator::validate(&config).expect_err("config should fail without oracle details");
......@@ -927,13 +920,9 @@ fn oracle_config_validation_requires_config_when_enabled() {
#[test]
fn oracle_config_validation_accepts_dsn_only() {
let config = RouterConfig {
chat_template: None,
mode: RoutingMode::OpenAI {
worker_urls: vec!["https://api.openai.com".to_string()],
},
history_backend: HistoryBackend::Oracle,
oracle: Some(OracleConfig {
let config = RouterConfig::builder()
.openai_mode(vec!["https://api.openai.com".to_string()])
.oracle_history(OracleConfig {
wallet_path: None,
connect_descriptor: "tcps://db.example.com:1522/service".to_string(),
username: "scott".to_string(),
......@@ -941,22 +930,17 @@ fn oracle_config_validation_accepts_dsn_only() {
pool_min: 1,
pool_max: 4,
pool_timeout_secs: 30,
}),
..Default::default()
};
})
.build_unchecked();
ConfigValidator::validate(&config).expect("dsn-based config should validate");
}
#[test]
fn oracle_config_validation_accepts_wallet_alias() {
let config = RouterConfig {
chat_template: None,
mode: RoutingMode::OpenAI {
worker_urls: vec!["https://api.openai.com".to_string()],
},
history_backend: HistoryBackend::Oracle,
oracle: Some(OracleConfig {
let config = RouterConfig::builder()
.openai_mode(vec!["https://api.openai.com".to_string()])
.oracle_history(OracleConfig {
wallet_path: Some("/etc/sglang/oracle-wallet".to_string()),
connect_descriptor: "db_low".to_string(),
username: "app_user".to_string(),
......@@ -964,9 +948,8 @@ fn oracle_config_validation_accepts_wallet_alias() {
pool_min: 1,
pool_max: 8,
pool_timeout_secs: 45,
}),
..Default::default()
};
})
.build_unchecked();
ConfigValidator::validate(&config).expect("wallet-based config should validate");
}
......@@ -2,8 +2,8 @@
mod test_pd_routing {
use serde_json::json;
use sglang_router_rs::{
config::{CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode},
core::{BasicWorkerBuilder, ConnectionMode, Worker, WorkerType},
config::{PolicyConfig, RouterConfig, RoutingMode},
core::{BasicWorkerBuilder, Worker, WorkerType},
routers::{http::pd_types::PDSelectionPolicy, RouterFactory},
};
......@@ -162,42 +162,24 @@ mod test_pd_routing {
];
for (mode, policy) in test_cases {
let config = RouterConfig {
chat_template: None,
mode,
policy,
host: "127.0.0.1".to_string(),
port: 3001,
max_payload_size: 1024 * 1024,
request_timeout_secs: 60,
worker_startup_timeout_secs: 10,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
rate_limit_tokens_per_second: None,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
let config = match mode {
RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
..
} => RouterConfig::builder()
.prefill_decode_mode(prefill_urls, decode_urls)
.policy(policy)
.host("127.0.0.1")
.port(3001)
.max_payload_size(1024 * 1024)
.request_timeout_secs(60)
.worker_startup_timeout_secs(10)
.worker_startup_check_interval_secs(1)
.max_concurrent_requests(64)
.queue_timeout_secs(60)
.build_unchecked(),
_ => panic!("Expected PrefillDecode mode"),
};
let app_context = {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment