Unverified commit 6d6e24bc, authored by Simo Lin, committed by GitHub
Browse files

[router] Add builder pattern for RouterConfig with zero duplication (#12030)

parent 2c057fbf
This diff is collapsed.
pub mod builder;
pub mod types; pub mod types;
pub mod validation; pub mod validation;
pub use builder::*;
pub use types::*; pub use types::*;
pub use validation::*; pub use validation::*;
......
...@@ -609,17 +609,14 @@ mod tests { ...@@ -609,17 +609,14 @@ mod tests {
#[test] #[test]
fn test_router_config_serialization() { fn test_router_config_serialization() {
let config = RouterConfig { let config = RouterConfig::builder()
mode: RoutingMode::Regular { .regular_mode(vec!["http://worker1".to_string()])
worker_urls: vec!["http://worker1".to_string()], .random_policy()
}, .host("0.0.0.0")
policy: PolicyConfig::Random, .port(8080)
host: "0.0.0.0".to_string(), .log_dir("/var/log")
port: 8080, .log_level("debug")
log_dir: Some("/var/log".to_string()), .build_unchecked();
log_level: Some("debug".to_string()),
..Default::default()
};
let json = serde_json::to_string(&config).unwrap(); let json = serde_json::to_string(&config).unwrap();
let deserialized: RouterConfig = serde_json::from_str(&json).unwrap(); let deserialized: RouterConfig = serde_json::from_str(&json).unwrap();
...@@ -866,23 +863,14 @@ mod tests { ...@@ -866,23 +863,14 @@ mod tests {
#[test] #[test]
fn test_mode_type() { fn test_mode_type() {
let config = RouterConfig { let config = RouterConfig::builder()
mode: RoutingMode::Regular { .regular_mode(vec![])
worker_urls: vec![], .build_unchecked();
},
..Default::default()
};
assert_eq!(config.mode_type(), "regular"); assert_eq!(config.mode_type(), "regular");
let config = RouterConfig { let config = RouterConfig::builder()
mode: RoutingMode::PrefillDecode { .prefill_decode_mode(vec![], vec![])
prefill_urls: vec![], .build_unchecked();
decode_urls: vec![],
prefill_policy: None,
decode_policy: None,
},
..Default::default()
};
assert_eq!(config.mode_type(), "prefill_decode"); assert_eq!(config.mode_type(), "prefill_decode");
} }
...@@ -891,22 +879,15 @@ mod tests { ...@@ -891,22 +879,15 @@ mod tests {
let config = RouterConfig::default(); let config = RouterConfig::default();
assert!(!config.has_service_discovery()); assert!(!config.has_service_discovery());
let config = RouterConfig { let config = RouterConfig::builder()
discovery: Some(DiscoveryConfig { .discovery_config(DiscoveryConfig {
enabled: false, enabled: false,
..Default::default() ..Default::default()
}), })
..Default::default() .build_unchecked();
};
assert!(!config.has_service_discovery()); assert!(!config.has_service_discovery());
let config = RouterConfig { let config = RouterConfig::builder().enable_discovery().build_unchecked();
discovery: Some(DiscoveryConfig {
enabled: true,
..Default::default()
}),
..Default::default()
};
assert!(config.has_service_discovery()); assert!(config.has_service_discovery());
} }
...@@ -915,10 +896,9 @@ mod tests { ...@@ -915,10 +896,9 @@ mod tests {
let config = RouterConfig::default(); let config = RouterConfig::default();
assert!(!config.has_metrics()); assert!(!config.has_metrics());
let config = RouterConfig { let config = RouterConfig::builder()
metrics: Some(MetricsConfig::default()), .metrics_config(MetricsConfig::default())
..Default::default() .build_unchecked();
};
assert!(config.has_metrics()); assert!(config.has_metrics());
} }
...@@ -926,16 +906,11 @@ mod tests { ...@@ -926,16 +906,11 @@ mod tests {
fn test_large_worker_lists() { fn test_large_worker_lists() {
let large_urls: Vec<String> = (0..1000).map(|i| format!("http://worker{}", i)).collect(); let large_urls: Vec<String> = (0..1000).map(|i| format!("http://worker{}", i)).collect();
let mode = RoutingMode::Regular { let config = RouterConfig::builder()
worker_urls: large_urls.clone(), .regular_mode(large_urls.clone())
}; .build_unchecked();
assert_eq!(mode.worker_count(), 1000);
let config = RouterConfig { assert_eq!(config.mode.worker_count(), 1000);
mode,
..Default::default()
};
let json = serde_json::to_string(&config).unwrap(); let json = serde_json::to_string(&config).unwrap();
let deserialized: RouterConfig = serde_json::from_str(&json).unwrap(); let deserialized: RouterConfig = serde_json::from_str(&json).unwrap();
...@@ -950,13 +925,13 @@ mod tests { ...@@ -950,13 +925,13 @@ mod tests {
#[test] #[test]
fn test_unicode_in_config() { fn test_unicode_in_config() {
let config = RouterConfig { let config = RouterConfig::builder()
mode: RoutingMode::Regular { .regular_mode(vec![
worker_urls: vec!["http://работник1".to_string(), "http://工作者2".to_string()], "http://работник1".to_string(),
}, "http://工作者2".to_string(),
log_dir: Some("/日志/目录".to_string()), ])
..Default::default() .log_dir("/日志/目录")
}; .build_unchecked();
let json = serde_json::to_string(&config).unwrap(); let json = serde_json::to_string(&config).unwrap();
let deserialized: RouterConfig = serde_json::from_str(&json).unwrap(); let deserialized: RouterConfig = serde_json::from_str(&json).unwrap();
...@@ -974,12 +949,11 @@ mod tests { ...@@ -974,12 +949,11 @@ mod tests {
#[test] #[test]
fn test_empty_string_fields() { fn test_empty_string_fields() {
let config = RouterConfig { let config = RouterConfig::builder()
host: "".to_string(), .host("")
log_dir: Some("".to_string()), .log_dir("")
log_level: Some("".to_string()), .log_level("")
..Default::default() .build_unchecked();
};
assert_eq!(config.host, ""); assert_eq!(config.host, "");
assert_eq!(config.log_dir, Some("".to_string())); assert_eq!(config.log_dir, Some("".to_string()));
...@@ -988,63 +962,34 @@ mod tests { ...@@ -988,63 +962,34 @@ mod tests {
#[test] #[test]
fn test_full_pd_mode_config() { fn test_full_pd_mode_config() {
let config = RouterConfig { let config = RouterConfig::builder()
mode: RoutingMode::PrefillDecode { .prefill_decode_mode(
prefill_urls: vec![ vec![
("http://prefill1:8000".to_string(), Some(8001)), ("http://prefill1:8000".to_string(), Some(8001)),
("http://prefill2:8000".to_string(), None), ("http://prefill2:8000".to_string(), None),
], ],
decode_urls: vec![ vec![
"http://decode1:8000".to_string(), "http://decode1:8000".to_string(),
"http://decode2:8000".to_string(), "http://decode2:8000".to_string(),
], ],
prefill_policy: None, )
decode_policy: None, .power_of_two_policy(30)
}, .host("0.0.0.0")
policy: PolicyConfig::PowerOfTwo { .port(3000)
load_check_interval_secs: 30, .max_payload_size(1048576)
}, .request_timeout_secs(120)
host: "0.0.0.0".to_string(), .worker_startup_timeout_secs(60)
port: 3000, .worker_startup_check_interval_secs(5)
max_payload_size: 1048576, .discovery_config(DiscoveryConfig {
request_timeout_secs: 120,
worker_startup_timeout_secs: 60,
worker_startup_check_interval_secs: 5,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
enabled: true, enabled: true,
namespace: Some("sglang".to_string()), namespace: Some("sglang".to_string()),
..Default::default() ..Default::default()
}), })
metrics: Some(MetricsConfig { .enable_metrics("0.0.0.0", 9090)
port: 9090, .log_dir("/var/log/sglang")
host: "0.0.0.0".to_string(), .log_level("info")
}), .max_concurrent_requests(64)
log_dir: Some("/var/log/sglang".to_string()), .build_unchecked();
log_level: Some("info".to_string()),
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
chat_template: None,
history_backend: default_history_backend(),
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: TokenizerCacheConfig::default(),
};
assert!(config.mode.is_pd_mode()); assert!(config.mode.is_pd_mode());
assert_eq!(config.mode.worker_count(), 4); assert_eq!(config.mode.worker_count(), 4);
...@@ -1058,62 +1003,31 @@ mod tests { ...@@ -1058,62 +1003,31 @@ mod tests {
let mut selector = HashMap::new(); let mut selector = HashMap::new();
selector.insert("app".to_string(), "sglang".to_string()); selector.insert("app".to_string(), "sglang".to_string());
let config = RouterConfig { let config = RouterConfig::builder()
mode: RoutingMode::Regular { .regular_mode(vec![
worker_urls: vec![
"http://worker1:8000".to_string(), "http://worker1:8000".to_string(),
"http://worker2:8000".to_string(), "http://worker2:8000".to_string(),
"http://worker3:8000".to_string(), "http://worker3:8000".to_string(),
], ])
}, .cache_aware_policy(0.9, 5, 1.2, 600, 10000)
policy: PolicyConfig::CacheAware { .host("0.0.0.0")
cache_threshold: 0.9, .port(3001)
balance_abs_threshold: 5, .max_payload_size(536870912)
balance_rel_threshold: 1.2, .request_timeout_secs(300)
eviction_interval_secs: 600, .worker_startup_timeout_secs(180)
max_tree_size: 10000, .worker_startup_check_interval_secs(15)
}, .discovery_config(DiscoveryConfig {
host: "0.0.0.0".to_string(),
port: 3001,
max_payload_size: 536870912,
request_timeout_secs: 300,
worker_startup_timeout_secs: 180,
worker_startup_check_interval_secs: 15,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
enabled: true, enabled: true,
namespace: None, namespace: None,
port: 8080, port: 8080,
check_interval_secs: 45, check_interval_secs: 45,
selector, selector,
..Default::default() ..Default::default()
}), })
metrics: Some(MetricsConfig::default()), .metrics_config(MetricsConfig::default())
log_dir: None, .log_level("debug")
log_level: Some("debug".to_string()), .max_concurrent_requests(64)
request_id_headers: None, .build_unchecked();
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
chat_template: None,
history_backend: default_history_backend(),
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: TokenizerCacheConfig::default(),
};
assert!(!config.mode.is_pd_mode()); assert!(!config.mode.is_pd_mode());
assert_eq!(config.mode.worker_count(), 3); assert_eq!(config.mode.worker_count(), 3);
...@@ -1128,20 +1042,16 @@ mod tests { ...@@ -1128,20 +1042,16 @@ mod tests {
selectors.insert("env".to_string(), "prod".to_string()); selectors.insert("env".to_string(), "prod".to_string());
selectors.insert("version".to_string(), "v1".to_string()); selectors.insert("version".to_string(), "v1".to_string());
let config = RouterConfig { let config = RouterConfig::builder()
mode: RoutingMode::Regular { .regular_mode(vec!["http://worker1".to_string()])
worker_urls: vec!["http://worker1".to_string()], .round_robin_policy()
}, .host("::1") // IPv6
policy: PolicyConfig::RoundRobin, .port(8888)
host: "::1".to_string(), // IPv6 .max_payload_size(1024 * 1024 * 512) // 512MB
port: 8888, .request_timeout_secs(900)
max_payload_size: 1024 * 1024 * 512, // 512MB .worker_startup_timeout_secs(600)
request_timeout_secs: 900, .worker_startup_check_interval_secs(20)
worker_startup_timeout_secs: 600, .discovery_config(DiscoveryConfig {
worker_startup_check_interval_secs: 20,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
enabled: true, enabled: true,
namespace: Some("production".to_string()), namespace: Some("production".to_string()),
port: 8443, port: 8443,
...@@ -1150,35 +1060,12 @@ mod tests { ...@@ -1150,35 +1060,12 @@ mod tests {
prefill_selector: selectors.clone(), prefill_selector: selectors.clone(),
decode_selector: selectors, decode_selector: selectors,
bootstrap_port_annotation: "mycompany.io/bootstrap".to_string(), bootstrap_port_annotation: "mycompany.io/bootstrap".to_string(),
}), })
metrics: Some(MetricsConfig { .enable_metrics("::", 9999) // IPv6 any
port: 9999, .log_dir("/opt/logs/sglang")
host: "::".to_string(), // IPv6 any .log_level("trace")
}), .max_concurrent_requests(64)
log_dir: Some("/opt/logs/sglang".to_string()), .build_unchecked();
log_level: Some("trace".to_string()),
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
chat_template: None,
history_backend: default_history_backend(),
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: TokenizerCacheConfig::default(),
};
assert!(config.has_service_discovery()); assert!(config.has_service_discovery());
assert!(config.has_metrics()); assert!(config.has_metrics());
......
...@@ -302,65 +302,66 @@ impl Router { ...@@ -302,65 +302,66 @@ impl Router {
None None
}; };
Ok(config::RouterConfig { let builder = config::RouterConfig::builder()
mode, .mode(mode)
policy, .policy(policy)
host: self.host.clone(), .host(&self.host)
port: self.port, .port(self.port)
connection_mode: self.connection_mode.clone(), .connection_mode(self.connection_mode.clone())
max_payload_size: self.max_payload_size, .max_payload_size(self.max_payload_size)
request_timeout_secs: self.request_timeout_secs, .request_timeout_secs(self.request_timeout_secs)
worker_startup_timeout_secs: self.worker_startup_timeout_secs, .worker_startup_timeout_secs(self.worker_startup_timeout_secs)
worker_startup_check_interval_secs: self.worker_startup_check_interval, .worker_startup_check_interval_secs(self.worker_startup_check_interval)
dp_aware: self.dp_aware, .max_concurrent_requests(self.max_concurrent_requests)
api_key: self.api_key.clone(), .queue_size(self.queue_size)
discovery, .queue_timeout_secs(self.queue_timeout_secs)
metrics, .cors_allowed_origins(self.cors_allowed_origins.clone())
log_dir: self.log_dir.clone(), .retry_config(config::RetryConfig {
log_level: self.log_level.clone(),
request_id_headers: self.request_id_headers.clone(),
max_concurrent_requests: self.max_concurrent_requests,
queue_size: self.queue_size,
queue_timeout_secs: self.queue_timeout_secs,
rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: config::RetryConfig {
max_retries: self.retry_max_retries, max_retries: self.retry_max_retries,
initial_backoff_ms: self.retry_initial_backoff_ms, initial_backoff_ms: self.retry_initial_backoff_ms,
max_backoff_ms: self.retry_max_backoff_ms, max_backoff_ms: self.retry_max_backoff_ms,
backoff_multiplier: self.retry_backoff_multiplier, backoff_multiplier: self.retry_backoff_multiplier,
jitter_factor: self.retry_jitter_factor, jitter_factor: self.retry_jitter_factor,
}, })
circuit_breaker: config::CircuitBreakerConfig { .circuit_breaker_config(config::CircuitBreakerConfig {
failure_threshold: self.cb_failure_threshold, failure_threshold: self.cb_failure_threshold,
success_threshold: self.cb_success_threshold, success_threshold: self.cb_success_threshold,
timeout_duration_secs: self.cb_timeout_duration_secs, timeout_duration_secs: self.cb_timeout_duration_secs,
window_duration_secs: self.cb_window_duration_secs, window_duration_secs: self.cb_window_duration_secs,
}, })
disable_retries: self.disable_retries, .health_check_config(config::HealthCheckConfig {
disable_circuit_breaker: self.disable_circuit_breaker,
health_check: config::HealthCheckConfig {
failure_threshold: self.health_failure_threshold, failure_threshold: self.health_failure_threshold,
success_threshold: self.health_success_threshold, success_threshold: self.health_success_threshold,
timeout_secs: self.health_check_timeout_secs, timeout_secs: self.health_check_timeout_secs,
check_interval_secs: self.health_check_interval_secs, check_interval_secs: self.health_check_interval_secs,
endpoint: self.health_check_endpoint.clone(), endpoint: self.health_check_endpoint.clone(),
}, })
enable_igw: self.enable_igw, .tokenizer_cache(config::TokenizerCacheConfig {
model_path: self.model_path.clone(),
tokenizer_path: self.tokenizer_path.clone(),
chat_template: self.chat_template.clone(),
history_backend,
oracle,
reasoning_parser: self.reasoning_parser.clone(),
tool_call_parser: self.tool_call_parser.clone(),
tokenizer_cache: config::TokenizerCacheConfig {
enable_l0: self.tokenizer_cache_enable_l0, enable_l0: self.tokenizer_cache_enable_l0,
l0_max_entries: self.tokenizer_cache_l0_max_entries, l0_max_entries: self.tokenizer_cache_l0_max_entries,
enable_l1: self.tokenizer_cache_enable_l1, enable_l1: self.tokenizer_cache_enable_l1,
l1_max_memory: self.tokenizer_cache_l1_max_memory, l1_max_memory: self.tokenizer_cache_l1_max_memory,
},
}) })
.history_backend(history_backend)
.maybe_api_key(self.api_key.as_ref())
.maybe_discovery(discovery)
.maybe_metrics(metrics)
.maybe_log_dir(self.log_dir.as_ref())
.maybe_log_level(self.log_level.as_ref())
.maybe_request_id_headers(self.request_id_headers.clone())
.maybe_rate_limit_tokens_per_second(self.rate_limit_tokens_per_second)
.maybe_model_path(self.model_path.as_ref())
.maybe_tokenizer_path(self.tokenizer_path.as_ref())
.maybe_chat_template(self.chat_template.as_ref())
.maybe_oracle(oracle)
.maybe_reasoning_parser(self.reasoning_parser.as_ref())
.maybe_tool_call_parser(self.tool_call_parser.as_ref())
.dp_aware(self.dp_aware)
.retries(!self.disable_retries)
.circuit_breaker(!self.disable_circuit_breaker)
.igw(self.enable_igw);
builder.build()
} }
} }
......
...@@ -538,69 +538,68 @@ impl CliArgs { ...@@ -538,69 +538,68 @@ impl CliArgs {
None None
}; };
Ok(RouterConfig { let builder = RouterConfig::builder()
mode, .mode(mode)
policy, .policy(policy)
connection_mode, .connection_mode(connection_mode)
host: self.host.clone(), .host(&self.host)
port: self.port, .port(self.port)
max_payload_size: self.max_payload_size, .max_payload_size(self.max_payload_size)
request_timeout_secs: self.request_timeout_secs, .request_timeout_secs(self.request_timeout_secs)
worker_startup_timeout_secs: self.worker_startup_timeout_secs, .worker_startup_timeout_secs(self.worker_startup_timeout_secs)
worker_startup_check_interval_secs: self.worker_startup_check_interval, .worker_startup_check_interval_secs(self.worker_startup_check_interval)
dp_aware: self.dp_aware, .max_concurrent_requests(self.max_concurrent_requests)
api_key: self.api_key.clone(), .queue_size(self.queue_size)
discovery, .queue_timeout_secs(self.queue_timeout_secs)
metrics, .cors_allowed_origins(self.cors_allowed_origins.clone())
log_dir: self.log_dir.clone(), .retry_config(RetryConfig {
log_level: Some(self.log_level.clone()),
request_id_headers: if self.request_id_headers.is_empty() {
None
} else {
Some(self.request_id_headers.clone())
},
max_concurrent_requests: self.max_concurrent_requests,
queue_size: self.queue_size,
queue_timeout_secs: self.queue_timeout_secs,
cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: RetryConfig {
max_retries: self.retry_max_retries, max_retries: self.retry_max_retries,
initial_backoff_ms: self.retry_initial_backoff_ms, initial_backoff_ms: self.retry_initial_backoff_ms,
max_backoff_ms: self.retry_max_backoff_ms, max_backoff_ms: self.retry_max_backoff_ms,
backoff_multiplier: self.retry_backoff_multiplier, backoff_multiplier: self.retry_backoff_multiplier,
jitter_factor: self.retry_jitter_factor, jitter_factor: self.retry_jitter_factor,
}, })
circuit_breaker: CircuitBreakerConfig { .circuit_breaker_config(CircuitBreakerConfig {
failure_threshold: self.cb_failure_threshold, failure_threshold: self.cb_failure_threshold,
success_threshold: self.cb_success_threshold, success_threshold: self.cb_success_threshold,
timeout_duration_secs: self.cb_timeout_duration_secs, timeout_duration_secs: self.cb_timeout_duration_secs,
window_duration_secs: self.cb_window_duration_secs, window_duration_secs: self.cb_window_duration_secs,
}, })
disable_retries: self.disable_retries, .health_check_config(HealthCheckConfig {
disable_circuit_breaker: self.disable_circuit_breaker,
health_check: HealthCheckConfig {
failure_threshold: self.health_failure_threshold, failure_threshold: self.health_failure_threshold,
success_threshold: self.health_success_threshold, success_threshold: self.health_success_threshold,
timeout_secs: self.health_check_timeout_secs, timeout_secs: self.health_check_timeout_secs,
check_interval_secs: self.health_check_interval_secs, check_interval_secs: self.health_check_interval_secs,
endpoint: self.health_check_endpoint.clone(), endpoint: self.health_check_endpoint.clone(),
}, })
enable_igw: self.enable_igw, .tokenizer_cache(TokenizerCacheConfig {
rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
model_path: self.model_path.clone(),
tokenizer_path: self.tokenizer_path.clone(),
chat_template: self.chat_template.clone(),
history_backend,
oracle,
reasoning_parser: self.reasoning_parser.clone(),
tool_call_parser: self.tool_call_parser.clone(),
tokenizer_cache: TokenizerCacheConfig {
enable_l0: self.tokenizer_cache_enable_l0, enable_l0: self.tokenizer_cache_enable_l0,
l0_max_entries: self.tokenizer_cache_l0_max_entries, l0_max_entries: self.tokenizer_cache_l0_max_entries,
enable_l1: self.tokenizer_cache_enable_l1, enable_l1: self.tokenizer_cache_enable_l1,
l1_max_memory: self.tokenizer_cache_l1_max_memory, l1_max_memory: self.tokenizer_cache_l1_max_memory,
},
}) })
.history_backend(history_backend)
.log_level(&self.log_level)
.maybe_api_key(self.api_key.as_ref())
.maybe_discovery(discovery)
.maybe_metrics(metrics)
.maybe_log_dir(self.log_dir.as_ref())
.maybe_request_id_headers(
(!self.request_id_headers.is_empty()).then(|| self.request_id_headers.clone()),
)
.maybe_rate_limit_tokens_per_second(self.rate_limit_tokens_per_second)
.maybe_model_path(self.model_path.as_ref())
.maybe_tokenizer_path(self.tokenizer_path.as_ref())
.maybe_chat_template(self.chat_template.as_ref())
.maybe_oracle(oracle)
.maybe_reasoning_parser(self.reasoning_parser.as_ref())
.maybe_tool_call_parser(self.tool_call_parser.as_ref())
.dp_aware(self.dp_aware)
.retries(!self.disable_retries)
.circuit_breaker(!self.disable_circuit_breaker)
.igw(self.enable_igw);
builder.build()
} }
fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig { fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig {
......
...@@ -565,10 +565,9 @@ mod tests { ...@@ -565,10 +565,9 @@ mod tests {
async fn create_test_app_context() -> Arc<AppContext> { async fn create_test_app_context() -> Arc<AppContext> {
use crate::{config::RouterConfig, middleware::TokenBucket}; use crate::{config::RouterConfig, middleware::TokenBucket};
let router_config = RouterConfig { let router_config = RouterConfig::builder()
worker_startup_timeout_secs: 1, .worker_startup_timeout_secs(1)
..Default::default() .build_unchecked();
};
// Note: Using uninitialized queue for tests to avoid spawning background workers // Note: Using uninitialized queue for tests to avoid spawning background workers
// Jobs submitted during tests will queue but not be processed // Jobs submitted during tests will queue but not be processed
......
...@@ -11,8 +11,8 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType ...@@ -11,8 +11,8 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
use reqwest::Client; use reqwest::Client;
use serde_json::json; use serde_json::json;
use sglang_router_rs::{ use sglang_router_rs::{
config::{CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode}, config::{RouterConfig, RoutingMode},
core::{ConnectionMode, Job}, core::Job,
routers::{RouterFactory, RouterTrait}, routers::{RouterFactory, RouterTrait},
server::AppContext, server::AppContext,
}; };
...@@ -30,45 +30,18 @@ struct TestContext { ...@@ -30,45 +30,18 @@ struct TestContext {
impl TestContext { impl TestContext {
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self { async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
// Create default router config // Create default router config
let config = RouterConfig { let config = RouterConfig::builder()
chat_template: None, .regular_mode(vec![])
mode: RoutingMode::Regular { .random_policy()
worker_urls: vec![], .host("127.0.0.1")
}, .port(3002)
policy: PolicyConfig::Random, .max_payload_size(256 * 1024 * 1024)
host: "127.0.0.1".to_string(), .request_timeout_secs(600)
port: 3002, .worker_startup_timeout_secs(1)
max_payload_size: 256 * 1024 * 1024, .worker_startup_check_interval_secs(1)
request_timeout_secs: 600, .max_concurrent_requests(64)
worker_startup_timeout_secs: 1, .queue_timeout_secs(60)
worker_startup_check_interval_secs: 1, .build_unchecked();
discovery: None,
dp_aware: false,
api_key: None,
metrics: None,
log_dir: None,
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
Self::new_with_config(config, worker_configs).await Self::new_with_config(config, worker_configs).await
} }
...@@ -1182,45 +1155,18 @@ mod error_tests { ...@@ -1182,45 +1155,18 @@ mod error_tests {
#[tokio::test] #[tokio::test]
async fn test_payload_too_large() { async fn test_payload_too_large() {
// Create context with small payload limit // Create context with small payload limit
let config = RouterConfig { let config = RouterConfig::builder()
chat_template: None, .regular_mode(vec![])
mode: RoutingMode::Regular { .random_policy()
worker_urls: vec![], .host("127.0.0.1")
}, .port(3010)
policy: PolicyConfig::Random, .max_payload_size(1024) // 1KB limit
host: "127.0.0.1".to_string(), .request_timeout_secs(600)
port: 3010, .worker_startup_timeout_secs(1)
max_payload_size: 1024, // 1KB limit .worker_startup_check_interval_secs(1)
request_timeout_secs: 600, .max_concurrent_requests(64)
worker_startup_timeout_secs: 1, .queue_timeout_secs(60)
worker_startup_check_interval_secs: 1, .build_unchecked();
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let ctx = TestContext::new_with_config( let ctx = TestContext::new_with_config(
config, config,
...@@ -1509,48 +1455,18 @@ mod pd_mode_tests { ...@@ -1509,48 +1455,18 @@ mod pd_mode_tests {
.and_then(|p| p.trim_end_matches('/').parse::<u16>().ok()) .and_then(|p| p.trim_end_matches('/').parse::<u16>().ok())
.unwrap_or(9000); .unwrap_or(9000);
let config = RouterConfig { let config = RouterConfig::builder()
chat_template: None, .prefill_decode_mode(vec![(prefill_url, Some(prefill_port))], vec![decode_url])
mode: RoutingMode::PrefillDecode { .random_policy()
prefill_urls: vec![(prefill_url, Some(prefill_port))], .host("127.0.0.1")
decode_urls: vec![decode_url], .port(3011)
prefill_policy: None, .max_payload_size(256 * 1024 * 1024)
decode_policy: None, .request_timeout_secs(600)
}, .worker_startup_timeout_secs(1)
policy: PolicyConfig::Random, .worker_startup_check_interval_secs(1)
host: "127.0.0.1".to_string(), .max_concurrent_requests(64)
port: 3011, .queue_timeout_secs(60)
max_payload_size: 256 * 1024 * 1024, .build_unchecked();
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
discovery: None,
metrics: None,
log_dir: None,
dp_aware: false,
api_key: None,
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
// Create app context // Create app context
let app_context = common::create_test_context(config); let app_context = common::create_test_context(config);
...@@ -1676,45 +1592,19 @@ mod request_id_tests { ...@@ -1676,45 +1592,19 @@ mod request_id_tests {
#[tokio::test] #[tokio::test]
async fn test_request_id_with_custom_headers() { async fn test_request_id_with_custom_headers() {
// Create config with custom request ID headers // Create config with custom request ID headers
let config = RouterConfig { let config = RouterConfig::builder()
chat_template: None, .regular_mode(vec![])
mode: RoutingMode::Regular { .random_policy()
worker_urls: vec![], .host("127.0.0.1")
}, .port(3002)
policy: PolicyConfig::Random, .max_payload_size(256 * 1024 * 1024)
host: "127.0.0.1".to_string(), .request_timeout_secs(600)
port: 3002, .worker_startup_timeout_secs(1)
max_payload_size: 256 * 1024 * 1024, .worker_startup_check_interval_secs(1)
request_timeout_secs: 600, .request_id_headers(vec!["custom-id".to_string(), "trace-id".to_string()])
worker_startup_timeout_secs: 1, .max_concurrent_requests(64)
worker_startup_check_interval_secs: 1, .queue_timeout_secs(60)
discovery: None, .build_unchecked();
metrics: None,
dp_aware: false,
api_key: None,
log_dir: None,
log_level: None,
request_id_headers: Some(vec!["custom-id".to_string(), "trace-id".to_string()]),
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let ctx = TestContext::new_with_config( let ctx = TestContext::new_with_config(
config, config,
......
...@@ -19,16 +19,12 @@ struct TestContext { ...@@ -19,16 +19,12 @@ struct TestContext {
impl TestContext { impl TestContext {
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self { async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
let mut config = RouterConfig { let mut config = RouterConfig::builder()
chat_template: None, .regular_mode(vec![])
mode: RoutingMode::Regular { .port(3003)
worker_urls: vec![], .worker_startup_timeout_secs(1)
}, .worker_startup_check_interval_secs(1)
port: 3003, .build_unchecked();
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
..Default::default()
};
let mut workers = Vec::new(); let mut workers = Vec::new();
let mut worker_urls = Vec::new(); let mut worker_urls = Vec::new();
......
This diff is collapsed.
...@@ -20,16 +20,12 @@ struct TestContext { ...@@ -20,16 +20,12 @@ struct TestContext {
impl TestContext { impl TestContext {
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self { async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
let mut config = RouterConfig { let mut config = RouterConfig::builder()
chat_template: None, .regular_mode(vec![])
mode: RoutingMode::Regular { .port(3004)
worker_urls: vec![], .worker_startup_timeout_secs(1)
}, .worker_startup_check_interval_secs(1)
port: 3004, .build_unchecked();
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
..Default::default()
};
let mut workers = Vec::new(); let mut workers = Vec::new();
let mut worker_urls = Vec::new(); let mut worker_urls = Vec::new();
......
...@@ -902,17 +902,10 @@ async fn test_openai_router_models_auth_forwarding() { ...@@ -902,17 +902,10 @@ async fn test_openai_router_models_auth_forwarding() {
#[test] #[test]
fn oracle_config_validation_requires_config_when_enabled() { fn oracle_config_validation_requires_config_when_enabled() {
let config = RouterConfig { let config = RouterConfig::builder()
chat_template: None, .openai_mode(vec!["https://api.openai.com".to_string()])
mode: RoutingMode::OpenAI { .history_backend(HistoryBackend::Oracle)
worker_urls: vec!["https://api.openai.com".to_string()], .build_unchecked();
},
history_backend: HistoryBackend::Oracle,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
..Default::default()
};
let err = let err =
ConfigValidator::validate(&config).expect_err("config should fail without oracle details"); ConfigValidator::validate(&config).expect_err("config should fail without oracle details");
...@@ -927,13 +920,9 @@ fn oracle_config_validation_requires_config_when_enabled() { ...@@ -927,13 +920,9 @@ fn oracle_config_validation_requires_config_when_enabled() {
#[test] #[test]
fn oracle_config_validation_accepts_dsn_only() { fn oracle_config_validation_accepts_dsn_only() {
let config = RouterConfig { let config = RouterConfig::builder()
chat_template: None, .openai_mode(vec!["https://api.openai.com".to_string()])
mode: RoutingMode::OpenAI { .oracle_history(OracleConfig {
worker_urls: vec!["https://api.openai.com".to_string()],
},
history_backend: HistoryBackend::Oracle,
oracle: Some(OracleConfig {
wallet_path: None, wallet_path: None,
connect_descriptor: "tcps://db.example.com:1522/service".to_string(), connect_descriptor: "tcps://db.example.com:1522/service".to_string(),
username: "scott".to_string(), username: "scott".to_string(),
...@@ -941,22 +930,17 @@ fn oracle_config_validation_accepts_dsn_only() { ...@@ -941,22 +930,17 @@ fn oracle_config_validation_accepts_dsn_only() {
pool_min: 1, pool_min: 1,
pool_max: 4, pool_max: 4,
pool_timeout_secs: 30, pool_timeout_secs: 30,
}), })
..Default::default() .build_unchecked();
};
ConfigValidator::validate(&config).expect("dsn-based config should validate"); ConfigValidator::validate(&config).expect("dsn-based config should validate");
} }
#[test] #[test]
fn oracle_config_validation_accepts_wallet_alias() { fn oracle_config_validation_accepts_wallet_alias() {
let config = RouterConfig { let config = RouterConfig::builder()
chat_template: None, .openai_mode(vec!["https://api.openai.com".to_string()])
mode: RoutingMode::OpenAI { .oracle_history(OracleConfig {
worker_urls: vec!["https://api.openai.com".to_string()],
},
history_backend: HistoryBackend::Oracle,
oracle: Some(OracleConfig {
wallet_path: Some("/etc/sglang/oracle-wallet".to_string()), wallet_path: Some("/etc/sglang/oracle-wallet".to_string()),
connect_descriptor: "db_low".to_string(), connect_descriptor: "db_low".to_string(),
username: "app_user".to_string(), username: "app_user".to_string(),
...@@ -964,9 +948,8 @@ fn oracle_config_validation_accepts_wallet_alias() { ...@@ -964,9 +948,8 @@ fn oracle_config_validation_accepts_wallet_alias() {
pool_min: 1, pool_min: 1,
pool_max: 8, pool_max: 8,
pool_timeout_secs: 45, pool_timeout_secs: 45,
}), })
..Default::default() .build_unchecked();
};
ConfigValidator::validate(&config).expect("wallet-based config should validate"); ConfigValidator::validate(&config).expect("wallet-based config should validate");
} }
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
mod test_pd_routing { mod test_pd_routing {
use serde_json::json; use serde_json::json;
use sglang_router_rs::{ use sglang_router_rs::{
config::{CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode}, config::{PolicyConfig, RouterConfig, RoutingMode},
core::{BasicWorkerBuilder, ConnectionMode, Worker, WorkerType}, core::{BasicWorkerBuilder, Worker, WorkerType},
routers::{http::pd_types::PDSelectionPolicy, RouterFactory}, routers::{http::pd_types::PDSelectionPolicy, RouterFactory},
}; };
...@@ -162,42 +162,24 @@ mod test_pd_routing { ...@@ -162,42 +162,24 @@ mod test_pd_routing {
]; ];
for (mode, policy) in test_cases { for (mode, policy) in test_cases {
let config = RouterConfig { let config = match mode {
chat_template: None, RoutingMode::PrefillDecode {
mode, prefill_urls,
policy, decode_urls,
host: "127.0.0.1".to_string(), ..
port: 3001, } => RouterConfig::builder()
max_payload_size: 1024 * 1024, .prefill_decode_mode(prefill_urls, decode_urls)
request_timeout_secs: 60, .policy(policy)
worker_startup_timeout_secs: 10, .host("127.0.0.1")
worker_startup_check_interval_secs: 1, .port(3001)
dp_aware: false, .max_payload_size(1024 * 1024)
api_key: None, .request_timeout_secs(60)
discovery: None, .worker_startup_timeout_secs(10)
metrics: None, .worker_startup_check_interval_secs(1)
log_dir: None, .max_concurrent_requests(64)
log_level: None, .queue_timeout_secs(60)
request_id_headers: None, .build_unchecked(),
max_concurrent_requests: 64, _ => panic!("Expected PrefillDecode mode"),
queue_size: 0,
queue_timeout_secs: 60,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
rate_limit_tokens_per_second: None,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
}; };
let app_context = { let app_context = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment