Unverified Commit 5c8365a0 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] add ut for pd router (#8208)

parent 8430bfe3
......@@ -1393,3 +1393,515 @@ impl RouterTrait for PDRouter {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::{BasicWorker, WorkerType};
use crate::policies::{CacheAwarePolicy, RandomPolicy};
use crate::routers::pd_types::SingleOrBatch;
use actix_web::test::TestRequest;
/// Builds a minimal `PDRouter` for unit tests.
///
/// The router starts with empty prefill and decode worker lists, a
/// `RandomPolicy`, no cache tree, no load-monitor task and no health checkers.
fn create_test_pd_router() -> PDRouter {
    let random_policy = Arc::new(RandomPolicy::new());
    // Only the receiving half of the load channel is stored on the router.
    let (_, load_rx) = tokio::sync::watch::channel(HashMap::new());

    PDRouter {
        prefill_workers: Arc::new(RwLock::new(Vec::new())),
        decode_workers: Arc::new(RwLock::new(Vec::new())),
        policy: random_policy,
        prefill_tree: None,
        timeout_secs: 5,
        interval_secs: 1,
        worker_loads: Arc::new(load_rx),
        load_monitor_handle: None,
        http_client: reqwest::Client::new(),
        _prefill_health_checker: None,
        _decode_health_checker: None,
    }
}
/// Creates a boxed `BasicWorker` for `url` with the given `worker_type`,
/// with its health flag set to `healthy`.
fn create_test_worker(url: String, worker_type: WorkerType, healthy: bool) -> Box<dyn Worker> {
    let basic = BasicWorker::new(url, worker_type);
    // Set the health flag explicitly so each test controls it.
    basic.set_healthy(healthy);
    let boxed: Box<dyn Worker> = Box::new(basic);
    boxed
}
// ============= Worker Management Tests =============
/// Checks that a prefill worker pushed into the router can be found by URL.
///
/// NOTE(review): this test does not call `add_prefill_server`; it only
/// exercises a URL lookup over the worker list.
#[tokio::test]
async fn test_add_prefill_server_already_exists() {
    const URL: &str = "http://localhost:8000";
    let router = create_test_pd_router();

    // Seed the router with one prefill worker.
    let existing = create_test_worker(
        URL.to_string(),
        WorkerType::Prefill {
            bootstrap_port: Some(8080),
        },
        true,
    );
    router.prefill_workers.write().unwrap().push(existing);

    // Try to add the same URL again - this would fail during health check in real scenario
    // For unit test, we test the duplicate check logic
    let registered = router.prefill_workers.read().unwrap();
    let mut found = false;
    for candidate in registered.iter() {
        if candidate.url() == URL {
            found = true;
            break;
        }
    }
    assert!(found);
}
/// Removing one of two registered prefill workers succeeds and leaves only
/// the other worker in the list.
#[tokio::test]
async fn test_remove_prefill_server_success() {
    let router = create_test_pd_router();

    // Register two prefill workers: one without and one with a bootstrap port.
    // The write guard is scoped so it is released before the `.await` below.
    {
        let mut pool = router.prefill_workers.write().unwrap();
        for (url, bootstrap_port) in [("http://worker1", None), ("http://worker2", Some(8080))] {
            pool.push(create_test_worker(
                url.to_string(),
                WorkerType::Prefill { bootstrap_port },
                true,
            ));
        }
    }

    // Remove the first worker and check the reported message.
    let outcome = router.remove_prefill_server("http://worker1").await;
    assert!(outcome.is_ok());
    let message = outcome.unwrap();
    assert!(message.contains("Successfully removed"));

    // Only the second worker must remain.
    let remaining = router.prefill_workers.read().unwrap();
    assert_eq!(remaining.len(), 1);
    assert_eq!(remaining[0].url(), "http://worker2");
}
/// Removing a URL that was never registered yields
/// `PDRouterError::WorkerNotFound` carrying that URL.
#[tokio::test]
async fn test_remove_prefill_server_not_found() {
    const MISSING_URL: &str = "http://nonexistent";
    let router = create_test_pd_router();

    let outcome = router.remove_prefill_server(MISSING_URL).await;
    assert!(outcome.is_err());

    // The error must name the URL that could not be found.
    if let PDRouterError::WorkerNotFound { url } = outcome.unwrap_err() {
        assert_eq!(url, MISSING_URL);
    } else {
        panic!("Expected WorkerNotFound error");
    }
}
/// Removing the only registered decode worker succeeds and empties the list.
#[tokio::test]
async fn test_remove_decode_server_success() {
    const DECODE_URL: &str = "http://decode1";
    let router = create_test_pd_router();

    // Register a single decode worker.
    router
        .decode_workers
        .write()
        .unwrap()
        .push(create_test_worker(DECODE_URL.to_string(), WorkerType::Decode, true));

    let outcome = router.remove_decode_server(DECODE_URL).await;
    assert!(outcome.is_ok());
    assert!(outcome.unwrap().contains("Successfully removed"));

    assert_eq!(router.decode_workers.read().unwrap().len(), 0);
}
// ============= Lock Error Handling Tests =============
/// The prefill worker list can be read, written, and read again through its
/// `RwLock`, and a write is visible to a later read.
#[test]
fn test_lock_operations() {
    let router = create_test_pd_router();
    // Each call takes a fresh read guard that is released at the end of the call.
    let prefill_count = || router.prefill_workers.read().unwrap().len();

    assert_eq!(prefill_count(), 0);

    // The write guard is a temporary and is released at the end of the statement.
    router.prefill_workers.write().unwrap().push(create_test_worker(
        "http://test".to_string(),
        WorkerType::Prefill {
            bootstrap_port: None,
        },
        true,
    ));

    assert_eq!(prefill_count(), 1);
}
// ============= Cache Tree Integration Tests =============
/// A tenant inserted into the shared cache tree under the empty prefix is
/// returned by `prefix_match("")`.
#[tokio::test]
async fn test_cache_tree_operations() {
    let cache_policy = Arc::new(CacheAwarePolicy::new());
    let mut router = create_test_pd_router();
    router.policy = cache_policy;

    // Attach a cache tree and keep our own handle to it.
    let shared_tree = Arc::new(Mutex::new(Tree::new()));
    router.prefill_tree = Some(Arc::clone(&shared_tree));

    // Register the worker directly, then record it in the tree by hand.
    router.prefill_workers.write().unwrap().push(create_test_worker(
        "http://worker1".to_string(),
        WorkerType::Prefill {
            bootstrap_port: None,
        },
        true,
    ));
    shared_tree.lock().unwrap().insert("", "http://worker1");

    // Since we inserted with empty prefix, we should get a match
    let guard = shared_tree.lock().unwrap();
    let (_matched_text, tenant) = guard.prefix_match("");
    assert_eq!(tenant, "http://worker1");
}
/// After removing one of two prefill workers, `prefix_match("")` on the cache
/// tree resolves to the remaining worker.
#[tokio::test]
async fn test_cache_tree_rebuild_on_remove() {
    const WORKER_URLS: [&str; 2] = ["http://worker1", "http://worker2"];

    let cache_policy = Arc::new(CacheAwarePolicy::new());
    let mut router = create_test_pd_router();
    router.policy = cache_policy;

    // Attach a cache tree and keep our own handle to it.
    let shared_tree = Arc::new(Mutex::new(Tree::new()));
    router.prefill_tree = Some(Arc::clone(&shared_tree));

    // Register both workers; the guard is scoped so it is released before `.await`.
    {
        let mut pool = router.prefill_workers.write().unwrap();
        for url in WORKER_URLS {
            pool.push(create_test_worker(
                url.to_string(),
                WorkerType::Prefill {
                    bootstrap_port: None,
                },
                true,
            ));
        }
    }

    // Seed the tree with both workers under the empty prefix.
    {
        let guard = shared_tree.lock().unwrap();
        for url in WORKER_URLS {
            guard.insert("", url);
        }
    }

    // Remove one worker
    let outcome = router.remove_prefill_server("http://worker1").await;
    assert!(outcome.is_ok());

    // After rebuild, tree should only have worker2
    let guard = shared_tree.lock().unwrap();
    let (_matched_text, tenant) = guard.prefix_match("");
    assert_eq!(tenant, "http://worker2");
}
/// Removing a prefill worker succeeds when the router has no cache tree.
#[tokio::test]
async fn test_no_cache_tree_operations() {
    const URL: &str = "http://worker1";
    let router = create_test_pd_router();
    assert!(router.prefill_tree.is_none());

    // Register a worker while the router has no tree attached.
    let lone_worker = create_test_worker(
        URL.to_string(),
        WorkerType::Prefill {
            bootstrap_port: None,
        },
        true,
    );
    router.prefill_workers.write().unwrap().push(lone_worker);

    // Remove should work without tree
    let outcome = router.remove_prefill_server(URL).await;
    assert!(outcome.is_ok());
}
// ============= Bootstrap Injection Tests =============
/// `add_bootstrap_info` overwrites bootstrap host, port and room that were
/// already set on the request.
#[test]
fn test_bootstrap_injection_with_existing_fields() {
    const ORIGINAL_ROOM: u64 = 12345;

    let prefill_worker = create_test_worker(
        "http://new-host:8000".to_string(),
        WorkerType::Prefill {
            bootstrap_port: Some(8080),
        },
        true,
    );

    // Request that already carries bootstrap values.
    let mut req = GenerateReqInput {
        text: Some(SingleOrBatch::Single("Test".to_string())),
        input_ids: None,
        stream: false,
        bootstrap_host: Some(SingleOrBatch::Single("existing-host".to_string())),
        bootstrap_port: Some(SingleOrBatch::Single(Some(9999))),
        bootstrap_room: Some(SingleOrBatch::Single(12345)),
        other: Value::Object(serde_json::Map::new()),
    };

    // Bootstrap info is added regardless of existing fields
    let outcome = req.add_bootstrap_info(prefill_worker.as_ref());
    assert!(outcome.is_ok());

    // Host and port must now reflect the prefill worker.
    let expected_host = Some(SingleOrBatch::Single("new-host".to_string()));
    assert_eq!(req.bootstrap_host, expected_host);
    let expected_port = Some(SingleOrBatch::Single(Some(8080)));
    assert_eq!(req.bootstrap_port, expected_port);

    // Room should be regenerated (different from original)
    match req.bootstrap_room {
        Some(SingleOrBatch::Single(room)) => assert_ne!(room, ORIGINAL_ROOM),
        _ => panic!("Expected single room ID"),
    }
}
/// Two requests given bootstrap info from the same prefill worker must
/// receive different room IDs.
///
/// Each `add_bootstrap_info` call is asserted to succeed, so an injection
/// failure is reported directly instead of as a missing room ID.
#[test]
fn test_bootstrap_room_generation() {
    // Both requests start out identical and without any bootstrap fields.
    let new_request = || GenerateReqInput {
        text: Some(SingleOrBatch::Single("Test".to_string())),
        input_ids: None,
        stream: false,
        bootstrap_host: None,
        bootstrap_port: None,
        bootstrap_room: None,
        other: Value::Object(serde_json::Map::new()),
    };
    let mut req1 = new_request();
    let mut req2 = new_request();

    let prefill_worker = create_test_worker(
        "http://host:8000".to_string(),
        WorkerType::Prefill {
            bootstrap_port: Some(8080),
        },
        true,
    );

    // Add bootstrap info to both requests and check that each injection succeeded.
    assert!(
        req1.add_bootstrap_info(prefill_worker.as_ref()).is_ok(),
        "bootstrap injection failed for first request"
    );
    assert!(
        req2.add_bootstrap_info(prefill_worker.as_ref()).is_ok(),
        "bootstrap injection failed for second request"
    );

    // Room IDs should be different
    match (req1.bootstrap_room, req2.bootstrap_room) {
        (Some(SingleOrBatch::Single(room1)), Some(SingleOrBatch::Single(room2))) => {
            assert_ne!(room1, room2, "Room IDs should be unique");
        }
        _ => panic!("Expected single room IDs"),
    }
}
// ============= Worker Selection Tests =============
/// With one unhealthy and one healthy prefill worker registered,
/// `select_pd_pair` must return the healthy one.
#[tokio::test]
async fn test_select_healthy_prefill_worker() {
    let router = create_test_pd_router();

    // Helper for prefill workers that differ only in URL and health.
    let prefill = |url: &str, healthy: bool| {
        create_test_worker(
            url.to_string(),
            WorkerType::Prefill {
                bootstrap_port: None,
            },
            healthy,
        )
    };

    // The unhealthy worker is registered first, the healthy one second.
    // The guard is scoped so it is released before the `.await` below.
    {
        let mut pool = router.prefill_workers.write().unwrap();
        pool.push(prefill("http://unhealthy", false));
        pool.push(prefill("http://healthy", true));
    }
    router
        .decode_workers
        .write()
        .unwrap()
        .push(create_test_worker("http://decode".to_string(), WorkerType::Decode, true));

    let client = reqwest::Client::new();
    let selection = router.select_pd_pair(&client, None).await;
    assert!(selection.is_ok());

    // Should select the healthy worker
    let (chosen_prefill, _chosen_decode) = selection.unwrap();
    assert_eq!(chosen_prefill.url(), "http://healthy");
    assert!(chosen_prefill.is_healthy());
}
/// `select_pd_pair` on a router with no workers fails with an error that
/// mentions the missing prefill workers.
#[tokio::test]
async fn test_empty_worker_lists() {
    let router = create_test_pd_router();
    let client = reqwest::Client::new();

    let selection = router.select_pd_pair(&client, None).await;
    assert!(selection.is_err());

    let message = selection.unwrap_err();
    assert!(message.contains("No prefill workers available"));
}
// ============= Health Endpoints Tests =============
/// With one healthy prefill worker and one healthy decode worker registered,
/// both `health` and `readiness` respond with status 200.
#[tokio::test]
async fn test_health_endpoints() {
    let router = create_test_pd_router();

    // Register one healthy worker of each kind.
    router.prefill_workers.write().unwrap().push(create_test_worker(
        "http://localhost:8000".to_string(),
        WorkerType::Prefill {
            bootstrap_port: None,
        },
        true,
    ));
    router.decode_workers.write().unwrap().push(create_test_worker(
        "http://localhost:8001".to_string(),
        WorkerType::Decode,
        true,
    ));

    // Test health endpoint
    let client = reqwest::Client::new();
    let http_req = TestRequest::default().to_http_request();
    let health_response = router.health(&client, &http_req).await;
    assert_eq!(health_response.status(), 200);

    // Test readiness endpoint
    let readiness_response = router.readiness();
    assert_eq!(readiness_response.status(), 200);
}
// ============= Load Monitoring Tests =============
/// Values sent on the load watch channel are visible through the router's
/// `worker_loads` receiver.
///
/// The send is asserted to succeed, so a closed channel is reported as such
/// rather than as missing load entries.
#[tokio::test]
async fn test_load_monitor_updates() {
    let policy = Arc::new(crate::policies::PowerOfTwoPolicy::new());
    let mut router = create_test_pd_router();
    router.policy = policy;

    // Replace the router's receiver with one whose sender this test controls.
    let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
    router.worker_loads = Arc::new(rx);

    // Simulate load updates
    let mut loads = HashMap::new();
    loads.insert("http://worker1".to_string(), 10);
    loads.insert("http://worker2".to_string(), 5);
    // The map is moved into the channel; it is not needed afterwards.
    assert!(
        tx.send(loads).is_ok(),
        "load update could not be sent: receiver was dropped"
    );

    // Router should receive updates
    let received = router.worker_loads.borrow().clone();
    assert_eq!(received.get("http://worker1"), Some(&10));
    assert_eq!(received.get("http://worker2"), Some(&5));
}
// ============= Worker Load Tests =============
/// A `WorkerLoadGuard` over two workers raises each worker's load to 1 while
/// it is alive and returns it to 0 once it is dropped.
#[test]
fn test_worker_load_metrics() {
    let decode_worker =
        create_test_worker("http://decode".to_string(), WorkerType::Decode, true);
    let prefill_worker = create_test_worker(
        "http://prefill".to_string(),
        WorkerType::Prefill {
            bootstrap_port: None,
        },
        true,
    );

    // Create load guard for both workers
    let tracked = vec![prefill_worker.as_ref(), decode_worker.as_ref()];
    let guard = WorkerLoadGuard::new_multi(tracked);

    // Load should be incremented
    assert_eq!(prefill_worker.load(), 1);
    assert_eq!(decode_worker.load(), 1);

    // Drop guard - load should decrement
    drop(guard);
    assert_eq!(prefill_worker.load(), 0);
    assert_eq!(decode_worker.load(), 0);
}
// ============= Concurrent Operations Tests =============
/// Five spawned tasks each push one prefill worker into the shared router;
/// afterwards the list must hold exactly five workers.
///
/// Each task's join result is checked, so a panicking task fails the test
/// with its own error instead of a worker-count mismatch.
#[tokio::test]
async fn test_concurrent_worker_operations() {
    let router = Arc::new(create_test_pd_router());

    // Spawn tasks to add workers
    let handles: Vec<_> = (0..5)
        .map(|i| {
            let router_clone = Arc::clone(&router);
            let url = format!("http://worker{}", i);
            tokio::spawn(async move {
                let worker = create_test_worker(
                    url,
                    WorkerType::Prefill {
                        bootstrap_port: None,
                    },
                    true,
                );
                router_clone.prefill_workers.write().unwrap().push(worker);
            })
        })
        .collect();

    // Wait for all tasks and fail on any join error.
    for handle in handles {
        handle.await.expect("worker insertion task failed to complete");
    }

    // Check final state
    let workers = router.prefill_workers.read().unwrap();
    assert_eq!(workers.len(), 5);
}
}
//! Comprehensive tests for PrefillDecode (PD) routing functionality
//!
//! This test suite covers:
//! - Phase 1: Basic PD router creation and configuration
//! - Phase 2: Bootstrap injection and request handling
//! - Phase 3: Cache-aware selection (when implemented)
//!
//! Note: PD mode is enabled via the pd_disaggregation flag, not as a policy type.
//! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode.
// TODO: This test file needs to be updated for the new configuration structure
// where RoutingMode and PolicyConfig are separate
#[cfg(test)]
mod test_pd_routing {
use rand::Rng;
......@@ -921,14 +908,6 @@ mod test_pd_routing {
#[test]
fn test_policy_type_to_pd_selection_policy_mapping() {
// Document the mapping from PolicyType to PDSelectionPolicy
// This mapping happens in lib.rs when pd_disaggregation=true
// PolicyType::Random -> PDSelectionPolicy::Random
// PolicyType::PowerOfTwo -> PDSelectionPolicy::PowerOfTwo
// PolicyType::CacheAware -> PDSelectionPolicy::CacheAware { ... }
// PolicyType::RoundRobin -> ERROR (not supported in PD mode)
// Test that PDSelectionPolicy doesn't include RoundRobin
let pd_policy_count = 3; // Random, PowerOfTwo, CacheAware
assert_eq!(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment