Unverified commit 6d6e24bc, authored by Simo Lin, committed by GitHub
Browse files

[router] Add builder pattern for RouterConfig with zero duplication (#12030)

parent 2c057fbf
This diff is collapsed.
pub mod builder;
pub mod types;
pub mod validation;
pub use builder::*;
pub use types::*;
pub use validation::*;
......
......@@ -609,17 +609,14 @@ mod tests {
#[test]
fn test_router_config_serialization() {
let config = RouterConfig {
mode: RoutingMode::Regular {
worker_urls: vec!["http://worker1".to_string()],
},
policy: PolicyConfig::Random,
host: "0.0.0.0".to_string(),
port: 8080,
log_dir: Some("/var/log".to_string()),
log_level: Some("debug".to_string()),
..Default::default()
};
let config = RouterConfig::builder()
.regular_mode(vec!["http://worker1".to_string()])
.random_policy()
.host("0.0.0.0")
.port(8080)
.log_dir("/var/log")
.log_level("debug")
.build_unchecked();
let json = serde_json::to_string(&config).unwrap();
let deserialized: RouterConfig = serde_json::from_str(&json).unwrap();
......@@ -866,23 +863,14 @@ mod tests {
/// Verifies the string returned by `RouterConfig::mode_type()` for the
/// regular and prefill/decode routing modes.
#[test]
fn test_mode_type() {
    // Regular mode with no worker URLs.
    let config = RouterConfig::builder()
        .regular_mode(vec![])
        .build_unchecked();
    assert_eq!(config.mode_type(), "regular");

    // Prefill/decode mode with empty prefill and decode URL lists.
    let config = RouterConfig::builder()
        .prefill_decode_mode(vec![], vec![])
        .build_unchecked();
    assert_eq!(config.mode_type(), "prefill_decode");
}
......@@ -891,22 +879,15 @@ mod tests {
let config = RouterConfig::default();
assert!(!config.has_service_discovery());
let config = RouterConfig {
discovery: Some(DiscoveryConfig {
let config = RouterConfig::builder()
.discovery_config(DiscoveryConfig {
enabled: false,
..Default::default()
}),
..Default::default()
};
})
.build_unchecked();
assert!(!config.has_service_discovery());
let config = RouterConfig {
discovery: Some(DiscoveryConfig {
enabled: true,
..Default::default()
}),
..Default::default()
};
let config = RouterConfig::builder().enable_discovery().build_unchecked();
assert!(config.has_service_discovery());
}
......@@ -915,10 +896,9 @@ mod tests {
let config = RouterConfig::default();
assert!(!config.has_metrics());
let config = RouterConfig {
metrics: Some(MetricsConfig::default()),
..Default::default()
};
let config = RouterConfig::builder()
.metrics_config(MetricsConfig::default())
.build_unchecked();
assert!(config.has_metrics());
}
......@@ -926,16 +906,11 @@ mod tests {
fn test_large_worker_lists() {
let large_urls: Vec<String> = (0..1000).map(|i| format!("http://worker{}", i)).collect();
let mode = RoutingMode::Regular {
worker_urls: large_urls.clone(),
};
let config = RouterConfig::builder()
.regular_mode(large_urls.clone())
.build_unchecked();
assert_eq!(mode.worker_count(), 1000);
let config = RouterConfig {
mode,
..Default::default()
};
assert_eq!(config.mode.worker_count(), 1000);
let json = serde_json::to_string(&config).unwrap();
let deserialized: RouterConfig = serde_json::from_str(&json).unwrap();
......@@ -950,13 +925,13 @@ mod tests {
#[test]
fn test_unicode_in_config() {
let config = RouterConfig {
mode: RoutingMode::Regular {
worker_urls: vec!["http://работник1".to_string(), "http://工作者2".to_string()],
},
log_dir: Some("/日志/目录".to_string()),
..Default::default()
};
let config = RouterConfig::builder()
.regular_mode(vec![
"http://работник1".to_string(),
"http://工作者2".to_string(),
])
.log_dir("/日志/目录")
.build_unchecked();
let json = serde_json::to_string(&config).unwrap();
let deserialized: RouterConfig = serde_json::from_str(&json).unwrap();
......@@ -974,12 +949,11 @@ mod tests {
#[test]
fn test_empty_string_fields() {
let config = RouterConfig {
host: "".to_string(),
log_dir: Some("".to_string()),
log_level: Some("".to_string()),
..Default::default()
};
let config = RouterConfig::builder()
.host("")
.log_dir("")
.log_level("")
.build_unchecked();
assert_eq!(config.host, "");
assert_eq!(config.log_dir, Some("".to_string()));
......@@ -988,63 +962,34 @@ mod tests {
#[test]
fn test_full_pd_mode_config() {
let config = RouterConfig {
mode: RoutingMode::PrefillDecode {
prefill_urls: vec![
let config = RouterConfig::builder()
.prefill_decode_mode(
vec![
("http://prefill1:8000".to_string(), Some(8001)),
("http://prefill2:8000".to_string(), None),
],
decode_urls: vec![
vec![
"http://decode1:8000".to_string(),
"http://decode2:8000".to_string(),
],
prefill_policy: None,
decode_policy: None,
},
policy: PolicyConfig::PowerOfTwo {
load_check_interval_secs: 30,
},
host: "0.0.0.0".to_string(),
port: 3000,
max_payload_size: 1048576,
request_timeout_secs: 120,
worker_startup_timeout_secs: 60,
worker_startup_check_interval_secs: 5,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
)
.power_of_two_policy(30)
.host("0.0.0.0")
.port(3000)
.max_payload_size(1048576)
.request_timeout_secs(120)
.worker_startup_timeout_secs(60)
.worker_startup_check_interval_secs(5)
.discovery_config(DiscoveryConfig {
enabled: true,
namespace: Some("sglang".to_string()),
..Default::default()
}),
metrics: Some(MetricsConfig {
port: 9090,
host: "0.0.0.0".to_string(),
}),
log_dir: Some("/var/log/sglang".to_string()),
log_level: Some("info".to_string()),
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
chat_template: None,
history_backend: default_history_backend(),
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: TokenizerCacheConfig::default(),
};
})
.enable_metrics("0.0.0.0", 9090)
.log_dir("/var/log/sglang")
.log_level("info")
.max_concurrent_requests(64)
.build_unchecked();
assert!(config.mode.is_pd_mode());
assert_eq!(config.mode.worker_count(), 4);
......@@ -1058,62 +1003,31 @@ mod tests {
let mut selector = HashMap::new();
selector.insert("app".to_string(), "sglang".to_string());
let config = RouterConfig {
mode: RoutingMode::Regular {
worker_urls: vec![
"http://worker1:8000".to_string(),
"http://worker2:8000".to_string(),
"http://worker3:8000".to_string(),
],
},
policy: PolicyConfig::CacheAware {
cache_threshold: 0.9,
balance_abs_threshold: 5,
balance_rel_threshold: 1.2,
eviction_interval_secs: 600,
max_tree_size: 10000,
},
host: "0.0.0.0".to_string(),
port: 3001,
max_payload_size: 536870912,
request_timeout_secs: 300,
worker_startup_timeout_secs: 180,
worker_startup_check_interval_secs: 15,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
let config = RouterConfig::builder()
.regular_mode(vec![
"http://worker1:8000".to_string(),
"http://worker2:8000".to_string(),
"http://worker3:8000".to_string(),
])
.cache_aware_policy(0.9, 5, 1.2, 600, 10000)
.host("0.0.0.0")
.port(3001)
.max_payload_size(536870912)
.request_timeout_secs(300)
.worker_startup_timeout_secs(180)
.worker_startup_check_interval_secs(15)
.discovery_config(DiscoveryConfig {
enabled: true,
namespace: None,
port: 8080,
check_interval_secs: 45,
selector,
..Default::default()
}),
metrics: Some(MetricsConfig::default()),
log_dir: None,
log_level: Some("debug".to_string()),
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
chat_template: None,
history_backend: default_history_backend(),
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: TokenizerCacheConfig::default(),
};
})
.metrics_config(MetricsConfig::default())
.log_level("debug")
.max_concurrent_requests(64)
.build_unchecked();
assert!(!config.mode.is_pd_mode());
assert_eq!(config.mode.worker_count(), 3);
......@@ -1128,20 +1042,16 @@ mod tests {
selectors.insert("env".to_string(), "prod".to_string());
selectors.insert("version".to_string(), "v1".to_string());
let config = RouterConfig {
mode: RoutingMode::Regular {
worker_urls: vec!["http://worker1".to_string()],
},
policy: PolicyConfig::RoundRobin,
host: "::1".to_string(), // IPv6
port: 8888,
max_payload_size: 1024 * 1024 * 512, // 512MB
request_timeout_secs: 900,
worker_startup_timeout_secs: 600,
worker_startup_check_interval_secs: 20,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
let config = RouterConfig::builder()
.regular_mode(vec!["http://worker1".to_string()])
.round_robin_policy()
.host("::1") // IPv6
.port(8888)
.max_payload_size(1024 * 1024 * 512) // 512MB
.request_timeout_secs(900)
.worker_startup_timeout_secs(600)
.worker_startup_check_interval_secs(20)
.discovery_config(DiscoveryConfig {
enabled: true,
namespace: Some("production".to_string()),
port: 8443,
......@@ -1150,35 +1060,12 @@ mod tests {
prefill_selector: selectors.clone(),
decode_selector: selectors,
bootstrap_port_annotation: "mycompany.io/bootstrap".to_string(),
}),
metrics: Some(MetricsConfig {
port: 9999,
host: "::".to_string(), // IPv6 any
}),
log_dir: Some("/opt/logs/sglang".to_string()),
log_level: Some("trace".to_string()),
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
chat_template: None,
history_backend: default_history_backend(),
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: TokenizerCacheConfig::default(),
};
})
.enable_metrics("::", 9999) // IPv6 any
.log_dir("/opt/logs/sglang")
.log_level("trace")
.max_concurrent_requests(64)
.build_unchecked();
assert!(config.has_service_discovery());
assert!(config.has_metrics());
......
......@@ -302,65 +302,66 @@ impl Router {
None
};
Ok(config::RouterConfig {
mode,
policy,
host: self.host.clone(),
port: self.port,
connection_mode: self.connection_mode.clone(),
max_payload_size: self.max_payload_size,
request_timeout_secs: self.request_timeout_secs,
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
worker_startup_check_interval_secs: self.worker_startup_check_interval,
dp_aware: self.dp_aware,
api_key: self.api_key.clone(),
discovery,
metrics,
log_dir: self.log_dir.clone(),
log_level: self.log_level.clone(),
request_id_headers: self.request_id_headers.clone(),
max_concurrent_requests: self.max_concurrent_requests,
queue_size: self.queue_size,
queue_timeout_secs: self.queue_timeout_secs,
rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: config::RetryConfig {
let builder = config::RouterConfig::builder()
.mode(mode)
.policy(policy)
.host(&self.host)
.port(self.port)
.connection_mode(self.connection_mode.clone())
.max_payload_size(self.max_payload_size)
.request_timeout_secs(self.request_timeout_secs)
.worker_startup_timeout_secs(self.worker_startup_timeout_secs)
.worker_startup_check_interval_secs(self.worker_startup_check_interval)
.max_concurrent_requests(self.max_concurrent_requests)
.queue_size(self.queue_size)
.queue_timeout_secs(self.queue_timeout_secs)
.cors_allowed_origins(self.cors_allowed_origins.clone())
.retry_config(config::RetryConfig {
max_retries: self.retry_max_retries,
initial_backoff_ms: self.retry_initial_backoff_ms,
max_backoff_ms: self.retry_max_backoff_ms,
backoff_multiplier: self.retry_backoff_multiplier,
jitter_factor: self.retry_jitter_factor,
},
circuit_breaker: config::CircuitBreakerConfig {
})
.circuit_breaker_config(config::CircuitBreakerConfig {
failure_threshold: self.cb_failure_threshold,
success_threshold: self.cb_success_threshold,
timeout_duration_secs: self.cb_timeout_duration_secs,
window_duration_secs: self.cb_window_duration_secs,
},
disable_retries: self.disable_retries,
disable_circuit_breaker: self.disable_circuit_breaker,
health_check: config::HealthCheckConfig {
})
.health_check_config(config::HealthCheckConfig {
failure_threshold: self.health_failure_threshold,
success_threshold: self.health_success_threshold,
timeout_secs: self.health_check_timeout_secs,
check_interval_secs: self.health_check_interval_secs,
endpoint: self.health_check_endpoint.clone(),
},
enable_igw: self.enable_igw,
model_path: self.model_path.clone(),
tokenizer_path: self.tokenizer_path.clone(),
chat_template: self.chat_template.clone(),
history_backend,
oracle,
reasoning_parser: self.reasoning_parser.clone(),
tool_call_parser: self.tool_call_parser.clone(),
tokenizer_cache: config::TokenizerCacheConfig {
})
.tokenizer_cache(config::TokenizerCacheConfig {
enable_l0: self.tokenizer_cache_enable_l0,
l0_max_entries: self.tokenizer_cache_l0_max_entries,
enable_l1: self.tokenizer_cache_enable_l1,
l1_max_memory: self.tokenizer_cache_l1_max_memory,
},
})
})
.history_backend(history_backend)
.maybe_api_key(self.api_key.as_ref())
.maybe_discovery(discovery)
.maybe_metrics(metrics)
.maybe_log_dir(self.log_dir.as_ref())
.maybe_log_level(self.log_level.as_ref())
.maybe_request_id_headers(self.request_id_headers.clone())
.maybe_rate_limit_tokens_per_second(self.rate_limit_tokens_per_second)
.maybe_model_path(self.model_path.as_ref())
.maybe_tokenizer_path(self.tokenizer_path.as_ref())
.maybe_chat_template(self.chat_template.as_ref())
.maybe_oracle(oracle)
.maybe_reasoning_parser(self.reasoning_parser.as_ref())
.maybe_tool_call_parser(self.tool_call_parser.as_ref())
.dp_aware(self.dp_aware)
.retries(!self.disable_retries)
.circuit_breaker(!self.disable_circuit_breaker)
.igw(self.enable_igw);
builder.build()
}
}
......
......@@ -538,69 +538,68 @@ impl CliArgs {
None
};
Ok(RouterConfig {
mode,
policy,
connection_mode,
host: self.host.clone(),
port: self.port,
max_payload_size: self.max_payload_size,
request_timeout_secs: self.request_timeout_secs,
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
worker_startup_check_interval_secs: self.worker_startup_check_interval,
dp_aware: self.dp_aware,
api_key: self.api_key.clone(),
discovery,
metrics,
log_dir: self.log_dir.clone(),
log_level: Some(self.log_level.clone()),
request_id_headers: if self.request_id_headers.is_empty() {
None
} else {
Some(self.request_id_headers.clone())
},
max_concurrent_requests: self.max_concurrent_requests,
queue_size: self.queue_size,
queue_timeout_secs: self.queue_timeout_secs,
cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: RetryConfig {
let builder = RouterConfig::builder()
.mode(mode)
.policy(policy)
.connection_mode(connection_mode)
.host(&self.host)
.port(self.port)
.max_payload_size(self.max_payload_size)
.request_timeout_secs(self.request_timeout_secs)
.worker_startup_timeout_secs(self.worker_startup_timeout_secs)
.worker_startup_check_interval_secs(self.worker_startup_check_interval)
.max_concurrent_requests(self.max_concurrent_requests)
.queue_size(self.queue_size)
.queue_timeout_secs(self.queue_timeout_secs)
.cors_allowed_origins(self.cors_allowed_origins.clone())
.retry_config(RetryConfig {
max_retries: self.retry_max_retries,
initial_backoff_ms: self.retry_initial_backoff_ms,
max_backoff_ms: self.retry_max_backoff_ms,
backoff_multiplier: self.retry_backoff_multiplier,
jitter_factor: self.retry_jitter_factor,
},
circuit_breaker: CircuitBreakerConfig {
})
.circuit_breaker_config(CircuitBreakerConfig {
failure_threshold: self.cb_failure_threshold,
success_threshold: self.cb_success_threshold,
timeout_duration_secs: self.cb_timeout_duration_secs,
window_duration_secs: self.cb_window_duration_secs,
},
disable_retries: self.disable_retries,
disable_circuit_breaker: self.disable_circuit_breaker,
health_check: HealthCheckConfig {
})
.health_check_config(HealthCheckConfig {
failure_threshold: self.health_failure_threshold,
success_threshold: self.health_success_threshold,
timeout_secs: self.health_check_timeout_secs,
check_interval_secs: self.health_check_interval_secs,
endpoint: self.health_check_endpoint.clone(),
},
enable_igw: self.enable_igw,
rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
model_path: self.model_path.clone(),
tokenizer_path: self.tokenizer_path.clone(),
chat_template: self.chat_template.clone(),
history_backend,
oracle,
reasoning_parser: self.reasoning_parser.clone(),
tool_call_parser: self.tool_call_parser.clone(),
tokenizer_cache: TokenizerCacheConfig {
})
.tokenizer_cache(TokenizerCacheConfig {
enable_l0: self.tokenizer_cache_enable_l0,
l0_max_entries: self.tokenizer_cache_l0_max_entries,
enable_l1: self.tokenizer_cache_enable_l1,
l1_max_memory: self.tokenizer_cache_l1_max_memory,
},
})
})
.history_backend(history_backend)
.log_level(&self.log_level)
.maybe_api_key(self.api_key.as_ref())
.maybe_discovery(discovery)
.maybe_metrics(metrics)
.maybe_log_dir(self.log_dir.as_ref())
.maybe_request_id_headers(
(!self.request_id_headers.is_empty()).then(|| self.request_id_headers.clone()),
)
.maybe_rate_limit_tokens_per_second(self.rate_limit_tokens_per_second)
.maybe_model_path(self.model_path.as_ref())
.maybe_tokenizer_path(self.tokenizer_path.as_ref())
.maybe_chat_template(self.chat_template.as_ref())
.maybe_oracle(oracle)
.maybe_reasoning_parser(self.reasoning_parser.as_ref())
.maybe_tool_call_parser(self.tool_call_parser.as_ref())
.dp_aware(self.dp_aware)
.retries(!self.disable_retries)
.circuit_breaker(!self.disable_circuit_breaker)
.igw(self.enable_igw);
builder.build()
}
fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig {
......
......@@ -565,10 +565,9 @@ mod tests {
async fn create_test_app_context() -> Arc<AppContext> {
use crate::{config::RouterConfig, middleware::TokenBucket};
let router_config = RouterConfig {
worker_startup_timeout_secs: 1,
..Default::default()
};
let router_config = RouterConfig::builder()
.worker_startup_timeout_secs(1)
.build_unchecked();
// Note: Using uninitialized queue for tests to avoid spawning background workers
// Jobs submitted during tests will queue but not be processed
......
......@@ -11,8 +11,8 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
use reqwest::Client;
use serde_json::json;
use sglang_router_rs::{
config::{CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode},
core::{ConnectionMode, Job},
config::{RouterConfig, RoutingMode},
core::Job,
routers::{RouterFactory, RouterTrait},
server::AppContext,
};
......@@ -30,45 +30,18 @@ struct TestContext {
impl TestContext {
/// Creates a `TestContext` from the default router configuration used by
/// these tests, then delegates to `new_with_config`.
///
/// The configuration is regular mode with no preset worker URLs, the random
/// policy, and a listener on 127.0.0.1:3002; `worker_configs` is passed
/// through to `new_with_config` unchanged.
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
    // Create default router config
    let config = RouterConfig::builder()
        .regular_mode(vec![])
        .random_policy()
        .host("127.0.0.1")
        .port(3002)
        .max_payload_size(256 * 1024 * 1024)
        .request_timeout_secs(600)
        .worker_startup_timeout_secs(1)
        .worker_startup_check_interval_secs(1)
        .max_concurrent_requests(64)
        .queue_timeout_secs(60)
        .build_unchecked();

    Self::new_with_config(config, worker_configs).await
}
......@@ -1182,45 +1155,18 @@ mod error_tests {
#[tokio::test]
async fn test_payload_too_large() {
// Create context with small payload limit
let config = RouterConfig {
chat_template: None,
mode: RoutingMode::Regular {
worker_urls: vec![],
},
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 3010,
max_payload_size: 1024, // 1KB limit
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let config = RouterConfig::builder()
.regular_mode(vec![])
.random_policy()
.host("127.0.0.1")
.port(3010)
.max_payload_size(1024) // 1KB limit
.request_timeout_secs(600)
.worker_startup_timeout_secs(1)
.worker_startup_check_interval_secs(1)
.max_concurrent_requests(64)
.queue_timeout_secs(60)
.build_unchecked();
let ctx = TestContext::new_with_config(
config,
......@@ -1509,48 +1455,18 @@ mod pd_mode_tests {
.and_then(|p| p.trim_end_matches('/').parse::<u16>().ok())
.unwrap_or(9000);
let config = RouterConfig {
chat_template: None,
mode: RoutingMode::PrefillDecode {
prefill_urls: vec![(prefill_url, Some(prefill_port))],
decode_urls: vec![decode_url],
prefill_policy: None,
decode_policy: None,
},
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 3011,
max_payload_size: 256 * 1024 * 1024,
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
discovery: None,
metrics: None,
log_dir: None,
dp_aware: false,
api_key: None,
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let config = RouterConfig::builder()
.prefill_decode_mode(vec![(prefill_url, Some(prefill_port))], vec![decode_url])
.random_policy()
.host("127.0.0.1")
.port(3011)
.max_payload_size(256 * 1024 * 1024)
.request_timeout_secs(600)
.worker_startup_timeout_secs(1)
.worker_startup_check_interval_secs(1)
.max_concurrent_requests(64)
.queue_timeout_secs(60)
.build_unchecked();
// Create app context
let app_context = common::create_test_context(config);
......@@ -1676,45 +1592,19 @@ mod request_id_tests {
#[tokio::test]
async fn test_request_id_with_custom_headers() {
// Create config with custom request ID headers
let config = RouterConfig {
chat_template: None,
mode: RoutingMode::Regular {
worker_urls: vec![],
},
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 3002,
max_payload_size: 256 * 1024 * 1024,
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
discovery: None,
metrics: None,
dp_aware: false,
api_key: None,
log_dir: None,
log_level: None,
request_id_headers: Some(vec!["custom-id".to_string(), "trace-id".to_string()]),
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let config = RouterConfig::builder()
.regular_mode(vec![])
.random_policy()
.host("127.0.0.1")
.port(3002)
.max_payload_size(256 * 1024 * 1024)
.request_timeout_secs(600)
.worker_startup_timeout_secs(1)
.worker_startup_check_interval_secs(1)
.request_id_headers(vec!["custom-id".to_string(), "trace-id".to_string()])
.max_concurrent_requests(64)
.queue_timeout_secs(60)
.build_unchecked();
let ctx = TestContext::new_with_config(
config,
......
......@@ -19,16 +19,12 @@ struct TestContext {
impl TestContext {
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
let mut config = RouterConfig {
chat_template: None,
mode: RoutingMode::Regular {
worker_urls: vec![],
},
port: 3003,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
..Default::default()
};
let mut config = RouterConfig::builder()
.regular_mode(vec![])
.port(3003)
.worker_startup_timeout_secs(1)
.worker_startup_check_interval_secs(1)
.build_unchecked();
let mut workers = Vec::new();
let mut worker_urls = Vec::new();
......
This diff is collapsed.
......@@ -20,16 +20,12 @@ struct TestContext {
impl TestContext {
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
let mut config = RouterConfig {
chat_template: None,
mode: RoutingMode::Regular {
worker_urls: vec![],
},
port: 3004,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
..Default::default()
};
let mut config = RouterConfig::builder()
.regular_mode(vec![])
.port(3004)
.worker_startup_timeout_secs(1)
.worker_startup_check_interval_secs(1)
.build_unchecked();
let mut workers = Vec::new();
let mut worker_urls = Vec::new();
......
......@@ -902,17 +902,10 @@ async fn test_openai_router_models_auth_forwarding() {
#[test]
fn oracle_config_validation_requires_config_when_enabled() {
let config = RouterConfig {
chat_template: None,
mode: RoutingMode::OpenAI {
worker_urls: vec!["https://api.openai.com".to_string()],
},
history_backend: HistoryBackend::Oracle,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
..Default::default()
};
let config = RouterConfig::builder()
.openai_mode(vec!["https://api.openai.com".to_string()])
.history_backend(HistoryBackend::Oracle)
.build_unchecked();
let err =
ConfigValidator::validate(&config).expect_err("config should fail without oracle details");
......@@ -927,13 +920,9 @@ fn oracle_config_validation_requires_config_when_enabled() {
#[test]
fn oracle_config_validation_accepts_dsn_only() {
let config = RouterConfig {
chat_template: None,
mode: RoutingMode::OpenAI {
worker_urls: vec!["https://api.openai.com".to_string()],
},
history_backend: HistoryBackend::Oracle,
oracle: Some(OracleConfig {
let config = RouterConfig::builder()
.openai_mode(vec!["https://api.openai.com".to_string()])
.oracle_history(OracleConfig {
wallet_path: None,
connect_descriptor: "tcps://db.example.com:1522/service".to_string(),
username: "scott".to_string(),
......@@ -941,22 +930,17 @@ fn oracle_config_validation_accepts_dsn_only() {
pool_min: 1,
pool_max: 4,
pool_timeout_secs: 30,
}),
..Default::default()
};
})
.build_unchecked();
ConfigValidator::validate(&config).expect("dsn-based config should validate");
}
#[test]
fn oracle_config_validation_accepts_wallet_alias() {
let config = RouterConfig {
chat_template: None,
mode: RoutingMode::OpenAI {
worker_urls: vec!["https://api.openai.com".to_string()],
},
history_backend: HistoryBackend::Oracle,
oracle: Some(OracleConfig {
let config = RouterConfig::builder()
.openai_mode(vec!["https://api.openai.com".to_string()])
.oracle_history(OracleConfig {
wallet_path: Some("/etc/sglang/oracle-wallet".to_string()),
connect_descriptor: "db_low".to_string(),
username: "app_user".to_string(),
......@@ -964,9 +948,8 @@ fn oracle_config_validation_accepts_wallet_alias() {
pool_min: 1,
pool_max: 8,
pool_timeout_secs: 45,
}),
..Default::default()
};
})
.build_unchecked();
ConfigValidator::validate(&config).expect("wallet-based config should validate");
}
......@@ -2,8 +2,8 @@
mod test_pd_routing {
use serde_json::json;
use sglang_router_rs::{
config::{CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode},
core::{BasicWorkerBuilder, ConnectionMode, Worker, WorkerType},
config::{PolicyConfig, RouterConfig, RoutingMode},
core::{BasicWorkerBuilder, Worker, WorkerType},
routers::{http::pd_types::PDSelectionPolicy, RouterFactory},
};
......@@ -162,42 +162,24 @@ mod test_pd_routing {
];
for (mode, policy) in test_cases {
let config = RouterConfig {
chat_template: None,
mode,
policy,
host: "127.0.0.1".to_string(),
port: 3001,
max_payload_size: 1024 * 1024,
request_timeout_secs: 60,
worker_startup_timeout_secs: 10,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
rate_limit_tokens_per_second: None,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
let config = match mode {
RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
..
} => RouterConfig::builder()
.prefill_decode_mode(prefill_urls, decode_urls)
.policy(policy)
.host("127.0.0.1")
.port(3001)
.max_payload_size(1024 * 1024)
.request_timeout_secs(60)
.worker_startup_timeout_secs(10)
.worker_startup_check_interval_secs(1)
.max_concurrent_requests(64)
.queue_timeout_secs(60)
.build_unchecked(),
_ => panic!("Expected PrefillDecode mode"),
};
let app_context = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment