Unverified Commit a7fe6e10 authored by Simo Lin, committed by GitHub

[router] remove old/outdated/useless comments (#10967)

parent be059b83
......@@ -67,58 +67,47 @@ struct Router {
decode_policy: Option<PolicyType>,
max_concurrent_requests: usize,
cors_allowed_origins: Vec<String>,
// Retry configuration
retry_max_retries: u32,
retry_initial_backoff_ms: u64,
retry_max_backoff_ms: u64,
retry_backoff_multiplier: f32,
retry_jitter_factor: f32,
disable_retries: bool,
// Circuit breaker configuration
cb_failure_threshold: u32,
cb_success_threshold: u32,
cb_timeout_duration_secs: u64,
cb_window_duration_secs: u64,
disable_circuit_breaker: bool,
// Health check configuration
health_failure_threshold: u32,
health_success_threshold: u32,
health_check_timeout_secs: u64,
health_check_interval_secs: u64,
health_check_endpoint: String,
// IGW (Inference Gateway) configuration
enable_igw: bool,
queue_size: usize,
queue_timeout_secs: u64,
rate_limit_tokens_per_second: Option<usize>,
// Connection mode (determined from worker URLs)
connection_mode: config::ConnectionMode,
// Model path for tokenizer
model_path: Option<String>,
// Explicit tokenizer path
tokenizer_path: Option<String>,
}
impl Router {
/// Determine connection mode from worker URLs
fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
// Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
for url in worker_urls {
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
return config::ConnectionMode::Grpc;
}
}
// Default to HTTP for all other cases (including http://, https://, or no scheme)
config::ConnectionMode::Http
}
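// Sketch of the detection rule above (illustrative values only):
//   ["http://w1:8000", "grpcs://w2:9000"] -> ConnectionMode::Grpc
//   ["w1:8000"]  (no scheme)              -> ConnectionMode::Http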
/// Convert PyO3 Router to RouterConfig
pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
use config::{
DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
};
// Convert policy helper function
let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
match policy {
PolicyType::Random => ConfigPolicyConfig::Random,
......@@ -131,14 +120,12 @@ impl Router {
max_tree_size: self.max_tree_size,
},
PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
load_check_interval_secs: 5, // Default value
load_check_interval_secs: 5,
},
}
};
// Determine routing mode
let mode = if self.enable_igw {
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
RoutingMode::Regular {
worker_urls: vec![],
}
......@@ -155,10 +142,8 @@ impl Router {
}
};
// Convert main policy
let policy = convert_policy(&self.policy);
// Service discovery configuration
let discovery = if self.service_discovery {
Some(DiscoveryConfig {
enabled: true,
......@@ -174,7 +159,6 @@ impl Router {
None
};
// Metrics configuration
let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
(Some(port), Some(host)) => Some(MetricsConfig {
port,
......@@ -251,7 +235,7 @@ impl Router {
balance_rel_threshold = 1.5,
eviction_interval_secs = 120,
max_tree_size = 2usize.pow(26),
max_payload_size = 512 * 1024 * 1024, // 512MB default for large batches
max_payload_size = 512 * 1024 * 1024,
dp_aware = false,
api_key = None,
log_dir = None,
......@@ -265,40 +249,35 @@ impl Router {
bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
prometheus_port = None,
prometheus_host = None,
request_timeout_secs = 1800, // Add configurable request timeout
request_id_headers = None, // Custom request ID headers
pd_disaggregation = false, // New flag for PD mode
request_timeout_secs = 1800,
request_id_headers = None,
pd_disaggregation = false,
prefill_urls = None,
decode_urls = None,
prefill_policy = None,
decode_policy = None,
max_concurrent_requests = 256,
cors_allowed_origins = vec![],
// Retry defaults
retry_max_retries = 5,
retry_initial_backoff_ms = 50,
retry_max_backoff_ms = 30_000,
retry_backoff_multiplier = 1.5,
retry_jitter_factor = 0.2,
disable_retries = false,
// Circuit breaker defaults
cb_failure_threshold = 10,
cb_success_threshold = 3,
cb_timeout_duration_secs = 60,
cb_window_duration_secs = 120,
disable_circuit_breaker = false,
// Health check defaults
health_failure_threshold = 3,
health_success_threshold = 2,
health_check_timeout_secs = 5,
health_check_interval_secs = 60,
health_check_endpoint = String::from("/health"),
// IGW defaults
enable_igw = false,
queue_size = 100,
queue_timeout_secs = 60,
rate_limit_tokens_per_second = None,
// Tokenizer defaults
model_path = None,
tokenizer_path = None,
))]
......@@ -361,17 +340,14 @@ impl Router {
model_path: Option<String>,
tokenizer_path: Option<String>,
) -> PyResult<Self> {
// Determine connection mode from worker URLs
let mut all_urls = worker_urls.clone();
// Add prefill URLs if in PD mode
if let Some(ref prefill_urls) = prefill_urls {
for (url, _) in prefill_urls {
all_urls.push(url.clone());
}
}
// Add decode URLs if in PD mode
if let Some(ref decode_urls) = decode_urls {
all_urls.extend(decode_urls.clone());
}
......@@ -440,12 +416,10 @@ impl Router {
}
fn start(&self) -> PyResult<()> {
// Convert to RouterConfig and validate
let router_config = self.to_router_config().map_err(|e| {
pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
})?;
// Validate the configuration
router_config.validate().map_err(|e| {
pyo3::exceptions::PyValueError::new_err(format!(
"Configuration validation failed: {}",
......@@ -453,7 +427,6 @@ impl Router {
))
})?;
// Create service discovery config if enabled
let service_discovery_config = if self.service_discovery {
Some(service_discovery::ServiceDiscoveryConfig {
enabled: true,
......@@ -470,7 +443,6 @@ impl Router {
None
};
// Create Prometheus config if enabled
let prometheus_config = Some(PrometheusConfig {
port: self.prometheus_port.unwrap_or(29000),
host: self
......@@ -479,11 +451,9 @@ impl Router {
.unwrap_or_else(|| "127.0.0.1".to_string()),
});
// Use tokio runtime instead of actix-web System for better compatibility
let runtime = tokio::runtime::Runtime::new()
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
// Block on the async startup function
runtime.block_on(async move {
server::startup(server::ServerConfig {
host: self.host.clone(),
......
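The start path above bridges a synchronous PyO3 call into the async server. A minimal standalone sketch of that bridge, assuming only tokio (the actual startup call is elided):

use tokio::runtime::Runtime;

// A sync entry point builds a dedicated runtime and drives one async
// task to completion, mirroring Router::start above.
fn start_blocking() -> Result<(), std::io::Error> {
    let runtime = Runtime::new()?;
    runtime.block_on(async {
        // server::startup(...).await would run here
        Ok(())
    })
}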
......@@ -8,20 +8,13 @@ use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
/// Configuration for the logging system
#[derive(Debug, Clone)]
pub struct LoggingConfig {
/// Log level for the application (default: INFO)
pub level: Level,
/// Whether to use json format for logs (default: false)
pub json_format: bool,
/// Path to store log files. If None, logs will only go to stdout/stderr
pub log_dir: Option<String>,
/// Whether to colorize logs when output is a terminal (default: true)
pub colorize: bool,
/// Log file name to use if log_dir is specified (default: "sgl-router")
pub log_file_name: String,
/// Custom log targets to filter (default: "sglang_router_rs")
pub log_targets: Option<Vec<String>>,
}
......@@ -38,30 +31,14 @@ impl Default for LoggingConfig {
}
}
/// Guard that keeps the file appender worker thread alive
///
/// This must be kept in scope for the duration of the program
/// to ensure logs are properly written to files
#[allow(dead_code)]
pub struct LogGuard {
_file_guard: Option<WorkerGuard>,
}
/// Initialize the logging system with the given configuration
///
/// # Arguments
/// * `config` - Configuration for the logging system
///
/// # Returns
/// A LogGuard that must be kept alive for the duration of the program
///
/// # Panics
/// Will not panic, as initialization errors are handled gracefully
pub fn init_logging(config: LoggingConfig) -> LogGuard {
// Forward logs to tracing - ignore errors to allow for multiple initialization
let _ = LogTracer::init();
// Convert log level to filter string
let level_filter = match config.level {
Level::TRACE => "trace",
Level::DEBUG => "debug",
......@@ -70,9 +47,7 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
Level::ERROR => "error",
};
// Create env filter
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| {
// Format: <target>=<level>,<target2>=<level2>,...
let filter_string = if let Some(targets) = &config.log_targets {
targets
.iter()
......@@ -92,13 +67,10 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
EnvFilter::new(filter_string)
});
// Setup stdout/stderr layer
let mut layers = Vec::new();
// Standard timestamp format: YYYY-MM-DD HH:MM:SS
let time_format = "%Y-%m-%d %H:%M:%S".to_string();
// Configure the console stdout layer
let stdout_layer = tracing_subscriber::fmt::layer()
.with_ansi(config.colorize)
.with_file(true)
......@@ -113,14 +85,12 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
layers.push(stdout_layer);
// Create a file appender if log_dir is specified
let mut file_guard = None;
if let Some(log_dir) = &config.log_dir {
let file_name = config.log_file_name.clone();
let log_dir = PathBuf::from(log_dir);
// Create log directory if it doesn't exist
if !log_dir.exists() {
if let Err(e) = std::fs::create_dir_all(&log_dir) {
eprintln!("Failed to create log directory: {}", e);
......@@ -134,7 +104,7 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
file_guard = Some(guard);
let file_layer = tracing_subscriber::fmt::layer()
.with_ansi(false) // Never use ANSI colors in log files
.with_ansi(false)
.with_file(true)
.with_line_number(true)
.with_timer(ChronoUtc::new(time_format))
......@@ -149,14 +119,11 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
layers.push(file_layer);
}
// Initialize the subscriber with all layers
// Use try_init to handle errors gracefully in case another subscriber is already set
let _ = tracing_subscriber::registry()
.with(env_filter)
.with(layers)
.try_init();
// Return the guard to keep the file appender worker thread alive
LogGuard {
_file_guard: file_guard,
}
......
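A minimal usage sketch for this module (the log directory is a placeholder; LoggingConfig's Default impl lives in the elided hunk above):

let _guard = init_logging(LoggingConfig {
    level: Level::DEBUG,
    log_dir: Some("/tmp/sgl-router-logs".to_string()),
    ..Default::default()
});
// Keep _guard alive for the program's lifetime so the file appender's
// worker thread keeps flushing buffered log lines.
tracing::info!("logging initialized");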
......@@ -9,7 +9,6 @@ use sglang_router_rs::server::{self, ServerConfig};
use sglang_router_rs::service_discovery::ServiceDiscoveryConfig;
use std::collections::HashMap;
// Helper function to parse prefill arguments from command line
fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
let args: Vec<String> = std::env::args().collect();
let mut prefill_entries = Vec::new();
......@@ -19,12 +18,11 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
if args[i] == "--prefill" && i + 1 < args.len() {
let url = args[i + 1].clone();
let bootstrap_port = if i + 2 < args.len() && !args[i + 2].starts_with("--") {
// Check if next arg is a port number
if let Ok(port) = args[i + 2].parse::<u16>() {
i += 1; // Skip the port argument
i += 1;
Some(port)
} else if args[i + 2].to_lowercase() == "none" {
i += 1; // Skip the "none" argument
i += 1;
None
} else {
None
......@@ -33,7 +31,7 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
None
};
prefill_entries.push((url, bootstrap_port));
i += 2; // Skip --prefill and URL
i += 2;
} else {
i += 1;
}
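// Example (hypothetical command line): the loop above turns
//   --prefill http://p1:8000 9000 --prefill http://p2:8000 none
// into
//   vec![("http://p1:8000".to_string(), Some(9000)),
//        ("http://p2:8000".to_string(), None)]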
......@@ -101,252 +99,186 @@ Examples:
"#)]
struct CliArgs {
/// Host address to bind the router server
#[arg(long, default_value = "127.0.0.1")]
host: String,
/// Port number to bind the router server
#[arg(long, default_value_t = 30000)]
port: u16,
/// List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)
#[arg(long, num_args = 0..)]
worker_urls: Vec<String>,
/// Load balancing policy to use
#[arg(long, default_value = "cache_aware", value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
policy: String,
/// Enable PD (Prefill-Decode) disaggregated mode
#[arg(long, default_value_t = false)]
pd_disaggregation: bool,
/// Decode server URL (can be specified multiple times)
#[arg(long, action = ArgAction::Append)]
decode: Vec<String>,
/// Specific policy for prefill nodes in PD mode
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
prefill_policy: Option<String>,
/// Specific policy for decode nodes in PD mode
#[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
decode_policy: Option<String>,
/// Timeout in seconds for worker startup
#[arg(long, default_value_t = 600)]
worker_startup_timeout_secs: u64,
/// Interval in seconds between checks for worker startup
#[arg(long, default_value_t = 30)]
worker_startup_check_interval: u64,
/// Cache threshold (0.0-1.0) for cache-aware routing
#[arg(long, default_value_t = 0.3)]
cache_threshold: f32,
/// Absolute threshold for load balancing
#[arg(long, default_value_t = 64)]
balance_abs_threshold: usize,
/// Relative threshold for load balancing
#[arg(long, default_value_t = 1.5)]
balance_rel_threshold: f32,
/// Interval in seconds between cache eviction operations
#[arg(long, default_value_t = 120)]
eviction_interval: u64,
/// Maximum size of the approximation tree for cache-aware routing
#[arg(long, default_value_t = 67108864)] // 2^26
#[arg(long, default_value_t = 67108864)]
max_tree_size: usize,
/// Maximum payload size in bytes
#[arg(long, default_value_t = 536870912)] // 512MB
#[arg(long, default_value_t = 536870912)]
max_payload_size: usize,
/// Enable data-parallelism-aware scheduling
#[arg(long, default_value_t = false)]
dp_aware: bool,
/// API key for worker authorization
#[arg(long)]
api_key: Option<String>,
/// Backend to route requests to (sglang, vllm, trtllm, openai, anthropic)
#[arg(long, value_enum, default_value_t = Backend::Sglang, alias = "runtime")]
backend: Backend,
/// Directory to store log files
#[arg(long)]
log_dir: Option<String>,
/// Set the logging level
#[arg(long, default_value = "info", value_parser = ["debug", "info", "warn", "error"])]
log_level: String,
/// Enable Kubernetes service discovery
#[arg(long, default_value_t = false)]
service_discovery: bool,
/// Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)
#[arg(long, num_args = 0..)]
selector: Vec<String>,
/// Port to use for discovered worker pods
#[arg(long, default_value_t = 80)]
service_discovery_port: u16,
/// Kubernetes namespace to watch for pods
#[arg(long)]
service_discovery_namespace: Option<String>,
/// Label selector for prefill server pods in PD mode
#[arg(long, num_args = 0..)]
prefill_selector: Vec<String>,
/// Label selector for decode server pods in PD mode
#[arg(long, num_args = 0..)]
decode_selector: Vec<String>,
/// Port to expose Prometheus metrics
#[arg(long, default_value_t = 29000)]
prometheus_port: u16,
/// Host address to bind the Prometheus metrics server
#[arg(long, default_value = "127.0.0.1")]
prometheus_host: String,
/// Custom HTTP headers to check for request IDs
#[arg(long, num_args = 0..)]
request_id_headers: Vec<String>,
/// Request timeout in seconds
#[arg(long, default_value_t = 1800)]
request_timeout_secs: u64,
/// Maximum number of concurrent requests allowed
#[arg(long, default_value_t = 256)]
max_concurrent_requests: usize,
/// CORS allowed origins
#[arg(long, num_args = 0..)]
cors_allowed_origins: Vec<String>,
// Retry configuration
/// Maximum number of retries
#[arg(long, default_value_t = 5)]
retry_max_retries: u32,
/// Initial backoff in milliseconds for retries
#[arg(long, default_value_t = 50)]
retry_initial_backoff_ms: u64,
/// Maximum backoff in milliseconds for retries
#[arg(long, default_value_t = 30000)]
retry_max_backoff_ms: u64,
/// Backoff multiplier for exponential backoff
#[arg(long, default_value_t = 1.5)]
retry_backoff_multiplier: f32,
/// Jitter factor for retry backoff
#[arg(long, default_value_t = 0.2)]
retry_jitter_factor: f32,
/// Disable retries
#[arg(long, default_value_t = false)]
disable_retries: bool,
// Circuit breaker configuration
/// Number of failures before circuit breaker opens
#[arg(long, default_value_t = 10)]
cb_failure_threshold: u32,
/// Number of successes before circuit breaker closes
#[arg(long, default_value_t = 3)]
cb_success_threshold: u32,
/// Timeout duration in seconds for circuit breaker
#[arg(long, default_value_t = 60)]
cb_timeout_duration_secs: u64,
/// Window duration in seconds for circuit breaker
#[arg(long, default_value_t = 120)]
cb_window_duration_secs: u64,
/// Disable circuit breaker
#[arg(long, default_value_t = false)]
disable_circuit_breaker: bool,
// Health check configuration
/// Number of consecutive health check failures before marking worker unhealthy
#[arg(long, default_value_t = 3)]
health_failure_threshold: u32,
/// Number of consecutive health check successes before marking worker healthy
#[arg(long, default_value_t = 2)]
health_success_threshold: u32,
/// Timeout in seconds for health check requests
#[arg(long, default_value_t = 5)]
health_check_timeout_secs: u64,
/// Interval in seconds between runtime health checks
#[arg(long, default_value_t = 60)]
health_check_interval_secs: u64,
/// Health check endpoint path
#[arg(long, default_value = "/health")]
health_check_endpoint: String,
// IGW (Inference Gateway) configuration
/// Enable Inference Gateway mode
#[arg(long, default_value_t = false)]
enable_igw: bool,
// Tokenizer configuration
/// Model path for loading tokenizer (HuggingFace model ID or local path)
#[arg(long)]
model_path: Option<String>,
/// Explicit tokenizer path (overrides model_path tokenizer if provided)
#[arg(long)]
tokenizer_path: Option<String>,
/// History backend configuration (memory, none, or oracle)
#[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle"])]
history_backend: String,
/// Directory containing the Oracle ATP wallet/config files (optional)
#[arg(long, env = "ATP_WALLET_PATH")]
oracle_wallet_path: Option<String>,
/// Wallet TNS alias to use (e.g. `<db_name>_low`)
#[arg(long, env = "ATP_TNS_ALIAS")]
oracle_tns_alias: Option<String>,
/// Oracle connection descriptor / DSN (e.g. `tcps://host:port/service_name`)
#[arg(long, env = "ATP_DSN")]
oracle_dsn: Option<String>,
/// Oracle ATP username
#[arg(long, env = "ATP_USER")]
oracle_user: Option<String>,
/// Oracle ATP password
#[arg(long, env = "ATP_PASSWORD")]
oracle_password: Option<String>,
/// Minimum number of pooled ATP connections (defaults to 1 when omitted)
#[arg(long, env = "ATP_POOL_MIN")]
oracle_pool_min: Option<usize>,
/// Maximum number of pooled ATP connections (defaults to 16 when omitted)
#[arg(long, env = "ATP_POOL_MAX")]
oracle_pool_max: Option<usize>,
/// Connection acquisition timeout in seconds (defaults to 30 when omitted)
#[arg(long, env = "ATP_POOL_TIMEOUT_SECS")]
oracle_pool_timeout_secs: Option<u64>,
}
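// The Oracle flags above also read from the environment via clap's
// `env` attribute, so these two invocations are equivalent
// (binary name and values are hypothetical):
//   router --oracle-dsn tcps://db.example.com:1522/svc --oracle-user admin
//   ATP_DSN=tcps://db.example.com:1522/svc ATP_USER=admin router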
......@@ -357,19 +289,15 @@ enum OracleConnectSource {
}
impl CliArgs {
/// Determine connection mode from worker URLs
fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode {
// Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
for url in worker_urls {
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
return ConnectionMode::Grpc;
}
}
// Default to HTTP for all other cases (including http://, https://, or no scheme)
ConnectionMode::Http
}
/// Parse selector strings into HashMap
fn parse_selector(selector_list: &[String]) -> HashMap<String, String> {
let mut map = HashMap::new();
for item in selector_list {
......@@ -382,7 +310,6 @@ impl CliArgs {
map
}
/// Convert policy string to PolicyConfig
fn parse_policy(&self, policy_str: &str) -> PolicyConfig {
match policy_str {
"random" => PolicyConfig::Random,
......@@ -395,9 +322,9 @@ impl CliArgs {
max_tree_size: self.max_tree_size,
},
"power_of_two" => PolicyConfig::PowerOfTwo {
load_check_interval_secs: 5, // Default value
load_check_interval_secs: 5,
},
_ => PolicyConfig::RoundRobin, // Fallback
_ => PolicyConfig::RoundRobin,
}
}
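// Mapping sketch for parse_policy (the cache_aware arm pulls its
// thresholds from the CLI fields defined above):
//   "random"       -> PolicyConfig::Random
//   "round_robin"  -> PolicyConfig::RoundRobin
//   "power_of_two" -> PolicyConfig::PowerOfTwo { load_check_interval_secs: 5 }
//   anything else  -> PolicyConfig::RoundRobin (unreachable in practice,
//                     since clap's value_parser rejects unknown values)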
......@@ -482,26 +409,21 @@ impl CliArgs {
})
}
/// Convert CLI arguments to RouterConfig
fn to_router_config(
&self,
prefill_urls: Vec<(String, Option<u16>)>,
) -> ConfigResult<RouterConfig> {
// Determine routing mode
let mode = if self.enable_igw {
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
RoutingMode::Regular {
worker_urls: vec![],
}
} else if matches!(self.backend, Backend::Openai) {
// OpenAI backend mode - use worker_urls as base(s)
RoutingMode::OpenAI {
worker_urls: self.worker_urls.clone(),
}
} else if self.pd_disaggregation {
let decode_urls = self.decode.clone();
// Validate PD configuration if not using service discovery
if !self.service_discovery && (prefill_urls.is_empty() || decode_urls.is_empty()) {
return Err(ConfigError::ValidationFailed {
reason: "PD disaggregation mode requires --prefill and --decode URLs when not using service discovery".to_string(),
......@@ -515,7 +437,6 @@ impl CliArgs {
decode_policy: self.decode_policy.as_ref().map(|p| self.parse_policy(p)),
}
} else {
// Regular mode
if !self.service_discovery && self.worker_urls.is_empty() {
return Err(ConfigError::ValidationFailed {
reason: "Regular mode requires --worker-urls when not using service discovery"
......@@ -527,10 +448,8 @@ impl CliArgs {
}
};
// Main policy
let policy = self.parse_policy(&self.policy);
// Service discovery configuration
let discovery = if self.service_discovery {
Some(DiscoveryConfig {
enabled: true,
......@@ -546,13 +465,11 @@ impl CliArgs {
None
};
// Metrics configuration
let metrics = Some(MetricsConfig {
port: self.prometheus_port,
host: self.prometheus_host.clone(),
});
// Determine connection mode from all worker URLs
let mut all_urls = Vec::new();
match &mode {
RoutingMode::Regular { worker_urls } => {
......@@ -568,9 +485,7 @@ impl CliArgs {
}
all_urls.extend(decode_urls.clone());
}
RoutingMode::OpenAI { .. } => {
// For connection-mode detection, skip URLs; OpenAI forces HTTP below.
}
RoutingMode::OpenAI { .. } => {}
}
let connection_mode = match &mode {
RoutingMode::OpenAI { .. } => ConnectionMode::Http,
......@@ -589,7 +504,6 @@ impl CliArgs {
None
};
// Build RouterConfig
Ok(RouterConfig {
mode,
policy,
......@@ -612,8 +526,8 @@ impl CliArgs {
Some(self.request_id_headers.clone())
},
max_concurrent_requests: self.max_concurrent_requests,
queue_size: 100, // Default queue size
queue_timeout_secs: 60, // Default timeout
queue_size: 100,
queue_timeout_secs: 60,
cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: RetryConfig {
max_retries: self.retry_max_retries,
......@@ -646,9 +560,7 @@ impl CliArgs {
})
}
/// Create ServerConfig from CLI args and RouterConfig
fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig {
// Create service discovery config if enabled
let service_discovery_config = if self.service_discovery {
Some(ServiceDiscoveryConfig {
enabled: true,
......@@ -665,7 +577,6 @@ impl CliArgs {
None
};
// Create Prometheus config
let prometheus_config = Some(PrometheusConfig {
port: self.prometheus_port,
host: self.prometheus_host.clone(),
......@@ -691,19 +602,15 @@ impl CliArgs {
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Parse prefill arguments manually before clap parsing
let prefill_urls = parse_prefill_args();
// Filter out prefill arguments and their values before passing to clap
let mut filtered_args: Vec<String> = Vec::new();
let raw_args: Vec<String> = std::env::args().collect();
let mut i = 0;
while i < raw_args.len() {
if raw_args[i] == "--prefill" && i + 1 < raw_args.len() {
// Skip --prefill and its URL
i += 2;
// Also skip bootstrap port if present
if i < raw_args.len()
&& !raw_args[i].starts_with("--")
&& (raw_args[i].parse::<u16>().is_ok() || raw_args[i].to_lowercase() == "none")
......@@ -716,10 +623,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}
}
// Parse CLI arguments with clap using filtered args
let cli_args = CliArgs::parse_from(filtered_args);
// Print startup info
println!("SGLang Router starting...");
println!("Host: {}:{}", cli_args.host, cli_args.port);
let mode_str = if cli_args.enable_igw {
......@@ -733,7 +638,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
};
println!("Mode: {}", mode_str);
// Warn for runtimes that are parsed but not yet implemented
match cli_args.backend {
Backend::Vllm | Backend::Trtllm | Backend::Anthropic => {
println!(
......@@ -754,19 +658,10 @@ Provide --worker-urls or PD flags as usual.",
}
}
// Convert to RouterConfig
let router_config = cli_args.to_router_config(prefill_urls)?;
// Validate configuration
router_config.validate()?;
// Create ServerConfig
let server_config = cli_args.to_server_config(router_config);
// Create a new runtime for the server (like Python binding does)
let runtime = tokio::runtime::Runtime::new()?;
// Block on the async startup function
runtime.block_on(async move { server::startup(server_config).await })?;
Ok(())
......
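Putting the flag groups together, a hypothetical PD-mode launch (binary name, hosts, and ports are placeholders; clap maps the snake_case fields above to kebab-case flags, while --prefill is parsed by hand as shown):

router --pd-disaggregation \
    --prefill http://prefill-0:8000 9000 \
    --decode http://decode-0:8001 \
    --policy cache_aware \
    --prometheus-port 29000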
......@@ -19,7 +19,6 @@ impl Default for PrometheusConfig {
}
pub fn init_metrics() {
// Request metrics
describe_counter!(
"sgl_router_requests_total",
"Total number of requests by route and method"
......@@ -45,7 +44,6 @@ pub fn init_metrics() {
"Total number of requests that exhausted retries by route"
);
// Circuit breaker metrics
describe_gauge!(
"sgl_router_cb_state",
"Circuit breaker state per worker (0=closed, 1=open, 2=half_open)"
......@@ -59,7 +57,6 @@ pub fn init_metrics() {
"Total number of circuit breaker outcomes by worker and outcome type (success/failure)"
);
// Worker metrics
describe_gauge!(
"sgl_router_active_workers",
"Number of currently active workers"
......@@ -74,7 +71,6 @@ pub fn init_metrics() {
"Total requests processed by each worker"
);
// Policy metrics
describe_counter!(
"sgl_router_policy_decisions_total",
"Total routing policy decisions by policy and worker"
......@@ -92,7 +88,6 @@ pub fn init_metrics() {
describe_gauge!("sgl_router_max_load", "Maximum worker load");
describe_gauge!("sgl_router_min_load", "Minimum worker load");
// PD-specific metrics
describe_counter!("sgl_router_pd_requests_total", "Total PD requests by route");
describe_counter!(
"sgl_router_pd_prefill_requests_total",
......@@ -123,7 +118,6 @@ pub fn init_metrics() {
"PD request duration by route"
);
// Service discovery metrics
describe_counter!(
"sgl_router_discovery_updates_total",
"Total service discovery update events"
......@@ -137,13 +131,11 @@ pub fn init_metrics() {
"Number of workers removed in last discovery update"
);
// Generate request specific metrics
describe_histogram!(
"sgl_router_generate_duration_seconds",
"Generate request duration"
);
// Embedding request specific metrics
describe_counter!("sgl_router_embeddings_total", "Total embedding requests");
describe_histogram!(
"sgl_router_embeddings_duration_seconds",
......@@ -155,13 +147,11 @@ pub fn init_metrics() {
);
describe_gauge!("sgl_router_embeddings_queue_size", "Embedding queue size");
// Running requests gauge for cache-aware policy
describe_gauge!(
"sgl_router_running_requests",
"Number of running requests per worker"
);
// Tokenizer metrics
describe_histogram!(
"sgl_tokenizer_encode_duration_seconds",
"Time to encode text to tokens"
......@@ -207,7 +197,6 @@ pub fn init_metrics() {
"Vocabulary size of the loaded tokenizer"
);
// Stop sequence detection metrics
describe_counter!(
"sgl_tokenizer_stop_sequences_detected_total",
"Total stop sequences detected by type"
......@@ -221,7 +210,6 @@ pub fn init_metrics() {
"Time to check for stop sequences per token"
);
// Streaming decode metrics
describe_counter!(
"sgl_tokenizer_stream_tokens_total",
"Total tokens processed in streaming decode"
......@@ -235,7 +223,6 @@ pub fn init_metrics() {
"Time per streaming decode step"
);
// Factory metrics
describe_counter!(
"sgl_tokenizer_factory_loads_total",
"Total tokenizer loads by file type"
......@@ -251,7 +238,6 @@ pub fn init_metrics() {
}
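// A minimal sketch of the describe/emit split used throughout this file:
// descriptions are registered once above, then hot paths emit by name.
fn _metrics_usage_sketch() {
    use metrics::counter;
    counter!("sgl_router_requests_total",
        "route" => "/generate".to_string()
    )
    .increment(1);
}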
pub fn start_prometheus(config: PrometheusConfig) {
// Initialize metric descriptions
init_metrics();
let duration_matcher = Matcher::Suffix(String::from("duration_seconds"));
......@@ -280,7 +266,6 @@ pub struct RouterMetrics;
pub struct TokenizerMetrics;
impl RouterMetrics {
// Request metrics
pub fn record_request(route: &str) {
counter!("sgl_router_requests_total",
"route" => route.to_string()
......@@ -324,7 +309,6 @@ impl RouterMetrics {
.increment(1);
}
// Worker metrics
pub fn set_active_workers(count: usize) {
gauge!("sgl_router_active_workers").set(count as f64);
}
......@@ -350,7 +334,6 @@ impl RouterMetrics {
.increment(1);
}
// Policy metrics
pub fn record_policy_decision(policy: &str, worker: &str) {
counter!("sgl_router_policy_decisions_total",
"policy" => policy.to_string(),
......@@ -383,7 +366,6 @@ impl RouterMetrics {
gauge!("sgl_router_min_load").set(min_load as f64);
}
// PD-specific metrics
pub fn record_pd_request(route: &str) {
counter!("sgl_router_pd_requests_total",
"route" => route.to_string()
......@@ -440,19 +422,16 @@ impl RouterMetrics {
.increment(1);
}
// Service discovery metrics
pub fn record_discovery_update(added: usize, removed: usize) {
counter!("sgl_router_discovery_updates_total").increment(1);
gauge!("sgl_router_discovery_workers_added").set(added as f64);
gauge!("sgl_router_discovery_workers_removed").set(removed as f64);
}
// Generate request metrics
pub fn record_generate_duration(duration: Duration) {
histogram!("sgl_router_generate_duration_seconds").record(duration.as_secs_f64());
}
// Embeddings metrics
pub fn record_embeddings_request() {
counter!("sgl_router_embeddings_total").increment(1);
}
......@@ -473,7 +452,6 @@ impl RouterMetrics {
gauge!("sgl_router_embeddings_queue_size").set(size as f64);
}
// Running requests for cache-aware policy
pub fn set_running_requests(worker: &str, count: usize) {
gauge!("sgl_router_running_requests",
"worker" => worker.to_string()
......@@ -481,7 +459,6 @@ impl RouterMetrics {
.set(count as f64);
}
// Circuit breaker metrics
pub fn set_cb_state(worker: &str, state_code: u8) {
gauge!("sgl_router_cb_state",
"worker" => worker.to_string()
......@@ -508,7 +485,6 @@ impl RouterMetrics {
}
impl TokenizerMetrics {
// Encoding metrics
pub fn record_encode_request(tokenizer_type: &str) {
counter!("sgl_tokenizer_encode_requests_total",
"tokenizer_type" => tokenizer_type.to_string()
......@@ -535,7 +511,6 @@ impl TokenizerMetrics {
histogram!("sgl_tokenizer_chars_per_encode").record(char_count as f64);
}
// Decoding metrics
pub fn record_decode_request(tokenizer_type: &str) {
counter!("sgl_tokenizer_decode_requests_total",
"tokenizer_type" => tokenizer_type.to_string()
......@@ -558,7 +533,6 @@ impl TokenizerMetrics {
histogram!("sgl_tokenizer_tokens_per_decode").record(token_count as f64);
}
// Batch encoding metrics
pub fn record_encode_batch_duration(duration: Duration, batch_size: usize) {
histogram!("sgl_tokenizer_encode_batch_duration_seconds",
"batch_size" => batch_size.to_string()
......@@ -566,7 +540,6 @@ impl TokenizerMetrics {
.record(duration.as_secs_f64());
}
// Stop sequence detection metrics
pub fn record_stop_sequence_detected(stop_type: &str) {
counter!("sgl_tokenizer_stop_sequences_detected_total",
"type" => stop_type.to_string()
......@@ -582,7 +555,6 @@ impl TokenizerMetrics {
histogram!("sgl_tokenizer_stop_detection_duration_seconds").record(duration.as_secs_f64());
}
// Streaming decode metrics
pub fn record_stream_token() {
counter!("sgl_tokenizer_stream_tokens_total").increment(1);
}
......@@ -595,7 +567,6 @@ impl TokenizerMetrics {
histogram!("sgl_tokenizer_stream_step_duration_seconds").record(duration.as_secs_f64());
}
// Factory metrics
pub fn record_factory_load(file_type: &str) {
counter!("sgl_tokenizer_factory_loads_total",
"file_type" => file_type.to_string()
......@@ -614,7 +585,6 @@ impl TokenizerMetrics {
histogram!("sgl_tokenizer_factory_load_duration_seconds").record(duration.as_secs_f64());
}
// Vocabulary metrics
pub fn set_vocab_size(tokenizer_type: &str, size: usize) {
gauge!("sgl_tokenizer_vocab_size",
"tokenizer_type" => tokenizer_type.to_string()
......@@ -705,7 +675,6 @@ mod tests {
.parse()
.unwrap_or(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
// Should fall back to 0.0.0.0
assert_eq!(ip_addr, IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
}
}
......@@ -780,7 +749,6 @@ mod tests {
fn test_duration_suffix_matcher() {
let matcher = Matcher::Suffix(String::from("duration_seconds"));
// Test matching behavior
let _matching_metrics = [
"request_duration_seconds",
"response_duration_seconds",
......@@ -789,8 +757,6 @@ mod tests {
let _non_matching_metrics = ["duration_total", "duration_seconds_total", "other_metric"];
// Note: We can't directly test Matcher matching without the internals,
// but we can verify the matcher is created correctly
match matcher {
Matcher::Suffix(suffix) => assert_eq!(suffix, "duration_seconds"),
_ => panic!("Expected Suffix matcher"),
......@@ -801,7 +767,6 @@ mod tests {
#[test]
fn test_prometheus_builder_configuration() {
// This test verifies the builder configuration without actually starting Prometheus
let _config = PrometheusConfig::default();
let duration_matcher = Matcher::Suffix(String::from("duration_seconds"));
......@@ -810,10 +775,8 @@ mod tests {
60.0, 90.0, 120.0, 180.0, 240.0,
];
// Verify bucket configuration
assert_eq!(duration_bucket.len(), 20);
// Verify matcher is suffix type
match duration_matcher {
Matcher::Suffix(s) => assert_eq!(s, "duration_seconds"),
_ => panic!("Expected Suffix matcher"),
......@@ -832,14 +795,12 @@ mod tests {
#[test]
fn test_custom_buckets_for_different_metrics() {
// Test that we can create different bucket configurations
let request_buckets = [0.001, 0.01, 0.1, 1.0, 10.0];
let generate_buckets = [0.1, 0.5, 1.0, 5.0, 30.0, 60.0];
assert_eq!(request_buckets.len(), 5);
assert_eq!(generate_buckets.len(), 6);
// Verify each set is sorted
for i in 1..request_buckets.len() {
assert!(request_buckets[i] > request_buckets[i - 1]);
}
......@@ -853,7 +814,6 @@ mod tests {
#[test]
fn test_metrics_static_methods() {
// Test that all static methods can be called without panic
RouterMetrics::record_request("/generate");
RouterMetrics::record_request_duration("/generate", Duration::from_millis(100));
RouterMetrics::record_request_error("/generate", "timeout");
......@@ -887,41 +847,32 @@ mod tests {
#[test]
fn test_tokenizer_metrics_static_methods() {
// Test that all tokenizer metric methods can be called without panic
// Encoding metrics
TokenizerMetrics::record_encode_request("huggingface");
TokenizerMetrics::record_encode_duration(Duration::from_millis(10));
TokenizerMetrics::record_encode_error("invalid_input");
TokenizerMetrics::record_tokens_per_encode(100);
TokenizerMetrics::record_chars_per_encode(500);
// Decoding metrics
TokenizerMetrics::record_decode_request("huggingface");
TokenizerMetrics::record_decode_duration(Duration::from_millis(5));
TokenizerMetrics::record_decode_error("invalid_tokens");
TokenizerMetrics::record_tokens_per_decode(50);
// Batch encoding
TokenizerMetrics::record_encode_batch_duration(Duration::from_millis(100), 10);
// Stop sequence detection
TokenizerMetrics::record_stop_sequence_detected("token");
TokenizerMetrics::record_stop_sequence_detected("string");
TokenizerMetrics::record_partial_match();
TokenizerMetrics::record_stop_detection_duration(Duration::from_micros(100));
// Streaming decode
TokenizerMetrics::record_stream_token();
TokenizerMetrics::record_incomplete_utf8();
TokenizerMetrics::record_stream_step_duration(Duration::from_micros(50));
// Factory metrics
TokenizerMetrics::record_factory_load("json");
TokenizerMetrics::record_factory_error("unsupported_format");
TokenizerMetrics::record_factory_load_duration(Duration::from_millis(200));
// Vocabulary metrics
TokenizerMetrics::set_vocab_size("huggingface", 50000);
}
......@@ -929,17 +880,14 @@ mod tests {
#[test]
fn test_port_already_in_use() {
// Skip this test if we can't bind to the port
let port = 29123; // Use a different port to avoid conflicts
let port = 29123;
if let Ok(_listener) = TcpListener::bind(("127.0.0.1", port)) {
// Port is available, we can test
let config = PrometheusConfig {
port,
host: "127.0.0.1".to_string(),
};
// Just verify config is created correctly
assert_eq!(config.port, port);
}
}
......@@ -948,8 +896,6 @@ mod tests {
#[test]
fn test_metrics_endpoint_accessibility() {
// This would be an integration test in practice
// Here we just verify the configuration
let config = PrometheusConfig {
port: 29000,
host: "127.0.0.1".to_string(),
......@@ -963,7 +909,6 @@ mod tests {
#[test]
fn test_concurrent_metric_updates() {
// Test that metric updates can be called concurrently
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
......@@ -984,11 +929,9 @@ mod tests {
handles.push(handle);
}
// Let threads run briefly
thread::sleep(Duration::from_millis(10));
done.store(true, Ordering::Relaxed);
// Wait for all threads
for handle in handles {
handle.join().unwrap();
}
......@@ -998,7 +941,6 @@ mod tests {
#[test]
fn test_empty_string_metrics() {
// Test that empty strings don't cause issues
RouterMetrics::record_request("");
RouterMetrics::set_worker_health("", true);
RouterMetrics::record_policy_decision("", "");
......@@ -1030,7 +972,6 @@ mod tests {
#[test]
fn test_extreme_metric_values() {
// Test extreme values
RouterMetrics::set_active_workers(0);
RouterMetrics::set_active_workers(usize::MAX);
......@@ -1038,7 +979,6 @@ mod tests {
RouterMetrics::set_worker_load("worker", usize::MAX);
RouterMetrics::record_request_duration("route", Duration::from_nanos(1));
// 24 hours
RouterMetrics::record_request_duration("route", Duration::from_secs(86400));
}
}
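For completeness, a minimal wiring sketch for this module (values mirror the defaults used elsewhere in the file):

// start_prometheus registers all metric descriptions via init_metrics()
// and then installs the exporter with the custom histogram buckets.
let config = PrometheusConfig {
    port: 29000,
    host: "127.0.0.1".to_string(),
};
start_prometheus(config);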
......@@ -19,7 +19,6 @@ use tokio::task;
use tokio::time;
use tracing::{debug, error, info, warn};
/// Represents the service discovery configuration
#[derive(Debug, Clone)]
pub struct ServiceDiscoveryConfig {
pub enabled: bool,
......@@ -41,8 +40,8 @@ impl Default for ServiceDiscoveryConfig {
enabled: false,
selector: HashMap::new(),
check_interval: Duration::from_secs(60),
port: 8000, // Standard port for modern services
namespace: None, // None means watch all namespaces
port: 8000,
namespace: None,
pd_mode: false,
prefill_selector: HashMap::new(),
decode_selector: HashMap::new(),
......@@ -51,7 +50,6 @@ impl Default for ServiceDiscoveryConfig {
}
}
/// Pod type for PD mode service discovery
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PodType {
Prefill,
......@@ -59,7 +57,6 @@ pub enum PodType {
Regular,
}
/// Represents a Kubernetes pod's information used for worker management
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PodInfo {
pub name: String,
......@@ -71,7 +68,6 @@ pub struct PodInfo {
}
impl PodInfo {
/// Check if a pod matches any of the given selectors
fn matches_selector(pod: &Pod, selector: &HashMap<String, String>) -> bool {
if selector.is_empty() {
return false;
......@@ -83,19 +79,15 @@ impl PodInfo {
.is_some_and(|labels| selector.iter().all(|(k, v)| labels.get(k) == Some(v)))
}
/// Check if a pod should be included in service discovery
pub fn should_include(pod: &Pod, config: &ServiceDiscoveryConfig) -> bool {
if config.pd_mode {
// In PD mode, at least one selector must be non-empty
if config.prefill_selector.is_empty() && config.decode_selector.is_empty() {
warn!("PD mode enabled but both prefill_selector and decode_selector are empty");
return false;
}
// In PD mode, pod must match either prefill or decode selector
Self::matches_selector(pod, &config.prefill_selector)
|| Self::matches_selector(pod, &config.decode_selector)
} else {
// In regular mode, pod must match the general selector
if config.selector.is_empty() {
warn!("Regular mode enabled but selector is empty");
return false;
......@@ -104,7 +96,6 @@ impl PodInfo {
}
}
/// Unified PodInfo creation with optional PD configuration
pub fn from_pod(pod: &Pod, config: Option<&ServiceDiscoveryConfig>) -> Option<Self> {
let name = pod.metadata.name.clone()?;
let status = pod.status.clone()?;
......@@ -120,10 +111,8 @@ impl PodInfo {
let pod_status = status.phase.unwrap_or_else(|| "Unknown".to_string());
// Determine pod type based on labels if config is provided and in PD mode
let pod_type = if let Some(config) = config {
if config.pd_mode {
// Use simplified helper methods for cleaner logic
if Self::matches_selector(pod, &config.prefill_selector) {
Some(PodType::Prefill)
} else if Self::matches_selector(pod, &config.decode_selector) {
......@@ -135,11 +124,9 @@ impl PodInfo {
Some(PodType::Regular)
}
} else {
// No config provided, default to None (for backwards compatibility)
None
};
// Extract bootstrap port from annotations for prefill pods
let bootstrap_port = if matches!(pod_type, Some(PodType::Prefill)) {
if let Some(config) = config {
pod.metadata
......@@ -164,12 +151,10 @@ impl PodInfo {
})
}
/// Returns true if the pod is in a state where it can accept traffic
pub fn is_healthy(&self) -> bool {
self.is_ready && self.status == "Running"
}
/// Generates a worker URL for this pod
pub fn worker_url(&self, port: u16) -> String {
format!("http://{}:{}", self.ip, port)
}
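// Sketch: a pod with ip "10.0.0.5", status "Running", and is_ready == true
// satisfies is_healthy(), and worker_url(8000) yields "http://10.0.0.5:8000".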
......@@ -179,9 +164,7 @@ pub async fn start_service_discovery(
config: ServiceDiscoveryConfig,
app_context: Arc<AppContext>,
) -> Result<task::JoinHandle<()>, kube::Error> {
// Don't initialize anything if service discovery is disabled
if !config.enabled {
// Return a generic error when service discovery is disabled
return Err(kube::Error::Api(kube::error::ErrorResponse {
status: "Disabled".to_string(),
message: "Service discovery is disabled".to_string(),
......@@ -192,7 +175,6 @@ pub async fn start_service_discovery(
let _ = rustls::crypto::ring::default_provider().install_default();
// Initialize Kubernetes client
let client = Client::try_default().await?;
// Log the appropriate selectors based on mode
......@@ -229,12 +211,9 @@ pub async fn start_service_discovery(
);
}
// Create the task that will run in the background
let handle = task::spawn(async move {
// We'll track pods we've already added to avoid duplicates
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
// Create a watcher for pods
let pods: Api<Pod> = if let Some(namespace) = &config.namespace {
Api::namespaced(client, namespace)
} else {
......@@ -243,23 +222,19 @@ pub async fn start_service_discovery(
debug!("K8s service discovery initialized");
// Create Arcs for configuration data
let config_arc = Arc::new(config.clone());
let port = config.port;
let mut retry_delay = Duration::from_secs(1);
const MAX_RETRY_DELAY: Duration = Duration::from_secs(300); // 5 minutes max
const MAX_RETRY_DELAY: Duration = Duration::from_secs(300);
loop {
// Create a watcher with the proper parameters according to the kube-rs API
let watcher_config = Config::default();
let watcher_stream = watcher(pods.clone(), watcher_config).applied_objects();
// Clone Arcs for the closures
let config_clone = Arc::clone(&config_arc);
let tracked_pods_clone = Arc::clone(&tracked_pods);
// Simplified label selector filter using helper method
let filtered_stream = watcher_stream.filter_map(move |obj_res| {
let config_inner = Arc::clone(&config_clone);
......@@ -277,7 +252,6 @@ pub async fn start_service_discovery(
}
});
// Clone again for the next closure
let tracked_pods_clone2 = Arc::clone(&tracked_pods_clone);
let app_context_clone = Arc::clone(&app_context);
let config_clone2 = Arc::clone(&config_arc);
......@@ -317,7 +291,6 @@ pub async fn start_service_discovery(
.await
{
Ok(_) => {
// Reset retry delay on success
retry_delay = Duration::from_secs(1);
}
Err(err) => {
......@@ -328,12 +301,10 @@ pub async fn start_service_discovery(
);
time::sleep(retry_delay).await;
// Exponential backoff with jitter
retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY);
}
}
// If the watcher exits for some reason, wait a bit before restarting
warn!(
"Kubernetes watcher exited, restarting in {} seconds",
config_arc.check_interval.as_secs()
......@@ -354,9 +325,7 @@ async fn handle_pod_event(
) {
let worker_url = pod_info.worker_url(port);
// If pod is healthy, try to add it (with atomic check-and-insert)
if pod_info.is_healthy() {
// Atomic check-and-insert to prevent race conditions
let should_add = {
let mut tracker = match tracked_pods.lock() {
Ok(tracker) => tracker,
......@@ -367,9 +336,8 @@ async fn handle_pod_event(
};
if tracker.contains(pod_info) {
false // Already tracked
false
} else {
// Reserve the spot to prevent other threads from adding the same pod
tracker.insert(pod_info.clone());
true
}
......@@ -381,7 +349,6 @@ async fn handle_pod_event(
pod_info.name, pod_info.pod_type, worker_url
);
// Build worker config based on pod type and routing mode
let worker_type = if pd_mode {
match &pod_info.pod_type {
Some(PodType::Prefill) => Some("prefill".to_string()),
......@@ -392,7 +359,6 @@ async fn handle_pod_event(
None
};
// Only set bootstrap_port for prefill workers in PD mode
let bootstrap_port = if pd_mode {
match &pod_info.pod_type {
Some(PodType::Prefill) => pod_info.bootstrap_port,
......@@ -425,7 +391,6 @@ async fn handle_pod_event(
}
Err(e) => {
error!("Failed to add worker {} to router: {}", worker_url, e);
// Remove from tracking since addition failed
if let Ok(mut tracker) = tracked_pods.lock() {
tracker.remove(pod_info);
}
......@@ -464,8 +429,6 @@ async fn handle_pod_deletion(
error!("Failed to remove worker {}: {}", worker_url, e);
}
} else {
// This case might occur if a pod is deleted before it was ever marked healthy and added.
// Or if the event is duplicated. No action needed on the router if it wasn't tracked (and thus not added).
debug!(
"Pod deletion event for untracked/already removed pod: {} (type: {:?}). Worker URL: {}",
pod_info.name, pod_info.pod_type, worker_url
......@@ -480,7 +443,6 @@ mod tests {
use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta;
use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time;
// Helper function to create a Pod for testing PodInfo::from_pod
fn create_k8s_pod(
name: Option<&str>,
ip: Option<&str>,
......@@ -523,7 +485,6 @@ mod tests {
pod
}
// Helper function to create a Pod with PD-specific labels and annotations
fn create_pd_k8s_pod(name: &str, ip: &str, pod_type: &str, bootstrap_port: Option<u16>) -> Pod {
let mut labels = std::collections::BTreeMap::new();
labels.insert("app".to_string(), "sglang".to_string());
......@@ -559,18 +520,15 @@ mod tests {
}
}
// Helper to create an AppContext instance for testing event handlers
async fn create_test_app_context() -> Arc<AppContext> {
use crate::config::RouterConfig;
use crate::middleware::TokenBucket;
// Create a minimal RouterConfig for testing with very short timeout
let router_config = RouterConfig {
worker_startup_timeout_secs: 1,
..Default::default()
}; // Very short timeout for tests
};
// Create AppContext with minimal components
Arc::new(AppContext {
client: reqwest::Client::new(),
router_config: router_config.clone(),
......@@ -579,16 +537,15 @@ mod tests {
policy_registry: Arc::new(crate::policies::PolicyRegistry::new(
router_config.policy.clone(),
)),
tokenizer: None, // HTTP mode doesn't need tokenizer
reasoning_parser_factory: None, // HTTP mode doesn't need reasoning parser
tool_parser_registry: None, // HTTP mode doesn't need tool parser
router_manager: None, // Test doesn't need router manager
tokenizer: None,
reasoning_parser_factory: None,
tool_parser_registry: None,
router_manager: None,
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
load_monitor: None,
})
}
// Helper to create a PD config for testing
fn create_pd_config() -> ServiceDiscoveryConfig {
let mut prefill_selector = HashMap::new();
prefill_selector.insert("app".to_string(), "sglang".to_string());
......@@ -615,19 +572,15 @@ mod tests {
fn test_pod_info_should_include() {
let config = create_pd_config();
// Test prefill pod should be included
let prefill_pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", Some(8081));
assert!(PodInfo::should_include(&prefill_pod, &config));
// Test decode pod should be included
let decode_pod = create_pd_k8s_pod("decode-pod", "10.0.0.2", "decode", None);
assert!(PodInfo::should_include(&decode_pod, &config));
// Test unmatched pod should not be included
let unmatched_pod = create_pd_k8s_pod("other-pod", "10.0.0.3", "other", None);
assert!(!PodInfo::should_include(&unmatched_pod, &config));
// Test regular mode
let mut regular_config = ServiceDiscoveryConfig::default();
regular_config
.selector
......@@ -654,7 +607,6 @@ mod tests {
#[test]
fn test_pod_type_enum() {
// Test that PodType enum has expected variants
let prefill = PodType::Prefill;
let decode = PodType::Decode;
let regular = PodType::Regular;
......@@ -714,7 +666,7 @@ mod tests {
fn test_pod_info_from_pod_with_pd_config_regular_mode() {
let k8s_pod = create_pd_k8s_pod("regular-pod", "10.0.0.3", "worker", None);
let mut config = create_pd_config();
config.pd_mode = false; // Set to regular mode
config.pd_mode = false;
let pod_info = PodInfo::from_pod(&k8s_pod, Some(&config)).unwrap();
assert_eq!(pod_info.name, "regular-pod");
......@@ -742,7 +694,6 @@ mod tests {
#[test]
fn test_pod_info_from_pod_with_pd_config_invalid_bootstrap_port() {
let mut pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", None);
// Add invalid bootstrap port annotation
pod.metadata.annotations.as_mut().unwrap().insert(
"sglang.ai/bootstrap-port".to_string(),
"invalid".to_string(),
......@@ -751,7 +702,7 @@ mod tests {
let pod_info = PodInfo::from_pod(&pod, Some(&config)).unwrap();
assert_eq!(pod_info.pod_type, Some(PodType::Prefill));
assert!(pod_info.bootstrap_port.is_none()); // Should be None for invalid port
assert!(pod_info.bootstrap_port.is_none());
}
#[test]
......@@ -1077,7 +1028,6 @@ mod tests {
)
.await;
// Pod should not be tracked since add_worker_from_url will fail for non-running server
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
}
......