Unverified Commit a7fe6e10 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] remove old/oudated/useless comments (#10967)

parent be059b83
...@@ -67,58 +67,47 @@ struct Router { ...@@ -67,58 +67,47 @@ struct Router {
decode_policy: Option<PolicyType>, decode_policy: Option<PolicyType>,
max_concurrent_requests: usize, max_concurrent_requests: usize,
cors_allowed_origins: Vec<String>, cors_allowed_origins: Vec<String>,
// Retry configuration
retry_max_retries: u32, retry_max_retries: u32,
retry_initial_backoff_ms: u64, retry_initial_backoff_ms: u64,
retry_max_backoff_ms: u64, retry_max_backoff_ms: u64,
retry_backoff_multiplier: f32, retry_backoff_multiplier: f32,
retry_jitter_factor: f32, retry_jitter_factor: f32,
disable_retries: bool, disable_retries: bool,
// Circuit breaker configuration
cb_failure_threshold: u32, cb_failure_threshold: u32,
cb_success_threshold: u32, cb_success_threshold: u32,
cb_timeout_duration_secs: u64, cb_timeout_duration_secs: u64,
cb_window_duration_secs: u64, cb_window_duration_secs: u64,
disable_circuit_breaker: bool, disable_circuit_breaker: bool,
// Health check configuration
health_failure_threshold: u32, health_failure_threshold: u32,
health_success_threshold: u32, health_success_threshold: u32,
health_check_timeout_secs: u64, health_check_timeout_secs: u64,
health_check_interval_secs: u64, health_check_interval_secs: u64,
health_check_endpoint: String, health_check_endpoint: String,
// IGW (Inference Gateway) configuration
enable_igw: bool, enable_igw: bool,
queue_size: usize, queue_size: usize,
queue_timeout_secs: u64, queue_timeout_secs: u64,
rate_limit_tokens_per_second: Option<usize>, rate_limit_tokens_per_second: Option<usize>,
// Connection mode (determined from worker URLs)
connection_mode: config::ConnectionMode, connection_mode: config::ConnectionMode,
// Model path for tokenizer
model_path: Option<String>, model_path: Option<String>,
// Explicit tokenizer path
tokenizer_path: Option<String>, tokenizer_path: Option<String>,
} }
impl Router { impl Router {
/// Determine connection mode from worker URLs /// Determine connection mode from worker URLs
fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode { fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
// Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
for url in worker_urls { for url in worker_urls {
if url.starts_with("grpc://") || url.starts_with("grpcs://") { if url.starts_with("grpc://") || url.starts_with("grpcs://") {
return config::ConnectionMode::Grpc; return config::ConnectionMode::Grpc;
} }
} }
// Default to HTTP for all other cases (including http://, https://, or no scheme)
config::ConnectionMode::Http config::ConnectionMode::Http
} }
/// Convert PyO3 Router to RouterConfig
pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> { pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
use config::{ use config::{
DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode, DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
}; };
// Convert policy helper function
let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig { let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
match policy { match policy {
PolicyType::Random => ConfigPolicyConfig::Random, PolicyType::Random => ConfigPolicyConfig::Random,
...@@ -131,14 +120,12 @@ impl Router { ...@@ -131,14 +120,12 @@ impl Router {
max_tree_size: self.max_tree_size, max_tree_size: self.max_tree_size,
}, },
PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo { PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
load_check_interval_secs: 5, // Default value load_check_interval_secs: 5,
}, },
} }
}; };
// Determine routing mode
let mode = if self.enable_igw { let mode = if self.enable_igw {
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
RoutingMode::Regular { RoutingMode::Regular {
worker_urls: vec![], worker_urls: vec![],
} }
...@@ -155,10 +142,8 @@ impl Router { ...@@ -155,10 +142,8 @@ impl Router {
} }
}; };
// Convert main policy
let policy = convert_policy(&self.policy); let policy = convert_policy(&self.policy);
// Service discovery configuration
let discovery = if self.service_discovery { let discovery = if self.service_discovery {
Some(DiscoveryConfig { Some(DiscoveryConfig {
enabled: true, enabled: true,
...@@ -174,7 +159,6 @@ impl Router { ...@@ -174,7 +159,6 @@ impl Router {
None None
}; };
// Metrics configuration
let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) { let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
(Some(port), Some(host)) => Some(MetricsConfig { (Some(port), Some(host)) => Some(MetricsConfig {
port, port,
...@@ -251,7 +235,7 @@ impl Router { ...@@ -251,7 +235,7 @@ impl Router {
balance_rel_threshold = 1.5, balance_rel_threshold = 1.5,
eviction_interval_secs = 120, eviction_interval_secs = 120,
max_tree_size = 2usize.pow(26), max_tree_size = 2usize.pow(26),
max_payload_size = 512 * 1024 * 1024, // 512MB default for large batches max_payload_size = 512 * 1024 * 1024,
dp_aware = false, dp_aware = false,
api_key = None, api_key = None,
log_dir = None, log_dir = None,
...@@ -265,40 +249,35 @@ impl Router { ...@@ -265,40 +249,35 @@ impl Router {
bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"), bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
prometheus_port = None, prometheus_port = None,
prometheus_host = None, prometheus_host = None,
request_timeout_secs = 1800, // Add configurable request timeout request_timeout_secs = 1800,
request_id_headers = None, // Custom request ID headers request_id_headers = None,
pd_disaggregation = false, // New flag for PD mode pd_disaggregation = false,
prefill_urls = None, prefill_urls = None,
decode_urls = None, decode_urls = None,
prefill_policy = None, prefill_policy = None,
decode_policy = None, decode_policy = None,
max_concurrent_requests = 256, max_concurrent_requests = 256,
cors_allowed_origins = vec![], cors_allowed_origins = vec![],
// Retry defaults
retry_max_retries = 5, retry_max_retries = 5,
retry_initial_backoff_ms = 50, retry_initial_backoff_ms = 50,
retry_max_backoff_ms = 30_000, retry_max_backoff_ms = 30_000,
retry_backoff_multiplier = 1.5, retry_backoff_multiplier = 1.5,
retry_jitter_factor = 0.2, retry_jitter_factor = 0.2,
disable_retries = false, disable_retries = false,
// Circuit breaker defaults
cb_failure_threshold = 10, cb_failure_threshold = 10,
cb_success_threshold = 3, cb_success_threshold = 3,
cb_timeout_duration_secs = 60, cb_timeout_duration_secs = 60,
cb_window_duration_secs = 120, cb_window_duration_secs = 120,
disable_circuit_breaker = false, disable_circuit_breaker = false,
// Health check defaults
health_failure_threshold = 3, health_failure_threshold = 3,
health_success_threshold = 2, health_success_threshold = 2,
health_check_timeout_secs = 5, health_check_timeout_secs = 5,
health_check_interval_secs = 60, health_check_interval_secs = 60,
health_check_endpoint = String::from("/health"), health_check_endpoint = String::from("/health"),
// IGW defaults
enable_igw = false, enable_igw = false,
queue_size = 100, queue_size = 100,
queue_timeout_secs = 60, queue_timeout_secs = 60,
rate_limit_tokens_per_second = None, rate_limit_tokens_per_second = None,
// Tokenizer defaults
model_path = None, model_path = None,
tokenizer_path = None, tokenizer_path = None,
))] ))]
...@@ -361,17 +340,14 @@ impl Router { ...@@ -361,17 +340,14 @@ impl Router {
model_path: Option<String>, model_path: Option<String>,
tokenizer_path: Option<String>, tokenizer_path: Option<String>,
) -> PyResult<Self> { ) -> PyResult<Self> {
// Determine connection mode from worker URLs
let mut all_urls = worker_urls.clone(); let mut all_urls = worker_urls.clone();
// Add prefill URLs if in PD mode
if let Some(ref prefill_urls) = prefill_urls { if let Some(ref prefill_urls) = prefill_urls {
for (url, _) in prefill_urls { for (url, _) in prefill_urls {
all_urls.push(url.clone()); all_urls.push(url.clone());
} }
} }
// Add decode URLs if in PD mode
if let Some(ref decode_urls) = decode_urls { if let Some(ref decode_urls) = decode_urls {
all_urls.extend(decode_urls.clone()); all_urls.extend(decode_urls.clone());
} }
...@@ -440,12 +416,10 @@ impl Router { ...@@ -440,12 +416,10 @@ impl Router {
} }
fn start(&self) -> PyResult<()> { fn start(&self) -> PyResult<()> {
// Convert to RouterConfig and validate
let router_config = self.to_router_config().map_err(|e| { let router_config = self.to_router_config().map_err(|e| {
pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e)) pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
})?; })?;
// Validate the configuration
router_config.validate().map_err(|e| { router_config.validate().map_err(|e| {
pyo3::exceptions::PyValueError::new_err(format!( pyo3::exceptions::PyValueError::new_err(format!(
"Configuration validation failed: {}", "Configuration validation failed: {}",
...@@ -453,7 +427,6 @@ impl Router { ...@@ -453,7 +427,6 @@ impl Router {
)) ))
})?; })?;
// Create service discovery config if enabled
let service_discovery_config = if self.service_discovery { let service_discovery_config = if self.service_discovery {
Some(service_discovery::ServiceDiscoveryConfig { Some(service_discovery::ServiceDiscoveryConfig {
enabled: true, enabled: true,
...@@ -470,7 +443,6 @@ impl Router { ...@@ -470,7 +443,6 @@ impl Router {
None None
}; };
// Create Prometheus config if enabled
let prometheus_config = Some(PrometheusConfig { let prometheus_config = Some(PrometheusConfig {
port: self.prometheus_port.unwrap_or(29000), port: self.prometheus_port.unwrap_or(29000),
host: self host: self
...@@ -479,11 +451,9 @@ impl Router { ...@@ -479,11 +451,9 @@ impl Router {
.unwrap_or_else(|| "127.0.0.1".to_string()), .unwrap_or_else(|| "127.0.0.1".to_string()),
}); });
// Use tokio runtime instead of actix-web System for better compatibility
let runtime = tokio::runtime::Runtime::new() let runtime = tokio::runtime::Runtime::new()
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
// Block on the async startup function
runtime.block_on(async move { runtime.block_on(async move {
server::startup(server::ServerConfig { server::startup(server::ServerConfig {
host: self.host.clone(), host: self.host.clone(),
......
...@@ -8,20 +8,13 @@ use tracing_subscriber::layer::SubscriberExt; ...@@ -8,20 +8,13 @@ use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer}; use tracing_subscriber::{EnvFilter, Layer};
/// Configuration for the logging system
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct LoggingConfig { pub struct LoggingConfig {
/// Log level for the application (default: INFO)
pub level: Level, pub level: Level,
/// Whether to use json format for logs (default: false)
pub json_format: bool, pub json_format: bool,
/// Path to store log files. If None, logs will only go to stdout/stderr
pub log_dir: Option<String>, pub log_dir: Option<String>,
/// Whether to colorize logs when output is a terminal (default: true)
pub colorize: bool, pub colorize: bool,
/// Log file name to use if log_dir is specified (default: "sgl-router")
pub log_file_name: String, pub log_file_name: String,
/// Custom log targets to filter (default: "sglang_router_rs")
pub log_targets: Option<Vec<String>>, pub log_targets: Option<Vec<String>>,
} }
...@@ -38,30 +31,14 @@ impl Default for LoggingConfig { ...@@ -38,30 +31,14 @@ impl Default for LoggingConfig {
} }
} }
/// Guard that keeps the file appender worker thread alive
///
/// This must be kept in scope for the duration of the program
/// to ensure logs are properly written to files
#[allow(dead_code)] #[allow(dead_code)]
pub struct LogGuard { pub struct LogGuard {
_file_guard: Option<WorkerGuard>, _file_guard: Option<WorkerGuard>,
} }
/// Initialize the logging system with the given configuration
///
/// # Arguments
/// * `config` - Configuration for the logging system
///
/// # Returns
/// A LogGuard that must be kept alive for the duration of the program
///
/// # Panics
/// Will not panic, as initialization errors are handled gracefully
pub fn init_logging(config: LoggingConfig) -> LogGuard { pub fn init_logging(config: LoggingConfig) -> LogGuard {
// Forward logs to tracing - ignore errors to allow for multiple initialization
let _ = LogTracer::init(); let _ = LogTracer::init();
// Convert log level to filter string
let level_filter = match config.level { let level_filter = match config.level {
Level::TRACE => "trace", Level::TRACE => "trace",
Level::DEBUG => "debug", Level::DEBUG => "debug",
...@@ -70,9 +47,7 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard { ...@@ -70,9 +47,7 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
Level::ERROR => "error", Level::ERROR => "error",
}; };
// Create env filter
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| { let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| {
// Format: <target>=<level>,<target2>=<level2>,...
let filter_string = if let Some(targets) = &config.log_targets { let filter_string = if let Some(targets) = &config.log_targets {
targets targets
.iter() .iter()
...@@ -92,13 +67,10 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard { ...@@ -92,13 +67,10 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
EnvFilter::new(filter_string) EnvFilter::new(filter_string)
}); });
// Setup stdout/stderr layer
let mut layers = Vec::new(); let mut layers = Vec::new();
// Standard timestamp format: YYYY-MM-DD HH:MM:SS
let time_format = "%Y-%m-%d %H:%M:%S".to_string(); let time_format = "%Y-%m-%d %H:%M:%S".to_string();
// Configure the console stdout layer
let stdout_layer = tracing_subscriber::fmt::layer() let stdout_layer = tracing_subscriber::fmt::layer()
.with_ansi(config.colorize) .with_ansi(config.colorize)
.with_file(true) .with_file(true)
...@@ -113,14 +85,12 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard { ...@@ -113,14 +85,12 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
layers.push(stdout_layer); layers.push(stdout_layer);
// Create a file appender if log_dir is specified
let mut file_guard = None; let mut file_guard = None;
if let Some(log_dir) = &config.log_dir { if let Some(log_dir) = &config.log_dir {
let file_name = config.log_file_name.clone(); let file_name = config.log_file_name.clone();
let log_dir = PathBuf::from(log_dir); let log_dir = PathBuf::from(log_dir);
// Create log directory if it doesn't exist
if !log_dir.exists() { if !log_dir.exists() {
if let Err(e) = std::fs::create_dir_all(&log_dir) { if let Err(e) = std::fs::create_dir_all(&log_dir) {
eprintln!("Failed to create log directory: {}", e); eprintln!("Failed to create log directory: {}", e);
...@@ -134,7 +104,7 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard { ...@@ -134,7 +104,7 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
file_guard = Some(guard); file_guard = Some(guard);
let file_layer = tracing_subscriber::fmt::layer() let file_layer = tracing_subscriber::fmt::layer()
.with_ansi(false) // Never use ANSI colors in log files .with_ansi(false)
.with_file(true) .with_file(true)
.with_line_number(true) .with_line_number(true)
.with_timer(ChronoUtc::new(time_format)) .with_timer(ChronoUtc::new(time_format))
...@@ -149,14 +119,11 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard { ...@@ -149,14 +119,11 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
layers.push(file_layer); layers.push(file_layer);
} }
// Initialize the subscriber with all layers
// Use try_init to handle errors gracefully in case another subscriber is already set
let _ = tracing_subscriber::registry() let _ = tracing_subscriber::registry()
.with(env_filter) .with(env_filter)
.with(layers) .with(layers)
.try_init(); .try_init();
// Return the guard to keep the file appender worker thread alive
LogGuard { LogGuard {
_file_guard: file_guard, _file_guard: file_guard,
} }
......
...@@ -9,7 +9,6 @@ use sglang_router_rs::server::{self, ServerConfig}; ...@@ -9,7 +9,6 @@ use sglang_router_rs::server::{self, ServerConfig};
use sglang_router_rs::service_discovery::ServiceDiscoveryConfig; use sglang_router_rs::service_discovery::ServiceDiscoveryConfig;
use std::collections::HashMap; use std::collections::HashMap;
// Helper function to parse prefill arguments from command line
fn parse_prefill_args() -> Vec<(String, Option<u16>)> { fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
let args: Vec<String> = std::env::args().collect(); let args: Vec<String> = std::env::args().collect();
let mut prefill_entries = Vec::new(); let mut prefill_entries = Vec::new();
...@@ -19,12 +18,11 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> { ...@@ -19,12 +18,11 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
if args[i] == "--prefill" && i + 1 < args.len() { if args[i] == "--prefill" && i + 1 < args.len() {
let url = args[i + 1].clone(); let url = args[i + 1].clone();
let bootstrap_port = if i + 2 < args.len() && !args[i + 2].starts_with("--") { let bootstrap_port = if i + 2 < args.len() && !args[i + 2].starts_with("--") {
// Check if next arg is a port number
if let Ok(port) = args[i + 2].parse::<u16>() { if let Ok(port) = args[i + 2].parse::<u16>() {
i += 1; // Skip the port argument i += 1;
Some(port) Some(port)
} else if args[i + 2].to_lowercase() == "none" { } else if args[i + 2].to_lowercase() == "none" {
i += 1; // Skip the "none" argument i += 1;
None None
} else { } else {
None None
...@@ -33,7 +31,7 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> { ...@@ -33,7 +31,7 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
None None
}; };
prefill_entries.push((url, bootstrap_port)); prefill_entries.push((url, bootstrap_port));
i += 2; // Skip --prefill and URL i += 2;
} else { } else {
i += 1; i += 1;
} }
...@@ -101,252 +99,186 @@ Examples: ...@@ -101,252 +99,186 @@ Examples:
"#)] "#)]
struct CliArgs { struct CliArgs {
/// Host address to bind the router server
#[arg(long, default_value = "127.0.0.1")] #[arg(long, default_value = "127.0.0.1")]
host: String, host: String,
/// Port number to bind the router server
#[arg(long, default_value_t = 30000)] #[arg(long, default_value_t = 30000)]
port: u16, port: u16,
/// List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)
#[arg(long, num_args = 0..)] #[arg(long, num_args = 0..)]
worker_urls: Vec<String>, worker_urls: Vec<String>,
/// Load balancing policy to use
#[arg(long, default_value = "cache_aware", value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])] #[arg(long, default_value = "cache_aware", value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
policy: String, policy: String,
/// Enable PD (Prefill-Decode) disaggregated mode
#[arg(long, default_value_t = false)] #[arg(long, default_value_t = false)]
pd_disaggregation: bool, pd_disaggregation: bool,
/// Decode server URL (can be specified multiple times)
#[arg(long, action = ArgAction::Append)] #[arg(long, action = ArgAction::Append)]
decode: Vec<String>, decode: Vec<String>,
/// Specific policy for prefill nodes in PD mode
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])] #[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
prefill_policy: Option<String>, prefill_policy: Option<String>,
/// Specific policy for decode nodes in PD mode
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])] #[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
decode_policy: Option<String>, decode_policy: Option<String>,
/// Timeout in seconds for worker startup
#[arg(long, default_value_t = 600)] #[arg(long, default_value_t = 600)]
worker_startup_timeout_secs: u64, worker_startup_timeout_secs: u64,
/// Interval in seconds between checks for worker startup
#[arg(long, default_value_t = 30)] #[arg(long, default_value_t = 30)]
worker_startup_check_interval: u64, worker_startup_check_interval: u64,
/// Cache threshold (0.0-1.0) for cache-aware routing
#[arg(long, default_value_t = 0.3)] #[arg(long, default_value_t = 0.3)]
cache_threshold: f32, cache_threshold: f32,
/// Absolute threshold for load balancing
#[arg(long, default_value_t = 64)] #[arg(long, default_value_t = 64)]
balance_abs_threshold: usize, balance_abs_threshold: usize,
/// Relative threshold for load balancing
#[arg(long, default_value_t = 1.5)] #[arg(long, default_value_t = 1.5)]
balance_rel_threshold: f32, balance_rel_threshold: f32,
/// Interval in seconds between cache eviction operations
#[arg(long, default_value_t = 120)] #[arg(long, default_value_t = 120)]
eviction_interval: u64, eviction_interval: u64,
/// Maximum size of the approximation tree for cache-aware routing #[arg(long, default_value_t = 67108864)]
#[arg(long, default_value_t = 67108864)] // 2^26
max_tree_size: usize, max_tree_size: usize,
/// Maximum payload size in bytes #[arg(long, default_value_t = 536870912)]
#[arg(long, default_value_t = 536870912)] // 512MB
max_payload_size: usize, max_payload_size: usize,
/// Enable data parallelism aware schedule
#[arg(long, default_value_t = false)] #[arg(long, default_value_t = false)]
dp_aware: bool, dp_aware: bool,
/// API key for worker authorization
#[arg(long)] #[arg(long)]
api_key: Option<String>, api_key: Option<String>,
/// Backend to route requests to (sglang, vllm, trtllm, openai, anthropic)
#[arg(long, value_enum, default_value_t = Backend::Sglang, alias = "runtime")] #[arg(long, value_enum, default_value_t = Backend::Sglang, alias = "runtime")]
backend: Backend, backend: Backend,
/// Directory to store log files
#[arg(long)] #[arg(long)]
log_dir: Option<String>, log_dir: Option<String>,
/// Set the logging level
#[arg(long, default_value = "info", value_parser = ["debug", "info", "warn", "error"])] #[arg(long, default_value = "info", value_parser = ["debug", "info", "warn", "error"])]
log_level: String, log_level: String,
/// Enable Kubernetes service discovery
#[arg(long, default_value_t = false)] #[arg(long, default_value_t = false)]
service_discovery: bool, service_discovery: bool,
/// Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)
#[arg(long, num_args = 0..)] #[arg(long, num_args = 0..)]
selector: Vec<String>, selector: Vec<String>,
/// Port to use for discovered worker pods
#[arg(long, default_value_t = 80)] #[arg(long, default_value_t = 80)]
service_discovery_port: u16, service_discovery_port: u16,
/// Kubernetes namespace to watch for pods
#[arg(long)] #[arg(long)]
service_discovery_namespace: Option<String>, service_discovery_namespace: Option<String>,
/// Label selector for prefill server pods in PD mode
#[arg(long, num_args = 0..)] #[arg(long, num_args = 0..)]
prefill_selector: Vec<String>, prefill_selector: Vec<String>,
/// Label selector for decode server pods in PD mode
#[arg(long, num_args = 0..)] #[arg(long, num_args = 0..)]
decode_selector: Vec<String>, decode_selector: Vec<String>,
/// Port to expose Prometheus metrics
#[arg(long, default_value_t = 29000)] #[arg(long, default_value_t = 29000)]
prometheus_port: u16, prometheus_port: u16,
/// Host address to bind the Prometheus metrics server
#[arg(long, default_value = "127.0.0.1")] #[arg(long, default_value = "127.0.0.1")]
prometheus_host: String, prometheus_host: String,
/// Custom HTTP headers to check for request IDs
#[arg(long, num_args = 0..)] #[arg(long, num_args = 0..)]
request_id_headers: Vec<String>, request_id_headers: Vec<String>,
/// Request timeout in seconds
#[arg(long, default_value_t = 1800)] #[arg(long, default_value_t = 1800)]
request_timeout_secs: u64, request_timeout_secs: u64,
/// Maximum number of concurrent requests allowed
#[arg(long, default_value_t = 256)] #[arg(long, default_value_t = 256)]
max_concurrent_requests: usize, max_concurrent_requests: usize,
/// CORS allowed origins
#[arg(long, num_args = 0..)] #[arg(long, num_args = 0..)]
cors_allowed_origins: Vec<String>, cors_allowed_origins: Vec<String>,
// Retry configuration
/// Maximum number of retries
#[arg(long, default_value_t = 5)] #[arg(long, default_value_t = 5)]
retry_max_retries: u32, retry_max_retries: u32,
/// Initial backoff in milliseconds for retries
#[arg(long, default_value_t = 50)] #[arg(long, default_value_t = 50)]
retry_initial_backoff_ms: u64, retry_initial_backoff_ms: u64,
/// Maximum backoff in milliseconds for retries
#[arg(long, default_value_t = 30000)] #[arg(long, default_value_t = 30000)]
retry_max_backoff_ms: u64, retry_max_backoff_ms: u64,
/// Backoff multiplier for exponential backoff
#[arg(long, default_value_t = 1.5)] #[arg(long, default_value_t = 1.5)]
retry_backoff_multiplier: f32, retry_backoff_multiplier: f32,
/// Jitter factor for retry backoff
#[arg(long, default_value_t = 0.2)] #[arg(long, default_value_t = 0.2)]
retry_jitter_factor: f32, retry_jitter_factor: f32,
/// Disable retries
#[arg(long, default_value_t = false)] #[arg(long, default_value_t = false)]
disable_retries: bool, disable_retries: bool,
// Circuit breaker configuration
/// Number of failures before circuit breaker opens
#[arg(long, default_value_t = 10)] #[arg(long, default_value_t = 10)]
cb_failure_threshold: u32, cb_failure_threshold: u32,
/// Number of successes before circuit breaker closes
#[arg(long, default_value_t = 3)] #[arg(long, default_value_t = 3)]
cb_success_threshold: u32, cb_success_threshold: u32,
/// Timeout duration in seconds for circuit breaker
#[arg(long, default_value_t = 60)] #[arg(long, default_value_t = 60)]
cb_timeout_duration_secs: u64, cb_timeout_duration_secs: u64,
/// Window duration in seconds for circuit breaker
#[arg(long, default_value_t = 120)] #[arg(long, default_value_t = 120)]
cb_window_duration_secs: u64, cb_window_duration_secs: u64,
/// Disable circuit breaker
#[arg(long, default_value_t = false)] #[arg(long, default_value_t = false)]
disable_circuit_breaker: bool, disable_circuit_breaker: bool,
// Health check configuration
/// Number of consecutive health check failures before marking worker unhealthy
#[arg(long, default_value_t = 3)] #[arg(long, default_value_t = 3)]
health_failure_threshold: u32, health_failure_threshold: u32,
/// Number of consecutive health check successes before marking worker healthy
#[arg(long, default_value_t = 2)] #[arg(long, default_value_t = 2)]
health_success_threshold: u32, health_success_threshold: u32,
/// Timeout in seconds for health check requests
#[arg(long, default_value_t = 5)] #[arg(long, default_value_t = 5)]
health_check_timeout_secs: u64, health_check_timeout_secs: u64,
/// Interval in seconds between runtime health checks
#[arg(long, default_value_t = 60)] #[arg(long, default_value_t = 60)]
health_check_interval_secs: u64, health_check_interval_secs: u64,
/// Health check endpoint path
#[arg(long, default_value = "/health")] #[arg(long, default_value = "/health")]
health_check_endpoint: String, health_check_endpoint: String,
// IGW (Inference Gateway) configuration
/// Enable Inference Gateway mode
#[arg(long, default_value_t = false)] #[arg(long, default_value_t = false)]
enable_igw: bool, enable_igw: bool,
// Tokenizer configuration
/// Model path for loading tokenizer (HuggingFace model ID or local path)
#[arg(long)] #[arg(long)]
model_path: Option<String>, model_path: Option<String>,
/// Explicit tokenizer path (overrides model_path tokenizer if provided)
#[arg(long)] #[arg(long)]
tokenizer_path: Option<String>, tokenizer_path: Option<String>,
/// History backend configuration (memory, none, or oracle)
#[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle"])] #[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle"])]
history_backend: String, history_backend: String,
/// Directory containing the Oracle ATP wallet/config files (optional)
#[arg(long, env = "ATP_WALLET_PATH")] #[arg(long, env = "ATP_WALLET_PATH")]
oracle_wallet_path: Option<String>, oracle_wallet_path: Option<String>,
/// Wallet TNS alias to use (e.g. `<db_name>_low`)
#[arg(long, env = "ATP_TNS_ALIAS")] #[arg(long, env = "ATP_TNS_ALIAS")]
oracle_tns_alias: Option<String>, oracle_tns_alias: Option<String>,
/// Oracle connection descriptor / DSN (e.g. `tcps://host:port/service_name`)
#[arg(long, env = "ATP_DSN")] #[arg(long, env = "ATP_DSN")]
oracle_dsn: Option<String>, oracle_dsn: Option<String>,
/// Oracle ATP username
#[arg(long, env = "ATP_USER")] #[arg(long, env = "ATP_USER")]
oracle_user: Option<String>, oracle_user: Option<String>,
/// Oracle ATP password
#[arg(long, env = "ATP_PASSWORD")] #[arg(long, env = "ATP_PASSWORD")]
oracle_password: Option<String>, oracle_password: Option<String>,
/// Minimum number of pooled ATP connections (defaults to 1 when omitted)
#[arg(long, env = "ATP_POOL_MIN")] #[arg(long, env = "ATP_POOL_MIN")]
oracle_pool_min: Option<usize>, oracle_pool_min: Option<usize>,
/// Maximum number of pooled ATP connections (defaults to 16 when omitted)
#[arg(long, env = "ATP_POOL_MAX")] #[arg(long, env = "ATP_POOL_MAX")]
oracle_pool_max: Option<usize>, oracle_pool_max: Option<usize>,
/// Connection acquisition timeout in seconds (defaults to 30 when omitted)
#[arg(long, env = "ATP_POOL_TIMEOUT_SECS")] #[arg(long, env = "ATP_POOL_TIMEOUT_SECS")]
oracle_pool_timeout_secs: Option<u64>, oracle_pool_timeout_secs: Option<u64>,
} }
...@@ -357,19 +289,15 @@ enum OracleConnectSource { ...@@ -357,19 +289,15 @@ enum OracleConnectSource {
} }
impl CliArgs { impl CliArgs {
/// Determine connection mode from worker URLs
fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode { fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode {
// Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
for url in worker_urls { for url in worker_urls {
if url.starts_with("grpc://") || url.starts_with("grpcs://") { if url.starts_with("grpc://") || url.starts_with("grpcs://") {
return ConnectionMode::Grpc; return ConnectionMode::Grpc;
} }
} }
// Default to HTTP for all other cases (including http://, https://, or no scheme)
ConnectionMode::Http ConnectionMode::Http
} }
/// Parse selector strings into HashMap
fn parse_selector(selector_list: &[String]) -> HashMap<String, String> { fn parse_selector(selector_list: &[String]) -> HashMap<String, String> {
let mut map = HashMap::new(); let mut map = HashMap::new();
for item in selector_list { for item in selector_list {
...@@ -382,7 +310,6 @@ impl CliArgs { ...@@ -382,7 +310,6 @@ impl CliArgs {
map map
} }
/// Convert policy string to PolicyConfig
fn parse_policy(&self, policy_str: &str) -> PolicyConfig { fn parse_policy(&self, policy_str: &str) -> PolicyConfig {
match policy_str { match policy_str {
"random" => PolicyConfig::Random, "random" => PolicyConfig::Random,
...@@ -395,9 +322,9 @@ impl CliArgs { ...@@ -395,9 +322,9 @@ impl CliArgs {
max_tree_size: self.max_tree_size, max_tree_size: self.max_tree_size,
}, },
"power_of_two" => PolicyConfig::PowerOfTwo { "power_of_two" => PolicyConfig::PowerOfTwo {
load_check_interval_secs: 5, // Default value load_check_interval_secs: 5,
}, },
_ => PolicyConfig::RoundRobin, // Fallback _ => PolicyConfig::RoundRobin,
} }
} }
...@@ -482,26 +409,21 @@ impl CliArgs { ...@@ -482,26 +409,21 @@ impl CliArgs {
}) })
} }
/// Convert CLI arguments to RouterConfig
fn to_router_config( fn to_router_config(
&self, &self,
prefill_urls: Vec<(String, Option<u16>)>, prefill_urls: Vec<(String, Option<u16>)>,
) -> ConfigResult<RouterConfig> { ) -> ConfigResult<RouterConfig> {
// Determine routing mode
let mode = if self.enable_igw { let mode = if self.enable_igw {
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
RoutingMode::Regular { RoutingMode::Regular {
worker_urls: vec![], worker_urls: vec![],
} }
} else if matches!(self.backend, Backend::Openai) { } else if matches!(self.backend, Backend::Openai) {
// OpenAI backend mode - use worker_urls as base(s)
RoutingMode::OpenAI { RoutingMode::OpenAI {
worker_urls: self.worker_urls.clone(), worker_urls: self.worker_urls.clone(),
} }
} else if self.pd_disaggregation { } else if self.pd_disaggregation {
let decode_urls = self.decode.clone(); let decode_urls = self.decode.clone();
// Validate PD configuration if not using service discovery
if !self.service_discovery && (prefill_urls.is_empty() || decode_urls.is_empty()) { if !self.service_discovery && (prefill_urls.is_empty() || decode_urls.is_empty()) {
return Err(ConfigError::ValidationFailed { return Err(ConfigError::ValidationFailed {
reason: "PD disaggregation mode requires --prefill and --decode URLs when not using service discovery".to_string(), reason: "PD disaggregation mode requires --prefill and --decode URLs when not using service discovery".to_string(),
...@@ -515,7 +437,6 @@ impl CliArgs { ...@@ -515,7 +437,6 @@ impl CliArgs {
decode_policy: self.decode_policy.as_ref().map(|p| self.parse_policy(p)), decode_policy: self.decode_policy.as_ref().map(|p| self.parse_policy(p)),
} }
} else { } else {
// Regular mode
if !self.service_discovery && self.worker_urls.is_empty() { if !self.service_discovery && self.worker_urls.is_empty() {
return Err(ConfigError::ValidationFailed { return Err(ConfigError::ValidationFailed {
reason: "Regular mode requires --worker-urls when not using service discovery" reason: "Regular mode requires --worker-urls when not using service discovery"
...@@ -527,10 +448,8 @@ impl CliArgs { ...@@ -527,10 +448,8 @@ impl CliArgs {
} }
}; };
// Main policy
let policy = self.parse_policy(&self.policy); let policy = self.parse_policy(&self.policy);
// Service discovery configuration
let discovery = if self.service_discovery { let discovery = if self.service_discovery {
Some(DiscoveryConfig { Some(DiscoveryConfig {
enabled: true, enabled: true,
...@@ -546,13 +465,11 @@ impl CliArgs { ...@@ -546,13 +465,11 @@ impl CliArgs {
None None
}; };
// Metrics configuration
let metrics = Some(MetricsConfig { let metrics = Some(MetricsConfig {
port: self.prometheus_port, port: self.prometheus_port,
host: self.prometheus_host.clone(), host: self.prometheus_host.clone(),
}); });
// Determine connection mode from all worker URLs
let mut all_urls = Vec::new(); let mut all_urls = Vec::new();
match &mode { match &mode {
RoutingMode::Regular { worker_urls } => { RoutingMode::Regular { worker_urls } => {
...@@ -568,9 +485,7 @@ impl CliArgs { ...@@ -568,9 +485,7 @@ impl CliArgs {
} }
all_urls.extend(decode_urls.clone()); all_urls.extend(decode_urls.clone());
} }
RoutingMode::OpenAI { .. } => { RoutingMode::OpenAI { .. } => {}
// For connection-mode detection, skip URLs; OpenAI forces HTTP below.
}
} }
let connection_mode = match &mode { let connection_mode = match &mode {
RoutingMode::OpenAI { .. } => ConnectionMode::Http, RoutingMode::OpenAI { .. } => ConnectionMode::Http,
...@@ -589,7 +504,6 @@ impl CliArgs { ...@@ -589,7 +504,6 @@ impl CliArgs {
None None
}; };
// Build RouterConfig
Ok(RouterConfig { Ok(RouterConfig {
mode, mode,
policy, policy,
...@@ -612,8 +526,8 @@ impl CliArgs { ...@@ -612,8 +526,8 @@ impl CliArgs {
Some(self.request_id_headers.clone()) Some(self.request_id_headers.clone())
}, },
max_concurrent_requests: self.max_concurrent_requests, max_concurrent_requests: self.max_concurrent_requests,
queue_size: 100, // Default queue size queue_size: 100,
queue_timeout_secs: 60, // Default timeout queue_timeout_secs: 60,
cors_allowed_origins: self.cors_allowed_origins.clone(), cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: RetryConfig { retry: RetryConfig {
max_retries: self.retry_max_retries, max_retries: self.retry_max_retries,
...@@ -646,9 +560,7 @@ impl CliArgs { ...@@ -646,9 +560,7 @@ impl CliArgs {
}) })
} }
/// Create ServerConfig from CLI args and RouterConfig
fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig { fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig {
// Create service discovery config if enabled
let service_discovery_config = if self.service_discovery { let service_discovery_config = if self.service_discovery {
Some(ServiceDiscoveryConfig { Some(ServiceDiscoveryConfig {
enabled: true, enabled: true,
...@@ -665,7 +577,6 @@ impl CliArgs { ...@@ -665,7 +577,6 @@ impl CliArgs {
None None
}; };
// Create Prometheus config
let prometheus_config = Some(PrometheusConfig { let prometheus_config = Some(PrometheusConfig {
port: self.prometheus_port, port: self.prometheus_port,
host: self.prometheus_host.clone(), host: self.prometheus_host.clone(),
...@@ -691,19 +602,15 @@ impl CliArgs { ...@@ -691,19 +602,15 @@ impl CliArgs {
} }
fn main() -> Result<(), Box<dyn std::error::Error>> { fn main() -> Result<(), Box<dyn std::error::Error>> {
// Parse prefill arguments manually before clap parsing
let prefill_urls = parse_prefill_args(); let prefill_urls = parse_prefill_args();
// Filter out prefill arguments and their values before passing to clap
let mut filtered_args: Vec<String> = Vec::new(); let mut filtered_args: Vec<String> = Vec::new();
let raw_args: Vec<String> = std::env::args().collect(); let raw_args: Vec<String> = std::env::args().collect();
let mut i = 0; let mut i = 0;
while i < raw_args.len() { while i < raw_args.len() {
if raw_args[i] == "--prefill" && i + 1 < raw_args.len() { if raw_args[i] == "--prefill" && i + 1 < raw_args.len() {
// Skip --prefill and its URL
i += 2; i += 2;
// Also skip bootstrap port if present
if i < raw_args.len() if i < raw_args.len()
&& !raw_args[i].starts_with("--") && !raw_args[i].starts_with("--")
&& (raw_args[i].parse::<u16>().is_ok() || raw_args[i].to_lowercase() == "none") && (raw_args[i].parse::<u16>().is_ok() || raw_args[i].to_lowercase() == "none")
...@@ -716,10 +623,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> { ...@@ -716,10 +623,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
} }
} }
// Parse CLI arguments with clap using filtered args
let cli_args = CliArgs::parse_from(filtered_args); let cli_args = CliArgs::parse_from(filtered_args);
// Print startup info
println!("SGLang Router starting..."); println!("SGLang Router starting...");
println!("Host: {}:{}", cli_args.host, cli_args.port); println!("Host: {}:{}", cli_args.host, cli_args.port);
let mode_str = if cli_args.enable_igw { let mode_str = if cli_args.enable_igw {
...@@ -733,7 +638,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> { ...@@ -733,7 +638,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}; };
println!("Mode: {}", mode_str); println!("Mode: {}", mode_str);
// Warn for runtimes that are parsed but not yet implemented
match cli_args.backend { match cli_args.backend {
Backend::Vllm | Backend::Trtllm | Backend::Anthropic => { Backend::Vllm | Backend::Trtllm | Backend::Anthropic => {
println!( println!(
...@@ -754,19 +658,10 @@ Provide --worker-urls or PD flags as usual.", ...@@ -754,19 +658,10 @@ Provide --worker-urls or PD flags as usual.",
} }
} }
// Convert to RouterConfig
let router_config = cli_args.to_router_config(prefill_urls)?; let router_config = cli_args.to_router_config(prefill_urls)?;
// Validate configuration
router_config.validate()?; router_config.validate()?;
// Create ServerConfig
let server_config = cli_args.to_server_config(router_config); let server_config = cli_args.to_server_config(router_config);
// Create a new runtime for the server (like Python binding does)
let runtime = tokio::runtime::Runtime::new()?; let runtime = tokio::runtime::Runtime::new()?;
// Block on the async startup function
runtime.block_on(async move { server::startup(server_config).await })?; runtime.block_on(async move { server::startup(server_config).await })?;
Ok(()) Ok(())
......
...@@ -19,7 +19,6 @@ impl Default for PrometheusConfig { ...@@ -19,7 +19,6 @@ impl Default for PrometheusConfig {
} }
pub fn init_metrics() { pub fn init_metrics() {
// Request metrics
describe_counter!( describe_counter!(
"sgl_router_requests_total", "sgl_router_requests_total",
"Total number of requests by route and method" "Total number of requests by route and method"
...@@ -45,7 +44,6 @@ pub fn init_metrics() { ...@@ -45,7 +44,6 @@ pub fn init_metrics() {
"Total number of requests that exhausted retries by route" "Total number of requests that exhausted retries by route"
); );
// Circuit breaker metrics
describe_gauge!( describe_gauge!(
"sgl_router_cb_state", "sgl_router_cb_state",
"Circuit breaker state per worker (0=closed, 1=open, 2=half_open)" "Circuit breaker state per worker (0=closed, 1=open, 2=half_open)"
...@@ -59,7 +57,6 @@ pub fn init_metrics() { ...@@ -59,7 +57,6 @@ pub fn init_metrics() {
"Total number of circuit breaker outcomes by worker and outcome type (success/failure)" "Total number of circuit breaker outcomes by worker and outcome type (success/failure)"
); );
// Worker metrics
describe_gauge!( describe_gauge!(
"sgl_router_active_workers", "sgl_router_active_workers",
"Number of currently active workers" "Number of currently active workers"
...@@ -74,7 +71,6 @@ pub fn init_metrics() { ...@@ -74,7 +71,6 @@ pub fn init_metrics() {
"Total requests processed by each worker" "Total requests processed by each worker"
); );
// Policy metrics
describe_counter!( describe_counter!(
"sgl_router_policy_decisions_total", "sgl_router_policy_decisions_total",
"Total routing policy decisions by policy and worker" "Total routing policy decisions by policy and worker"
...@@ -92,7 +88,6 @@ pub fn init_metrics() { ...@@ -92,7 +88,6 @@ pub fn init_metrics() {
describe_gauge!("sgl_router_max_load", "Maximum worker load"); describe_gauge!("sgl_router_max_load", "Maximum worker load");
describe_gauge!("sgl_router_min_load", "Minimum worker load"); describe_gauge!("sgl_router_min_load", "Minimum worker load");
// PD-specific metrics
describe_counter!("sgl_router_pd_requests_total", "Total PD requests by route"); describe_counter!("sgl_router_pd_requests_total", "Total PD requests by route");
describe_counter!( describe_counter!(
"sgl_router_pd_prefill_requests_total", "sgl_router_pd_prefill_requests_total",
...@@ -123,7 +118,6 @@ pub fn init_metrics() { ...@@ -123,7 +118,6 @@ pub fn init_metrics() {
"PD request duration by route" "PD request duration by route"
); );
// Service discovery metrics
describe_counter!( describe_counter!(
"sgl_router_discovery_updates_total", "sgl_router_discovery_updates_total",
"Total service discovery update events" "Total service discovery update events"
...@@ -137,13 +131,11 @@ pub fn init_metrics() { ...@@ -137,13 +131,11 @@ pub fn init_metrics() {
"Number of workers removed in last discovery update" "Number of workers removed in last discovery update"
); );
// Generate request specific metrics
describe_histogram!( describe_histogram!(
"sgl_router_generate_duration_seconds", "sgl_router_generate_duration_seconds",
"Generate request duration" "Generate request duration"
); );
// Embedding request specific metrics
describe_counter!("sgl_router_embeddings_total", "Total embedding requests"); describe_counter!("sgl_router_embeddings_total", "Total embedding requests");
describe_histogram!( describe_histogram!(
"sgl_router_embeddings_duration_seconds", "sgl_router_embeddings_duration_seconds",
...@@ -155,13 +147,11 @@ pub fn init_metrics() { ...@@ -155,13 +147,11 @@ pub fn init_metrics() {
); );
describe_gauge!("sgl_router_embeddings_queue_size", "Embedding queue size"); describe_gauge!("sgl_router_embeddings_queue_size", "Embedding queue size");
// Running requests gauge for cache-aware policy
describe_gauge!( describe_gauge!(
"sgl_router_running_requests", "sgl_router_running_requests",
"Number of running requests per worker" "Number of running requests per worker"
); );
// Tokenizer metrics
describe_histogram!( describe_histogram!(
"sgl_tokenizer_encode_duration_seconds", "sgl_tokenizer_encode_duration_seconds",
"Time to encode text to tokens" "Time to encode text to tokens"
...@@ -207,7 +197,6 @@ pub fn init_metrics() { ...@@ -207,7 +197,6 @@ pub fn init_metrics() {
"Vocabulary size of the loaded tokenizer" "Vocabulary size of the loaded tokenizer"
); );
// Stop sequence detection metrics
describe_counter!( describe_counter!(
"sgl_tokenizer_stop_sequences_detected_total", "sgl_tokenizer_stop_sequences_detected_total",
"Total stop sequences detected by type" "Total stop sequences detected by type"
...@@ -221,7 +210,6 @@ pub fn init_metrics() { ...@@ -221,7 +210,6 @@ pub fn init_metrics() {
"Time to check for stop sequences per token" "Time to check for stop sequences per token"
); );
// Streaming decode metrics
describe_counter!( describe_counter!(
"sgl_tokenizer_stream_tokens_total", "sgl_tokenizer_stream_tokens_total",
"Total tokens processed in streaming decode" "Total tokens processed in streaming decode"
...@@ -235,7 +223,6 @@ pub fn init_metrics() { ...@@ -235,7 +223,6 @@ pub fn init_metrics() {
"Time per streaming decode step" "Time per streaming decode step"
); );
// Factory metrics
describe_counter!( describe_counter!(
"sgl_tokenizer_factory_loads_total", "sgl_tokenizer_factory_loads_total",
"Total tokenizer loads by file type" "Total tokenizer loads by file type"
...@@ -251,7 +238,6 @@ pub fn init_metrics() { ...@@ -251,7 +238,6 @@ pub fn init_metrics() {
} }
pub fn start_prometheus(config: PrometheusConfig) { pub fn start_prometheus(config: PrometheusConfig) {
// Initialize metric descriptions
init_metrics(); init_metrics();
let duration_matcher = Matcher::Suffix(String::from("duration_seconds")); let duration_matcher = Matcher::Suffix(String::from("duration_seconds"));
...@@ -280,7 +266,6 @@ pub struct RouterMetrics; ...@@ -280,7 +266,6 @@ pub struct RouterMetrics;
pub struct TokenizerMetrics; pub struct TokenizerMetrics;
impl RouterMetrics { impl RouterMetrics {
// Request metrics
pub fn record_request(route: &str) { pub fn record_request(route: &str) {
counter!("sgl_router_requests_total", counter!("sgl_router_requests_total",
"route" => route.to_string() "route" => route.to_string()
...@@ -324,7 +309,6 @@ impl RouterMetrics { ...@@ -324,7 +309,6 @@ impl RouterMetrics {
.increment(1); .increment(1);
} }
// Worker metrics
pub fn set_active_workers(count: usize) { pub fn set_active_workers(count: usize) {
gauge!("sgl_router_active_workers").set(count as f64); gauge!("sgl_router_active_workers").set(count as f64);
} }
...@@ -350,7 +334,6 @@ impl RouterMetrics { ...@@ -350,7 +334,6 @@ impl RouterMetrics {
.increment(1); .increment(1);
} }
// Policy metrics
pub fn record_policy_decision(policy: &str, worker: &str) { pub fn record_policy_decision(policy: &str, worker: &str) {
counter!("sgl_router_policy_decisions_total", counter!("sgl_router_policy_decisions_total",
"policy" => policy.to_string(), "policy" => policy.to_string(),
...@@ -383,7 +366,6 @@ impl RouterMetrics { ...@@ -383,7 +366,6 @@ impl RouterMetrics {
gauge!("sgl_router_min_load").set(min_load as f64); gauge!("sgl_router_min_load").set(min_load as f64);
} }
// PD-specific metrics
pub fn record_pd_request(route: &str) { pub fn record_pd_request(route: &str) {
counter!("sgl_router_pd_requests_total", counter!("sgl_router_pd_requests_total",
"route" => route.to_string() "route" => route.to_string()
...@@ -440,19 +422,16 @@ impl RouterMetrics { ...@@ -440,19 +422,16 @@ impl RouterMetrics {
.increment(1); .increment(1);
} }
// Service discovery metrics
pub fn record_discovery_update(added: usize, removed: usize) { pub fn record_discovery_update(added: usize, removed: usize) {
counter!("sgl_router_discovery_updates_total").increment(1); counter!("sgl_router_discovery_updates_total").increment(1);
gauge!("sgl_router_discovery_workers_added").set(added as f64); gauge!("sgl_router_discovery_workers_added").set(added as f64);
gauge!("sgl_router_discovery_workers_removed").set(removed as f64); gauge!("sgl_router_discovery_workers_removed").set(removed as f64);
} }
// Generate request metrics
pub fn record_generate_duration(duration: Duration) { pub fn record_generate_duration(duration: Duration) {
histogram!("sgl_router_generate_duration_seconds").record(duration.as_secs_f64()); histogram!("sgl_router_generate_duration_seconds").record(duration.as_secs_f64());
} }
// Embeddings metrics
pub fn record_embeddings_request() { pub fn record_embeddings_request() {
counter!("sgl_router_embeddings_total").increment(1); counter!("sgl_router_embeddings_total").increment(1);
} }
...@@ -473,7 +452,6 @@ impl RouterMetrics { ...@@ -473,7 +452,6 @@ impl RouterMetrics {
gauge!("sgl_router_embeddings_queue_size").set(size as f64); gauge!("sgl_router_embeddings_queue_size").set(size as f64);
} }
// Running requests for cache-aware policy
pub fn set_running_requests(worker: &str, count: usize) { pub fn set_running_requests(worker: &str, count: usize) {
gauge!("sgl_router_running_requests", gauge!("sgl_router_running_requests",
"worker" => worker.to_string() "worker" => worker.to_string()
...@@ -481,7 +459,6 @@ impl RouterMetrics { ...@@ -481,7 +459,6 @@ impl RouterMetrics {
.set(count as f64); .set(count as f64);
} }
// Circuit breaker metrics
pub fn set_cb_state(worker: &str, state_code: u8) { pub fn set_cb_state(worker: &str, state_code: u8) {
gauge!("sgl_router_cb_state", gauge!("sgl_router_cb_state",
"worker" => worker.to_string() "worker" => worker.to_string()
...@@ -508,7 +485,6 @@ impl RouterMetrics { ...@@ -508,7 +485,6 @@ impl RouterMetrics {
} }
impl TokenizerMetrics { impl TokenizerMetrics {
// Encoding metrics
pub fn record_encode_request(tokenizer_type: &str) { pub fn record_encode_request(tokenizer_type: &str) {
counter!("sgl_tokenizer_encode_requests_total", counter!("sgl_tokenizer_encode_requests_total",
"tokenizer_type" => tokenizer_type.to_string() "tokenizer_type" => tokenizer_type.to_string()
...@@ -535,7 +511,6 @@ impl TokenizerMetrics { ...@@ -535,7 +511,6 @@ impl TokenizerMetrics {
histogram!("sgl_tokenizer_chars_per_encode").record(char_count as f64); histogram!("sgl_tokenizer_chars_per_encode").record(char_count as f64);
} }
// Decoding metrics
pub fn record_decode_request(tokenizer_type: &str) { pub fn record_decode_request(tokenizer_type: &str) {
counter!("sgl_tokenizer_decode_requests_total", counter!("sgl_tokenizer_decode_requests_total",
"tokenizer_type" => tokenizer_type.to_string() "tokenizer_type" => tokenizer_type.to_string()
...@@ -558,7 +533,6 @@ impl TokenizerMetrics { ...@@ -558,7 +533,6 @@ impl TokenizerMetrics {
histogram!("sgl_tokenizer_tokens_per_decode").record(token_count as f64); histogram!("sgl_tokenizer_tokens_per_decode").record(token_count as f64);
} }
// Batch encoding metrics
pub fn record_encode_batch_duration(duration: Duration, batch_size: usize) { pub fn record_encode_batch_duration(duration: Duration, batch_size: usize) {
histogram!("sgl_tokenizer_encode_batch_duration_seconds", histogram!("sgl_tokenizer_encode_batch_duration_seconds",
"batch_size" => batch_size.to_string() "batch_size" => batch_size.to_string()
...@@ -566,7 +540,6 @@ impl TokenizerMetrics { ...@@ -566,7 +540,6 @@ impl TokenizerMetrics {
.record(duration.as_secs_f64()); .record(duration.as_secs_f64());
} }
// Stop sequence detection metrics
pub fn record_stop_sequence_detected(stop_type: &str) { pub fn record_stop_sequence_detected(stop_type: &str) {
counter!("sgl_tokenizer_stop_sequences_detected_total", counter!("sgl_tokenizer_stop_sequences_detected_total",
"type" => stop_type.to_string() "type" => stop_type.to_string()
...@@ -582,7 +555,6 @@ impl TokenizerMetrics { ...@@ -582,7 +555,6 @@ impl TokenizerMetrics {
histogram!("sgl_tokenizer_stop_detection_duration_seconds").record(duration.as_secs_f64()); histogram!("sgl_tokenizer_stop_detection_duration_seconds").record(duration.as_secs_f64());
} }
// Streaming decode metrics
pub fn record_stream_token() { pub fn record_stream_token() {
counter!("sgl_tokenizer_stream_tokens_total").increment(1); counter!("sgl_tokenizer_stream_tokens_total").increment(1);
} }
...@@ -595,7 +567,6 @@ impl TokenizerMetrics { ...@@ -595,7 +567,6 @@ impl TokenizerMetrics {
histogram!("sgl_tokenizer_stream_step_duration_seconds").record(duration.as_secs_f64()); histogram!("sgl_tokenizer_stream_step_duration_seconds").record(duration.as_secs_f64());
} }
// Factory metrics
pub fn record_factory_load(file_type: &str) { pub fn record_factory_load(file_type: &str) {
counter!("sgl_tokenizer_factory_loads_total", counter!("sgl_tokenizer_factory_loads_total",
"file_type" => file_type.to_string() "file_type" => file_type.to_string()
...@@ -614,7 +585,6 @@ impl TokenizerMetrics { ...@@ -614,7 +585,6 @@ impl TokenizerMetrics {
histogram!("sgl_tokenizer_factory_load_duration_seconds").record(duration.as_secs_f64()); histogram!("sgl_tokenizer_factory_load_duration_seconds").record(duration.as_secs_f64());
} }
// Vocabulary metrics
pub fn set_vocab_size(tokenizer_type: &str, size: usize) { pub fn set_vocab_size(tokenizer_type: &str, size: usize) {
gauge!("sgl_tokenizer_vocab_size", gauge!("sgl_tokenizer_vocab_size",
"tokenizer_type" => tokenizer_type.to_string() "tokenizer_type" => tokenizer_type.to_string()
...@@ -705,7 +675,6 @@ mod tests { ...@@ -705,7 +675,6 @@ mod tests {
.parse() .parse()
.unwrap_or(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); .unwrap_or(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
// Should fall back to 0.0.0.0
assert_eq!(ip_addr, IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); assert_eq!(ip_addr, IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
} }
} }
...@@ -780,7 +749,6 @@ mod tests { ...@@ -780,7 +749,6 @@ mod tests {
fn test_duration_suffix_matcher() { fn test_duration_suffix_matcher() {
let matcher = Matcher::Suffix(String::from("duration_seconds")); let matcher = Matcher::Suffix(String::from("duration_seconds"));
// Test matching behavior
let _matching_metrics = [ let _matching_metrics = [
"request_duration_seconds", "request_duration_seconds",
"response_duration_seconds", "response_duration_seconds",
...@@ -789,8 +757,6 @@ mod tests { ...@@ -789,8 +757,6 @@ mod tests {
let _non_matching_metrics = ["duration_total", "duration_seconds_total", "other_metric"]; let _non_matching_metrics = ["duration_total", "duration_seconds_total", "other_metric"];
// Note: We can't directly test Matcher matching without the internals,
// but we can verify the matcher is created correctly
match matcher { match matcher {
Matcher::Suffix(suffix) => assert_eq!(suffix, "duration_seconds"), Matcher::Suffix(suffix) => assert_eq!(suffix, "duration_seconds"),
_ => panic!("Expected Suffix matcher"), _ => panic!("Expected Suffix matcher"),
...@@ -801,7 +767,6 @@ mod tests { ...@@ -801,7 +767,6 @@ mod tests {
#[test] #[test]
fn test_prometheus_builder_configuration() { fn test_prometheus_builder_configuration() {
// This test verifies the builder configuration without actually starting Prometheus
let _config = PrometheusConfig::default(); let _config = PrometheusConfig::default();
let duration_matcher = Matcher::Suffix(String::from("duration_seconds")); let duration_matcher = Matcher::Suffix(String::from("duration_seconds"));
...@@ -810,10 +775,8 @@ mod tests { ...@@ -810,10 +775,8 @@ mod tests {
60.0, 90.0, 120.0, 180.0, 240.0, 60.0, 90.0, 120.0, 180.0, 240.0,
]; ];
// Verify bucket configuration
assert_eq!(duration_bucket.len(), 20); assert_eq!(duration_bucket.len(), 20);
// Verify matcher is suffix type
match duration_matcher { match duration_matcher {
Matcher::Suffix(s) => assert_eq!(s, "duration_seconds"), Matcher::Suffix(s) => assert_eq!(s, "duration_seconds"),
_ => panic!("Expected Suffix matcher"), _ => panic!("Expected Suffix matcher"),
...@@ -832,14 +795,12 @@ mod tests { ...@@ -832,14 +795,12 @@ mod tests {
#[test] #[test]
fn test_custom_buckets_for_different_metrics() { fn test_custom_buckets_for_different_metrics() {
// Test that we can create different bucket configurations
let request_buckets = [0.001, 0.01, 0.1, 1.0, 10.0]; let request_buckets = [0.001, 0.01, 0.1, 1.0, 10.0];
let generate_buckets = [0.1, 0.5, 1.0, 5.0, 30.0, 60.0]; let generate_buckets = [0.1, 0.5, 1.0, 5.0, 30.0, 60.0];
assert_eq!(request_buckets.len(), 5); assert_eq!(request_buckets.len(), 5);
assert_eq!(generate_buckets.len(), 6); assert_eq!(generate_buckets.len(), 6);
// Verify each set is sorted
for i in 1..request_buckets.len() { for i in 1..request_buckets.len() {
assert!(request_buckets[i] > request_buckets[i - 1]); assert!(request_buckets[i] > request_buckets[i - 1]);
} }
...@@ -853,7 +814,6 @@ mod tests { ...@@ -853,7 +814,6 @@ mod tests {
#[test] #[test]
fn test_metrics_static_methods() { fn test_metrics_static_methods() {
// Test that all static methods can be called without panic
RouterMetrics::record_request("/generate"); RouterMetrics::record_request("/generate");
RouterMetrics::record_request_duration("/generate", Duration::from_millis(100)); RouterMetrics::record_request_duration("/generate", Duration::from_millis(100));
RouterMetrics::record_request_error("/generate", "timeout"); RouterMetrics::record_request_error("/generate", "timeout");
...@@ -887,41 +847,32 @@ mod tests { ...@@ -887,41 +847,32 @@ mod tests {
#[test] #[test]
fn test_tokenizer_metrics_static_methods() { fn test_tokenizer_metrics_static_methods() {
// Test that all tokenizer metric methods can be called without panic
// Encoding metrics
TokenizerMetrics::record_encode_request("huggingface"); TokenizerMetrics::record_encode_request("huggingface");
TokenizerMetrics::record_encode_duration(Duration::from_millis(10)); TokenizerMetrics::record_encode_duration(Duration::from_millis(10));
TokenizerMetrics::record_encode_error("invalid_input"); TokenizerMetrics::record_encode_error("invalid_input");
TokenizerMetrics::record_tokens_per_encode(100); TokenizerMetrics::record_tokens_per_encode(100);
TokenizerMetrics::record_chars_per_encode(500); TokenizerMetrics::record_chars_per_encode(500);
// Decoding metrics
TokenizerMetrics::record_decode_request("huggingface"); TokenizerMetrics::record_decode_request("huggingface");
TokenizerMetrics::record_decode_duration(Duration::from_millis(5)); TokenizerMetrics::record_decode_duration(Duration::from_millis(5));
TokenizerMetrics::record_decode_error("invalid_tokens"); TokenizerMetrics::record_decode_error("invalid_tokens");
TokenizerMetrics::record_tokens_per_decode(50); TokenizerMetrics::record_tokens_per_decode(50);
// Batch encoding
TokenizerMetrics::record_encode_batch_duration(Duration::from_millis(100), 10); TokenizerMetrics::record_encode_batch_duration(Duration::from_millis(100), 10);
// Stop sequence detection
TokenizerMetrics::record_stop_sequence_detected("token"); TokenizerMetrics::record_stop_sequence_detected("token");
TokenizerMetrics::record_stop_sequence_detected("string"); TokenizerMetrics::record_stop_sequence_detected("string");
TokenizerMetrics::record_partial_match(); TokenizerMetrics::record_partial_match();
TokenizerMetrics::record_stop_detection_duration(Duration::from_micros(100)); TokenizerMetrics::record_stop_detection_duration(Duration::from_micros(100));
// Streaming decode
TokenizerMetrics::record_stream_token(); TokenizerMetrics::record_stream_token();
TokenizerMetrics::record_incomplete_utf8(); TokenizerMetrics::record_incomplete_utf8();
TokenizerMetrics::record_stream_step_duration(Duration::from_micros(50)); TokenizerMetrics::record_stream_step_duration(Duration::from_micros(50));
// Factory metrics
TokenizerMetrics::record_factory_load("json"); TokenizerMetrics::record_factory_load("json");
TokenizerMetrics::record_factory_error("unsupported_format"); TokenizerMetrics::record_factory_error("unsupported_format");
TokenizerMetrics::record_factory_load_duration(Duration::from_millis(200)); TokenizerMetrics::record_factory_load_duration(Duration::from_millis(200));
// Vocabulary metrics
TokenizerMetrics::set_vocab_size("huggingface", 50000); TokenizerMetrics::set_vocab_size("huggingface", 50000);
} }
...@@ -929,17 +880,14 @@ mod tests { ...@@ -929,17 +880,14 @@ mod tests {
#[test] #[test]
fn test_port_already_in_use() { fn test_port_already_in_use() {
// Skip this test if we can't bind to the port let port = 29123;
let port = 29123; // Use a different port to avoid conflicts
if let Ok(_listener) = TcpListener::bind(("127.0.0.1", port)) { if let Ok(_listener) = TcpListener::bind(("127.0.0.1", port)) {
// Port is available, we can test
let config = PrometheusConfig { let config = PrometheusConfig {
port, port,
host: "127.0.0.1".to_string(), host: "127.0.0.1".to_string(),
}; };
// Just verify config is created correctly
assert_eq!(config.port, port); assert_eq!(config.port, port);
} }
} }
...@@ -948,8 +896,6 @@ mod tests { ...@@ -948,8 +896,6 @@ mod tests {
#[test] #[test]
fn test_metrics_endpoint_accessibility() { fn test_metrics_endpoint_accessibility() {
// This would be an integration test in practice
// Here we just verify the configuration
let config = PrometheusConfig { let config = PrometheusConfig {
port: 29000, port: 29000,
host: "127.0.0.1".to_string(), host: "127.0.0.1".to_string(),
...@@ -963,7 +909,6 @@ mod tests { ...@@ -963,7 +909,6 @@ mod tests {
#[test] #[test]
fn test_concurrent_metric_updates() { fn test_concurrent_metric_updates() {
// Test that metric updates can be called concurrently
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::thread; use std::thread;
...@@ -984,11 +929,9 @@ mod tests { ...@@ -984,11 +929,9 @@ mod tests {
handles.push(handle); handles.push(handle);
} }
// Let threads run briefly
thread::sleep(Duration::from_millis(10)); thread::sleep(Duration::from_millis(10));
done.store(true, Ordering::Relaxed); done.store(true, Ordering::Relaxed);
// Wait for all threads
for handle in handles { for handle in handles {
handle.join().unwrap(); handle.join().unwrap();
} }
...@@ -998,7 +941,6 @@ mod tests { ...@@ -998,7 +941,6 @@ mod tests {
#[test] #[test]
fn test_empty_string_metrics() { fn test_empty_string_metrics() {
// Test that empty strings don't cause issues
RouterMetrics::record_request(""); RouterMetrics::record_request("");
RouterMetrics::set_worker_health("", true); RouterMetrics::set_worker_health("", true);
RouterMetrics::record_policy_decision("", ""); RouterMetrics::record_policy_decision("", "");
...@@ -1030,7 +972,6 @@ mod tests { ...@@ -1030,7 +972,6 @@ mod tests {
#[test] #[test]
fn test_extreme_metric_values() { fn test_extreme_metric_values() {
// Test extreme values
RouterMetrics::set_active_workers(0); RouterMetrics::set_active_workers(0);
RouterMetrics::set_active_workers(usize::MAX); RouterMetrics::set_active_workers(usize::MAX);
...@@ -1038,7 +979,6 @@ mod tests { ...@@ -1038,7 +979,6 @@ mod tests {
RouterMetrics::set_worker_load("worker", usize::MAX); RouterMetrics::set_worker_load("worker", usize::MAX);
RouterMetrics::record_request_duration("route", Duration::from_nanos(1)); RouterMetrics::record_request_duration("route", Duration::from_nanos(1));
// 24 hours
RouterMetrics::record_request_duration("route", Duration::from_secs(86400)); RouterMetrics::record_request_duration("route", Duration::from_secs(86400));
} }
} }
...@@ -19,7 +19,6 @@ use tokio::task; ...@@ -19,7 +19,6 @@ use tokio::task;
use tokio::time; use tokio::time;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
/// Represents the service discovery configuration
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ServiceDiscoveryConfig { pub struct ServiceDiscoveryConfig {
pub enabled: bool, pub enabled: bool,
...@@ -41,8 +40,8 @@ impl Default for ServiceDiscoveryConfig { ...@@ -41,8 +40,8 @@ impl Default for ServiceDiscoveryConfig {
enabled: false, enabled: false,
selector: HashMap::new(), selector: HashMap::new(),
check_interval: Duration::from_secs(60), check_interval: Duration::from_secs(60),
port: 8000, // Standard port for modern services port: 8000,
namespace: None, // None means watch all namespaces namespace: None,
pd_mode: false, pd_mode: false,
prefill_selector: HashMap::new(), prefill_selector: HashMap::new(),
decode_selector: HashMap::new(), decode_selector: HashMap::new(),
...@@ -51,7 +50,6 @@ impl Default for ServiceDiscoveryConfig { ...@@ -51,7 +50,6 @@ impl Default for ServiceDiscoveryConfig {
} }
} }
/// Pod type for PD mode service discovery
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PodType { pub enum PodType {
Prefill, Prefill,
...@@ -59,7 +57,6 @@ pub enum PodType { ...@@ -59,7 +57,6 @@ pub enum PodType {
Regular, Regular,
} }
/// Represents a Kubernetes pod's information used for worker management
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PodInfo { pub struct PodInfo {
pub name: String, pub name: String,
...@@ -71,7 +68,6 @@ pub struct PodInfo { ...@@ -71,7 +68,6 @@ pub struct PodInfo {
} }
impl PodInfo { impl PodInfo {
/// Check if a pod matches any of the given selectors
fn matches_selector(pod: &Pod, selector: &HashMap<String, String>) -> bool { fn matches_selector(pod: &Pod, selector: &HashMap<String, String>) -> bool {
if selector.is_empty() { if selector.is_empty() {
return false; return false;
...@@ -83,19 +79,15 @@ impl PodInfo { ...@@ -83,19 +79,15 @@ impl PodInfo {
.is_some_and(|labels| selector.iter().all(|(k, v)| labels.get(k) == Some(v))) .is_some_and(|labels| selector.iter().all(|(k, v)| labels.get(k) == Some(v)))
} }
/// Check if a pod should be included in service discovery
pub fn should_include(pod: &Pod, config: &ServiceDiscoveryConfig) -> bool { pub fn should_include(pod: &Pod, config: &ServiceDiscoveryConfig) -> bool {
if config.pd_mode { if config.pd_mode {
// In PD mode, at least one selector must be non-empty
if config.prefill_selector.is_empty() && config.decode_selector.is_empty() { if config.prefill_selector.is_empty() && config.decode_selector.is_empty() {
warn!("PD mode enabled but both prefill_selector and decode_selector are empty"); warn!("PD mode enabled but both prefill_selector and decode_selector are empty");
return false; return false;
} }
// In PD mode, pod must match either prefill or decode selector
Self::matches_selector(pod, &config.prefill_selector) Self::matches_selector(pod, &config.prefill_selector)
|| Self::matches_selector(pod, &config.decode_selector) || Self::matches_selector(pod, &config.decode_selector)
} else { } else {
// In regular mode, pod must match the general selector
if config.selector.is_empty() { if config.selector.is_empty() {
warn!("Regular mode enabled but selector is empty"); warn!("Regular mode enabled but selector is empty");
return false; return false;
...@@ -104,7 +96,6 @@ impl PodInfo { ...@@ -104,7 +96,6 @@ impl PodInfo {
} }
} }
/// Unified PodInfo creation with optional PD configuration
pub fn from_pod(pod: &Pod, config: Option<&ServiceDiscoveryConfig>) -> Option<Self> { pub fn from_pod(pod: &Pod, config: Option<&ServiceDiscoveryConfig>) -> Option<Self> {
let name = pod.metadata.name.clone()?; let name = pod.metadata.name.clone()?;
let status = pod.status.clone()?; let status = pod.status.clone()?;
...@@ -120,10 +111,8 @@ impl PodInfo { ...@@ -120,10 +111,8 @@ impl PodInfo {
let pod_status = status.phase.unwrap_or_else(|| "Unknown".to_string()); let pod_status = status.phase.unwrap_or_else(|| "Unknown".to_string());
// Determine pod type based on labels if config is provided and in PD mode
let pod_type = if let Some(config) = config { let pod_type = if let Some(config) = config {
if config.pd_mode { if config.pd_mode {
// Use simplified helper methods for cleaner logic
if Self::matches_selector(pod, &config.prefill_selector) { if Self::matches_selector(pod, &config.prefill_selector) {
Some(PodType::Prefill) Some(PodType::Prefill)
} else if Self::matches_selector(pod, &config.decode_selector) { } else if Self::matches_selector(pod, &config.decode_selector) {
...@@ -135,11 +124,9 @@ impl PodInfo { ...@@ -135,11 +124,9 @@ impl PodInfo {
Some(PodType::Regular) Some(PodType::Regular)
} }
} else { } else {
// No config provided, default to None (for backwards compatibility)
None None
}; };
// Extract bootstrap port from annotations for prefill pods
let bootstrap_port = if matches!(pod_type, Some(PodType::Prefill)) { let bootstrap_port = if matches!(pod_type, Some(PodType::Prefill)) {
if let Some(config) = config { if let Some(config) = config {
pod.metadata pod.metadata
...@@ -164,12 +151,10 @@ impl PodInfo { ...@@ -164,12 +151,10 @@ impl PodInfo {
}) })
} }
/// Returns true if the pod is in a state where it can accept traffic
pub fn is_healthy(&self) -> bool { pub fn is_healthy(&self) -> bool {
self.is_ready && self.status == "Running" self.is_ready && self.status == "Running"
} }
/// Generates a worker URL for this pod
pub fn worker_url(&self, port: u16) -> String { pub fn worker_url(&self, port: u16) -> String {
format!("http://{}:{}", self.ip, port) format!("http://{}:{}", self.ip, port)
} }
...@@ -179,9 +164,7 @@ pub async fn start_service_discovery( ...@@ -179,9 +164,7 @@ pub async fn start_service_discovery(
config: ServiceDiscoveryConfig, config: ServiceDiscoveryConfig,
app_context: Arc<AppContext>, app_context: Arc<AppContext>,
) -> Result<task::JoinHandle<()>, kube::Error> { ) -> Result<task::JoinHandle<()>, kube::Error> {
// Don't initialize anything if service discovery is disabled
if !config.enabled { if !config.enabled {
// Return a generic error when service discovery is disabled
return Err(kube::Error::Api(kube::error::ErrorResponse { return Err(kube::Error::Api(kube::error::ErrorResponse {
status: "Disabled".to_string(), status: "Disabled".to_string(),
message: "Service discovery is disabled".to_string(), message: "Service discovery is disabled".to_string(),
...@@ -192,7 +175,6 @@ pub async fn start_service_discovery( ...@@ -192,7 +175,6 @@ pub async fn start_service_discovery(
let _ = rustls::crypto::ring::default_provider().install_default(); let _ = rustls::crypto::ring::default_provider().install_default();
// Initialize Kubernetes client
let client = Client::try_default().await?; let client = Client::try_default().await?;
// Log the appropriate selectors based on mode // Log the appropriate selectors based on mode
...@@ -229,12 +211,9 @@ pub async fn start_service_discovery( ...@@ -229,12 +211,9 @@ pub async fn start_service_discovery(
); );
} }
// Create the task that will run in the background
let handle = task::spawn(async move { let handle = task::spawn(async move {
// We'll track pods we've already added to avoid duplicates
let tracked_pods = Arc::new(Mutex::new(HashSet::new())); let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
// Create a watcher for pods
let pods: Api<Pod> = if let Some(namespace) = &config.namespace { let pods: Api<Pod> = if let Some(namespace) = &config.namespace {
Api::namespaced(client, namespace) Api::namespaced(client, namespace)
} else { } else {
...@@ -243,23 +222,19 @@ pub async fn start_service_discovery( ...@@ -243,23 +222,19 @@ pub async fn start_service_discovery(
debug!("K8s service discovery initialized"); debug!("K8s service discovery initialized");
// Create Arcs for configuration data
let config_arc = Arc::new(config.clone()); let config_arc = Arc::new(config.clone());
let port = config.port; let port = config.port;
let mut retry_delay = Duration::from_secs(1); let mut retry_delay = Duration::from_secs(1);
const MAX_RETRY_DELAY: Duration = Duration::from_secs(300); // 5 minutes max const MAX_RETRY_DELAY: Duration = Duration::from_secs(300);
loop { loop {
// Create a watcher with the proper parameters according to the kube-rs API
let watcher_config = Config::default(); let watcher_config = Config::default();
let watcher_stream = watcher(pods.clone(), watcher_config).applied_objects(); let watcher_stream = watcher(pods.clone(), watcher_config).applied_objects();
// Clone Arcs for the closures
let config_clone = Arc::clone(&config_arc); let config_clone = Arc::clone(&config_arc);
let tracked_pods_clone = Arc::clone(&tracked_pods); let tracked_pods_clone = Arc::clone(&tracked_pods);
// Simplified label selector filter using helper method
let filtered_stream = watcher_stream.filter_map(move |obj_res| { let filtered_stream = watcher_stream.filter_map(move |obj_res| {
let config_inner = Arc::clone(&config_clone); let config_inner = Arc::clone(&config_clone);
...@@ -277,7 +252,6 @@ pub async fn start_service_discovery( ...@@ -277,7 +252,6 @@ pub async fn start_service_discovery(
} }
}); });
// Clone again for the next closure
let tracked_pods_clone2 = Arc::clone(&tracked_pods_clone); let tracked_pods_clone2 = Arc::clone(&tracked_pods_clone);
let app_context_clone = Arc::clone(&app_context); let app_context_clone = Arc::clone(&app_context);
let config_clone2 = Arc::clone(&config_arc); let config_clone2 = Arc::clone(&config_arc);
...@@ -317,7 +291,6 @@ pub async fn start_service_discovery( ...@@ -317,7 +291,6 @@ pub async fn start_service_discovery(
.await .await
{ {
Ok(_) => { Ok(_) => {
// Reset retry delay on success
retry_delay = Duration::from_secs(1); retry_delay = Duration::from_secs(1);
} }
Err(err) => { Err(err) => {
...@@ -328,12 +301,10 @@ pub async fn start_service_discovery( ...@@ -328,12 +301,10 @@ pub async fn start_service_discovery(
); );
time::sleep(retry_delay).await; time::sleep(retry_delay).await;
// Exponential backoff with jitter
retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY); retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY);
} }
} }
// If the watcher exits for some reason, wait a bit before restarting
warn!( warn!(
"Kubernetes watcher exited, restarting in {} seconds", "Kubernetes watcher exited, restarting in {} seconds",
config_arc.check_interval.as_secs() config_arc.check_interval.as_secs()
...@@ -354,9 +325,7 @@ async fn handle_pod_event( ...@@ -354,9 +325,7 @@ async fn handle_pod_event(
) { ) {
let worker_url = pod_info.worker_url(port); let worker_url = pod_info.worker_url(port);
// If pod is healthy, try to add it (with atomic check-and-insert)
if pod_info.is_healthy() { if pod_info.is_healthy() {
// Atomic check-and-insert to prevent race conditions
let should_add = { let should_add = {
let mut tracker = match tracked_pods.lock() { let mut tracker = match tracked_pods.lock() {
Ok(tracker) => tracker, Ok(tracker) => tracker,
...@@ -367,9 +336,8 @@ async fn handle_pod_event( ...@@ -367,9 +336,8 @@ async fn handle_pod_event(
}; };
if tracker.contains(pod_info) { if tracker.contains(pod_info) {
false // Already tracked false
} else { } else {
// Reserve the spot to prevent other threads from adding the same pod
tracker.insert(pod_info.clone()); tracker.insert(pod_info.clone());
true true
} }
...@@ -381,7 +349,6 @@ async fn handle_pod_event( ...@@ -381,7 +349,6 @@ async fn handle_pod_event(
pod_info.name, pod_info.pod_type, worker_url pod_info.name, pod_info.pod_type, worker_url
); );
// Build worker config based on pod type and routing mode
let worker_type = if pd_mode { let worker_type = if pd_mode {
match &pod_info.pod_type { match &pod_info.pod_type {
Some(PodType::Prefill) => Some("prefill".to_string()), Some(PodType::Prefill) => Some("prefill".to_string()),
...@@ -392,7 +359,6 @@ async fn handle_pod_event( ...@@ -392,7 +359,6 @@ async fn handle_pod_event(
None None
}; };
// Only set bootstrap_port for prefill workers in PD mode
let bootstrap_port = if pd_mode { let bootstrap_port = if pd_mode {
match &pod_info.pod_type { match &pod_info.pod_type {
Some(PodType::Prefill) => pod_info.bootstrap_port, Some(PodType::Prefill) => pod_info.bootstrap_port,
...@@ -425,7 +391,6 @@ async fn handle_pod_event( ...@@ -425,7 +391,6 @@ async fn handle_pod_event(
} }
Err(e) => { Err(e) => {
error!("Failed to add worker {} to router: {}", worker_url, e); error!("Failed to add worker {} to router: {}", worker_url, e);
// Remove from tracking since addition failed
if let Ok(mut tracker) = tracked_pods.lock() { if let Ok(mut tracker) = tracked_pods.lock() {
tracker.remove(pod_info); tracker.remove(pod_info);
} }
...@@ -464,8 +429,6 @@ async fn handle_pod_deletion( ...@@ -464,8 +429,6 @@ async fn handle_pod_deletion(
error!("Failed to remove worker {}: {}", worker_url, e); error!("Failed to remove worker {}: {}", worker_url, e);
} }
} else { } else {
// This case might occur if a pod is deleted before it was ever marked healthy and added.
// Or if the event is duplicated. No action needed on the router if it wasn't tracked (and thus not added).
debug!( debug!(
"Pod deletion event for untracked/already removed pod: {} (type: {:?}). Worker URL: {}", "Pod deletion event for untracked/already removed pod: {} (type: {:?}). Worker URL: {}",
pod_info.name, pod_info.pod_type, worker_url pod_info.name, pod_info.pod_type, worker_url
...@@ -480,7 +443,6 @@ mod tests { ...@@ -480,7 +443,6 @@ mod tests {
use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta; use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta;
use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time; use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time;
// Helper function to create a Pod for testing PodInfo::from_pod
fn create_k8s_pod( fn create_k8s_pod(
name: Option<&str>, name: Option<&str>,
ip: Option<&str>, ip: Option<&str>,
...@@ -523,7 +485,6 @@ mod tests { ...@@ -523,7 +485,6 @@ mod tests {
pod pod
} }
// Helper function to create a Pod with PD-specific labels and annotations
fn create_pd_k8s_pod(name: &str, ip: &str, pod_type: &str, bootstrap_port: Option<u16>) -> Pod { fn create_pd_k8s_pod(name: &str, ip: &str, pod_type: &str, bootstrap_port: Option<u16>) -> Pod {
let mut labels = std::collections::BTreeMap::new(); let mut labels = std::collections::BTreeMap::new();
labels.insert("app".to_string(), "sglang".to_string()); labels.insert("app".to_string(), "sglang".to_string());
...@@ -559,18 +520,15 @@ mod tests { ...@@ -559,18 +520,15 @@ mod tests {
} }
} }
// Helper to create an AppContext instance for testing event handlers
async fn create_test_app_context() -> Arc<AppContext> { async fn create_test_app_context() -> Arc<AppContext> {
use crate::config::RouterConfig; use crate::config::RouterConfig;
use crate::middleware::TokenBucket; use crate::middleware::TokenBucket;
// Create a minimal RouterConfig for testing with very short timeout
let router_config = RouterConfig { let router_config = RouterConfig {
worker_startup_timeout_secs: 1, worker_startup_timeout_secs: 1,
..Default::default() ..Default::default()
}; // Very short timeout for tests };
// Create AppContext with minimal components
Arc::new(AppContext { Arc::new(AppContext {
client: reqwest::Client::new(), client: reqwest::Client::new(),
router_config: router_config.clone(), router_config: router_config.clone(),
...@@ -579,16 +537,15 @@ mod tests { ...@@ -579,16 +537,15 @@ mod tests {
policy_registry: Arc::new(crate::policies::PolicyRegistry::new( policy_registry: Arc::new(crate::policies::PolicyRegistry::new(
router_config.policy.clone(), router_config.policy.clone(),
)), )),
tokenizer: None, // HTTP mode doesn't need tokenizer tokenizer: None,
reasoning_parser_factory: None, // HTTP mode doesn't need reasoning parser reasoning_parser_factory: None,
tool_parser_registry: None, // HTTP mode doesn't need tool parser tool_parser_registry: None,
router_manager: None, // Test doesn't need router manager router_manager: None,
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()), response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
load_monitor: None, load_monitor: None,
}) })
} }
// Helper to create a PD config for testing
fn create_pd_config() -> ServiceDiscoveryConfig { fn create_pd_config() -> ServiceDiscoveryConfig {
let mut prefill_selector = HashMap::new(); let mut prefill_selector = HashMap::new();
prefill_selector.insert("app".to_string(), "sglang".to_string()); prefill_selector.insert("app".to_string(), "sglang".to_string());
...@@ -615,19 +572,15 @@ mod tests { ...@@ -615,19 +572,15 @@ mod tests {
fn test_pod_info_should_include() { fn test_pod_info_should_include() {
let config = create_pd_config(); let config = create_pd_config();
// Test prefill pod should be included
let prefill_pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", Some(8081)); let prefill_pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", Some(8081));
assert!(PodInfo::should_include(&prefill_pod, &config)); assert!(PodInfo::should_include(&prefill_pod, &config));
// Test decode pod should be included
let decode_pod = create_pd_k8s_pod("decode-pod", "10.0.0.2", "decode", None); let decode_pod = create_pd_k8s_pod("decode-pod", "10.0.0.2", "decode", None);
assert!(PodInfo::should_include(&decode_pod, &config)); assert!(PodInfo::should_include(&decode_pod, &config));
// Test unmatched pod should not be included
let unmatched_pod = create_pd_k8s_pod("other-pod", "10.0.0.3", "other", None); let unmatched_pod = create_pd_k8s_pod("other-pod", "10.0.0.3", "other", None);
assert!(!PodInfo::should_include(&unmatched_pod, &config)); assert!(!PodInfo::should_include(&unmatched_pod, &config));
// Test regular mode
let mut regular_config = ServiceDiscoveryConfig::default(); let mut regular_config = ServiceDiscoveryConfig::default();
regular_config regular_config
.selector .selector
...@@ -654,7 +607,6 @@ mod tests { ...@@ -654,7 +607,6 @@ mod tests {
#[test] #[test]
fn test_pod_type_enum() { fn test_pod_type_enum() {
// Test that PodType enum has expected variants
let prefill = PodType::Prefill; let prefill = PodType::Prefill;
let decode = PodType::Decode; let decode = PodType::Decode;
let regular = PodType::Regular; let regular = PodType::Regular;
...@@ -714,7 +666,7 @@ mod tests { ...@@ -714,7 +666,7 @@ mod tests {
fn test_pod_info_from_pod_with_pd_config_regular_mode() { fn test_pod_info_from_pod_with_pd_config_regular_mode() {
let k8s_pod = create_pd_k8s_pod("regular-pod", "10.0.0.3", "worker", None); let k8s_pod = create_pd_k8s_pod("regular-pod", "10.0.0.3", "worker", None);
let mut config = create_pd_config(); let mut config = create_pd_config();
config.pd_mode = false; // Set to regular mode config.pd_mode = false;
let pod_info = PodInfo::from_pod(&k8s_pod, Some(&config)).unwrap(); let pod_info = PodInfo::from_pod(&k8s_pod, Some(&config)).unwrap();
assert_eq!(pod_info.name, "regular-pod"); assert_eq!(pod_info.name, "regular-pod");
...@@ -742,7 +694,6 @@ mod tests { ...@@ -742,7 +694,6 @@ mod tests {
#[test] #[test]
fn test_pod_info_from_pod_with_pd_config_invalid_bootstrap_port() { fn test_pod_info_from_pod_with_pd_config_invalid_bootstrap_port() {
let mut pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", None); let mut pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", None);
// Add invalid bootstrap port annotation
pod.metadata.annotations.as_mut().unwrap().insert( pod.metadata.annotations.as_mut().unwrap().insert(
"sglang.ai/bootstrap-port".to_string(), "sglang.ai/bootstrap-port".to_string(),
"invalid".to_string(), "invalid".to_string(),
...@@ -751,7 +702,7 @@ mod tests { ...@@ -751,7 +702,7 @@ mod tests {
let pod_info = PodInfo::from_pod(&pod, Some(&config)).unwrap(); let pod_info = PodInfo::from_pod(&pod, Some(&config)).unwrap();
assert_eq!(pod_info.pod_type, Some(PodType::Prefill)); assert_eq!(pod_info.pod_type, Some(PodType::Prefill));
assert!(pod_info.bootstrap_port.is_none()); // Should be None for invalid port assert!(pod_info.bootstrap_port.is_none());
} }
#[test] #[test]
...@@ -1077,7 +1028,6 @@ mod tests { ...@@ -1077,7 +1028,6 @@ mod tests {
) )
.await; .await;
// Pod should not be tracked since add_worker_from_url will fail for non-running server
assert!(!tracked_pods.lock().unwrap().contains(&pod_info)); assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment