use pyo3::prelude::*; pub mod config; pub mod logging; use std::collections::HashMap; pub mod core; pub mod data_connector; #[cfg(feature = "grpc-client")] pub mod grpc_client; pub mod mcp; pub mod metrics; pub mod middleware; pub mod policies; pub mod protocols; pub mod reasoning_parser; pub mod routers; pub mod server; pub mod service_discovery; pub mod tokenizer; pub mod tool_parser; pub mod tree; use crate::metrics::PrometheusConfig; #[pyclass(eq)] #[derive(Clone, PartialEq, Debug)] pub enum PolicyType { Random, RoundRobin, CacheAware, PowerOfTwo, } #[pyclass(eq)] #[derive(Clone, PartialEq, Debug)] pub enum BackendType { Sglang, Openai, } #[pyclass(eq)] #[derive(Clone, PartialEq, Debug)] pub enum HistoryBackendType { Memory, None, Oracle, } #[pyclass] #[derive(Clone, PartialEq)] pub struct PyOracleConfig { #[pyo3(get, set)] pub wallet_path: Option, #[pyo3(get, set)] pub connect_descriptor: Option, #[pyo3(get, set)] pub username: Option, #[pyo3(get, set)] pub password: Option, #[pyo3(get, set)] pub pool_min: usize, #[pyo3(get, set)] pub pool_max: usize, #[pyo3(get, set)] pub pool_timeout_secs: u64, } impl std::fmt::Debug for PyOracleConfig { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PyOracleConfig") .field("wallet_path", &self.wallet_path) .field("connect_descriptor", &"") .field("username", &self.username) .field("password", &"") .field("pool_min", &self.pool_min) .field("pool_max", &self.pool_max) .field("pool_timeout_secs", &self.pool_timeout_secs) .finish() } } #[pymethods] impl PyOracleConfig { #[new] #[pyo3(signature = ( password = None, username = None, connect_descriptor = None, wallet_path = None, pool_min = 1, pool_max = 16, pool_timeout_secs = 30, ))] fn new( password: Option, username: Option, connect_descriptor: Option, wallet_path: Option, pool_min: usize, pool_max: usize, pool_timeout_secs: u64, ) -> PyResult { if pool_min == 0 { return Err(pyo3::exceptions::PyValueError::new_err( "pool_min must be at least 1", )); } if pool_max < pool_min { return Err(pyo3::exceptions::PyValueError::new_err( "pool_max must be >= pool_min", )); } Ok(PyOracleConfig { wallet_path, connect_descriptor, username, password, pool_min, pool_max, pool_timeout_secs, }) } } impl PyOracleConfig { fn to_config_oracle(&self) -> config::OracleConfig { // Simple conversion - validation happens later in validate_oracle() config::OracleConfig { wallet_path: self.wallet_path.clone(), connect_descriptor: self.connect_descriptor.clone().unwrap_or_default(), username: self.username.clone().unwrap_or_default(), password: self.password.clone().unwrap_or_default(), pool_min: self.pool_min, pool_max: self.pool_max, pool_timeout_secs: self.pool_timeout_secs, } } } #[pyclass] #[derive(Debug, Clone, PartialEq)] struct Router { host: String, port: u16, worker_urls: Vec, policy: PolicyType, worker_startup_timeout_secs: u64, worker_startup_check_interval: u64, cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, eviction_interval_secs: u64, max_tree_size: usize, max_payload_size: usize, dp_aware: bool, api_key: Option, log_dir: Option, log_level: Option, service_discovery: bool, selector: HashMap, service_discovery_port: u16, service_discovery_namespace: Option, prefill_selector: HashMap, decode_selector: HashMap, bootstrap_port_annotation: String, prometheus_port: Option, prometheus_host: Option, request_timeout_secs: u64, request_id_headers: Option>, pd_disaggregation: bool, prefill_urls: Option)>>, decode_urls: Option>, prefill_policy: Option, decode_policy: Option, max_concurrent_requests: i32, cors_allowed_origins: Vec, retry_max_retries: u32, retry_initial_backoff_ms: u64, retry_max_backoff_ms: u64, retry_backoff_multiplier: f32, retry_jitter_factor: f32, disable_retries: bool, cb_failure_threshold: u32, cb_success_threshold: u32, cb_timeout_duration_secs: u64, cb_window_duration_secs: u64, disable_circuit_breaker: bool, health_failure_threshold: u32, health_success_threshold: u32, health_check_timeout_secs: u64, health_check_interval_secs: u64, health_check_endpoint: String, enable_igw: bool, queue_size: usize, queue_timeout_secs: u64, rate_limit_tokens_per_second: Option, connection_mode: core::ConnectionMode, model_path: Option, tokenizer_path: Option, chat_template: Option, tokenizer_cache_enable_l0: bool, tokenizer_cache_l0_max_entries: usize, tokenizer_cache_enable_l1: bool, tokenizer_cache_l1_max_memory: usize, reasoning_parser: Option, tool_call_parser: Option, backend: BackendType, history_backend: HistoryBackendType, oracle_config: Option, } impl Router { /// Determine connection mode from worker URLs fn determine_connection_mode(worker_urls: &[String]) -> core::ConnectionMode { for url in worker_urls { if url.starts_with("grpc://") || url.starts_with("grpcs://") { return core::ConnectionMode::Grpc { port: None }; } } core::ConnectionMode::Http } pub fn to_router_config(&self) -> config::ConfigResult { use config::{ DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode, }; let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig { match policy { PolicyType::Random => ConfigPolicyConfig::Random, PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin, PolicyType::CacheAware => ConfigPolicyConfig::CacheAware { cache_threshold: self.cache_threshold, balance_abs_threshold: self.balance_abs_threshold, balance_rel_threshold: self.balance_rel_threshold, eviction_interval_secs: self.eviction_interval_secs, max_tree_size: self.max_tree_size, }, PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo { load_check_interval_secs: 5, }, } }; let mode = if self.enable_igw { RoutingMode::Regular { worker_urls: vec![], } } else if matches!(self.backend, BackendType::Openai) { RoutingMode::OpenAI { worker_urls: self.worker_urls.clone(), } } else if self.pd_disaggregation { RoutingMode::PrefillDecode { prefill_urls: self.prefill_urls.clone().unwrap_or_default(), decode_urls: self.decode_urls.clone().unwrap_or_default(), prefill_policy: self.prefill_policy.as_ref().map(convert_policy), decode_policy: self.decode_policy.as_ref().map(convert_policy), } } else { RoutingMode::Regular { worker_urls: self.worker_urls.clone(), } }; let policy = convert_policy(&self.policy); let discovery = if self.service_discovery { Some(DiscoveryConfig { enabled: true, namespace: self.service_discovery_namespace.clone(), port: self.service_discovery_port, check_interval_secs: 60, selector: self.selector.clone(), prefill_selector: self.prefill_selector.clone(), decode_selector: self.decode_selector.clone(), bootstrap_port_annotation: self.bootstrap_port_annotation.clone(), }) } else { None }; let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) { (Some(port), Some(host)) => Some(MetricsConfig { port, host: host.clone(), }), _ => None, }; let history_backend = match self.history_backend { HistoryBackendType::Memory => config::HistoryBackend::Memory, HistoryBackendType::None => config::HistoryBackend::None, HistoryBackendType::Oracle => config::HistoryBackend::Oracle, }; let oracle = if matches!(self.history_backend, HistoryBackendType::Oracle) { self.oracle_config .as_ref() .map(|cfg| cfg.to_config_oracle()) } else { None }; let builder = config::RouterConfig::builder() .mode(mode) .policy(policy) .host(&self.host) .port(self.port) .connection_mode(self.connection_mode.clone()) .max_payload_size(self.max_payload_size) .request_timeout_secs(self.request_timeout_secs) .worker_startup_timeout_secs(self.worker_startup_timeout_secs) .worker_startup_check_interval_secs(self.worker_startup_check_interval) .max_concurrent_requests(self.max_concurrent_requests) .queue_size(self.queue_size) .queue_timeout_secs(self.queue_timeout_secs) .cors_allowed_origins(self.cors_allowed_origins.clone()) .retry_config(config::RetryConfig { max_retries: self.retry_max_retries, initial_backoff_ms: self.retry_initial_backoff_ms, max_backoff_ms: self.retry_max_backoff_ms, backoff_multiplier: self.retry_backoff_multiplier, jitter_factor: self.retry_jitter_factor, }) .circuit_breaker_config(config::CircuitBreakerConfig { failure_threshold: self.cb_failure_threshold, success_threshold: self.cb_success_threshold, timeout_duration_secs: self.cb_timeout_duration_secs, window_duration_secs: self.cb_window_duration_secs, }) .health_check_config(config::HealthCheckConfig { failure_threshold: self.health_failure_threshold, success_threshold: self.health_success_threshold, timeout_secs: self.health_check_timeout_secs, check_interval_secs: self.health_check_interval_secs, endpoint: self.health_check_endpoint.clone(), }) .tokenizer_cache(config::TokenizerCacheConfig { enable_l0: self.tokenizer_cache_enable_l0, l0_max_entries: self.tokenizer_cache_l0_max_entries, enable_l1: self.tokenizer_cache_enable_l1, l1_max_memory: self.tokenizer_cache_l1_max_memory, }) .history_backend(history_backend) .maybe_api_key(self.api_key.as_ref()) .maybe_discovery(discovery) .maybe_metrics(metrics) .maybe_log_dir(self.log_dir.as_ref()) .maybe_log_level(self.log_level.as_ref()) .maybe_request_id_headers(self.request_id_headers.clone()) .maybe_rate_limit_tokens_per_second(self.rate_limit_tokens_per_second) .maybe_model_path(self.model_path.as_ref()) .maybe_tokenizer_path(self.tokenizer_path.as_ref()) .maybe_chat_template(self.chat_template.as_ref()) .maybe_oracle(oracle) .maybe_reasoning_parser(self.reasoning_parser.as_ref()) .maybe_tool_call_parser(self.tool_call_parser.as_ref()) .dp_aware(self.dp_aware) .retries(!self.disable_retries) .circuit_breaker(!self.disable_circuit_breaker) .igw(self.enable_igw); builder.build() } } #[pymethods] impl Router { #[new] #[pyo3(signature = ( worker_urls, policy = PolicyType::RoundRobin, host = String::from("0.0.0.0"), port = 3001, worker_startup_timeout_secs = 600, worker_startup_check_interval = 30, cache_threshold = 0.3, balance_abs_threshold = 64, balance_rel_threshold = 1.5, eviction_interval_secs = 120, max_tree_size = 2usize.pow(26), max_payload_size = 512 * 1024 * 1024, dp_aware = false, api_key = None, log_dir = None, log_level = None, service_discovery = false, selector = HashMap::new(), service_discovery_port = 80, service_discovery_namespace = None, prefill_selector = HashMap::new(), decode_selector = HashMap::new(), bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"), prometheus_port = None, prometheus_host = None, request_timeout_secs = 1800, request_id_headers = None, pd_disaggregation = false, prefill_urls = None, decode_urls = None, prefill_policy = None, decode_policy = None, max_concurrent_requests = -1, cors_allowed_origins = vec![], retry_max_retries = 5, retry_initial_backoff_ms = 50, retry_max_backoff_ms = 30_000, retry_backoff_multiplier = 1.5, retry_jitter_factor = 0.2, disable_retries = false, cb_failure_threshold = 10, cb_success_threshold = 3, cb_timeout_duration_secs = 60, cb_window_duration_secs = 120, disable_circuit_breaker = false, health_failure_threshold = 3, health_success_threshold = 2, health_check_timeout_secs = 5, health_check_interval_secs = 60, health_check_endpoint = String::from("/health"), enable_igw = false, queue_size = 100, queue_timeout_secs = 60, rate_limit_tokens_per_second = None, model_path = None, tokenizer_path = None, chat_template = None, tokenizer_cache_enable_l0 = false, tokenizer_cache_l0_max_entries = 10000, tokenizer_cache_enable_l1 = false, tokenizer_cache_l1_max_memory = 52428800, reasoning_parser = None, tool_call_parser = None, backend = BackendType::Sglang, history_backend = HistoryBackendType::Memory, oracle_config = None, ))] #[allow(clippy::too_many_arguments)] fn new( worker_urls: Vec, policy: PolicyType, host: String, port: u16, worker_startup_timeout_secs: u64, worker_startup_check_interval: u64, cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, eviction_interval_secs: u64, max_tree_size: usize, max_payload_size: usize, dp_aware: bool, api_key: Option, log_dir: Option, log_level: Option, service_discovery: bool, selector: HashMap, service_discovery_port: u16, service_discovery_namespace: Option, prefill_selector: HashMap, decode_selector: HashMap, bootstrap_port_annotation: String, prometheus_port: Option, prometheus_host: Option, request_timeout_secs: u64, request_id_headers: Option>, pd_disaggregation: bool, prefill_urls: Option)>>, decode_urls: Option>, prefill_policy: Option, decode_policy: Option, max_concurrent_requests: i32, cors_allowed_origins: Vec, retry_max_retries: u32, retry_initial_backoff_ms: u64, retry_max_backoff_ms: u64, retry_backoff_multiplier: f32, retry_jitter_factor: f32, disable_retries: bool, cb_failure_threshold: u32, cb_success_threshold: u32, cb_timeout_duration_secs: u64, cb_window_duration_secs: u64, disable_circuit_breaker: bool, health_failure_threshold: u32, health_success_threshold: u32, health_check_timeout_secs: u64, health_check_interval_secs: u64, health_check_endpoint: String, enable_igw: bool, queue_size: usize, queue_timeout_secs: u64, rate_limit_tokens_per_second: Option, model_path: Option, tokenizer_path: Option, chat_template: Option, tokenizer_cache_enable_l0: bool, tokenizer_cache_l0_max_entries: usize, tokenizer_cache_enable_l1: bool, tokenizer_cache_l1_max_memory: usize, reasoning_parser: Option, tool_call_parser: Option, backend: BackendType, history_backend: HistoryBackendType, oracle_config: Option, ) -> PyResult { let mut all_urls = worker_urls.clone(); if let Some(ref prefill_urls) = prefill_urls { for (url, _) in prefill_urls { all_urls.push(url.clone()); } } if let Some(ref decode_urls) = decode_urls { all_urls.extend(decode_urls.clone()); } let connection_mode = Self::determine_connection_mode(&all_urls); Ok(Router { host, port, worker_urls, policy, worker_startup_timeout_secs, worker_startup_check_interval, cache_threshold, balance_abs_threshold, balance_rel_threshold, eviction_interval_secs, max_tree_size, max_payload_size, dp_aware, api_key, log_dir, log_level, service_discovery, selector, service_discovery_port, service_discovery_namespace, prefill_selector, decode_selector, bootstrap_port_annotation, prometheus_port, prometheus_host, request_timeout_secs, request_id_headers, pd_disaggregation, prefill_urls, decode_urls, prefill_policy, decode_policy, max_concurrent_requests, cors_allowed_origins, retry_max_retries, retry_initial_backoff_ms, retry_max_backoff_ms, retry_backoff_multiplier, retry_jitter_factor, disable_retries, cb_failure_threshold, cb_success_threshold, cb_timeout_duration_secs, cb_window_duration_secs, disable_circuit_breaker, health_failure_threshold, health_success_threshold, health_check_timeout_secs, health_check_interval_secs, health_check_endpoint, enable_igw, queue_size, queue_timeout_secs, rate_limit_tokens_per_second, connection_mode, model_path, tokenizer_path, chat_template, tokenizer_cache_enable_l0, tokenizer_cache_l0_max_entries, tokenizer_cache_enable_l1, tokenizer_cache_l1_max_memory, reasoning_parser, tool_call_parser, backend, history_backend, oracle_config, }) } fn start(&self) -> PyResult<()> { let router_config = self.to_router_config().map_err(|e| { pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e)) })?; router_config.validate().map_err(|e| { pyo3::exceptions::PyValueError::new_err(format!( "Configuration validation failed: {}", e )) })?; let service_discovery_config = if self.service_discovery { Some(service_discovery::ServiceDiscoveryConfig { enabled: true, selector: self.selector.clone(), check_interval: std::time::Duration::from_secs(60), port: self.service_discovery_port, namespace: self.service_discovery_namespace.clone(), pd_mode: self.pd_disaggregation, prefill_selector: self.prefill_selector.clone(), decode_selector: self.decode_selector.clone(), bootstrap_port_annotation: self.bootstrap_port_annotation.clone(), }) } else { None }; let prometheus_config = Some(PrometheusConfig { port: self.prometheus_port.unwrap_or(29000), host: self .prometheus_host .clone() .unwrap_or_else(|| "127.0.0.1".to_string()), }); let runtime = tokio::runtime::Runtime::new() .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; runtime.block_on(async move { server::startup(server::ServerConfig { host: self.host.clone(), port: self.port, router_config, max_payload_size: self.max_payload_size, log_dir: self.log_dir.clone(), log_level: self.log_level.clone(), service_discovery_config, prometheus_config, request_timeout_secs: self.request_timeout_secs, request_id_headers: self.request_id_headers.clone(), }) .await .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string())) }) } } #[pymodule] fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; Ok(()) }