"examples/pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "4a8cc48981306f80fb725370d1572cf9288b51c4"
Unverified commit 773d89da, authored by Chang Su and committed by GitHub

[router] cleaned up all the redundant comments in the config module (#12147)

parent 03e7d949
@@ -6,7 +6,6 @@ pub use builder::*;
pub use types::*;
pub use validation::*;
-/// Configuration errors
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    #[error("Validation failed: {reason}")]
@@ -26,5 +25,4 @@ pub enum ConfigError {
    MissingRequired { field: String },
}
-/// Result type for configuration operations
pub type ConfigResult<T> = Result<T, ConfigError>;
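
Note: the variants above are used throughout the validation code later in this diff. A minimal, illustrative helper (not part of the commit) showing the intended calling convention:

fn require_nonempty(field: &str, value: &str) -> ConfigResult<()> {
    if value.is_empty() {
        // The exact error text comes from the #[error] attribute, part of which
        // is hidden in the collapsed hunk.
        return Err(ConfigError::MissingRequired {
            field: field.to_string(),
        });
    }
    Ok(())
}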
@@ -8,93 +8,64 @@ use crate::core::ConnectionMode;
/// Main router configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterConfig {
-    /// Routing mode configuration
    pub mode: RoutingMode,
-    /// Worker connection mode
    #[serde(default)]
    pub connection_mode: ConnectionMode,
-    /// Policy configuration
    pub policy: PolicyConfig,
-    /// Server host address
    pub host: String,
-    /// Server port
    pub port: u16,
-    /// Maximum payload size in bytes
    pub max_payload_size: usize,
-    /// Request timeout in seconds
    pub request_timeout_secs: u64,
-    /// Worker startup timeout in seconds
    pub worker_startup_timeout_secs: u64,
-    /// Worker health check interval in seconds
    pub worker_startup_check_interval_secs: u64,
-    /// Enable data parallelism aware schedule
    pub dp_aware: bool,
-    /// The api key used for the authorization with the worker
    pub api_key: Option<String>,
-    /// Service discovery configuration (optional)
    pub discovery: Option<DiscoveryConfig>,
-    /// Metrics configuration (optional)
    pub metrics: Option<MetricsConfig>,
-    /// Log directory (None = stdout only)
    pub log_dir: Option<String>,
-    /// Log level (None = info)
    pub log_level: Option<String>,
-    /// Custom request ID headers to check (defaults to common headers)
    pub request_id_headers: Option<Vec<String>>,
-    /// Maximum concurrent requests allowed (for rate limiting). Set to -1 to disable rate limiting.
+    /// Set to -1 to disable rate limiting
    pub max_concurrent_requests: i32,
-    /// Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)
    pub queue_size: usize,
-    /// Maximum time (in seconds) a request can wait in queue before timing out
    pub queue_timeout_secs: u64,
-    /// Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
+    /// If not set, defaults to max_concurrent_requests
    pub rate_limit_tokens_per_second: Option<i32>,
-    /// CORS allowed origins
    pub cors_allowed_origins: Vec<String>,
-    /// Retry configuration
    pub retry: RetryConfig,
-    /// Circuit breaker configuration
    pub circuit_breaker: CircuitBreakerConfig,
-    /// Disable retries (overrides retry.max_retries to 1 when true)
+    /// When true, overrides retry.max_retries to 1
    #[serde(default)]
    pub disable_retries: bool,
-    /// Disable circuit breaker (overrides circuit_breaker.failure_threshold to u32::MAX when true)
+    /// When true, overrides circuit_breaker.failure_threshold to u32::MAX
    #[serde(default)]
    pub disable_circuit_breaker: bool,
-    /// Health check configuration
    pub health_check: HealthCheckConfig,
-    /// Enable Inference Gateway mode (false = proxy mode, true = IGW mode)
    #[serde(default)]
    pub enable_igw: bool,
-    /// Model path for loading tokenizer (can be a HuggingFace model ID or local path)
+    /// Can be a HuggingFace model ID or local path
    pub model_path: Option<String>,
-    /// Explicit tokenizer path (overrides model_path tokenizer if provided)
+    /// Overrides model_path tokenizer if provided
    pub tokenizer_path: Option<String>,
-    /// Chat template path (optional)
    pub chat_template: Option<String>,
-    /// History backend configuration (memory or none, default: memory)
    #[serde(default = "default_history_backend")]
    pub history_backend: HistoryBackend,
-    /// Oracle history backend configuration (required when `history_backend` = "oracle")
+    /// Required when history_backend = "oracle"
    #[serde(skip_serializing_if = "Option::is_none")]
    pub oracle: Option<OracleConfig>,
-    /// Parser for reasoning models (e.g., deepseek-r1, qwen3)
+    /// For reasoning models (e.g., deepseek-r1, qwen3)
    pub reasoning_parser: Option<String>,
-    /// Parser for handling tool-call interactions
+    /// For tool-call interactions
    pub tool_call_parser: Option<String>,
-    /// Tokenizer cache configuration
    #[serde(default)]
    pub tokenizer_cache: TokenizerCacheConfig,
-    /// mTLS client identity (combined certificate + key in PEM format)
-    /// This is loaded from client_cert_path and client_key_path during config creation
+    /// Combined certificate + key in PEM format, loaded from client_cert_path and client_key_path during config creation
    #[serde(skip)]
    pub client_identity: Option<Vec<u8>>,
-    /// CA certificates for verifying worker TLS certificates (PEM format)
-    /// Loaded from ca_cert_paths during config creation
+    /// PEM format, loaded from ca_cert_paths during config creation
    #[serde(default)]
    pub ca_certificates: Vec<Vec<u8>>,
-    /// MCP server configuration (loaded from mcp_config_path during config creation)
-    /// This is loaded from the config file path and stored here for runtime use
+    /// Loaded from mcp_config_path during config creation
    #[serde(skip)]
    pub mcp_config: Option<crate::mcp::McpConfig>,
}
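
Note: the disable_retries / disable_circuit_breaker doc comments describe an override that the validator later applies via config.effective_retry_config() and config.effective_circuit_breaker_config(). Those method bodies are not part of this diff; a plausible sketch of what they do, based only on the comments above:

impl RouterConfig {
    // Sketch only - the real implementations live outside this diff.
    pub fn effective_retry_config(&self) -> RetryConfig {
        let mut retry = self.retry.clone();
        if self.disable_retries {
            retry.max_retries = 1; // "overrides retry.max_retries to 1"
        }
        retry
    }

    pub fn effective_circuit_breaker_config(&self) -> CircuitBreakerConfig {
        let mut cb = self.circuit_breaker.clone();
        if self.disable_circuit_breaker {
            cb.failure_threshold = u32::MAX; // breaker effectively never opens
        }
        cb
    }
}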
@@ -102,16 +73,14 @@ pub struct RouterConfig {
/// Tokenizer cache configuration
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TokenizerCacheConfig {
-    /// Enable L0 cache (whole-string exact match)
+    /// Whole-string exact match cache
    #[serde(default = "default_enable_l0")]
    pub enable_l0: bool,
-    /// Maximum number of entries in L0 cache
    #[serde(default = "default_l0_max_entries")]
    pub l0_max_entries: usize,
-    /// Enable L1 cache (prefix matching at fixed boundaries)
+    /// Prefix matching at fixed boundaries
    #[serde(default = "default_enable_l1")]
    pub enable_l1: bool,
-    /// Maximum memory for L1 cache in bytes
    #[serde(default = "default_l1_max_memory")]
    pub l1_max_memory: usize,
}
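
Note: every field carries #[serde(default = "...")], so an empty tokenizer_cache section still deserializes; the concrete values come from the default_* functions elsewhere in this file. Illustration only (serde_json used just to show the behaviour):

fn tokenizer_cache_from_empty_section() -> TokenizerCacheConfig {
    // "{}" is valid: each missing field falls back to its default_* function.
    serde_json::from_str::<TokenizerCacheConfig>("{}").expect("defaults apply")
}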
@@ -151,33 +120,25 @@ fn default_history_backend() -> HistoryBackend {
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum HistoryBackend {
-    /// In-memory storage (default)
    Memory,
-    /// No history storage
    None,
-    /// Oracle ATP-backed storage
    Oracle,
}
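
Note: with rename_all = "lowercase", the backend is selected by the plain strings "memory", "none", or "oracle"; default_history_backend (hidden in the collapsed hunk) supplies the value when the field is omitted. Illustration only:

fn parse_history_backend() {
    let backend: HistoryBackend = serde_json::from_str("\"oracle\"").unwrap();
    assert_eq!(backend, HistoryBackend::Oracle);
}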
/// Oracle history backend configuration
#[derive(Clone, Serialize, Deserialize, PartialEq)]
pub struct OracleConfig {
-    /// Directory containing the ATP wallet or TLS config files (optional)
+    /// ATP wallet or TLS config files directory
    #[serde(skip_serializing_if = "Option::is_none")]
    pub wallet_path: Option<String>,
-    /// Connection descriptor / DSN (e.g. `tcps://host:port/service`)
+    /// DSN (e.g. `tcps://host:port/service`)
    pub connect_descriptor: String,
-    /// Database username
    pub username: String,
-    /// Database password
    pub password: String,
-    /// Minimum number of pooled connections to keep ready
    #[serde(default = "default_pool_min")]
    pub pool_min: usize,
-    /// Maximum number of pooled connections
    #[serde(default = "default_pool_max")]
    pub pool_max: usize,
-    /// Maximum time to wait for a connection from the pool (seconds)
    #[serde(default = "default_pool_timeout_secs")]
    pub pool_timeout_secs: u64,
}
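
Note: an illustrative oracle section (all values are placeholders; serde_json used only to show the shape). The pool_* fields may be omitted thanks to their serde defaults:

fn example_oracle_section() -> OracleConfig {
    serde_json::from_value(serde_json::json!({
        "wallet_path": "/etc/router/oracle-wallet",
        "connect_descriptor": "tcps://db.example.com:1522/my_service",
        "username": "router",
        "password": "change-me"
        // pool_min / pool_max / pool_timeout_secs fall back to their defaults
    }))
    .expect("placeholder values, shape follows the struct above")
}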
@@ -226,28 +187,19 @@ impl std::fmt::Debug for OracleConfig {
#[serde(tag = "type")]
pub enum RoutingMode {
    #[serde(rename = "regular")]
-    Regular {
-        /// List of worker URLs
-        worker_urls: Vec<String>,
-    },
+    Regular { worker_urls: Vec<String> },
    #[serde(rename = "prefill_decode")]
    PrefillDecode {
-        /// Prefill worker URLs with optional bootstrap ports
+        /// With optional bootstrap ports
        prefill_urls: Vec<(String, Option<u16>)>,
-        /// Decode worker URLs
        decode_urls: Vec<String>,
-        /// Optional separate policy for prefill workers
        #[serde(skip_serializing_if = "Option::is_none")]
        prefill_policy: Option<PolicyConfig>,
-        /// Optional separate policy for decode workers
        #[serde(skip_serializing_if = "Option::is_none")]
        decode_policy: Option<PolicyConfig>,
    },
    #[serde(rename = "openai")]
-    OpenAI {
-        /// OpenAI-compatible API base(s), provided via worker URLs
-        worker_urls: Vec<String>,
-    },
+    OpenAI { worker_urls: Vec<String> },
}
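
Note: with tag = "type" and the per-variant renames, routing modes deserialize from shapes like the following (URLs and ports are placeholders; serde_json is used only to show the wire format). prefill_urls entries are (url, optional bootstrap port) tuples; the two optional policy fields are passed explicitly as null here because their serde defaults are not visible in this diff:

fn example_routing_modes() {
    let regular: RoutingMode = serde_json::from_value(serde_json::json!({
        "type": "regular",
        "worker_urls": ["http://worker-0:8000", "http://worker-1:8000"]
    }))
    .unwrap();

    let pd: RoutingMode = serde_json::from_value(serde_json::json!({
        "type": "prefill_decode",
        "prefill_urls": [["http://prefill-0:8000", 9000], ["http://prefill-1:8000", null]],
        "decode_urls": ["http://decode-0:8000"],
        "prefill_policy": null,
        "decode_policy": null
    }))
    .unwrap();

    let _ = (regular, pd);
}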
impl RoutingMode {
@@ -302,23 +254,15 @@ pub enum PolicyConfig {
    #[serde(rename = "cache_aware")]
    CacheAware {
-        /// Minimum prefix match ratio to use cache-based routing
        cache_threshold: f32,
-        /// Absolute load difference threshold for load balancing
        balance_abs_threshold: usize,
-        /// Relative load ratio threshold for load balancing
        balance_rel_threshold: f32,
-        /// Interval between cache eviction cycles (seconds)
        eviction_interval_secs: u64,
-        /// Maximum cache tree size per tenant
        max_tree_size: usize,
    },
    #[serde(rename = "power_of_two")]
-    PowerOfTwo {
-        /// Interval for load monitoring (seconds)
-        load_check_interval_secs: u64,
-    },
+    PowerOfTwo { load_check_interval_secs: u64 },
}
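
Note: the enum's own serde attributes sit in the collapsed part of the hunk, so only direct construction is shown here; the numbers are placeholders chosen to illustrate the knobs, whose meanings are exactly the doc comments this commit removes:

fn example_cache_aware_policy() -> PolicyConfig {
    PolicyConfig::CacheAware {
        cache_threshold: 0.5,       // minimum prefix-match ratio for cache-based routing
        balance_abs_threshold: 32,  // absolute load difference threshold
        balance_rel_threshold: 1.1, // relative load ratio threshold
        eviction_interval_secs: 60,
        max_tree_size: 16_384,
    }
}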
impl PolicyConfig {
@@ -335,21 +279,17 @@ impl PolicyConfig {
/// Service discovery configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscoveryConfig {
-    /// Enable service discovery
    pub enabled: bool,
-    /// Kubernetes namespace (None = all namespaces)
+    /// None = all namespaces
    pub namespace: Option<String>,
-    /// Service discovery port
    pub port: u16,
-    /// Check interval for service discovery
    pub check_interval_secs: u64,
-    /// Regular mode selector
+    /// Regular mode
    pub selector: HashMap<String, String>,
-    /// PD mode prefill selector
+    /// PD mode prefill
    pub prefill_selector: HashMap<String, String>,
-    /// PD mode decode selector
+    /// PD mode decode
    pub decode_selector: HashMap<String, String>,
-    /// Bootstrap port annotation key
    pub bootstrap_port_annotation: String,
}
@@ -371,16 +311,11 @@ impl Default for DiscoveryConfig {
/// Retry configuration for request handling
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
-    /// Maximum number of retry attempts
    pub max_retries: u32,
-    /// Initial backoff delay in milliseconds
    pub initial_backoff_ms: u64,
-    /// Maximum backoff delay in milliseconds
    pub max_backoff_ms: u64,
-    /// Backoff multiplier for exponential backoff
    pub backoff_multiplier: f32,
-    /// Jitter factor applied to backoff (0.0 - 1.0)
-    /// Effective delay D' = D * (1 + U[-j, +j])
+    /// D' = D * (1 + U[-j, +j]) where j is jitter factor
    #[serde(default = "default_retry_jitter_factor")]
    pub jitter_factor: f32,
}
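
Note: the retained comment is the whole contract: the effective delay is D' = D * (1 + U[-j, +j]). A sketch of how a caller could turn these fields into a per-attempt delay (the rand crate is used purely for illustration; the router's own retry loop is outside this diff):

use rand::Rng;

fn backoff_delay_ms(cfg: &RetryConfig, attempt: u32) -> u64 {
    // Exponential growth from initial_backoff_ms, capped at max_backoff_ms ...
    let base = (cfg.initial_backoff_ms as f64)
        * f64::from(cfg.backoff_multiplier).powi(attempt as i32);
    let capped = base.min(cfg.max_backoff_ms as f64);
    // ... then jittered: D' = D * (1 + U[-j, +j])
    let j = f64::from(cfg.jitter_factor);
    let u = rand::thread_rng().gen_range(-j..=j);
    (capped * (1.0 + u)).max(0.0) as u64
}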
@@ -404,15 +339,10 @@ fn default_retry_jitter_factor() -> f32 {
/// Health check configuration for worker monitoring
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthCheckConfig {
-    /// Number of consecutive failures before marking unhealthy
    pub failure_threshold: u32,
-    /// Number of consecutive successes before marking healthy
    pub success_threshold: u32,
-    /// Timeout for health check requests in seconds
    pub timeout_secs: u64,
-    /// Interval between health checks in seconds
    pub check_interval_secs: u64,
-    /// Health check endpoint path
    pub endpoint: String,
}
@@ -431,13 +361,9 @@ impl Default for HealthCheckConfig {
/// Circuit breaker configuration for worker reliability
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerConfig {
-    /// Number of consecutive failures before opening circuit
    pub failure_threshold: u32,
-    /// Number of consecutive successes before closing circuit
    pub success_threshold: u32,
-    /// Time before attempting to recover from open state (in seconds)
    pub timeout_duration_secs: u64,
-    /// Window duration for failure tracking (in seconds)
    pub window_duration_secs: u64,
}
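
Note: one plausible reading of these four knobs, not the router's actual breaker (which is outside this diff): consecutive failures open the circuit, the open state is held for timeout_duration_secs before a half-open probe, and consecutive successes close it again. The failure-tracking window (window_duration_secs) is omitted from this sketch for brevity.

use std::time::{Duration, Instant};

enum BreakerState { Closed, Open { since: Instant }, HalfOpen }

struct BreakerSketch {
    cfg: CircuitBreakerConfig,
    state: BreakerState,
    failures: u32,
    successes: u32,
}

impl BreakerSketch {
    fn allow_request(&mut self) -> bool {
        match &self.state {
            BreakerState::Closed | BreakerState::HalfOpen => true,
            BreakerState::Open { since } => {
                if since.elapsed() >= Duration::from_secs(self.cfg.timeout_duration_secs) {
                    self.state = BreakerState::HalfOpen; // probe after the open timeout
                    true
                } else {
                    false
                }
            }
        }
    }

    fn record(&mut self, ok: bool) {
        if ok {
            self.failures = 0;
            self.successes += 1;
            if self.successes >= self.cfg.success_threshold {
                self.state = BreakerState::Closed;
            }
        } else {
            self.successes = 0;
            self.failures += 1;
            if self.failures >= self.cfg.failure_threshold {
                self.state = BreakerState::Open { since: Instant::now() };
            }
        }
    }
}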
@@ -455,9 +381,7 @@ impl Default for CircuitBreakerConfig {
/// Metrics configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsConfig {
-    /// Prometheus metrics port
    pub port: u16,
-    /// Prometheus metrics host
    pub host: String,
}
...
@@ -5,9 +5,7 @@ use crate::core::ConnectionMode;
pub struct ConfigValidator;
impl ConfigValidator {
-    /// Validate a complete router configuration
    pub fn validate(config: &RouterConfig) -> ConfigResult<()> {
-        // Check if service discovery is enabled
        let has_service_discovery = config.discovery.as_ref().is_some_and(|d| d.enabled);
        Self::validate_mode(&config.mode, has_service_discovery)?;
@@ -24,55 +22,46 @@ impl ConfigValidator {
        Self::validate_compatibility(config)?;
-        // Validate effective retry/CB configs (respect disable flags)
        let retry_cfg = config.effective_retry_config();
        let cb_cfg = config.effective_circuit_breaker_config();
        Self::validate_retry(&retry_cfg)?;
        Self::validate_circuit_breaker(&cb_cfg)?;
-        // Validate Oracle configuration if enabled
        if config.history_backend == HistoryBackend::Oracle {
            if config.oracle.is_none() {
                return Err(ConfigError::MissingRequired {
                    field: "oracle".to_string(),
                });
            }
-            // Validate Oracle configuration details
            if let Some(oracle) = &config.oracle {
                Self::validate_oracle(oracle)?;
            }
        }
-        // Validate tokenizer cache configuration
        Self::validate_tokenizer_cache(&config.tokenizer_cache)?;
        Ok(())
    }
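
Note: a hypothetical caller, just to show where validate sits during startup; load_router_config is an assumed helper, not part of this commit:

fn init_router(path: &str) -> ConfigResult<RouterConfig> {
    let config: RouterConfig = load_router_config(path)?; // hypothetical loader
    ConfigValidator::validate(&config)?; // fail fast before serving traffic
    Ok(config)
}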
-    /// Validate Oracle configuration
    fn validate_oracle(oracle: &OracleConfig) -> ConfigResult<()> {
-        // Validate username is not empty
        if oracle.username.is_empty() {
            return Err(ConfigError::MissingRequired {
                field: "oracle.username".to_string(),
            });
        }
-        // Validate password is not empty
        if oracle.password.is_empty() {
            return Err(ConfigError::MissingRequired {
                field: "oracle.password".to_string(),
            });
        }
-        // Validate connect_descriptor is not empty
        if oracle.connect_descriptor.is_empty() {
            return Err(ConfigError::MissingRequired {
                field: "oracle_dsn or oracle_tns_alias".to_string(),
            });
        }
-        // Validate pool_min is at least 1
        if oracle.pool_min < 1 {
            return Err(ConfigError::InvalidValue {
                field: "oracle.pool_min".to_string(),
@@ -81,7 +70,6 @@ impl ConfigValidator {
            });
        }
-        // Validate pool_max is greater than or equal to pool_min
        if oracle.pool_max < oracle.pool_min {
            return Err(ConfigError::InvalidValue {
                field: "oracle.pool_max".to_string(),
@@ -90,7 +78,6 @@ impl ConfigValidator {
            });
        }
-        // Validate pool_timeout_secs is greater than 0
        if oracle.pool_timeout_secs == 0 {
            return Err(ConfigError::InvalidValue {
                field: "oracle.pool_timeout_secs".to_string(),
@@ -102,17 +89,13 @@ impl ConfigValidator {
        Ok(())
    }
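
Note: a sketch of a unit test for the checks above, assuming it lives in this module's #[cfg(test)] block so the private helper is reachable; field values are placeholders:

#[test]
fn oracle_config_rejects_empty_username() {
    let oracle = OracleConfig {
        wallet_path: None,
        connect_descriptor: "tcps://db.example.com:1522/my_service".to_string(),
        username: String::new(), // hits the MissingRequired("oracle.username") branch
        password: "secret".to_string(),
        pool_min: 1,
        pool_max: 4,
        pool_timeout_secs: 30,
    };
    assert!(ConfigValidator::validate_oracle(&oracle).is_err());
}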
-    /// Validate routing mode configuration
    fn validate_mode(mode: &RoutingMode, has_service_discovery: bool) -> ConfigResult<()> {
        match mode {
            RoutingMode::Regular { worker_urls } => {
-                // Validate URLs if any are provided
                if !worker_urls.is_empty() {
                    Self::validate_urls(worker_urls)?;
                }
-                // Note: We allow empty worker URLs even without service discovery
-                // to let the router start and fail at runtime when routing requests.
-                // This matches legacy behavior and test expectations.
+                // Allow empty URLs without service discovery to match legacy behavior
            }
            RoutingMode::PrefillDecode {
                prefill_urls,
@@ -120,7 +103,6 @@ impl ConfigValidator {
                prefill_policy,
                decode_policy,
            } => {
-                // Only require URLs if service discovery is disabled
                if !has_service_discovery {
                    if prefill_urls.is_empty() {
                        return Err(ConfigError::ValidationFailed {
@@ -134,7 +116,6 @@ impl ConfigValidator {
                    }
                }
-                // Validate URLs if any are provided
                if !prefill_urls.is_empty() {
                    let prefill_url_strings: Vec<String> =
                        prefill_urls.iter().map(|(url, _)| url.clone()).collect();
@@ -144,7 +125,6 @@ impl ConfigValidator {
                    Self::validate_urls(decode_urls)?;
                }
-                // Validate bootstrap ports
                for (_url, port) in prefill_urls {
                    if let Some(port) = port {
                        if *port == 0 {
@@ -157,7 +137,6 @@ impl ConfigValidator {
                    }
                }
-                // Validate optional prefill and decode policies
                if let Some(p_policy) = prefill_policy {
                    Self::validate_policy(p_policy)?;
                }
@@ -166,25 +145,20 @@ impl ConfigValidator {
                }
            }
            RoutingMode::OpenAI { worker_urls } => {
-                // Require at least one worker URL for OpenAI router
                if worker_urls.is_empty() {
                    return Err(ConfigError::ValidationFailed {
                        reason: "OpenAI mode requires at least one --worker-urls entry".to_string(),
                    });
                }
-                // Validate URLs
                Self::validate_urls(worker_urls)?;
            }
        }
        Ok(())
    }
-    /// Validate policy configuration
    fn validate_policy(policy: &PolicyConfig) -> ConfigResult<()> {
        match policy {
-            PolicyConfig::Random | PolicyConfig::RoundRobin => {
-                // No specific validation needed
-            }
+            PolicyConfig::Random | PolicyConfig::RoundRobin => {}
            PolicyConfig::CacheAware {
                cache_threshold,
                balance_abs_threshold: _,
@@ -239,7 +213,6 @@ impl ConfigValidator {
        Ok(())
    }
-    /// Validate server configuration
    fn validate_server_settings(config: &RouterConfig) -> ConfigResult<()> {
        if config.port == 0 {
            return Err(ConfigError::InvalidValue {
@@ -302,10 +275,9 @@ impl ConfigValidator {
        Ok(())
    }
-    /// Validate service discovery configuration
    fn validate_discovery(discovery: &DiscoveryConfig, mode: &RoutingMode) -> ConfigResult<()> {
        if !discovery.enabled {
-            return Ok(()); // No validation needed if disabled
+            return Ok(());
        }
        if discovery.port == 0 {
@@ -324,7 +296,6 @@ impl ConfigValidator {
            });
        }
-        // Validate selectors based on mode
        match mode {
            RoutingMode::Regular { .. } => {
                if discovery.selector.is_empty() {
@@ -342,7 +313,6 @@ impl ConfigValidator {
                }
            }
            RoutingMode::OpenAI { .. } => {
-                // OpenAI mode doesn't use service discovery
                return Err(ConfigError::ValidationFailed {
                    reason: "OpenAI mode does not support service discovery".to_string(),
                });
@@ -352,7 +322,6 @@ impl ConfigValidator {
        Ok(())
    }
-    /// Validate metrics configuration
    fn validate_metrics(metrics: &MetricsConfig) -> ConfigResult<()> {
        if metrics.port == 0 {
            return Err(ConfigError::InvalidValue {
@@ -373,7 +342,6 @@ impl ConfigValidator {
        Ok(())
    }
-    /// Validate retry configuration
    fn validate_retry(retry: &RetryConfig) -> ConfigResult<()> {
        if retry.max_retries < 1 {
            return Err(ConfigError::InvalidValue {
@@ -413,7 +381,6 @@ impl ConfigValidator {
        Ok(())
    }
-    /// Validate circuit breaker configuration
    fn validate_circuit_breaker(cb: &CircuitBreakerConfig) -> ConfigResult<()> {
        if cb.failure_threshold < 1 {
            return Err(ConfigError::InvalidValue {
@@ -446,9 +413,7 @@ impl ConfigValidator {
        Ok(())
    }
-    /// Validate tokenizer cache configuration
    fn validate_tokenizer_cache(cache: &TokenizerCacheConfig) -> ConfigResult<()> {
-        // Validate L0 max entries when L0 is enabled
        if cache.enable_l0 && cache.l0_max_entries == 0 {
            return Err(ConfigError::InvalidValue {
                field: "tokenizer_cache.l0_max_entries".to_string(),
@@ -457,7 +422,6 @@ impl ConfigValidator {
            });
        }
-        // Validate L1 max memory when L1 is enabled
        if cache.enable_l1 && cache.l1_max_memory == 0 {
            return Err(ConfigError::InvalidValue {
                field: "tokenizer_cache.l1_max_memory".to_string(),
@@ -469,9 +433,7 @@ impl ConfigValidator {
        Ok(())
    }
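
Note: same idea for the cache checks above, again assuming a #[cfg(test)] context; an enabled L0 cache with zero capacity should be rejected:

#[test]
fn l0_cache_requires_nonzero_capacity() {
    let cache = TokenizerCacheConfig {
        enable_l0: true,
        l0_max_entries: 0, // invalid while enable_l0 is true
        enable_l1: false,
        l1_max_memory: 0,
    };
    assert!(ConfigValidator::validate_tokenizer_cache(&cache).is_err());
}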
-    /// Validate mTLS certificate configuration
    fn validate_mtls(config: &RouterConfig) -> ConfigResult<()> {
-        // Validate that if we have client_identity, it's not empty
        if let Some(identity) = &config.client_identity {
            if identity.is_empty() {
                return Err(ConfigError::ValidationFailed {
@@ -480,7 +442,6 @@ impl ConfigValidator {
            }
        }
-        // Validate CA certificates are not empty
        for (idx, ca_cert) in config.ca_certificates.iter().enumerate() {
            if ca_cert.is_empty() {
                return Err(ConfigError::ValidationFailed {
@@ -492,14 +453,11 @@ impl ConfigValidator {
        Ok(())
    }
-    /// Validate compatibility between different configuration sections
    fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> {
-        // IGW mode is independent - skip other compatibility checks when enabled
        if config.enable_igw {
            return Ok(());
        }
-        // Validate gRPC connection mode requires tokenizer configuration
        if matches!(config.connection_mode, ConnectionMode::Grpc { .. })
            && config.tokenizer_path.is_none()
            && config.model_path.is_none()
@@ -509,18 +467,11 @@ impl ConfigValidator {
            });
        }
-        // Validate mTLS configuration
        Self::validate_mtls(config)?;
-        // All policies are now supported for both router types thanks to the unified trait design
-        // No mode/policy restrictions needed anymore
-        // Check if service discovery is enabled for worker count validation
        let has_service_discovery = config.discovery.as_ref().is_some_and(|d| d.enabled);
-        // Only validate worker counts if service discovery is disabled
        if !has_service_discovery {
-            // Check if power-of-two policy makes sense with insufficient workers
            if let PolicyConfig::PowerOfTwo { .. } = &config.policy {
                let worker_count = config.mode.worker_count();
                if worker_count < 2 {
@@ -530,7 +481,6 @@ impl ConfigValidator {
                }
            }
-            // For PD mode, validate that policies have sufficient workers
            if let RoutingMode::PrefillDecode {
                prefill_urls,
                decode_urls,
@@ -538,7 +488,6 @@ impl ConfigValidator {
                decode_policy,
            } = &config.mode
            {
-                // Check power-of-two for prefill
                if let Some(PolicyConfig::PowerOfTwo { .. }) = prefill_policy {
                    if prefill_urls.len() < 2 {
                        return Err(ConfigError::IncompatibleConfig {
@@ -547,7 +496,6 @@ impl ConfigValidator {
                    }
                }
-                // Check power-of-two for decode
                if let Some(PolicyConfig::PowerOfTwo { .. }) = decode_policy {
                    if decode_urls.len() < 2 {
                        return Err(ConfigError::IncompatibleConfig {
@@ -560,8 +508,6 @@ impl ConfigValidator {
                }
            }
-        // Service discovery is conflict with dp_aware routing for now
-        // since it's not fully supported yet
        if has_service_discovery && config.dp_aware {
            return Err(ConfigError::IncompatibleConfig {
                reason: "DP-aware routing is not compatible with service discovery".to_string(),
@@ -571,7 +517,6 @@ impl ConfigValidator {
        Ok(())
    }
-    /// Validate URL format
    fn validate_urls(urls: &[String]) -> ConfigResult<()> {
        for url in urls {
            if url.is_empty() {
@@ -593,10 +538,8 @@ impl ConfigValidator {
                });
            }
-            // Basic URL validation
            match ::url::Url::parse(url) {
                Ok(parsed) => {
-                    // Additional validation
                    if parsed.host_str().is_none() {
                        return Err(ConfigError::InvalidValue {
                            field: "worker_url".to_string(),
...
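
Note: illustration of the host check above using the same url crate; a bare host:port string parses with the host portion treated as a scheme and no host, so it would be rejected here, while a full http:// URL passes:

fn has_host(raw: &str) -> bool {
    ::url::Url::parse(raw)
        .ok()
        .and_then(|u| u.host_str().map(str::to_owned))
        .is_some()
}
// has_host("http://worker-0:8000") == true
// has_host("worker-0:8000") == false ("worker-0" is parsed as the scheme)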