Unverified Commit 7b7e5615 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] fix radix tree integration issues in PD router (#8982)

parent 1a8706c8
......@@ -112,7 +112,7 @@ impl CacheAwarePolicy {
}
}
/// Initialize the tree with worker URLs
/// Initialize the tree with worker URLs (used only during initial setup)
pub fn init_workers(&self, workers: &[Box<dyn Worker>]) {
if let Ok(tree) = self.tree.lock() {
for worker in workers {
......@@ -121,6 +121,13 @@ impl CacheAwarePolicy {
}
}
/// Add a single worker to the tree (incremental update)
pub fn add_worker(&self, url: &str) {
if let Ok(tree) = self.tree.lock() {
tree.insert("", url);
}
}
/// Remove a worker from the tree
pub fn remove_worker(&self, url: &str) {
if let Ok(tree) = self.tree.lock() {
......@@ -178,6 +185,13 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
.min_by_key(|&&idx| workers[idx].load())
.copied()?;
// Even in imbalanced mode, update the tree to maintain cache state
if let Some(text) = request_text {
if let Ok(tree) = self.tree.lock() {
tree.insert(text, workers[min_load_idx].url());
}
}
// Increment processed counter
workers[min_load_idx].increment_processed();
RouterMetrics::record_processed_request(workers[min_load_idx].url());
......@@ -206,21 +220,26 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
};
// Find the index of the selected worker
let selected_idx = workers.iter().position(|w| w.url() == selected_url)?;
if let Some(selected_idx) = workers.iter().position(|w| w.url() == selected_url) {
// Only proceed if the worker is healthy
if workers[selected_idx].is_healthy() {
// Update the tree with this request
tree.insert(text, &selected_url);
// Only proceed if the worker is healthy
if !workers[selected_idx].is_healthy() {
return healthy_indices.first().copied();
}
// Update the tree with this request
tree.insert(text, &selected_url);
// Increment processed counter
workers[selected_idx].increment_processed();
RouterMetrics::record_processed_request(&selected_url);
// Increment processed counter
workers[selected_idx].increment_processed();
RouterMetrics::record_processed_request(&selected_url);
return Some(selected_idx);
}
} else {
// Selected worker no longer exists, remove it from tree
tree.remove_tenant(&selected_url);
debug!("Removed stale worker {} from cache tree", selected_url);
}
return Some(selected_idx);
// Fallback to first healthy worker
return healthy_indices.first().copied();
}
// Fallback to first healthy worker if tree operations fail
......
......@@ -7,7 +7,6 @@ use crate::metrics::RouterMetrics;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::policies::LoadBalancingPolicy;
use crate::routers::{RouterTrait, WorkerManagement};
use crate::tree::Tree;
use async_trait::async_trait;
use axum::{
body::Body,
......@@ -20,7 +19,7 @@ use futures_util::StreamExt;
use reqwest::Client;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info, warn};
......@@ -31,8 +30,6 @@ pub struct PDRouter {
pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
pub prefill_policy: Arc<dyn LoadBalancingPolicy>,
pub decode_policy: Arc<dyn LoadBalancingPolicy>,
pub prefill_tree: Option<Arc<Mutex<Tree>>>,
pub decode_tree: Option<Arc<Mutex<Tree>>>,
pub timeout_secs: u64,
pub interval_secs: u64,
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
......@@ -91,9 +88,14 @@ impl PDRouter {
workers.push(worker);
// Add to cache tree if using cache-aware policy for prefill
if let Some(ref tree) = self.prefill_tree {
tree.lock().unwrap().insert("", &url);
// Update cache-aware policy if applicable
drop(workers); // Release write lock
if let Some(cache_policy) = self
.prefill_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.add_worker(&url);
}
info!("Added prefill server: {}", url);
......@@ -125,9 +127,14 @@ impl PDRouter {
workers.push(worker);
// Add to cache tree if using cache-aware policy for decode
if let Some(ref tree) = self.decode_tree {
tree.lock().unwrap().insert("", &url);
// Update cache-aware policy if applicable
drop(workers); // Release write lock
if let Some(cache_policy) = self
.decode_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.add_worker(&url);
}
info!("Added decode server: {}", url);
......@@ -152,9 +159,13 @@ impl PDRouter {
});
}
// Remove from cache tree if using cache-aware policy
if let Some(ref tree) = self.prefill_tree {
tree.lock().unwrap().remove_tenant(url);
// Remove from cache-aware policy if applicable
if let Some(cache_policy) = self
.prefill_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.remove_worker(url);
}
info!("Removed prefill server: {}", url);
......@@ -179,9 +190,13 @@ impl PDRouter {
});
}
// Remove from the cache tree if using cache-aware policy for decode
if let Some(ref tree) = self.decode_tree {
tree.lock().unwrap().remove_tenant(url);
// Remove from cache-aware policy if applicable
if let Some(cache_policy) = self
.decode_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.remove_worker(url);
}
info!("Removed decode server: {}", url);
......@@ -238,11 +253,20 @@ impl PDRouter {
)?;
}
// Initialize cache-aware components if needed for prefill policy
let prefill_tree = Self::initialize_radix_tree(&prefill_policy, &prefill_workers)?;
// Initialize cache-aware policies with workers
if let Some(cache_policy) = prefill_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.init_workers(&prefill_workers);
}
// Initialize cache-aware components if needed for decode policy
let decode_tree = Self::initialize_radix_tree(&decode_policy, &decode_workers)?;
if let Some(cache_policy) = decode_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.init_workers(&decode_workers);
}
// Set up background load monitoring for power-of-two selection
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
......@@ -294,8 +318,6 @@ impl PDRouter {
decode_workers,
prefill_policy,
decode_policy,
prefill_tree,
decode_tree,
timeout_secs,
interval_secs,
worker_loads,
......@@ -309,35 +331,6 @@ impl PDRouter {
})
}
// Helper function to initialize radix tree for cache-aware policies
fn initialize_radix_tree(
policy: &Arc<dyn LoadBalancingPolicy>,
workers: &[Box<dyn Worker>],
) -> Result<Option<Arc<Mutex<Tree>>>, String> {
if let Some(cache_policy) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
// Initialize the policy's internal tree with workers
cache_policy.init_workers(workers);
let tree = Arc::new(Mutex::new(Tree::new()));
{
let tree_guard = tree
.lock()
.map_err(|e| format!("Failed to lock tree: {}", e))?;
for worker in workers {
tree_guard.insert("", worker.url());
}
}
Ok(Some(tree))
} else {
Ok(None)
}
}
// Helper to handle server selection errors
fn handle_server_selection_error(error: String) -> Response {
error!("Failed to select PD pair error={}", error);
......@@ -1863,8 +1856,6 @@ mod tests {
decode_workers: Arc::new(RwLock::new(vec![])),
prefill_policy,
decode_policy,
prefill_tree: None,
decode_tree: None,
timeout_secs: 5,
interval_secs: 1,
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
......@@ -2002,105 +1993,6 @@ mod tests {
}
}
// ============= Cache Tree Integration Tests =============
#[tokio::test]
async fn test_cache_tree_operations() {
let cache_policy = Arc::new(CacheAwarePolicy::new());
let mut router = create_test_pd_router();
router.prefill_policy = cache_policy;
// Initialize cache tree
let tree = Arc::new(Mutex::new(Tree::new()));
router.prefill_tree = Some(Arc::clone(&tree));
// Manually add worker and update tree
let worker = create_test_worker(
"http://worker1".to_string(),
WorkerType::Prefill {
bootstrap_port: None,
},
true,
);
router.prefill_workers.write().unwrap().push(worker);
// Update tree
tree.lock().unwrap().insert("", "http://worker1");
// Verify tree contains the worker
let tree_guard = tree.lock().unwrap();
let (_matched_text, tenant) = tree_guard.prefix_match("");
// Since we inserted with empty prefix, we should get a match
assert_eq!(tenant, "http://worker1");
}
#[tokio::test]
async fn test_cache_tree_rebuild_on_remove() {
let cache_policy = Arc::new(CacheAwarePolicy::new());
let mut router = create_test_pd_router();
router.prefill_policy = cache_policy;
// Initialize cache tree
let tree = Arc::new(Mutex::new(Tree::new()));
router.prefill_tree = Some(Arc::clone(&tree));
// Add multiple workers
let worker1 = create_test_worker(
"http://worker1".to_string(),
WorkerType::Prefill {
bootstrap_port: None,
},
true,
);
let worker2 = create_test_worker(
"http://worker2".to_string(),
WorkerType::Prefill {
bootstrap_port: None,
},
true,
);
router.prefill_workers.write().unwrap().push(worker1);
router.prefill_workers.write().unwrap().push(worker2);
// Initialize tree with both workers
{
let tree_guard = tree.lock().unwrap();
tree_guard.insert("", "http://worker1");
tree_guard.insert("", "http://worker2");
}
// Remove one worker
let result = router.remove_prefill_server("http://worker1").await;
assert!(result.is_ok());
// Verify tree only contains remaining worker
let tree_guard = tree.lock().unwrap();
let (_matched_text, tenant) = tree_guard.prefix_match("");
// After rebuild, tree should only have worker2
assert_eq!(tenant, "http://worker2");
}
#[tokio::test]
async fn test_no_cache_tree_operations() {
let router = create_test_pd_router();
assert!(router.prefill_tree.is_none());
// Add a worker without cache tree
let worker = create_test_worker(
"http://worker1".to_string(),
WorkerType::Prefill {
bootstrap_port: None,
},
true,
);
router.prefill_workers.write().unwrap().push(worker);
// Remove should work without tree
let result = router.remove_prefill_server("http://worker1").await;
assert!(result.is_ok());
}
// ============= Bootstrap Injection Tests =============
// Note: These tests are commented out as we've moved to the optimized bootstrap injection
// approach that doesn't use the Bootstrap trait on GenerateReqInput anymore.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment