use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig};
use crate::core::{CircuitBreakerConfig, HealthChecker, RetryExecutor, Worker, WorkerFactory};
use crate::metrics::RouterMetrics;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::policies::LoadBalancingPolicy;
use crate::routers::{RouterTrait, WorkerManagement};
use axum::{
    body::Body,
    extract::Request,
    http::{header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
    response::{IntoResponse, Response},
    Json,
};
use futures_util::StreamExt;
use reqwest::Client;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::thread;
use std::time::{Duration, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info, warn};

pub fn copy_request_headers(req: &Request) -> Vec<(String, String)> {
    req.headers()
        .iter()
        .filter_map(|(name, value)| {
            value
                .to_str()
                .ok()
                .map(|v| (name.to_string(), v.to_string()))
        })
        .collect()
}

/// Regular router that uses injected load balancing policies
#[derive(Debug)]
pub struct Router {
    workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
    policy: Arc<dyn LoadBalancingPolicy>,
    client: Client,
    timeout_secs: u64,
    interval_secs: u64,
    dp_aware: bool,
    api_key: Option<String>,
    retry_config: RetryConfig,
    circuit_breaker_config: CircuitBreakerConfig,
    _worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
    _load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
    _health_checker: Option<HealthChecker>,
}

impl Router {
    /// Create a new router with injected policy and client
    pub fn new(
        worker_urls: Vec<String>,
        policy: Arc<dyn LoadBalancingPolicy>,
        client: Client,
        timeout_secs: u64,
        interval_secs: u64,
        dp_aware: bool,
        api_key: Option<String>,
        retry_config: RetryConfig,
        circuit_breaker_config: ConfigCircuitBreakerConfig,
    ) -> Result<Self, String> {
        // Update active workers gauge
        RouterMetrics::set_active_workers(worker_urls.len());

        // Wait for workers to be healthy (skip if empty - for service discovery mode)
        if !worker_urls.is_empty() {
            Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
        }
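        // For example, a worker at http://host:8000 that reports dp_size = 2 expands to
        // "http://host:8000@0" and "http://host:8000@1" (the URL is hypothetical; see
        // get_dp_aware_workers below for the expansion).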
        let worker_urls = if dp_aware {
            // worker address now in the format of "http://host:port@dp_rank"
            Self::get_dp_aware_workers(&worker_urls, &api_key)
                .map_err(|e| format!("Failed to get dp-aware workers: {}", e))?
        } else {
            worker_urls
        };

        // Convert config CircuitBreakerConfig to core CircuitBreakerConfig
        let core_cb_config = CircuitBreakerConfig {
            failure_threshold: circuit_breaker_config.failure_threshold,
            success_threshold: circuit_breaker_config.success_threshold,
            timeout_duration: std::time::Duration::from_secs(
                circuit_breaker_config.timeout_duration_secs,
            ),
            window_duration: std::time::Duration::from_secs(
                circuit_breaker_config.window_duration_secs,
            ),
        };

        // Create Worker trait objects from URLs
        let workers: Vec<Box<dyn Worker>> = worker_urls
            .iter()
            .map(|url| {
                WorkerFactory::create_regular_with_config(url.clone(), core_cb_config.clone())
            })
            .collect();

        // Initialize policy with workers if needed (e.g., for cache-aware)
        if let Some(cache_aware) = policy
            .as_any()
            .downcast_ref::<crate::policies::CacheAwarePolicy>()
        {
            cache_aware.init_workers(&workers);
        }

        let workers = Arc::new(RwLock::new(workers));
        let health_checker =
            crate::core::start_health_checker(Arc::clone(&workers), interval_secs);

        // Setup load monitoring for PowerOfTwo policy
        let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
        let worker_loads = Arc::new(rx);

        let load_monitor_handle = if policy.name() == "power_of_two" {
            let monitor_urls = worker_urls.clone();
            let monitor_interval = interval_secs;
            let policy_clone = Arc::clone(&policy);
            let client_clone = client.clone();

            Some(Arc::new(tokio::spawn(async move {
                Self::monitor_worker_loads(
                    monitor_urls,
                    tx,
                    monitor_interval,
                    policy_clone,
                    client_clone,
                )
                .await;
            })))
        } else {
            None
        };

        Ok(Router {
            workers,
            policy,
            client,
            timeout_secs,
            interval_secs,
            dp_aware,
            api_key,
            retry_config,
            circuit_breaker_config: core_cb_config,
            _worker_loads: worker_loads,
            _load_monitor_handle: load_monitor_handle,
            _health_checker: Some(health_checker),
        })
    }

    /// Get the current list of worker URLs
    pub fn get_worker_urls(&self) -> Vec<String> {
        self.workers
            .read()
            .unwrap()
            .iter()
            .map(|w| w.url().to_string())
            .collect()
    }

    pub fn wait_for_healthy_workers(
        worker_urls: &[String],
        timeout_secs: u64,
        interval_secs: u64,
    ) -> Result<(), String> {
        if worker_urls.is_empty() {
            return Err(
                "Timeout waiting for workers to become healthy: no workers provided".to_string(),
            );
        }

        let start_time = std::time::Instant::now();
        let sync_client = reqwest::blocking::Client::builder()
            .timeout(Duration::from_secs(timeout_secs))
            .build()
            .map_err(|e| format!("Failed to create HTTP client: {}", e))?;

        loop {
            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
                error!(
                    "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_urls
                );
                return Err(format!(
                    "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_urls
                ));
            }

            let mut all_healthy = true;
            let mut unhealthy_workers = Vec::new();

            for url in worker_urls {
                match sync_client.get(&format!("{}/health", url)).send() {
                    Ok(res) => {
                        if !res.status().is_success() {
                            all_healthy = false;
                            unhealthy_workers.push((url, format!("status: {}", res.status())));
                        }
                    }
                    Err(_) => {
                        all_healthy = false;
                        unhealthy_workers.push((url, "not ready".to_string()));
                    }
                }
            }

            if all_healthy {
                info!("All {} workers are healthy", worker_urls.len());
                return Ok(());
            } else {
                debug!(
                    "Waiting for {} workers to become healthy ({} unhealthy)",
                    worker_urls.len(),
                    unhealthy_workers.len()
                );
                thread::sleep(Duration::from_secs(interval_secs));
            }
        }
    }

    fn get_worker_dp_size(worker_url: &str, api_key: &Option<String>) -> Result<usize, String> {
        let sync_client = reqwest::blocking::Client::new();
        let mut req_builder = sync_client.get(&format!("{}/get_server_info", worker_url));
        if let Some(key) = api_key {
            req_builder = req_builder.bearer_auth(key);
        }
        match req_builder.send() {
            Ok(res) => {
                if res.status().is_success() {
                    let server_info = res
                        .text()
                        .map_err(|e| format!("failed to read text from response: {}", e))?;
                    let server_info: serde_json::Value = serde_json::from_str(&server_info)
                        .map_err(|e| format!("failed to decode JSON: {}", e))?;
                    let dp_size = server_info
                        .get("dp_size")
                        .and_then(|v| v.as_u64())
                        .ok_or_else(|| String::from("dp_size not found or not a u64"))?;
                    Ok(if dp_size > usize::MAX as u64 {
                        return Err(format!("dp_size is too large: {}", dp_size));
                    } else {
                        dp_size as usize
                    })
                } else {
                    Err(format!("unexpected status code: {}", res.status()))
                }
            }
            Err(e) => Err(format!("error response: {}", e)),
        }
    }

    // Given a list of workers, return a list of workers with dp_rank as suffix
    fn get_dp_aware_workers(
        worker_urls: &[String],
        api_key: &Option<String>,
    ) -> Result<Vec<String>, String> {
        let mut dp_aware_workers: Vec<String> = Vec::new();

        for url in worker_urls {
            match Self::get_worker_dp_size(url, api_key) {
                Ok(dp_size) => {
                    for i in 0..dp_size {
                        dp_aware_workers.push(format!("{}@{}", url, i));
                    }
                }
                Err(e) => return Err(format!("Failed to get DP size for {}: {}", url, e)),
            }
        }

        Ok(dp_aware_workers)
    }

    fn select_first_worker(&self) -> Result<String, String> {
        let workers_guard = self.workers.read().unwrap();
        if workers_guard.is_empty() {
            Err("No workers are available".to_string())
        } else {
            Ok(workers_guard[0].url().to_string())
        }
    }

    pub async fn send_health_check(&self, worker_url: &str) -> Response {
        let health_url = if self.dp_aware {
            // Need to extract the URL from "http://host:port@dp_rank"
            match Self::extract_dp_rank(worker_url) {
                Ok((worker_url_prefix, _dp_rank)) => worker_url_prefix,
                Err(e) => {
                    error!("Failed to extract dp_rank for health check: {}", e);
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Failed to extract dp_rank: {}", e),
                    )
                        .into_response();
                }
            }
        } else {
            worker_url
        };

        let request_builder = self.client.get(format!("{}/health", health_url));

        let response = match request_builder.send().await {
            Ok(res) => {
                let status = StatusCode::from_u16(res.status().as_u16())
                    .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

                match res.bytes().await {
                    Ok(body) => (status, body).into_response(),
                    Err(e) => {
                        error!(
                            worker_url = %health_url,
                            error = %e,
                            "Failed to read health response body"
                        );
                        (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            format!("Failed to read response body: {}", e),
                        )
                            .into_response()
                    }
                }
            }
            Err(e) => {
                error!(
                    worker_url = %health_url,
                    error = %e,
                    "Failed to send health request to worker"
                );
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("Failed to send request to worker {}: {}", health_url, e),
                )
                    .into_response()
            }
        };

        // Don't record metrics for health checks
        response
    }

    // Helper method to proxy GET requests to the first available worker
    async fn proxy_get_request(&self, req: Request, endpoint: &str) -> Response {
        let headers = copy_request_headers(&req);

        match self.select_first_worker() {
            Ok(worker_url) => {
                let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint));
                for (name, value) in headers {
                    let name_lc = name.to_lowercase();
                    if name_lc != "content-type" && name_lc != "content-length" {
                        request_builder = request_builder.header(name, value);
                    }
                }
                match request_builder.send().await {
                    Ok(res) => {
                        let status = StatusCode::from_u16(res.status().as_u16())
                            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                        match res.bytes().await {
                            Ok(body) => (status, body).into_response(),
                            Err(e) => (
                                StatusCode::INTERNAL_SERVER_ERROR,
                                format!("Failed to read response: {}", e),
                            )
                                .into_response(),
                        }
                    }
                    Err(e) => (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Request failed: {}", e),
                    )
                        .into_response(),
                }
            }
            Err(e) => (StatusCode::SERVICE_UNAVAILABLE, e).into_response(),
        }
    }

    // New method to route typed requests directly
    /// Select worker considering circuit breaker state
    fn select_worker_with_circuit_breaker(&self, text: Option<&str>) -> Option<Box<dyn Worker>> {
        let workers = self.workers.read().ok()?;
        let available: Vec<Box<dyn Worker>> = workers
            .iter()
            .filter(|w| w.is_available())
            .map(|w| w.clone_worker())
            .collect();
        if available.is_empty() {
            return None;
        }
        let idx = self.policy.select_worker(&available, text)?;
        Some(available[idx].clone_worker())
    }

    fn is_retryable_status(status: StatusCode) -> bool {
        matches!(
            status,
            StatusCode::REQUEST_TIMEOUT
                | StatusCode::TOO_MANY_REQUESTS
                | StatusCode::INTERNAL_SERVER_ERROR
                | StatusCode::BAD_GATEWAY
                | StatusCode::SERVICE_UNAVAILABLE
                | StatusCode::GATEWAY_TIMEOUT
        )
    }

    pub async fn route_typed_request<
        T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
    >(
        &self,
        headers: Option<&HeaderMap>,
        typed_req: &T,
        route: &str,
    ) -> Response {
        let start = Instant::now();
        let is_stream = typed_req.is_stream();
        let text = typed_req.extract_text_for_routing();

        let response = RetryExecutor::execute_response_with_retry(
            &self.retry_config,
            // operation per attempt
            |_: u32| async {
                let worker = match self.select_worker_with_circuit_breaker(Some(&text)) {
                    Some(w) => w,
                    None => {
                        RouterMetrics::record_request_error(route, "no_available_workers");
                        return (
                            StatusCode::SERVICE_UNAVAILABLE,
                            "No available workers (all circuits open or unhealthy)",
                        )
                            .into_response();
                    }
                };

                // Optional load tracking for cache-aware policy
                let load_incremented = if self.policy.name() == "cache_aware" {
                    worker.increment_load();
                    RouterMetrics::set_running_requests(worker.url(), worker.load());
                    true
                } else {
                    false
                };

                let response = self
                    .send_typed_request(
                        headers,
                        typed_req,
                        route,
                        worker.url(),
                        is_stream,
                        load_incremented,
                    )
                    .await;

                worker.record_outcome(response.status().is_success());
                response
            },
            // should_retry predicate
            |res, _attempt| Self::is_retryable_status(res.status()),
            // on_backoff hook
            |delay, attempt| {
                RouterMetrics::record_retry(route);
                RouterMetrics::record_retry_backoff_duration(delay, attempt);
            },
            // on_exhausted hook
            || RouterMetrics::record_retries_exhausted(route),
        )
        .await;

        if response.status().is_success() {
            let duration = start.elapsed();
            RouterMetrics::record_request(route);
            RouterMetrics::record_generate_duration(duration);
        } else if !Self::is_retryable_status(response.status()) {
            RouterMetrics::record_request_error(route, "non_retryable_error");
        }

        response
    }

    // TODO (rui): Better accommodate to the Worker abstraction
    fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
        let parts: Vec<&str> = worker_url.split('@').collect();
        if parts.len() != 2 {
            return Err(format!("invalid worker_url format: {}", worker_url));
        }

        // Parse the second part (dp_rank) into an integer
        match parts[1].parse::<usize>() {
            Ok(dp_rank) => Ok((parts[0], dp_rank)),
            Err(_) => Err(format!(
                "failed to parse dp_rank from worker_url: {}",
                worker_url
            )),
        }
    }

    // Send typed request directly without conversion
    async fn send_typed_request<T: serde::Serialize>(
        &self,
        headers: Option<&HeaderMap>,
        typed_req: &T,
        route: &str,
        worker_url: &str,
        is_stream: bool,
        load_incremented: bool, // Whether load was incremented for this request
    ) -> Response {
        let mut request_builder = if self.dp_aware {
            let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
                Ok(tup) => tup,
                Err(e) => {
                    error!("Failed to extract dp_rank: {}", e);
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Failed to extract dp_rank: {}", e),
                    )
                        .into_response();
                }
            };

            // Parse the request body
            let mut json_val = match serde_json::to_value(typed_req) {
                Ok(j) => j,
                Err(e) => {
                    return (
                        StatusCode::BAD_REQUEST,
                        format!("Convert into serde_json::Value failed: {}", e),
                    )
                        .into_response();
                }
            };

            // Insert the data_parallel_rank field
            if let Some(map) = json_val.as_object_mut() {
                map.insert(
                    String::from("data_parallel_rank"),
                    serde_json::json!(dp_rank),
                );
                debug!(
                    "Modified request body: {}",
                    serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
                );
            } else {
                return (
                    StatusCode::BAD_REQUEST,
                    "Failed to insert the data_parallel_rank field into the request body",
                )
                    .into_response();
            }

            self.client
                .post(format!("{}{}", worker_url_prefix, route))
                .json(&json_val)
        } else {
            self.client
                .post(format!("{}{}", worker_url, route))
                .json(typed_req) // Use json() directly with typed request
        };

        // Copy all headers from original request if provided
        if let Some(headers) = headers {
            for (name, value) in headers {
                // Skip Content-Type and Content-Length as .json() sets them
                if *name != CONTENT_TYPE && *name != CONTENT_LENGTH {
                    request_builder = request_builder.header(name, value);
                }
            }
        }

        let res = match request_builder.send().await {
            Ok(res) => res,
            Err(e) => {
                error!(
                    "Failed to send typed request worker_url={} route={} error={}",
                    worker_url, route, e
                );

                // Decrement load on error if it was incremented
                if load_incremented {
                    if let Ok(workers_guard) = self.workers.read() {
                        if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
                            worker.decrement_load();
                            RouterMetrics::set_running_requests(&worker_url, worker.load());
                        }
                    }
                }

                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("Request failed: {}", e),
                )
                    .into_response();
            }
        };

        let status = StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

        if !is_stream {
            // For non-streaming requests, get response first
            let response = match res.bytes().await {
                Ok(body) => (status, body).into_response(),
                Err(e) => {
                    let error_msg = format!("Failed to get response body: {}", e);
                    (StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
                }
            };

            // Decrement load counter for non-streaming requests if it was incremented
            if load_incremented && !is_stream {
                if let Ok(workers_guard) = self.workers.read() {
                    if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
                        worker.decrement_load();
                        RouterMetrics::set_running_requests(&worker_url, worker.load());
                    }
                }
            }

            response
        } else if load_incremented {
            // For streaming with load tracking, we need to manually decrement when done
            let workers = Arc::clone(&self.workers);
            let worker_url = worker_url.to_string();

            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream and detect completion
            tokio::spawn(async move {
                let mut stream = stream;
                let mut decremented = false;

                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            // Check for stream end marker
                            if bytes
                                .as_ref()
                                .windows(12)
                                .any(|window| window == b"data: [DONE]")
                            {
                                if let Ok(workers_guard) = workers.read() {
                                    if let Some(worker) =
                                        workers_guard.iter().find(|w| w.url() == &worker_url)
                                    {
                                        worker.decrement_load();
                                        RouterMetrics::set_running_requests(
                                            &worker_url,
                                            worker.load(),
                                        );
                                        decremented = true;
                                    }
                                }
                            }

                            if tx.send(Ok(bytes)).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }

                if !decremented {
                    if let Ok(workers_guard) = workers.read() {
                        if let Some(worker) =
                            workers_guard.iter().find(|w| w.url() == &worker_url)
                        {
                            worker.decrement_load();
                            RouterMetrics::set_running_requests(&worker_url, worker.load());
                        }
                    }
                }
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
            response
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
            response
        } else {
            // For requests without load tracking, just stream
            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream
            tokio::spawn(async move {
                let mut stream = stream;
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            if tx.send(Ok(bytes)).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
            response
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
            response
        }
    }

    pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
        let start_time = std::time::Instant::now();
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(self.timeout_secs))
            .build()
            .map_err(|e| format!("Failed to create HTTP client: {}", e))?;

        loop {
            if start_time.elapsed() > Duration::from_secs(self.timeout_secs) {
                error!(
                    "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    self.timeout_secs, worker_url
                );
                return Err(format!(
                    "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    self.timeout_secs, worker_url
                ));
            }

            match client.get(&format!("{}/health", worker_url)).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        let mut workers_guard = self.workers.write().unwrap();

                        if self.dp_aware {
                            // Need to contact the worker to extract the dp_size,
                            // and add them as multiple workers
                            let url_vec = vec![String::from(worker_url)];
                            let dp_url_vec = Self::get_dp_aware_workers(&url_vec, &self.api_key)
                                .map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;

                            let mut worker_added: bool = false;
                            for dp_url in &dp_url_vec {
                                if workers_guard.iter().any(|w| w.url() == dp_url) {
                                    warn!("Worker {} already exists", dp_url);
                                    continue;
                                }

                                info!("Added worker: {}", dp_url);
                                let new_worker = WorkerFactory::create_regular_with_config(
                                    dp_url.to_string(),
                                    self.circuit_breaker_config.clone(),
                                );
                                workers_guard.push(new_worker);
                                worker_added = true;
                            }

                            if !worker_added {
                                return Err(format!("No worker added for {}", worker_url));
                            }
                        } else {
                            if workers_guard.iter().any(|w| w.url() == worker_url) {
                                return Err(format!("Worker {} already exists", worker_url));
                            }

                            info!("Added worker: {}", worker_url);
                            let new_worker = WorkerFactory::create_regular_with_config(
                                worker_url.to_string(),
                                self.circuit_breaker_config.clone(),
                            );
                            workers_guard.push(new_worker);
                        }

                        RouterMetrics::set_active_workers(workers_guard.len());

                        // If cache aware policy, initialize the worker in the tree
                        if let Some(cache_aware) = self
                            .policy
                            .as_any()
                            .downcast_ref::<crate::policies::CacheAwarePolicy>()
                        {
                            // Get updated workers after adding
                            drop(workers_guard);
                            let workers_guard = self.workers.read().unwrap();
                            cache_aware.init_workers(&workers_guard);
                        }

                        return Ok(format!("Successfully added worker: {}", worker_url));
                    } else {
                        debug!(
                            "Worker {} health check pending - status: {}",
                            worker_url,
                            res.status()
                        );
                        // if the url does not have http or https prefix, warn users
                        if !worker_url.starts_with("http://")
                            && !worker_url.starts_with("https://")
                        {
                            warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
                        }

                        tokio::time::sleep(Duration::from_secs(self.interval_secs)).await;
                        continue;
                    }
                }
                Err(e) => {
                    debug!("Worker {} health check pending - error: {}", worker_url, e);

                    // if the url does not have http or https prefix, warn users
                    if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") {
                        warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
                    }

                    tokio::time::sleep(Duration::from_secs(self.interval_secs)).await;
                    continue;
                }
            }
        }
    }

    pub fn remove_worker(&self, worker_url: &str) {
        if self.dp_aware {
            // remove dp-aware workers in a prefix-matching fashion
            // without contacting the remote worker
            let mut candidate_workers: Vec<String> = Vec::new();
            let mut removed_workers: Vec<String> = Vec::new();
            let worker_url_prefix = format!("{}@", worker_url);

            {
                // find the candidate workers to be removed
                let workers_guard = self.workers.read().unwrap();
                for w in workers_guard.iter() {
                    if w.url().starts_with(&worker_url_prefix) {
                        candidate_workers.push(w.url().to_string());
                    }
                }
            }

            {
                // do the removing on the worker_urls
                let mut workers_guard = self.workers.write().unwrap();
                for dp_url in candidate_workers.iter() {
                    if let Some(index) = workers_guard.iter().position(|w| w.url() == dp_url) {
                        workers_guard.remove(index);
                        info!("Removed worker: {}", dp_url);
                        removed_workers.push(dp_url.to_string());
                    } else {
                        warn!("Worker {} not found, skipping removal", dp_url);
                        continue;
                    }
                }
                RouterMetrics::set_active_workers(workers_guard.len());
            }

            // If cache aware policy, remove the workers from the tree
            if let Some(cache_aware) = self
                .policy
                .as_any()
                .downcast_ref::<crate::policies::CacheAwarePolicy>()
            {
                for dp_url in removed_workers.iter() {
                    cache_aware.remove_worker(dp_url);
                    info!("Removed worker from tree: {}", dp_url);
                }
            }
        } else {
            let mut workers_guard = self.workers.write().unwrap();
            if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
                workers_guard.remove(index);
                info!("Removed worker: {}", worker_url);
                RouterMetrics::set_active_workers(workers_guard.len());
            } else {
                warn!("Worker {} not found, skipping removal", worker_url);
                return;
            }

            // If cache aware policy, remove the workers from the tree
            if let Some(cache_aware) = self
                .policy
                .as_any()
                .downcast_ref::<crate::policies::CacheAwarePolicy>()
            {
                cache_aware.remove_worker(worker_url);
                info!("Removed worker from tree: {}", worker_url);
            }
        }
    }

    async fn get_worker_load(&self, worker_url: &str) -> Option<isize> {
        let worker_url = if self.dp_aware {
            // Need to extract the URL from "http://host:port@dp_rank"
            let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
                Ok(tup) => tup,
                Err(e) => {
                    error!("Failed to extract dp_rank: {}", e);
                    return None;
                }
            };
            worker_url_prefix
        } else {
            worker_url
        };

        match self
            .client
            .get(&format!("{}/get_load", worker_url))
            .send()
            .await
        {
            Ok(res) if res.status().is_success() => match res.bytes().await {
                Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
                    Ok(data) => data
                        .get("load")
                        .and_then(|v| v.as_i64())
                        .map(|v| v as isize),
                    Err(e) => {
                        debug!("Failed to parse load response from {}: {}", worker_url, e);
                        None
                    }
                },
                Err(e) => {
                    debug!("Failed to read load response from {}: {}", worker_url, e);
                    None
                }
            },
            Ok(res) => {
                debug!(
                    "Worker {} returned non-success status: {}",
                    worker_url,
                    res.status()
                );
                None
            }
            Err(e) => {
                debug!("Failed to get load from {}: {}", worker_url, e);
                None
            }
        }
    }

    // Background task to monitor worker loads
    async fn monitor_worker_loads(
        worker_urls: Vec<String>,
        tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
        interval_secs: u64,
        policy: Arc<dyn LoadBalancingPolicy>,
        client: Client,
    ) {
        let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));

        loop {
            interval.tick().await;

            let mut loads = HashMap::new();
            for url in &worker_urls {
                if let Some(load) = Self::get_worker_load_static(&client, url).await {
                    loads.insert(url.clone(), load);
                }
            }

            if !loads.is_empty() {
                // Update policy with new loads
                policy.update_loads(&loads);

                // Send to watchers
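                // (tokio watch channel: send() only fails once every receiver has been dropped)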
                if let Err(e) = tx.send(loads) {
                    error!("Failed to send load update: {}", e);
                }
            }
        }
    }

    // Static version of get_worker_load for use in monitoring task
    async fn get_worker_load_static(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
        let worker_url = if worker_url.contains("@") {
            // Need to extract the URL from "http://host:port@dp_rank"
            let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
                Ok(tup) => tup,
                Err(e) => {
                    debug!("Failed to extract dp_rank: {}", e);
                    return None;
                }
            };
            worker_url_prefix
        } else {
            worker_url
        };

        match client.get(&format!("{}/get_load", worker_url)).send().await {
            Ok(res) if res.status().is_success() => match res.bytes().await {
                Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
                    Ok(data) => data
                        .get("load")
                        .and_then(|v| v.as_i64())
                        .map(|v| v as isize),
                    Err(e) => {
                        debug!("Failed to parse load response from {}: {}", worker_url, e);
                        None
                    }
                },
                Err(e) => {
                    debug!("Failed to read load response from {}: {}", worker_url, e);
                    None
                }
            },
            Ok(res) => {
                debug!(
                    "Worker {} returned non-success status: {}",
                    worker_url,
                    res.status()
                );
                None
            }
            Err(e) => {
                debug!("Failed to get load from {}: {}", worker_url, e);
                None
            }
        }
    }
}

use async_trait::async_trait;

#[async_trait]
impl WorkerManagement for Router {
    async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
        Router::add_worker(self, worker_url).await
    }

    fn remove_worker(&self, worker_url: &str) {
        Router::remove_worker(self, worker_url)
    }

    fn get_worker_urls(&self) -> Vec<String> {
        Router::get_worker_urls(self)
    }
}

#[async_trait]
impl RouterTrait for Router {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn health(&self, _req: Request) -> Response {
        let workers = self.workers.read().unwrap();
        let unhealthy_servers: Vec<_> = workers
            .iter()
            .filter(|w| !w.is_healthy())
            .map(|w| w.url().to_string())
            .collect();

        if unhealthy_servers.is_empty() {
            (StatusCode::OK, "All servers healthy").into_response()
        } else {
            (
                StatusCode::SERVICE_UNAVAILABLE,
                format!("Unhealthy servers: {:?}", unhealthy_servers),
            )
                .into_response()
        }
    }

    async fn health_generate(&self, req: Request) -> Response {
        self.proxy_get_request(req, "health_generate").await
    }

    async fn get_server_info(&self, req: Request) -> Response {
        self.proxy_get_request(req, "get_server_info").await
    }

    async fn get_models(&self, req: Request) -> Response {
        self.proxy_get_request(req, "v1/models").await
    }

    async fn get_model_info(&self, req: Request) -> Response {
        self.proxy_get_request(req, "get_model_info").await
    }

    async fn route_generate(
        &self,
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
    ) -> Response {
        self.route_typed_request(headers, body, "/generate").await
    }

    async fn route_chat(
        &self,
        headers: Option<&HeaderMap>,
        body: &ChatCompletionRequest,
    ) -> Response {
        self.route_typed_request(headers, body, "/v1/chat/completions")
            .await
    }

    async fn route_completion(
        &self,
        headers: Option<&HeaderMap>,
        body: &CompletionRequest,
    ) -> Response {
        self.route_typed_request(headers, body, "/v1/completions")
            .await
    }

    async fn flush_cache(&self) -> Response {
        // Get all worker URLs
        let worker_urls = self.get_worker_urls();

        // Send requests to all workers concurrently without headers
        let mut tasks = Vec::new();
        for worker_url in &worker_urls {
            let worker_url = if self.dp_aware {
                // Need to extract the URL from "http://host:port@dp_rank"
                let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
                    Ok(tup) => tup,
                    Err(e) => {
                        error!("Failed to extract dp_rank: {}", e);
                        return (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            format!("Failed to extract dp_rank: {}", e),
                        )
                            .into_response();
                    }
                };
                worker_url_prefix
            } else {
                worker_url
            };

            let request_builder = self.client.post(format!("{}/flush_cache", worker_url));
            tasks.push(request_builder.send());
        }

        // Wait for all responses
        let results = futures_util::future::join_all(tasks).await;

        // Check if all succeeded
        let all_success = results.iter().all(|r| {
            r.as_ref()
                .map(|res| res.status().is_success())
                .unwrap_or(false)
        });

        if all_success {
            (StatusCode::OK, "Cache flushed on all servers").into_response()
        } else {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Cache flush failed on one or more servers",
            )
                .into_response()
        }
    }

    async fn get_worker_loads(&self) -> Response {
        let urls = self.get_worker_urls();
        let mut loads = Vec::new();

        // Get loads from all workers
        for url in &urls {
            let load = self.get_worker_load(url).await.unwrap_or(-1);
            loads.push(serde_json::json!({
                "worker": url,
                "load": load
            }));
        }

        Json(serde_json::json!({
            "workers": loads
        }))
        .into_response()
    }

    fn router_type(&self) -> &'static str {
        "regular"
    }

    fn readiness(&self) -> Response {
        // Regular router is ready if it has at least one healthy worker
        let healthy_count = self
            .workers
            .read()
            .unwrap()
            .iter()
            .filter(|w| w.is_healthy())
            .count();

        if healthy_count > 0 {
            Json(serde_json::json!({
                "status": "ready",
                "healthy_workers": healthy_count,
                "total_workers": self.workers.read().unwrap().len()
            }))
            .into_response()
        } else {
            (
                StatusCode::SERVICE_UNAVAILABLE,
                Json(serde_json::json!({
                    "status": "not_ready",
                    "reason": "no healthy workers available",
                    "total_workers": self.workers.read().unwrap().len()
                })),
            )
                .into_response()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::policies::RandomPolicy;
    use std::collections::HashMap;

    fn create_test_regular_router() -> Router {
        let workers = vec![
            WorkerFactory::create_regular("http://worker1:8080".to_string()),
            WorkerFactory::create_regular("http://worker2:8080".to_string()),
        ];
        let (_, rx) = tokio::sync::watch::channel(HashMap::new());
        Router {
            workers: Arc::new(RwLock::new(workers)),
            policy: Arc::new(RandomPolicy::new()),
            timeout_secs: 5,
            interval_secs: 1,
            dp_aware: false,
            api_key: None,
            client: Client::new(),
            retry_config: RetryConfig::default(),
            circuit_breaker_config: CircuitBreakerConfig::default(),
            _worker_loads: Arc::new(rx),
            _load_monitor_handle: None,
            _health_checker: None,
        }
    }

    #[test]
    fn test_router_get_worker_urls_regular() {
        let router = create_test_regular_router();
        let urls = router.get_worker_urls();

        assert_eq!(urls.len(), 2);
        assert!(urls.contains(&"http://worker1:8080".to_string()));
        assert!(urls.contains(&"http://worker2:8080".to_string()));
    }

    #[test]
    fn test_select_first_worker_regular() {
        let router = create_test_regular_router();
        let result = router.select_first_worker();

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "http://worker1:8080");
    }

    #[test]
    fn test_wait_for_healthy_workers_empty_list() {
        // Empty list will timeout as there are no workers to check
        let result = Router::wait_for_healthy_workers(&[], 1, 1);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Timeout"));
    }

    #[test]
    fn test_wait_for_healthy_workers_invalid_urls() {
        // This test will timeout quickly since the URLs are invalid
        let result =
            Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Timeout"));
    }
}
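
// Additional unit tests for the pure helpers above. This is a minimal sketch: it only
// exercises extract_dp_rank and is_retryable_status, which need no network or worker
// setup; the worker URLs used here are hypothetical.
#[cfg(test)]
mod helper_tests {
    use super::*;

    #[test]
    fn test_extract_dp_rank_valid() {
        let (prefix, rank) = Router::extract_dp_rank("http://worker1:8080@3").unwrap();
        assert_eq!(prefix, "http://worker1:8080");
        assert_eq!(rank, 3);
    }

    #[test]
    fn test_extract_dp_rank_invalid() {
        // A missing "@dp_rank" suffix and a non-numeric rank are both rejected
        assert!(Router::extract_dp_rank("http://worker1:8080").is_err());
        assert!(Router::extract_dp_rank("http://worker1:8080@abc").is_err());
    }

    #[test]
    fn test_is_retryable_status() {
        assert!(Router::is_retryable_status(StatusCode::SERVICE_UNAVAILABLE));
        assert!(Router::is_retryable_status(StatusCode::TOO_MANY_REQUESTS));
        assert!(!Router::is_retryable_status(StatusCode::BAD_REQUEST));
        assert!(!Router::is_retryable_status(StatusCode::OK));
    }
}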