Unverified commit c8f31042, authored by Simo Lin and committed by GitHub
Browse files

[router] Refactor router and policy traits with dependency injection (#7987)


Co-authored-by: Jin Pan <jpan236@wisc.edu>
Co-authored-by: Keru Yang <rukeyang@gmail.com>
Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com>
Co-authored-by: Philip Zhu <phlipzhux@gmail.com>
parent 1f76fc87
use crate::router::Router;
use crate::routers::RouterTrait;
use futures::{StreamExt, TryStreamExt};
use k8s_openapi::api::core::v1::Pod;
......@@ -176,7 +176,7 @@ impl PodInfo {
pub async fn start_service_discovery(
config: ServiceDiscoveryConfig,
router: Arc<Router>,
router: Arc<dyn RouterTrait>,
) -> Result<task::JoinHandle<()>, kube::Error> {
// Don't initialize anything if service discovery is disabled
if !config.enabled {
......@@ -346,7 +346,7 @@ pub async fn start_service_discovery(
async fn handle_pod_event(
pod_info: &PodInfo,
tracked_pods: Arc<Mutex<HashSet<PodInfo>>>,
router: Arc<Router>,
router: Arc<dyn RouterTrait>,
port: u16,
pd_mode: bool,
) {
......@@ -379,17 +379,32 @@ async fn handle_pod_event(
pod_info.name, pod_info.pod_type, worker_url
);
// Handle PD mode with specific pod types
let result = if pd_mode && pod_info.pod_type.is_some() {
// Use PD-aware worker management
if let Some(pod_type) = &pod_info.pod_type {
router
.add_pd_worker(&worker_url, pod_type.clone(), pod_info.bootstrap_port)
.await
// Need to import PDRouter type
use crate::routers::pd_router::PDRouter;
// Try to downcast to PDRouter
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
match &pod_info.pod_type {
Some(PodType::Prefill) => pd_router
.add_prefill_server(worker_url.clone(), pod_info.bootstrap_port)
.await
.map_err(|e| e.to_string()),
Some(PodType::Decode) => pd_router
.add_decode_server(worker_url.clone())
.await
.map_err(|e| e.to_string()),
Some(PodType::Regular) | None => {
// Fall back to regular add_worker for regular pods
router.add_worker(&worker_url).await
}
}
} else {
Err("Pod type is None in PD mode".to_string())
Err("PD mode enabled but router is not a PDRouter".to_string())
}
} else {
// Fallback to regular worker management
// Regular mode or no pod type specified
router.add_worker(&worker_url).await
};
......@@ -412,7 +427,7 @@ async fn handle_pod_event(
async fn handle_pod_deletion(
pod_info: &PodInfo,
tracked_pods: Arc<Mutex<HashSet<PodInfo>>>,
router: Arc<Router>,
router: Arc<dyn RouterTrait>,
port: u16,
pd_mode: bool,
) {
......@@ -435,18 +450,34 @@ async fn handle_pod_deletion(
pod_info.name, pod_info.pod_type, worker_url
);
// Handle PD mode removal
if pd_mode && pod_info.pod_type.is_some() {
// Use PD-aware worker removal
if let Some(pod_type) = &pod_info.pod_type {
if let Err(e) = router.remove_pd_worker(&worker_url, pod_type.clone()).await {
error!(
"Failed to remove PD worker {} from router: {}",
worker_url, e
);
use crate::routers::pd_router::PDRouter;
// Try to downcast to PDRouter for PD-specific removal
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
match &pod_info.pod_type {
Some(PodType::Prefill) => {
if let Err(e) = pd_router.remove_prefill_server(&worker_url).await {
error!("Failed to remove prefill server {}: {}", worker_url, e);
}
}
Some(PodType::Decode) => {
if let Err(e) = pd_router.remove_decode_server(&worker_url).await {
error!("Failed to remove decode server {}: {}", worker_url, e);
}
}
Some(PodType::Regular) | None => {
// Fall back to regular remove_worker
router.remove_worker(&worker_url);
}
}
} else {
// PD mode but not a PDRouter, use generic removal
router.remove_worker(&worker_url);
}
} else {
// Fallback to regular worker removal
// Regular mode removal
router.remove_worker(&worker_url);
}
} else {
......@@ -462,11 +493,9 @@ async fn handle_pod_deletion(
#[cfg(test)]
mod tests {
use super::*;
use crate::router::Router;
use k8s_openapi::api::core::v1::{Pod, PodCondition, PodSpec, PodStatus};
use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta;
use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time;
use std::sync::RwLock;
// Helper function to create a Pod for testing PodInfo::from_pod
fn create_k8s_pod(
......@@ -546,14 +575,14 @@ mod tests {
}
// Helper to create a Router instance for testing event handlers
fn create_test_router() -> Arc<Router> {
let workers = Arc::new(RwLock::new(Vec::new()));
Arc::new(Router::Random {
workers,
timeout_secs: 5,
interval_secs: 1,
_health_checker: None,
})
fn create_test_router() -> Arc<dyn RouterTrait> {
use crate::config::PolicyConfig;
use crate::policies::PolicyFactory;
use crate::routers::router::Router;
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
let router = Router::new(vec![], policy, 5, 1).unwrap();
Arc::new(router) as Arc<dyn RouterTrait>
}
// Helper to create a PD config for testing
......
......@@ -6,7 +6,7 @@ use sglang_router_rs::openai_api_types::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
SamplingParams, StringOrArray, UserMessageContent,
};
use sglang_router_rs::request_adapter::{RouteableRequest, ToPdRequest};
use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest};
#[test]
fn test_benchmark_request_creation() {
......
......@@ -8,12 +8,18 @@
//! Note: PD mode is enabled via the pd_disaggregation flag, not as a policy type.
//! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode.
// TODO: This test file needs to be updated for the new configuration structure
// where RoutingMode and PolicyConfig are separate
#[cfg(test)]
mod test_pd_routing {
use rand::Rng;
use serde_json::json;
use sglang_router_rs::pd_types::PDSelectionPolicy;
use sglang_router_rs::router::{PolicyConfig, Router};
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::routers::pd_types::get_hostname;
use sglang_router_rs::routers::pd_types::PDSelectionPolicy;
use sglang_router_rs::routers::RouterFactory;
// Test-only struct to help validate PD request parsing
#[derive(Debug)]
......@@ -116,49 +122,68 @@ mod test_pd_routing {
#[test]
fn test_pd_router_configuration() {
// Test PrefillDecodeConfig creation with various policies
// This config is used when pd_disaggregation=true
let configs = vec![
PolicyConfig::PrefillDecodeConfig {
selection_policy: PDSelectionPolicy::Random,
prefill_urls: vec![
("http://prefill1:8080".to_string(), Some(9000)),
("http://prefill2:8080".to_string(), None),
],
decode_urls: vec![
"http://decode1:8080".to_string(),
"http://decode2:8080".to_string(),
],
timeout_secs: 10,
interval_secs: 1,
},
PolicyConfig::PrefillDecodeConfig {
selection_policy: PDSelectionPolicy::PowerOfTwo,
prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))],
decode_urls: vec!["http://decode:8080".to_string()],
timeout_secs: 5,
interval_secs: 1,
},
PolicyConfig::PrefillDecodeConfig {
selection_policy: PDSelectionPolicy::CacheAware {
// Test PD router configuration with various policies
// In the new structure, RoutingMode and PolicyConfig are separate
let test_cases = vec![
(
RoutingMode::PrefillDecode {
prefill_urls: vec![
("http://prefill1:8080".to_string(), Some(9000)),
("http://prefill2:8080".to_string(), None),
],
decode_urls: vec![
"http://decode1:8080".to_string(),
"http://decode2:8080".to_string(),
],
},
PolicyConfig::Random,
),
(
RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))],
decode_urls: vec!["http://decode:8080".to_string()],
},
PolicyConfig::PowerOfTwo {
load_check_interval_secs: 5,
},
),
(
RoutingMode::PrefillDecode {
prefill_urls: vec![
("http://p1:8080".to_string(), Some(9000)),
("http://p2:8080".to_string(), Some(9001)),
("http://p3:8080".to_string(), Some(9002)),
],
decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()],
},
PolicyConfig::CacheAware {
cache_threshold: 0.7,
balance_abs_threshold: 20,
balance_rel_threshold: 1.2,
eviction_interval_secs: 60,
max_tree_size: 1000000,
},
prefill_urls: vec![
("http://p1:8080".to_string(), Some(9000)),
("http://p2:8080".to_string(), Some(9001)),
("http://p3:8080".to_string(), Some(9002)),
],
decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()],
timeout_secs: 10,
interval_secs: 2,
},
),
];
for config in configs {
for (mode, policy) in test_cases {
let config = RouterConfig {
mode,
policy,
host: "127.0.0.1".to_string(),
port: 3001,
max_payload_size: 1024 * 1024,
request_timeout_secs: 60,
worker_startup_timeout_secs: 10,
worker_startup_check_interval_secs: 1,
discovery: None,
metrics: None,
log_dir: None,
log_level: None,
};
// Router creation will fail due to health checks, but config should be valid
let result = Router::new(vec![], config);
let result = RouterFactory::create_router(&config);
assert!(result.is_err());
let error_msg = result.unwrap_err();
// Error should be about health/timeout, not configuration
......@@ -225,9 +250,6 @@ mod test_pd_routing {
#[test]
fn test_bootstrap_injection_simulation() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::pd_types::get_hostname;
// Since we can't test the actual inject_bootstrap_fields function here
// (it's private in the router module), we'll test the expected behavior
......@@ -315,8 +337,6 @@ mod test_pd_routing {
#[test]
fn test_hostname_extraction() {
use sglang_router_rs::pd_types::get_hostname;
// Test various URL formats
let test_cases = vec![
("http://localhost:8080", "localhost"),
......@@ -662,7 +682,6 @@ mod test_pd_routing {
#[test]
fn test_bootstrap_injection_with_benchmark_requests() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::pd_types::get_hostname;
// Test bootstrap injection with actual benchmark request patterns
let mut benchmark_request = json!({
......@@ -790,9 +809,6 @@ mod test_pd_routing {
#[test]
fn test_large_batch_bootstrap_injection() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::pd_types::get_hostname;
// Test bootstrap injection performance with very large batches
// This simulates the bench_one_batch_server.py scenario
let large_batch_sizes = vec![1024, 4096, 8192];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment