"...source/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "b1df760035a43aee39e88fa3d410c024ef7d6a48"
Unverified Commit 773d89da authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router] cleaned up all the redundant comments in the config module (#12147)

parent 03e7d949
...@@ -18,12 +18,11 @@ pub struct RouterConfigBuilder { ...@@ -18,12 +18,11 @@ pub struct RouterConfigBuilder {
} }
impl RouterConfigBuilder { impl RouterConfigBuilder {
/// Create a new builder with default values
pub fn new() -> Self { pub fn new() -> Self {
Self::default() Self::default()
} }
/// Create a builder from an existing configuration (takes ownership) /// Takes ownership
pub fn from_config(config: RouterConfig) -> Self { pub fn from_config(config: RouterConfig) -> Self {
Self { Self {
config, config,
...@@ -34,20 +33,17 @@ impl RouterConfigBuilder { ...@@ -34,20 +33,17 @@ impl RouterConfigBuilder {
} }
} }
/// Create a builder from a reference to an existing configuration
pub fn from_config_ref(config: &RouterConfig) -> Self { pub fn from_config_ref(config: &RouterConfig) -> Self {
Self::from_config(config.clone()) Self::from_config(config.clone())
} }
// ==================== Routing Mode Setters ==================== // ==================== Routing Mode ====================
/// Set regular routing mode with worker URLs
pub fn regular_mode(mut self, worker_urls: Vec<String>) -> Self { pub fn regular_mode(mut self, worker_urls: Vec<String>) -> Self {
self.config.mode = RoutingMode::Regular { worker_urls }; self.config.mode = RoutingMode::Regular { worker_urls };
self self
} }
/// Set prefill-decode routing mode
pub fn prefill_decode_mode( pub fn prefill_decode_mode(
mut self, mut self,
prefill_urls: Vec<(String, Option<u16>)>, prefill_urls: Vec<(String, Option<u16>)>,
...@@ -62,7 +58,7 @@ impl RouterConfigBuilder { ...@@ -62,7 +58,7 @@ impl RouterConfigBuilder {
self self
} }
/// Set prefill-decode mode with separate policies /// With separate policies
pub fn prefill_decode_mode_with_policies( pub fn prefill_decode_mode_with_policies(
mut self, mut self,
prefill_urls: Vec<(String, Option<u16>)>, prefill_urls: Vec<(String, Option<u16>)>,
...@@ -79,39 +75,33 @@ impl RouterConfigBuilder { ...@@ -79,39 +75,33 @@ impl RouterConfigBuilder {
self self
} }
/// Set OpenAI routing mode
pub fn openai_mode(mut self, worker_urls: Vec<String>) -> Self { pub fn openai_mode(mut self, worker_urls: Vec<String>) -> Self {
self.config.mode = RoutingMode::OpenAI { worker_urls }; self.config.mode = RoutingMode::OpenAI { worker_urls };
self self
} }
/// Set the routing mode directly
pub fn mode(mut self, mode: RoutingMode) -> Self { pub fn mode(mut self, mode: RoutingMode) -> Self {
self.config.mode = mode; self.config.mode = mode;
self self
} }
// ==================== Policy Setters ==================== // ==================== Policy ====================
/// Set the routing policy
pub fn policy(mut self, policy: PolicyConfig) -> Self { pub fn policy(mut self, policy: PolicyConfig) -> Self {
self.config.policy = policy; self.config.policy = policy;
self self
} }
/// Set random policy
pub fn random_policy(mut self) -> Self { pub fn random_policy(mut self) -> Self {
self.config.policy = PolicyConfig::Random; self.config.policy = PolicyConfig::Random;
self self
} }
/// Set round-robin policy
pub fn round_robin_policy(mut self) -> Self { pub fn round_robin_policy(mut self) -> Self {
self.config.policy = PolicyConfig::RoundRobin; self.config.policy = PolicyConfig::RoundRobin;
self self
} }
/// Set cache-aware policy with parameters
pub fn cache_aware_policy( pub fn cache_aware_policy(
mut self, mut self,
cache_threshold: f32, cache_threshold: f32,
...@@ -130,7 +120,6 @@ impl RouterConfigBuilder { ...@@ -130,7 +120,6 @@ impl RouterConfigBuilder {
self self
} }
/// Set power-of-two policy
pub fn power_of_two_policy(mut self, load_check_interval_secs: u64) -> Self { pub fn power_of_two_policy(mut self, load_check_interval_secs: u64) -> Self {
self.config.policy = PolicyConfig::PowerOfTwo { self.config.policy = PolicyConfig::PowerOfTwo {
load_check_interval_secs, load_check_interval_secs,
...@@ -138,65 +127,55 @@ impl RouterConfigBuilder { ...@@ -138,65 +127,55 @@ impl RouterConfigBuilder {
self self
} }
// ==================== Connection Settings ==================== // ==================== Connection ====================
/// Set connection mode
pub fn connection_mode(mut self, mode: ConnectionMode) -> Self { pub fn connection_mode(mut self, mode: ConnectionMode) -> Self {
self.config.connection_mode = mode; self.config.connection_mode = mode;
self self
} }
/// Set HTTP connection mode
pub fn http_connection(mut self) -> Self { pub fn http_connection(mut self) -> Self {
self.config.connection_mode = ConnectionMode::Http; self.config.connection_mode = ConnectionMode::Http;
self self
} }
/// Set gRPC connection mode with optional port
pub fn grpc_connection(mut self, port: Option<u16>) -> Self { pub fn grpc_connection(mut self, port: Option<u16>) -> Self {
self.config.connection_mode = ConnectionMode::Grpc { port }; self.config.connection_mode = ConnectionMode::Grpc { port };
self self
} }
/// Set gRPC connection mode without specifying a port
pub fn grpc_connection_default(mut self) -> Self { pub fn grpc_connection_default(mut self) -> Self {
self.config.connection_mode = ConnectionMode::Grpc { port: None }; self.config.connection_mode = ConnectionMode::Grpc { port: None };
self self
} }
/// Set host address
pub fn host<S: Into<String>>(mut self, host: S) -> Self { pub fn host<S: Into<String>>(mut self, host: S) -> Self {
self.config.host = host.into(); self.config.host = host.into();
self self
} }
/// Set port number
pub fn port(mut self, port: u16) -> Self { pub fn port(mut self, port: u16) -> Self {
self.config.port = port; self.config.port = port;
self self
} }
// ==================== Request Settings ==================== // ==================== Request ====================
/// Set maximum payload size in bytes
pub fn max_payload_size(mut self, size: usize) -> Self { pub fn max_payload_size(mut self, size: usize) -> Self {
self.config.max_payload_size = size; self.config.max_payload_size = size;
self self
} }
/// Set request timeout in seconds
pub fn request_timeout_secs(mut self, timeout: u64) -> Self { pub fn request_timeout_secs(mut self, timeout: u64) -> Self {
self.config.request_timeout_secs = timeout; self.config.request_timeout_secs = timeout;
self self
} }
/// Set worker startup timeout in seconds
pub fn worker_startup_timeout_secs(mut self, timeout: u64) -> Self { pub fn worker_startup_timeout_secs(mut self, timeout: u64) -> Self {
self.config.worker_startup_timeout_secs = timeout; self.config.worker_startup_timeout_secs = timeout;
self self
} }
/// Set worker startup check interval in seconds
pub fn worker_startup_check_interval_secs(mut self, interval: u64) -> Self { pub fn worker_startup_check_interval_secs(mut self, interval: u64) -> Self {
self.config.worker_startup_check_interval_secs = interval; self.config.worker_startup_check_interval_secs = interval;
self self
...@@ -204,31 +183,26 @@ impl RouterConfigBuilder { ...@@ -204,31 +183,26 @@ impl RouterConfigBuilder {
// ==================== Rate Limiting ==================== // ==================== Rate Limiting ====================
/// Set maximum concurrent requests
pub fn max_concurrent_requests(mut self, max: i32) -> Self { pub fn max_concurrent_requests(mut self, max: i32) -> Self {
self.config.max_concurrent_requests = max; self.config.max_concurrent_requests = max;
self self
} }
/// Disable rate limiting
pub fn disable_rate_limiting(mut self) -> Self { pub fn disable_rate_limiting(mut self) -> Self {
self.config.max_concurrent_requests = -1; self.config.max_concurrent_requests = -1;
self self
} }
/// Set queue size for pending requests
pub fn queue_size(mut self, size: usize) -> Self { pub fn queue_size(mut self, size: usize) -> Self {
self.config.queue_size = size; self.config.queue_size = size;
self self
} }
/// Set queue timeout in seconds
pub fn queue_timeout_secs(mut self, timeout: u64) -> Self { pub fn queue_timeout_secs(mut self, timeout: u64) -> Self {
self.config.queue_timeout_secs = timeout; self.config.queue_timeout_secs = timeout;
self self
} }
/// Set rate limit tokens per second
pub fn rate_limit_tokens_per_second(mut self, tokens: i32) -> Self { pub fn rate_limit_tokens_per_second(mut self, tokens: i32) -> Self {
self.config.rate_limit_tokens_per_second = Some(tokens); self.config.rate_limit_tokens_per_second = Some(tokens);
self self
...@@ -236,81 +210,70 @@ impl RouterConfigBuilder { ...@@ -236,81 +210,70 @@ impl RouterConfigBuilder {
// ==================== Security & CORS ==================== // ==================== Security & CORS ====================
/// Set API key for worker authorization
pub fn api_key<S: Into<String>>(mut self, key: S) -> Self { pub fn api_key<S: Into<String>>(mut self, key: S) -> Self {
self.config.api_key = Some(key.into()); self.config.api_key = Some(key.into());
self self
} }
/// Set CORS allowed origins
pub fn cors_allowed_origins(mut self, origins: Vec<String>) -> Self { pub fn cors_allowed_origins(mut self, origins: Vec<String>) -> Self {
self.config.cors_allowed_origins = origins; self.config.cors_allowed_origins = origins;
self self
} }
/// Add a single CORS origin
pub fn add_cors_origin<S: Into<String>>(mut self, origin: S) -> Self { pub fn add_cors_origin<S: Into<String>>(mut self, origin: S) -> Self {
self.config.cors_allowed_origins.push(origin.into()); self.config.cors_allowed_origins.push(origin.into());
self self
} }
// ==================== Retry Configuration ==================== // ==================== Retry ====================
/// Set retry configuration
pub fn retry_config(mut self, retry: RetryConfig) -> Self { pub fn retry_config(mut self, retry: RetryConfig) -> Self {
self.config.retry = retry; self.config.retry = retry;
self self
} }
/// Disable retries
pub fn disable_retries(mut self) -> Self { pub fn disable_retries(mut self) -> Self {
self.config.disable_retries = true; self.config.disable_retries = true;
self self
} }
/// Enable retries
pub fn enable_retries(mut self) -> Self { pub fn enable_retries(mut self) -> Self {
self.config.disable_retries = false; self.config.disable_retries = false;
self self
} }
// ==================== Circuit Breaker Configuration ==================== // ==================== Circuit Breaker ====================
/// Set circuit breaker configuration
pub fn circuit_breaker_config(mut self, circuit_breaker: CircuitBreakerConfig) -> Self { pub fn circuit_breaker_config(mut self, circuit_breaker: CircuitBreakerConfig) -> Self {
self.config.circuit_breaker = circuit_breaker; self.config.circuit_breaker = circuit_breaker;
self self
} }
/// Disable circuit breaker
pub fn disable_circuit_breaker(mut self) -> Self { pub fn disable_circuit_breaker(mut self) -> Self {
self.config.disable_circuit_breaker = true; self.config.disable_circuit_breaker = true;
self self
} }
/// Enable circuit breaker
pub fn enable_circuit_breaker(mut self) -> Self { pub fn enable_circuit_breaker(mut self) -> Self {
self.config.disable_circuit_breaker = false; self.config.disable_circuit_breaker = false;
self self
} }
// ==================== Health Check Configuration ==================== // ==================== Health Check ====================
/// Set health check configuration
pub fn health_check_config(mut self, health_check: HealthCheckConfig) -> Self { pub fn health_check_config(mut self, health_check: HealthCheckConfig) -> Self {
self.config.health_check = health_check; self.config.health_check = health_check;
self self
} }
// ==================== Discovery Configuration ==================== // ==================== Discovery ====================
/// Set service discovery configuration
pub fn discovery_config(mut self, discovery: DiscoveryConfig) -> Self { pub fn discovery_config(mut self, discovery: DiscoveryConfig) -> Self {
self.config.discovery = Some(discovery); self.config.discovery = Some(discovery);
self self
} }
/// Enable service discovery with default settings /// With default settings
pub fn enable_discovery(mut self) -> Self { pub fn enable_discovery(mut self) -> Self {
self.config.discovery = Some(DiscoveryConfig { self.config.discovery = Some(DiscoveryConfig {
enabled: true, enabled: true,
...@@ -319,15 +282,13 @@ impl RouterConfigBuilder { ...@@ -319,15 +282,13 @@ impl RouterConfigBuilder {
self self
} }
// ==================== Metrics Configuration ==================== // ==================== Metrics ====================
/// Set metrics configuration
pub fn metrics_config(mut self, metrics: MetricsConfig) -> Self { pub fn metrics_config(mut self, metrics: MetricsConfig) -> Self {
self.config.metrics = Some(metrics); self.config.metrics = Some(metrics);
self self
} }
/// Enable metrics with host and port
pub fn enable_metrics<S: Into<String>>(mut self, host: S, port: u16) -> Self { pub fn enable_metrics<S: Into<String>>(mut self, host: S, port: u16) -> Self {
self.config.metrics = Some(MetricsConfig { self.config.metrics = Some(MetricsConfig {
host: host.into(), host: host.into(),
...@@ -336,115 +297,100 @@ impl RouterConfigBuilder { ...@@ -336,115 +297,100 @@ impl RouterConfigBuilder {
self self
} }
// ==================== Logging Configuration ==================== // ==================== Logging ====================
/// Set log directory
pub fn log_dir<S: Into<String>>(mut self, dir: S) -> Self { pub fn log_dir<S: Into<String>>(mut self, dir: S) -> Self {
self.config.log_dir = Some(dir.into()); self.config.log_dir = Some(dir.into());
self self
} }
/// Set log level
pub fn log_level<S: Into<String>>(mut self, level: S) -> Self { pub fn log_level<S: Into<String>>(mut self, level: S) -> Self {
self.config.log_level = Some(level.into()); self.config.log_level = Some(level.into());
self self
} }
/// Set custom request ID headers
pub fn request_id_headers(mut self, headers: Vec<String>) -> Self { pub fn request_id_headers(mut self, headers: Vec<String>) -> Self {
self.config.request_id_headers = Some(headers); self.config.request_id_headers = Some(headers);
self self
} }
// ==================== IGW Mode Configuration ==================== // ==================== IGW Mode ====================
/// Enable Inference Gateway mode
pub fn enable_igw(mut self) -> Self { pub fn enable_igw(mut self) -> Self {
self.config.enable_igw = true; self.config.enable_igw = true;
self self
} }
/// Disable Inference Gateway mode (use proxy mode) /// Use proxy mode
pub fn disable_igw(mut self) -> Self { pub fn disable_igw(mut self) -> Self {
self.config.enable_igw = false; self.config.enable_igw = false;
self self
} }
/// Set model path for tokenizer
pub fn model_path<S: Into<String>>(mut self, path: S) -> Self { pub fn model_path<S: Into<String>>(mut self, path: S) -> Self {
self.config.model_path = Some(path.into()); self.config.model_path = Some(path.into());
self self
} }
/// Set tokenizer path (overrides model_path tokenizer) /// Overrides model_path tokenizer
pub fn tokenizer_path<S: Into<String>>(mut self, path: S) -> Self { pub fn tokenizer_path<S: Into<String>>(mut self, path: S) -> Self {
self.config.tokenizer_path = Some(path.into()); self.config.tokenizer_path = Some(path.into());
self self
} }
/// Set chat template path
pub fn chat_template<S: Into<String>>(mut self, path: S) -> Self { pub fn chat_template<S: Into<String>>(mut self, path: S) -> Self {
self.config.chat_template = Some(path.into()); self.config.chat_template = Some(path.into());
self self
} }
// ==================== History Backend Configuration ==================== // ==================== History Backend ====================
/// Set history backend
pub fn history_backend(mut self, backend: HistoryBackend) -> Self { pub fn history_backend(mut self, backend: HistoryBackend) -> Self {
self.config.history_backend = backend; self.config.history_backend = backend;
self self
} }
/// Use memory history backend
pub fn memory_history(mut self) -> Self { pub fn memory_history(mut self) -> Self {
self.config.history_backend = HistoryBackend::Memory; self.config.history_backend = HistoryBackend::Memory;
self self
} }
/// Disable history storage
pub fn no_history(mut self) -> Self { pub fn no_history(mut self) -> Self {
self.config.history_backend = HistoryBackend::None; self.config.history_backend = HistoryBackend::None;
self self
} }
/// Use Oracle history backend
pub fn oracle_history(mut self, oracle_config: OracleConfig) -> Self { pub fn oracle_history(mut self, oracle_config: OracleConfig) -> Self {
self.config.history_backend = HistoryBackend::Oracle; self.config.history_backend = HistoryBackend::Oracle;
self.config.oracle = Some(oracle_config); self.config.oracle = Some(oracle_config);
self self
} }
// ==================== Parsers Configuration ==================== // ==================== Parsers ====================
/// Set reasoning parser
pub fn reasoning_parser<S: Into<String>>(mut self, parser: S) -> Self { pub fn reasoning_parser<S: Into<String>>(mut self, parser: S) -> Self {
self.config.reasoning_parser = Some(parser.into()); self.config.reasoning_parser = Some(parser.into());
self self
} }
/// Set tool call parser
pub fn tool_call_parser<S: Into<String>>(mut self, parser: S) -> Self { pub fn tool_call_parser<S: Into<String>>(mut self, parser: S) -> Self {
self.config.tool_call_parser = Some(parser.into()); self.config.tool_call_parser = Some(parser.into());
self self
} }
// ==================== Tokenizer Cache Configuration ==================== // ==================== Tokenizer Cache ====================
/// Set tokenizer cache configuration
pub fn tokenizer_cache(mut self, cache: TokenizerCacheConfig) -> Self { pub fn tokenizer_cache(mut self, cache: TokenizerCacheConfig) -> Self {
self.config.tokenizer_cache = cache; self.config.tokenizer_cache = cache;
self self
} }
/// Enable L0 cache with entry limit
pub fn enable_l0_cache(mut self, max_entries: usize) -> Self { pub fn enable_l0_cache(mut self, max_entries: usize) -> Self {
self.config.tokenizer_cache.enable_l0 = true; self.config.tokenizer_cache.enable_l0 = true;
self.config.tokenizer_cache.l0_max_entries = max_entries; self.config.tokenizer_cache.l0_max_entries = max_entries;
self self
} }
/// Enable L1 cache with memory limit
pub fn enable_l1_cache(mut self, max_memory: usize) -> Self { pub fn enable_l1_cache(mut self, max_memory: usize) -> Self {
self.config.tokenizer_cache.enable_l1 = true; self.config.tokenizer_cache.enable_l1 = true;
self.config.tokenizer_cache.l1_max_memory = max_memory; self.config.tokenizer_cache.l1_max_memory = max_memory;
...@@ -453,51 +399,44 @@ impl RouterConfigBuilder { ...@@ -453,51 +399,44 @@ impl RouterConfigBuilder {
// ==================== Data Parallelism ==================== // ==================== Data Parallelism ====================
/// Enable data parallelism aware scheduling
pub fn enable_dp_aware(mut self) -> Self { pub fn enable_dp_aware(mut self) -> Self {
self.config.dp_aware = true; self.config.dp_aware = true;
self self
} }
/// Disable data parallelism aware scheduling
pub fn disable_dp_aware(mut self) -> Self { pub fn disable_dp_aware(mut self) -> Self {
self.config.dp_aware = false; self.config.dp_aware = false;
self self
} }
// ==================== Conditional Boolean Setters ==================== // ==================== Boolean Setters ====================
// These methods accept bool parameters to conditionally set flags, // Accept bool parameters to conditionally set flags without if statements
// eliminating the need for if statements in calling code
/// Set dp_aware flag conditionally
pub fn dp_aware(mut self, enable: bool) -> Self { pub fn dp_aware(mut self, enable: bool) -> Self {
self.config.dp_aware = enable; self.config.dp_aware = enable;
self self
} }
/// Enable or disable retries (inverse of disable_retries field) /// Inverse of disable_retries field
pub fn retries(mut self, enable: bool) -> Self { pub fn retries(mut self, enable: bool) -> Self {
self.config.disable_retries = !enable; self.config.disable_retries = !enable;
self self
} }
/// Enable or disable circuit breaker (inverse of disable_circuit_breaker field) /// Inverse of disable_circuit_breaker field
pub fn circuit_breaker(mut self, enable: bool) -> Self { pub fn circuit_breaker(mut self, enable: bool) -> Self {
self.config.disable_circuit_breaker = !enable; self.config.disable_circuit_breaker = !enable;
self self
} }
/// Set enable_igw flag conditionally
pub fn igw(mut self, enable: bool) -> Self { pub fn igw(mut self, enable: bool) -> Self {
self.config.enable_igw = enable; self.config.enable_igw = enable;
self self
} }
// ==================== Option-Aware Setters ==================== // ==================== Option Setters ====================
// These methods accept Option<T> and only set if Some, making it easier // Accept Option<T> and only set if Some
// to conditionally set values without if-let chains
/// Set API key if Some
pub fn maybe_api_key(mut self, key: Option<impl Into<String>>) -> Self { pub fn maybe_api_key(mut self, key: Option<impl Into<String>>) -> Self {
if let Some(k) = key { if let Some(k) = key {
self.config.api_key = Some(k.into()); self.config.api_key = Some(k.into());
...@@ -505,61 +444,51 @@ impl RouterConfigBuilder { ...@@ -505,61 +444,51 @@ impl RouterConfigBuilder {
self self
} }
/// Set discovery config if Some
pub fn maybe_discovery(mut self, discovery: Option<DiscoveryConfig>) -> Self { pub fn maybe_discovery(mut self, discovery: Option<DiscoveryConfig>) -> Self {
self.config.discovery = discovery; self.config.discovery = discovery;
self self
} }
/// Set metrics config if Some
pub fn maybe_metrics(mut self, metrics: Option<MetricsConfig>) -> Self { pub fn maybe_metrics(mut self, metrics: Option<MetricsConfig>) -> Self {
self.config.metrics = metrics; self.config.metrics = metrics;
self self
} }
/// Set log directory if Some
pub fn maybe_log_dir(mut self, dir: Option<impl Into<String>>) -> Self { pub fn maybe_log_dir(mut self, dir: Option<impl Into<String>>) -> Self {
self.config.log_dir = dir.map(|d| d.into()); self.config.log_dir = dir.map(|d| d.into());
self self
} }
/// Set log level if Some
pub fn maybe_log_level(mut self, level: Option<impl Into<String>>) -> Self { pub fn maybe_log_level(mut self, level: Option<impl Into<String>>) -> Self {
self.config.log_level = level.map(|l| l.into()); self.config.log_level = level.map(|l| l.into());
self self
} }
/// Set request ID headers if Some
pub fn maybe_request_id_headers(mut self, headers: Option<Vec<String>>) -> Self { pub fn maybe_request_id_headers(mut self, headers: Option<Vec<String>>) -> Self {
self.config.request_id_headers = headers; self.config.request_id_headers = headers;
self self
} }
/// Set rate limit tokens per second if Some
pub fn maybe_rate_limit_tokens_per_second(mut self, tokens: Option<i32>) -> Self { pub fn maybe_rate_limit_tokens_per_second(mut self, tokens: Option<i32>) -> Self {
self.config.rate_limit_tokens_per_second = tokens; self.config.rate_limit_tokens_per_second = tokens;
self self
} }
/// Set model path if Some
pub fn maybe_model_path(mut self, path: Option<impl Into<String>>) -> Self { pub fn maybe_model_path(mut self, path: Option<impl Into<String>>) -> Self {
self.config.model_path = path.map(|p| p.into()); self.config.model_path = path.map(|p| p.into());
self self
} }
/// Set tokenizer path if Some
pub fn maybe_tokenizer_path(mut self, path: Option<impl Into<String>>) -> Self { pub fn maybe_tokenizer_path(mut self, path: Option<impl Into<String>>) -> Self {
self.config.tokenizer_path = path.map(|p| p.into()); self.config.tokenizer_path = path.map(|p| p.into());
self self
} }
/// Set chat template if Some
pub fn maybe_chat_template(mut self, template: Option<impl Into<String>>) -> Self { pub fn maybe_chat_template(mut self, template: Option<impl Into<String>>) -> Self {
self.config.chat_template = template.map(|t| t.into()); self.config.chat_template = template.map(|t| t.into());
self self
} }
/// Set oracle config if Some
pub fn maybe_oracle(mut self, oracle: Option<OracleConfig>) -> Self { pub fn maybe_oracle(mut self, oracle: Option<OracleConfig>) -> Self {
if let Some(cfg) = oracle { if let Some(cfg) = oracle {
self.config.history_backend = HistoryBackend::Oracle; self.config.history_backend = HistoryBackend::Oracle;
...@@ -568,23 +497,19 @@ impl RouterConfigBuilder { ...@@ -568,23 +497,19 @@ impl RouterConfigBuilder {
self self
} }
/// Set reasoning parser if Some
pub fn maybe_reasoning_parser(mut self, parser: Option<impl Into<String>>) -> Self { pub fn maybe_reasoning_parser(mut self, parser: Option<impl Into<String>>) -> Self {
self.config.reasoning_parser = parser.map(|p| p.into()); self.config.reasoning_parser = parser.map(|p| p.into());
self self
} }
/// Set tool call parser if Some
pub fn maybe_tool_call_parser(mut self, parser: Option<impl Into<String>>) -> Self { pub fn maybe_tool_call_parser(mut self, parser: Option<impl Into<String>>) -> Self {
self.config.tool_call_parser = parser.map(|p| p.into()); self.config.tool_call_parser = parser.map(|p| p.into());
self self
} }
// ==================== mTLS Configuration ==================== // ==================== mTLS ====================
/// Set client certificate and key paths for mTLS authentication /// Both paths must be provided together. Files read during build()
/// Both paths must be provided together
/// Files will be read during build()
pub fn client_cert_and_key<S1: Into<String>, S2: Into<String>>( pub fn client_cert_and_key<S1: Into<String>, S2: Into<String>>(
mut self, mut self,
cert_path: S1, cert_path: S1,
...@@ -595,8 +520,7 @@ impl RouterConfigBuilder { ...@@ -595,8 +520,7 @@ impl RouterConfigBuilder {
self self
} }
/// Set client certificate and key paths for mTLS if both paths are provided /// Files read during build()
/// Files will be read during build()
pub fn maybe_client_cert_and_key( pub fn maybe_client_cert_and_key(
mut self, mut self,
cert_path: Option<impl Into<String>>, cert_path: Option<impl Into<String>>,
...@@ -607,49 +531,43 @@ impl RouterConfigBuilder { ...@@ -607,49 +531,43 @@ impl RouterConfigBuilder {
self self
} }
/// Add a CA certificate path for verifying worker TLS certificates /// File read during build()
/// File will be read during build()
pub fn add_ca_certificate<S: Into<String>>(mut self, ca_cert_path: S) -> Self { pub fn add_ca_certificate<S: Into<String>>(mut self, ca_cert_path: S) -> Self {
self.ca_cert_paths.push(ca_cert_path.into()); self.ca_cert_paths.push(ca_cert_path.into());
self self
} }
/// Add multiple CA certificate paths for verifying worker TLS certificates /// Files read during build()
/// Files will be read during build()
pub fn add_ca_certificates<S: Into<String>>(mut self, ca_cert_paths: Vec<S>) -> Self { pub fn add_ca_certificates<S: Into<String>>(mut self, ca_cert_paths: Vec<S>) -> Self {
self.ca_cert_paths self.ca_cert_paths
.extend(ca_cert_paths.into_iter().map(|p| p.into())); .extend(ca_cert_paths.into_iter().map(|p| p.into()));
self self
} }
// ==================== MCP Configuration ==================== // ==================== MCP ====================
/// Set MCP server configuration file path /// Config file loaded during build()
/// The config file will be loaded during build()
pub fn mcp_config_path<S: Into<String>>(mut self, path: S) -> Self { pub fn mcp_config_path<S: Into<String>>(mut self, path: S) -> Self {
self.mcp_config_path = Some(path.into()); self.mcp_config_path = Some(path.into());
self self
} }
/// Set MCP server configuration file path if Some /// Config file loaded during build()
pub fn maybe_mcp_config_path(mut self, path: Option<impl Into<String>>) -> Self { pub fn maybe_mcp_config_path(mut self, path: Option<impl Into<String>>) -> Self {
self.mcp_config_path = path.map(|p| p.into()); self.mcp_config_path = path.map(|p| p.into());
self self
} }
// ==================== Builder Methods ==================== // ==================== Build ====================
/// Build the RouterConfig, validating if requested
pub fn build(self) -> ConfigResult<RouterConfig> { pub fn build(self) -> ConfigResult<RouterConfig> {
self.build_with_validation(true) self.build_with_validation(true)
} }
/// Build the RouterConfig without validation
pub fn build_unchecked(self) -> RouterConfig { pub fn build_unchecked(self) -> RouterConfig {
self.into() self.into()
} }
/// Build with optional validation
pub fn build_with_validation(mut self, validate: bool) -> ConfigResult<RouterConfig> { pub fn build_with_validation(mut self, validate: bool) -> ConfigResult<RouterConfig> {
// Read mTLS certificates from paths if provided // Read mTLS certificates from paths if provided
self = self.read_mtls_certificates()?; self = self.read_mtls_certificates()?;
......
...@@ -6,7 +6,6 @@ pub use builder::*; ...@@ -6,7 +6,6 @@ pub use builder::*;
pub use types::*; pub use types::*;
pub use validation::*; pub use validation::*;
/// Configuration errors
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum ConfigError { pub enum ConfigError {
#[error("Validation failed: {reason}")] #[error("Validation failed: {reason}")]
...@@ -26,5 +25,4 @@ pub enum ConfigError { ...@@ -26,5 +25,4 @@ pub enum ConfigError {
MissingRequired { field: String }, MissingRequired { field: String },
} }
/// Result type for configuration operations
pub type ConfigResult<T> = Result<T, ConfigError>; pub type ConfigResult<T> = Result<T, ConfigError>;
...@@ -8,93 +8,64 @@ use crate::core::ConnectionMode; ...@@ -8,93 +8,64 @@ use crate::core::ConnectionMode;
/// Main router configuration /// Main router configuration
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterConfig { pub struct RouterConfig {
/// Routing mode configuration
pub mode: RoutingMode, pub mode: RoutingMode,
/// Worker connection mode
#[serde(default)] #[serde(default)]
pub connection_mode: ConnectionMode, pub connection_mode: ConnectionMode,
/// Policy configuration
pub policy: PolicyConfig, pub policy: PolicyConfig,
/// Server host address
pub host: String, pub host: String,
/// Server port
pub port: u16, pub port: u16,
/// Maximum payload size in bytes
pub max_payload_size: usize, pub max_payload_size: usize,
/// Request timeout in seconds
pub request_timeout_secs: u64, pub request_timeout_secs: u64,
/// Worker startup timeout in seconds
pub worker_startup_timeout_secs: u64, pub worker_startup_timeout_secs: u64,
/// Worker health check interval in seconds
pub worker_startup_check_interval_secs: u64, pub worker_startup_check_interval_secs: u64,
/// Enable data parallelism aware schedule
pub dp_aware: bool, pub dp_aware: bool,
/// The api key used for the authorization with the worker
pub api_key: Option<String>, pub api_key: Option<String>,
/// Service discovery configuration (optional)
pub discovery: Option<DiscoveryConfig>, pub discovery: Option<DiscoveryConfig>,
/// Metrics configuration (optional)
pub metrics: Option<MetricsConfig>, pub metrics: Option<MetricsConfig>,
/// Log directory (None = stdout only)
pub log_dir: Option<String>, pub log_dir: Option<String>,
/// Log level (None = info)
pub log_level: Option<String>, pub log_level: Option<String>,
/// Custom request ID headers to check (defaults to common headers)
pub request_id_headers: Option<Vec<String>>, pub request_id_headers: Option<Vec<String>>,
/// Maximum concurrent requests allowed (for rate limiting). Set to -1 to disable rate limiting. /// Set to -1 to disable rate limiting
pub max_concurrent_requests: i32, pub max_concurrent_requests: i32,
/// Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)
pub queue_size: usize, pub queue_size: usize,
/// Maximum time (in seconds) a request can wait in queue before timing out
pub queue_timeout_secs: u64, pub queue_timeout_secs: u64,
/// Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests /// If not set, defaults to max_concurrent_requests
pub rate_limit_tokens_per_second: Option<i32>, pub rate_limit_tokens_per_second: Option<i32>,
/// CORS allowed origins
pub cors_allowed_origins: Vec<String>, pub cors_allowed_origins: Vec<String>,
/// Retry configuration
pub retry: RetryConfig, pub retry: RetryConfig,
/// Circuit breaker configuration
pub circuit_breaker: CircuitBreakerConfig, pub circuit_breaker: CircuitBreakerConfig,
/// Disable retries (overrides retry.max_retries to 1 when true) /// When true, overrides retry.max_retries to 1
#[serde(default)] #[serde(default)]
pub disable_retries: bool, pub disable_retries: bool,
/// Disable circuit breaker (overrides circuit_breaker.failure_threshold to u32::MAX when true) /// When true, overrides circuit_breaker.failure_threshold to u32::MAX
#[serde(default)] #[serde(default)]
pub disable_circuit_breaker: bool, pub disable_circuit_breaker: bool,
/// Health check configuration
pub health_check: HealthCheckConfig, pub health_check: HealthCheckConfig,
/// Enable Inference Gateway mode (false = proxy mode, true = IGW mode)
#[serde(default)] #[serde(default)]
pub enable_igw: bool, pub enable_igw: bool,
/// Model path for loading tokenizer (can be a HuggingFace model ID or local path) /// Can be a HuggingFace model ID or local path
pub model_path: Option<String>, pub model_path: Option<String>,
/// Explicit tokenizer path (overrides model_path tokenizer if provided) /// Overrides model_path tokenizer if provided
pub tokenizer_path: Option<String>, pub tokenizer_path: Option<String>,
/// Chat template path (optional)
pub chat_template: Option<String>, pub chat_template: Option<String>,
/// History backend configuration (memory or none, default: memory)
#[serde(default = "default_history_backend")] #[serde(default = "default_history_backend")]
pub history_backend: HistoryBackend, pub history_backend: HistoryBackend,
/// Oracle history backend configuration (required when `history_backend` = "oracle") /// Required when history_backend = "oracle"
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub oracle: Option<OracleConfig>, pub oracle: Option<OracleConfig>,
/// Parser for reasoning models (e.g., deepseek-r1, qwen3) /// For reasoning models (e.g., deepseek-r1, qwen3)
pub reasoning_parser: Option<String>, pub reasoning_parser: Option<String>,
/// Parser for handling tool-call interactions /// For tool-call interactions
pub tool_call_parser: Option<String>, pub tool_call_parser: Option<String>,
/// Tokenizer cache configuration
#[serde(default)] #[serde(default)]
pub tokenizer_cache: TokenizerCacheConfig, pub tokenizer_cache: TokenizerCacheConfig,
/// mTLS client identity (combined certificate + key in PEM format) /// Combined certificate + key in PEM format, loaded from client_cert_path and client_key_path during config creation
/// This is loaded from client_cert_path and client_key_path during config creation
#[serde(skip)] #[serde(skip)]
pub client_identity: Option<Vec<u8>>, pub client_identity: Option<Vec<u8>>,
/// CA certificates for verifying worker TLS certificates (PEM format) /// PEM format, loaded from ca_cert_paths during config creation
/// Loaded from ca_cert_paths during config creation
#[serde(default)] #[serde(default)]
pub ca_certificates: Vec<Vec<u8>>, pub ca_certificates: Vec<Vec<u8>>,
/// MCP server configuration (loaded from mcp_config_path during config creation) /// Loaded from mcp_config_path during config creation
/// This is loaded from the config file path and stored here for runtime use
#[serde(skip)] #[serde(skip)]
pub mcp_config: Option<crate::mcp::McpConfig>, pub mcp_config: Option<crate::mcp::McpConfig>,
} }
...@@ -102,16 +73,14 @@ pub struct RouterConfig { ...@@ -102,16 +73,14 @@ pub struct RouterConfig {
/// Tokenizer cache configuration /// Tokenizer cache configuration
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TokenizerCacheConfig { pub struct TokenizerCacheConfig {
/// Enable L0 cache (whole-string exact match) /// Whole-string exact match cache
#[serde(default = "default_enable_l0")] #[serde(default = "default_enable_l0")]
pub enable_l0: bool, pub enable_l0: bool,
/// Maximum number of entries in L0 cache
#[serde(default = "default_l0_max_entries")] #[serde(default = "default_l0_max_entries")]
pub l0_max_entries: usize, pub l0_max_entries: usize,
/// Enable L1 cache (prefix matching at fixed boundaries) /// Prefix matching at fixed boundaries
#[serde(default = "default_enable_l1")] #[serde(default = "default_enable_l1")]
pub enable_l1: bool, pub enable_l1: bool,
/// Maximum memory for L1 cache in bytes
#[serde(default = "default_l1_max_memory")] #[serde(default = "default_l1_max_memory")]
pub l1_max_memory: usize, pub l1_max_memory: usize,
} }
...@@ -151,33 +120,25 @@ fn default_history_backend() -> HistoryBackend { ...@@ -151,33 +120,25 @@ fn default_history_backend() -> HistoryBackend {
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum HistoryBackend { pub enum HistoryBackend {
/// In-memory storage (default)
Memory, Memory,
/// No history storage
None, None,
/// Oracle ATP-backed storage
Oracle, Oracle,
} }
/// Oracle history backend configuration /// Oracle history backend configuration
#[derive(Clone, Serialize, Deserialize, PartialEq)] #[derive(Clone, Serialize, Deserialize, PartialEq)]
pub struct OracleConfig { pub struct OracleConfig {
/// Directory containing the ATP wallet or TLS config files (optional) /// ATP wallet or TLS config files directory
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub wallet_path: Option<String>, pub wallet_path: Option<String>,
/// Connection descriptor / DSN (e.g. `tcps://host:port/service`) /// DSN (e.g. `tcps://host:port/service`)
pub connect_descriptor: String, pub connect_descriptor: String,
/// Database username
pub username: String, pub username: String,
/// Database password
pub password: String, pub password: String,
/// Minimum number of pooled connections to keep ready
#[serde(default = "default_pool_min")] #[serde(default = "default_pool_min")]
pub pool_min: usize, pub pool_min: usize,
/// Maximum number of pooled connections
#[serde(default = "default_pool_max")] #[serde(default = "default_pool_max")]
pub pool_max: usize, pub pool_max: usize,
/// Maximum time to wait for a connection from the pool (seconds)
#[serde(default = "default_pool_timeout_secs")] #[serde(default = "default_pool_timeout_secs")]
pub pool_timeout_secs: u64, pub pool_timeout_secs: u64,
} }
...@@ -226,28 +187,19 @@ impl std::fmt::Debug for OracleConfig { ...@@ -226,28 +187,19 @@ impl std::fmt::Debug for OracleConfig {
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum RoutingMode { pub enum RoutingMode {
#[serde(rename = "regular")] #[serde(rename = "regular")]
Regular { Regular { worker_urls: Vec<String> },
/// List of worker URLs
worker_urls: Vec<String>,
},
#[serde(rename = "prefill_decode")] #[serde(rename = "prefill_decode")]
PrefillDecode { PrefillDecode {
/// Prefill worker URLs with optional bootstrap ports /// With optional bootstrap ports
prefill_urls: Vec<(String, Option<u16>)>, prefill_urls: Vec<(String, Option<u16>)>,
/// Decode worker URLs
decode_urls: Vec<String>, decode_urls: Vec<String>,
/// Optional separate policy for prefill workers
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
prefill_policy: Option<PolicyConfig>, prefill_policy: Option<PolicyConfig>,
/// Optional separate policy for decode workers
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
decode_policy: Option<PolicyConfig>, decode_policy: Option<PolicyConfig>,
}, },
#[serde(rename = "openai")] #[serde(rename = "openai")]
OpenAI { OpenAI { worker_urls: Vec<String> },
/// OpenAI-compatible API base(s), provided via worker URLs
worker_urls: Vec<String>,
},
} }
impl RoutingMode { impl RoutingMode {
...@@ -302,23 +254,15 @@ pub enum PolicyConfig { ...@@ -302,23 +254,15 @@ pub enum PolicyConfig {
#[serde(rename = "cache_aware")] #[serde(rename = "cache_aware")]
CacheAware { CacheAware {
/// Minimum prefix match ratio to use cache-based routing
cache_threshold: f32, cache_threshold: f32,
/// Absolute load difference threshold for load balancing
balance_abs_threshold: usize, balance_abs_threshold: usize,
/// Relative load ratio threshold for load balancing
balance_rel_threshold: f32, balance_rel_threshold: f32,
/// Interval between cache eviction cycles (seconds)
eviction_interval_secs: u64, eviction_interval_secs: u64,
/// Maximum cache tree size per tenant
max_tree_size: usize, max_tree_size: usize,
}, },
#[serde(rename = "power_of_two")] #[serde(rename = "power_of_two")]
PowerOfTwo { PowerOfTwo { load_check_interval_secs: u64 },
/// Interval for load monitoring (seconds)
load_check_interval_secs: u64,
},
} }
impl PolicyConfig { impl PolicyConfig {
...@@ -335,21 +279,17 @@ impl PolicyConfig { ...@@ -335,21 +279,17 @@ impl PolicyConfig {
/// Service discovery configuration /// Service discovery configuration
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscoveryConfig { pub struct DiscoveryConfig {
/// Enable service discovery
pub enabled: bool, pub enabled: bool,
/// Kubernetes namespace (None = all namespaces) /// None = all namespaces
pub namespace: Option<String>, pub namespace: Option<String>,
/// Service discovery port
pub port: u16, pub port: u16,
/// Check interval for service discovery
pub check_interval_secs: u64, pub check_interval_secs: u64,
/// Regular mode selector /// Regular mode
pub selector: HashMap<String, String>, pub selector: HashMap<String, String>,
/// PD mode prefill selector /// PD mode prefill
pub prefill_selector: HashMap<String, String>, pub prefill_selector: HashMap<String, String>,
/// PD mode decode selector /// PD mode decode
pub decode_selector: HashMap<String, String>, pub decode_selector: HashMap<String, String>,
/// Bootstrap port annotation key
pub bootstrap_port_annotation: String, pub bootstrap_port_annotation: String,
} }
...@@ -371,16 +311,11 @@ impl Default for DiscoveryConfig { ...@@ -371,16 +311,11 @@ impl Default for DiscoveryConfig {
/// Retry configuration for request handling /// Retry configuration for request handling
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig { pub struct RetryConfig {
/// Maximum number of retry attempts
pub max_retries: u32, pub max_retries: u32,
/// Initial backoff delay in milliseconds
pub initial_backoff_ms: u64, pub initial_backoff_ms: u64,
/// Maximum backoff delay in milliseconds
pub max_backoff_ms: u64, pub max_backoff_ms: u64,
/// Backoff multiplier for exponential backoff
pub backoff_multiplier: f32, pub backoff_multiplier: f32,
/// Jitter factor applied to backoff (0.0 - 1.0) /// D' = D * (1 + U[-j, +j]) where j is jitter factor
/// Effective delay D' = D * (1 + U[-j, +j])
#[serde(default = "default_retry_jitter_factor")] #[serde(default = "default_retry_jitter_factor")]
pub jitter_factor: f32, pub jitter_factor: f32,
} }
...@@ -404,15 +339,10 @@ fn default_retry_jitter_factor() -> f32 { ...@@ -404,15 +339,10 @@ fn default_retry_jitter_factor() -> f32 {
/// Health check configuration for worker monitoring /// Health check configuration for worker monitoring
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthCheckConfig { pub struct HealthCheckConfig {
/// Number of consecutive failures before marking unhealthy
pub failure_threshold: u32, pub failure_threshold: u32,
/// Number of consecutive successes before marking healthy
pub success_threshold: u32, pub success_threshold: u32,
/// Timeout for health check requests in seconds
pub timeout_secs: u64, pub timeout_secs: u64,
/// Interval between health checks in seconds
pub check_interval_secs: u64, pub check_interval_secs: u64,
/// Health check endpoint path
pub endpoint: String, pub endpoint: String,
} }
...@@ -431,13 +361,9 @@ impl Default for HealthCheckConfig { ...@@ -431,13 +361,9 @@ impl Default for HealthCheckConfig {
/// Circuit breaker configuration for worker reliability /// Circuit breaker configuration for worker reliability
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerConfig { pub struct CircuitBreakerConfig {
/// Number of consecutive failures before opening circuit
pub failure_threshold: u32, pub failure_threshold: u32,
/// Number of consecutive successes before closing circuit
pub success_threshold: u32, pub success_threshold: u32,
/// Time before attempting to recover from open state (in seconds)
pub timeout_duration_secs: u64, pub timeout_duration_secs: u64,
/// Window duration for failure tracking (in seconds)
pub window_duration_secs: u64, pub window_duration_secs: u64,
} }
...@@ -455,9 +381,7 @@ impl Default for CircuitBreakerConfig { ...@@ -455,9 +381,7 @@ impl Default for CircuitBreakerConfig {
/// Metrics configuration /// Metrics configuration
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsConfig { pub struct MetricsConfig {
/// Prometheus metrics port
pub port: u16, pub port: u16,
/// Prometheus metrics host
pub host: String, pub host: String,
} }
......
...@@ -5,9 +5,7 @@ use crate::core::ConnectionMode; ...@@ -5,9 +5,7 @@ use crate::core::ConnectionMode;
pub struct ConfigValidator; pub struct ConfigValidator;
impl ConfigValidator { impl ConfigValidator {
/// Validate a complete router configuration
pub fn validate(config: &RouterConfig) -> ConfigResult<()> { pub fn validate(config: &RouterConfig) -> ConfigResult<()> {
// Check if service discovery is enabled
let has_service_discovery = config.discovery.as_ref().is_some_and(|d| d.enabled); let has_service_discovery = config.discovery.as_ref().is_some_and(|d| d.enabled);
Self::validate_mode(&config.mode, has_service_discovery)?; Self::validate_mode(&config.mode, has_service_discovery)?;
...@@ -24,55 +22,46 @@ impl ConfigValidator { ...@@ -24,55 +22,46 @@ impl ConfigValidator {
Self::validate_compatibility(config)?; Self::validate_compatibility(config)?;
// Validate effective retry/CB configs (respect disable flags)
let retry_cfg = config.effective_retry_config(); let retry_cfg = config.effective_retry_config();
let cb_cfg = config.effective_circuit_breaker_config(); let cb_cfg = config.effective_circuit_breaker_config();
Self::validate_retry(&retry_cfg)?; Self::validate_retry(&retry_cfg)?;
Self::validate_circuit_breaker(&cb_cfg)?; Self::validate_circuit_breaker(&cb_cfg)?;
// Validate Oracle configuration if enabled
if config.history_backend == HistoryBackend::Oracle { if config.history_backend == HistoryBackend::Oracle {
if config.oracle.is_none() { if config.oracle.is_none() {
return Err(ConfigError::MissingRequired { return Err(ConfigError::MissingRequired {
field: "oracle".to_string(), field: "oracle".to_string(),
}); });
} }
// Validate Oracle configuration details
if let Some(oracle) = &config.oracle { if let Some(oracle) = &config.oracle {
Self::validate_oracle(oracle)?; Self::validate_oracle(oracle)?;
} }
} }
// Validate tokenizer cache configuration
Self::validate_tokenizer_cache(&config.tokenizer_cache)?; Self::validate_tokenizer_cache(&config.tokenizer_cache)?;
Ok(()) Ok(())
} }
/// Validate Oracle configuration
fn validate_oracle(oracle: &OracleConfig) -> ConfigResult<()> { fn validate_oracle(oracle: &OracleConfig) -> ConfigResult<()> {
// Validate username is not empty
if oracle.username.is_empty() { if oracle.username.is_empty() {
return Err(ConfigError::MissingRequired { return Err(ConfigError::MissingRequired {
field: "oracle.username".to_string(), field: "oracle.username".to_string(),
}); });
} }
// Validate password is not empty
if oracle.password.is_empty() { if oracle.password.is_empty() {
return Err(ConfigError::MissingRequired { return Err(ConfigError::MissingRequired {
field: "oracle.password".to_string(), field: "oracle.password".to_string(),
}); });
} }
// Validate connect_descriptor is not empty
if oracle.connect_descriptor.is_empty() { if oracle.connect_descriptor.is_empty() {
return Err(ConfigError::MissingRequired { return Err(ConfigError::MissingRequired {
field: "oracle_dsn or oracle_tns_alias".to_string(), field: "oracle_dsn or oracle_tns_alias".to_string(),
}); });
} }
// Validate pool_min is at least 1
if oracle.pool_min < 1 { if oracle.pool_min < 1 {
return Err(ConfigError::InvalidValue { return Err(ConfigError::InvalidValue {
field: "oracle.pool_min".to_string(), field: "oracle.pool_min".to_string(),
...@@ -81,7 +70,6 @@ impl ConfigValidator { ...@@ -81,7 +70,6 @@ impl ConfigValidator {
}); });
} }
// Validate pool_max is greater than or equal to pool_min
if oracle.pool_max < oracle.pool_min { if oracle.pool_max < oracle.pool_min {
return Err(ConfigError::InvalidValue { return Err(ConfigError::InvalidValue {
field: "oracle.pool_max".to_string(), field: "oracle.pool_max".to_string(),
...@@ -90,7 +78,6 @@ impl ConfigValidator { ...@@ -90,7 +78,6 @@ impl ConfigValidator {
}); });
} }
// Validate pool_timeout_secs is greater than 0
if oracle.pool_timeout_secs == 0 { if oracle.pool_timeout_secs == 0 {
return Err(ConfigError::InvalidValue { return Err(ConfigError::InvalidValue {
field: "oracle.pool_timeout_secs".to_string(), field: "oracle.pool_timeout_secs".to_string(),
...@@ -102,17 +89,13 @@ impl ConfigValidator { ...@@ -102,17 +89,13 @@ impl ConfigValidator {
Ok(()) Ok(())
} }
/// Validate routing mode configuration
fn validate_mode(mode: &RoutingMode, has_service_discovery: bool) -> ConfigResult<()> { fn validate_mode(mode: &RoutingMode, has_service_discovery: bool) -> ConfigResult<()> {
match mode { match mode {
RoutingMode::Regular { worker_urls } => { RoutingMode::Regular { worker_urls } => {
// Validate URLs if any are provided
if !worker_urls.is_empty() { if !worker_urls.is_empty() {
Self::validate_urls(worker_urls)?; Self::validate_urls(worker_urls)?;
} }
// Note: We allow empty worker URLs even without service discovery // Allow empty URLs without service discovery to match legacy behavior
// to let the router start and fail at runtime when routing requests.
// This matches legacy behavior and test expectations.
} }
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_urls, prefill_urls,
...@@ -120,7 +103,6 @@ impl ConfigValidator { ...@@ -120,7 +103,6 @@ impl ConfigValidator {
prefill_policy, prefill_policy,
decode_policy, decode_policy,
} => { } => {
// Only require URLs if service discovery is disabled
if !has_service_discovery { if !has_service_discovery {
if prefill_urls.is_empty() { if prefill_urls.is_empty() {
return Err(ConfigError::ValidationFailed { return Err(ConfigError::ValidationFailed {
...@@ -134,7 +116,6 @@ impl ConfigValidator { ...@@ -134,7 +116,6 @@ impl ConfigValidator {
} }
} }
// Validate URLs if any are provided
if !prefill_urls.is_empty() { if !prefill_urls.is_empty() {
let prefill_url_strings: Vec<String> = let prefill_url_strings: Vec<String> =
prefill_urls.iter().map(|(url, _)| url.clone()).collect(); prefill_urls.iter().map(|(url, _)| url.clone()).collect();
...@@ -144,7 +125,6 @@ impl ConfigValidator { ...@@ -144,7 +125,6 @@ impl ConfigValidator {
Self::validate_urls(decode_urls)?; Self::validate_urls(decode_urls)?;
} }
// Validate bootstrap ports
for (_url, port) in prefill_urls { for (_url, port) in prefill_urls {
if let Some(port) = port { if let Some(port) = port {
if *port == 0 { if *port == 0 {
...@@ -157,7 +137,6 @@ impl ConfigValidator { ...@@ -157,7 +137,6 @@ impl ConfigValidator {
} }
} }
// Validate optional prefill and decode policies
if let Some(p_policy) = prefill_policy { if let Some(p_policy) = prefill_policy {
Self::validate_policy(p_policy)?; Self::validate_policy(p_policy)?;
} }
...@@ -166,25 +145,20 @@ impl ConfigValidator { ...@@ -166,25 +145,20 @@ impl ConfigValidator {
} }
} }
RoutingMode::OpenAI { worker_urls } => { RoutingMode::OpenAI { worker_urls } => {
// Require at least one worker URL for OpenAI router
if worker_urls.is_empty() { if worker_urls.is_empty() {
return Err(ConfigError::ValidationFailed { return Err(ConfigError::ValidationFailed {
reason: "OpenAI mode requires at least one --worker-urls entry".to_string(), reason: "OpenAI mode requires at least one --worker-urls entry".to_string(),
}); });
} }
// Validate URLs
Self::validate_urls(worker_urls)?; Self::validate_urls(worker_urls)?;
} }
} }
Ok(()) Ok(())
} }
/// Validate policy configuration
fn validate_policy(policy: &PolicyConfig) -> ConfigResult<()> { fn validate_policy(policy: &PolicyConfig) -> ConfigResult<()> {
match policy { match policy {
PolicyConfig::Random | PolicyConfig::RoundRobin => { PolicyConfig::Random | PolicyConfig::RoundRobin => {}
// No specific validation needed
}
PolicyConfig::CacheAware { PolicyConfig::CacheAware {
cache_threshold, cache_threshold,
balance_abs_threshold: _, balance_abs_threshold: _,
...@@ -239,7 +213,6 @@ impl ConfigValidator { ...@@ -239,7 +213,6 @@ impl ConfigValidator {
Ok(()) Ok(())
} }
/// Validate server configuration
fn validate_server_settings(config: &RouterConfig) -> ConfigResult<()> { fn validate_server_settings(config: &RouterConfig) -> ConfigResult<()> {
if config.port == 0 { if config.port == 0 {
return Err(ConfigError::InvalidValue { return Err(ConfigError::InvalidValue {
...@@ -302,10 +275,9 @@ impl ConfigValidator { ...@@ -302,10 +275,9 @@ impl ConfigValidator {
Ok(()) Ok(())
} }
/// Validate service discovery configuration
fn validate_discovery(discovery: &DiscoveryConfig, mode: &RoutingMode) -> ConfigResult<()> { fn validate_discovery(discovery: &DiscoveryConfig, mode: &RoutingMode) -> ConfigResult<()> {
if !discovery.enabled { if !discovery.enabled {
return Ok(()); // No validation needed if disabled return Ok(());
} }
if discovery.port == 0 { if discovery.port == 0 {
...@@ -324,7 +296,6 @@ impl ConfigValidator { ...@@ -324,7 +296,6 @@ impl ConfigValidator {
}); });
} }
// Validate selectors based on mode
match mode { match mode {
RoutingMode::Regular { .. } => { RoutingMode::Regular { .. } => {
if discovery.selector.is_empty() { if discovery.selector.is_empty() {
...@@ -342,7 +313,6 @@ impl ConfigValidator { ...@@ -342,7 +313,6 @@ impl ConfigValidator {
} }
} }
RoutingMode::OpenAI { .. } => { RoutingMode::OpenAI { .. } => {
// OpenAI mode doesn't use service discovery
return Err(ConfigError::ValidationFailed { return Err(ConfigError::ValidationFailed {
reason: "OpenAI mode does not support service discovery".to_string(), reason: "OpenAI mode does not support service discovery".to_string(),
}); });
...@@ -352,7 +322,6 @@ impl ConfigValidator { ...@@ -352,7 +322,6 @@ impl ConfigValidator {
Ok(()) Ok(())
} }
/// Validate metrics configuration
fn validate_metrics(metrics: &MetricsConfig) -> ConfigResult<()> { fn validate_metrics(metrics: &MetricsConfig) -> ConfigResult<()> {
if metrics.port == 0 { if metrics.port == 0 {
return Err(ConfigError::InvalidValue { return Err(ConfigError::InvalidValue {
...@@ -373,7 +342,6 @@ impl ConfigValidator { ...@@ -373,7 +342,6 @@ impl ConfigValidator {
Ok(()) Ok(())
} }
/// Validate retry configuration
fn validate_retry(retry: &RetryConfig) -> ConfigResult<()> { fn validate_retry(retry: &RetryConfig) -> ConfigResult<()> {
if retry.max_retries < 1 { if retry.max_retries < 1 {
return Err(ConfigError::InvalidValue { return Err(ConfigError::InvalidValue {
...@@ -413,7 +381,6 @@ impl ConfigValidator { ...@@ -413,7 +381,6 @@ impl ConfigValidator {
Ok(()) Ok(())
} }
/// Validate circuit breaker configuration
fn validate_circuit_breaker(cb: &CircuitBreakerConfig) -> ConfigResult<()> { fn validate_circuit_breaker(cb: &CircuitBreakerConfig) -> ConfigResult<()> {
if cb.failure_threshold < 1 { if cb.failure_threshold < 1 {
return Err(ConfigError::InvalidValue { return Err(ConfigError::InvalidValue {
...@@ -446,9 +413,7 @@ impl ConfigValidator { ...@@ -446,9 +413,7 @@ impl ConfigValidator {
Ok(()) Ok(())
} }
/// Validate tokenizer cache configuration
fn validate_tokenizer_cache(cache: &TokenizerCacheConfig) -> ConfigResult<()> { fn validate_tokenizer_cache(cache: &TokenizerCacheConfig) -> ConfigResult<()> {
// Validate L0 max entries when L0 is enabled
if cache.enable_l0 && cache.l0_max_entries == 0 { if cache.enable_l0 && cache.l0_max_entries == 0 {
return Err(ConfigError::InvalidValue { return Err(ConfigError::InvalidValue {
field: "tokenizer_cache.l0_max_entries".to_string(), field: "tokenizer_cache.l0_max_entries".to_string(),
...@@ -457,7 +422,6 @@ impl ConfigValidator { ...@@ -457,7 +422,6 @@ impl ConfigValidator {
}); });
} }
// Validate L1 max memory when L1 is enabled
if cache.enable_l1 && cache.l1_max_memory == 0 { if cache.enable_l1 && cache.l1_max_memory == 0 {
return Err(ConfigError::InvalidValue { return Err(ConfigError::InvalidValue {
field: "tokenizer_cache.l1_max_memory".to_string(), field: "tokenizer_cache.l1_max_memory".to_string(),
...@@ -469,9 +433,7 @@ impl ConfigValidator { ...@@ -469,9 +433,7 @@ impl ConfigValidator {
Ok(()) Ok(())
} }
/// Validate mTLS certificate configuration
fn validate_mtls(config: &RouterConfig) -> ConfigResult<()> { fn validate_mtls(config: &RouterConfig) -> ConfigResult<()> {
// Validate that if we have client_identity, it's not empty
if let Some(identity) = &config.client_identity { if let Some(identity) = &config.client_identity {
if identity.is_empty() { if identity.is_empty() {
return Err(ConfigError::ValidationFailed { return Err(ConfigError::ValidationFailed {
...@@ -480,7 +442,6 @@ impl ConfigValidator { ...@@ -480,7 +442,6 @@ impl ConfigValidator {
} }
} }
// Validate CA certificates are not empty
for (idx, ca_cert) in config.ca_certificates.iter().enumerate() { for (idx, ca_cert) in config.ca_certificates.iter().enumerate() {
if ca_cert.is_empty() { if ca_cert.is_empty() {
return Err(ConfigError::ValidationFailed { return Err(ConfigError::ValidationFailed {
...@@ -492,14 +453,11 @@ impl ConfigValidator { ...@@ -492,14 +453,11 @@ impl ConfigValidator {
Ok(()) Ok(())
} }
/// Validate compatibility between different configuration sections
fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> { fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> {
// IGW mode is independent - skip other compatibility checks when enabled
if config.enable_igw { if config.enable_igw {
return Ok(()); return Ok(());
} }
// Validate gRPC connection mode requires tokenizer configuration
if matches!(config.connection_mode, ConnectionMode::Grpc { .. }) if matches!(config.connection_mode, ConnectionMode::Grpc { .. })
&& config.tokenizer_path.is_none() && config.tokenizer_path.is_none()
&& config.model_path.is_none() && config.model_path.is_none()
...@@ -509,18 +467,11 @@ impl ConfigValidator { ...@@ -509,18 +467,11 @@ impl ConfigValidator {
}); });
} }
// Validate mTLS configuration
Self::validate_mtls(config)?; Self::validate_mtls(config)?;
// All policies are now supported for both router types thanks to the unified trait design
// No mode/policy restrictions needed anymore
// Check if service discovery is enabled for worker count validation
let has_service_discovery = config.discovery.as_ref().is_some_and(|d| d.enabled); let has_service_discovery = config.discovery.as_ref().is_some_and(|d| d.enabled);
// Only validate worker counts if service discovery is disabled
if !has_service_discovery { if !has_service_discovery {
// Check if power-of-two policy makes sense with insufficient workers
if let PolicyConfig::PowerOfTwo { .. } = &config.policy { if let PolicyConfig::PowerOfTwo { .. } = &config.policy {
let worker_count = config.mode.worker_count(); let worker_count = config.mode.worker_count();
if worker_count < 2 { if worker_count < 2 {
...@@ -530,7 +481,6 @@ impl ConfigValidator { ...@@ -530,7 +481,6 @@ impl ConfigValidator {
} }
} }
// For PD mode, validate that policies have sufficient workers
if let RoutingMode::PrefillDecode { if let RoutingMode::PrefillDecode {
prefill_urls, prefill_urls,
decode_urls, decode_urls,
...@@ -538,7 +488,6 @@ impl ConfigValidator { ...@@ -538,7 +488,6 @@ impl ConfigValidator {
decode_policy, decode_policy,
} = &config.mode } = &config.mode
{ {
// Check power-of-two for prefill
if let Some(PolicyConfig::PowerOfTwo { .. }) = prefill_policy { if let Some(PolicyConfig::PowerOfTwo { .. }) = prefill_policy {
if prefill_urls.len() < 2 { if prefill_urls.len() < 2 {
return Err(ConfigError::IncompatibleConfig { return Err(ConfigError::IncompatibleConfig {
...@@ -547,7 +496,6 @@ impl ConfigValidator { ...@@ -547,7 +496,6 @@ impl ConfigValidator {
} }
} }
// Check power-of-two for decode
if let Some(PolicyConfig::PowerOfTwo { .. }) = decode_policy { if let Some(PolicyConfig::PowerOfTwo { .. }) = decode_policy {
if decode_urls.len() < 2 { if decode_urls.len() < 2 {
return Err(ConfigError::IncompatibleConfig { return Err(ConfigError::IncompatibleConfig {
...@@ -560,8 +508,6 @@ impl ConfigValidator { ...@@ -560,8 +508,6 @@ impl ConfigValidator {
} }
} }
// Service discovery is conflict with dp_aware routing for now
// since it's not fully supported yet
if has_service_discovery && config.dp_aware { if has_service_discovery && config.dp_aware {
return Err(ConfigError::IncompatibleConfig { return Err(ConfigError::IncompatibleConfig {
reason: "DP-aware routing is not compatible with service discovery".to_string(), reason: "DP-aware routing is not compatible with service discovery".to_string(),
...@@ -571,7 +517,6 @@ impl ConfigValidator { ...@@ -571,7 +517,6 @@ impl ConfigValidator {
Ok(()) Ok(())
} }
/// Validate URL format
fn validate_urls(urls: &[String]) -> ConfigResult<()> { fn validate_urls(urls: &[String]) -> ConfigResult<()> {
for url in urls { for url in urls {
if url.is_empty() { if url.is_empty() {
...@@ -593,10 +538,8 @@ impl ConfigValidator { ...@@ -593,10 +538,8 @@ impl ConfigValidator {
}); });
} }
// Basic URL validation
match ::url::Url::parse(url) { match ::url::Url::parse(url) {
Ok(parsed) => { Ok(parsed) => {
// Additional validation
if parsed.host_str().is_none() { if parsed.host_str().is_none() {
return Err(ConfigError::InvalidValue { return Err(ConfigError::InvalidValue {
field: "worker_url".to_string(), field: "worker_url".to_string(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment