use crate::tree::Tree; use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::{HttpRequest, HttpResponse}; use bytes::Bytes; use futures_util::{StreamExt, TryStreamExt}; use log::{debug, info}; use std::collections::HashMap; use std::fmt::Debug; use std::sync::atomic::AtomicUsize; use std::sync::{Arc, Mutex, RwLock}; use std::thread; use std::time::Duration; #[derive(Debug)] pub enum Router { RoundRobin { worker_urls: Arc>>, current_index: AtomicUsize, }, Random { worker_urls: Arc>>, }, CacheAware { /* Cache-Aware Load Balancing Router This router combines two strategies to optimize both cache utilization and request distribution: 1. Cache-Aware Routing (Approximate Tree) 2. Load Balancing (Shortest Queue with Balance Thresholds) The router dynamically switches between these strategies based on load conditions: - Uses load balancing when the system is imbalanced - Uses cache-aware routing when the system is balanced A system is considered imbalanced if both conditions are met: 1. (max - min) > abs_threshold 2. max > rel_threshold * min Strategy Details: 1. Cache-Aware Routing (Approximate Tree) ------------------------------------------- This strategy maintains an approximate radix tree for each worker based on request history, eliminating the need for direct cache state queries. The tree stores raw text characters instead of token IDs to avoid tokenization overhead. Process: a. For each request, find the worker with the highest prefix match b. If match rate > cache_threshold: Route to the worker with highest match (likely has relevant data cached) c. If match rate ≤ cache_threshold: Route to the worker with smallest tree size (most available cache capacity) d. Background maintenance: Periodically evict least recently used leaf nodes to prevent memory overflow 2. Load Balancing (Shortest Queue) ------------------------------------------- This strategy tracks pending request counts per worker and routes new requests to the least busy worker when the system is detected to be imbalanced. Configuration Parameters: ------------------------ 1. cache_threshold: (float, 0.0 to 1.0) Minimum prefix match ratio to use highest-match routing. Below this threshold, routes to worker with most available cache space. 2. balance_abs_threshold: (integer) Absolute difference threshold for load imbalance detection. System is potentially imbalanced if (max_load - min_load) > abs_threshold 3. balance_rel_threshold: (float) Relative ratio threshold for load imbalance detection. System is potentially imbalanced if max_load > min_load * rel_threshold Used in conjunction with abs_threshold to determine final imbalance state. 4. eviction_interval_secs: (integer) Interval between LRU eviction cycles for the approximate trees. 5. max_tree_size: (integer) Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted during the next eviction cycle. */ worker_urls: Arc>>, tree: Arc>, running_queue: Arc>>, processed_queue: Arc>>, cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, _eviction_thread: Option>, }, } #[derive(Debug)] pub enum PolicyConfig { RandomConfig, RoundRobinConfig, CacheAwareConfig { cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, eviction_interval_secs: u64, max_tree_size: usize, }, } fn get_text_from_request(body: &Bytes, route: &str) -> String { // convert body to json let json = serde_json::from_slice::(body).unwrap(); if route == "generate" { // get the "text" field let text = json.get("text").and_then(|t| t.as_str()).unwrap_or(""); return text.to_string(); } else if route == "v1/chat/completions" { // get the messages field as raw text if let Some(messages) = json.get("messages") { // Convert messages back to a string, preserving all JSON formatting return serde_json::to_string(messages).unwrap_or_default(); } } else if route == "v1/completions" { let prompt = json.get("prompt").and_then(|t| t.as_str()).unwrap_or(""); return prompt.to_string(); } return "".to_string(); } impl Router { pub fn new(worker_urls: Vec, policy_config: PolicyConfig) -> Self { match policy_config { PolicyConfig::RandomConfig => Router::Random { worker_urls: Arc::new(RwLock::new(worker_urls)), }, PolicyConfig::RoundRobinConfig => Router::RoundRobin { worker_urls: Arc::new(RwLock::new(worker_urls)), current_index: std::sync::atomic::AtomicUsize::new(0), }, PolicyConfig::CacheAwareConfig { cache_threshold, balance_abs_threshold, balance_rel_threshold, eviction_interval_secs, max_tree_size, } => { let mut running_queue = HashMap::new(); for url in &worker_urls { running_queue.insert(url.clone(), 0); } let mut processed_queue = HashMap::new(); for url in &worker_urls { processed_queue.insert(url.clone(), 0); } let tree = Arc::new(Mutex::new(Tree::new())); let running_queue = Arc::new(Mutex::new(running_queue)); let processed_queue = Arc::new(Mutex::new(processed_queue)); // Create background eviction thread let tree_clone = Arc::clone(&tree); let processed_queue_clone = Arc::clone(&processed_queue); let running_queue_clone = Arc::clone(&running_queue); let eviction_thread = thread::spawn(move || { loop { // Sleep for the specified interval thread::sleep(Duration::from_secs(eviction_interval_secs)); let locked_tree_clone = tree_clone.lock().unwrap(); // Run eviction locked_tree_clone.evict_tenant_by_size(max_tree_size); // Print the process queue let locked_processed_queue = processed_queue_clone.lock().unwrap(); info!("Processed Queue: {:?}", locked_processed_queue); // Print the running queue let locked_running_queue = running_queue_clone.lock().unwrap(); info!("Running Queue: {:?}", locked_running_queue); } }); for url in &worker_urls { tree.lock().unwrap().insert(&"".to_string(), url); } Router::CacheAware { worker_urls: Arc::new(RwLock::new(worker_urls)), tree, running_queue, processed_queue, cache_threshold, balance_abs_threshold, balance_rel_threshold, _eviction_thread: Some(eviction_thread), } } } } pub fn get_first(&self) -> Option { match self { Router::RoundRobin { worker_urls, .. } | Router::Random { worker_urls } | Router::CacheAware { worker_urls, .. } => { if worker_urls.read().unwrap().is_empty() { None } else { Some(worker_urls.read().unwrap()[0].clone()) } } } } pub async fn dispatch( &self, client: &reqwest::Client, req: HttpRequest, body: Bytes, route: &str, ) -> HttpResponse { let text = get_text_from_request(&body, route); let worker_url = match self { Router::RoundRobin { worker_urls, current_index, } => { let idx = current_index .fetch_update( std::sync::atomic::Ordering::SeqCst, std::sync::atomic::Ordering::SeqCst, |x| Some((x + 1) % worker_urls.read().unwrap().len()), ) .unwrap(); worker_urls.read().unwrap()[idx].clone() } Router::Random { worker_urls } => worker_urls.read().unwrap() [rand::random::() % worker_urls.read().unwrap().len()] .clone(), Router::CacheAware { worker_urls, tree, running_queue, processed_queue, cache_threshold, balance_abs_threshold, balance_rel_threshold, .. } => { // TODO: delay scheduling if cache hit rate is high because it may cause imbalance. prioritize low hit rate ones let tree = tree.lock().unwrap(); let mut running_queue = running_queue.lock().unwrap(); // Get current load statistics let max_load = *running_queue.values().max().unwrap_or(&0); let min_load = *running_queue.values().min().unwrap_or(&0); // Load is considered imbalanced if: // 1. (max - min) > abs_threshold AND // 2. max > rel_threshold * min let is_imbalanced = max_load.saturating_sub(min_load) > *balance_abs_threshold && (max_load as f32) > (min_load as f32 * balance_rel_threshold); let selected_url = if is_imbalanced { // Log load balancing trigger and current queue state info!( "Load balancing triggered due to workload imbalance:\n\ Max load: {}, Min load: {}\n\ Current running queue: {:?}", max_load, min_load, running_queue ); // Use shortest queue routing when load is imbalanced running_queue .iter() .min_by_key(|(_url, &count)| count) .map(|(url, _)| url.clone()) .unwrap_or_else(|| worker_urls.read().unwrap()[0].clone()) } else { // Use cache-aware routing when load is balanced let (matched_text, matched_worker) = tree.prefix_match(&text); let matched_rate = matched_text.chars().count() as f32 / text.chars().count() as f32; if matched_rate > *cache_threshold { matched_worker.to_string() } else { tree.get_smallest_tenant() } }; // Update queues and tree *running_queue.get_mut(&selected_url).unwrap() += 1; *processed_queue .lock() .unwrap() .get_mut(&selected_url) .unwrap() += 1; tree.insert(&text, &selected_url); selected_url } }; let is_stream = serde_json::from_slice::(&body) .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false)) .unwrap_or(false); let res = match client .post(format!("{}/{}", worker_url.clone(), route)) .header( "Content-Type", req.headers() .get("Content-Type") .and_then(|h| h.to_str().ok()) .unwrap_or("application/json"), ) .body(body.to_vec()) .send() .await { Ok(res) => res, Err(_) => return HttpResponse::InternalServerError().finish(), }; let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); if !is_stream { // For non-streaming requests, get response first let response = match res.bytes().await { Ok(body) => HttpResponse::build(status).body(body.to_vec()), Err(e) => { let error_msg = format!("Failed to get response body: {}", e); HttpResponse::InternalServerError().body(error_msg) } }; // Then decrement running queue counter if using CacheAware if let Router::CacheAware { running_queue, .. } = self { if let Ok(mut queue) = running_queue.lock() { if let Some(count) = queue.get_mut(&worker_url) { *count = count.saturating_sub(1); } } } response } else if let Router::CacheAware { running_queue, .. } = self { let running_queue = Arc::clone(running_queue); let worker_url = worker_url.clone(); HttpResponse::build(status) .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) .streaming( res.bytes_stream() .map_err(|_| { actix_web::error::ErrorInternalServerError("Failed to read stream") }) .inspect(move |bytes| { let bytes = bytes.as_ref().unwrap(); if bytes .as_ref() .windows(12) .any(|window| window == b"data: [DONE]") { let mut locked_queue = running_queue.lock().unwrap(); let count = locked_queue.get_mut(&worker_url).unwrap(); *count = count.saturating_sub(1); debug!("streaming is done!!") } }), ) } else { HttpResponse::build(status) .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) .streaming(res.bytes_stream().map_err(|_| { actix_web::error::ErrorInternalServerError("Failed to read stream") })) } } pub fn add_worker(&self, worker_url: String) { match self { Router::RoundRobin { worker_urls, .. } | Router::Random { worker_urls } | Router::CacheAware { worker_urls, .. } => { let mut urls = worker_urls.write().unwrap(); info!("Added worker: {}", worker_url); urls.push(worker_url); } } } pub fn remove_worker(&self, worker_url: String) { match self { Router::RoundRobin { worker_urls, .. } | Router::Random { worker_urls } | Router::CacheAware { worker_urls, .. } => { let mut urls = worker_urls.write().unwrap(); let index = urls.iter().position(|url| url == &worker_url).unwrap(); urls.remove(index); info!("Removed worker: {}", worker_url); } } // if cache aware, remove the worker from the tree if let Router::CacheAware { tree, .. } = self { tree.lock().unwrap().remove_tenant(&worker_url); info!("Removed worker from tree: {}", worker_url); } } }