"...tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "b76ac11c9001f94eea43044174739f15b3bef24f"
Unverified Commit c8f31042 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] Refactor router and policy traits with dependency injection (#7987)


Co-authored-by: default avatarJin Pan <jpan236@wisc.edu>
Co-authored-by: default avatarKeru Yang <rukeyang@gmail.com>
Co-authored-by: default avatarYingyi Huang <yingyihuang2000@outlook.com>
Co-authored-by: default avatarPhilip Zhu <phlipzhux@gmail.com>
parent 1f76fc87
use crate::router::Router; use crate::routers::RouterTrait;
use futures::{StreamExt, TryStreamExt}; use futures::{StreamExt, TryStreamExt};
use k8s_openapi::api::core::v1::Pod; use k8s_openapi::api::core::v1::Pod;
...@@ -176,7 +176,7 @@ impl PodInfo { ...@@ -176,7 +176,7 @@ impl PodInfo {
pub async fn start_service_discovery( pub async fn start_service_discovery(
config: ServiceDiscoveryConfig, config: ServiceDiscoveryConfig,
router: Arc<Router>, router: Arc<dyn RouterTrait>,
) -> Result<task::JoinHandle<()>, kube::Error> { ) -> Result<task::JoinHandle<()>, kube::Error> {
// Don't initialize anything if service discovery is disabled // Don't initialize anything if service discovery is disabled
if !config.enabled { if !config.enabled {
...@@ -346,7 +346,7 @@ pub async fn start_service_discovery( ...@@ -346,7 +346,7 @@ pub async fn start_service_discovery(
async fn handle_pod_event( async fn handle_pod_event(
pod_info: &PodInfo, pod_info: &PodInfo,
tracked_pods: Arc<Mutex<HashSet<PodInfo>>>, tracked_pods: Arc<Mutex<HashSet<PodInfo>>>,
router: Arc<Router>, router: Arc<dyn RouterTrait>,
port: u16, port: u16,
pd_mode: bool, pd_mode: bool,
) { ) {
...@@ -379,17 +379,32 @@ async fn handle_pod_event( ...@@ -379,17 +379,32 @@ async fn handle_pod_event(
pod_info.name, pod_info.pod_type, worker_url pod_info.name, pod_info.pod_type, worker_url
); );
// Handle PD mode with specific pod types
let result = if pd_mode && pod_info.pod_type.is_some() { let result = if pd_mode && pod_info.pod_type.is_some() {
// Use PD-aware worker management // Need to import PDRouter type
if let Some(pod_type) = &pod_info.pod_type { use crate::routers::pd_router::PDRouter;
router
.add_pd_worker(&worker_url, pod_type.clone(), pod_info.bootstrap_port) // Try to downcast to PDRouter
.await if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
match &pod_info.pod_type {
Some(PodType::Prefill) => pd_router
.add_prefill_server(worker_url.clone(), pod_info.bootstrap_port)
.await
.map_err(|e| e.to_string()),
Some(PodType::Decode) => pd_router
.add_decode_server(worker_url.clone())
.await
.map_err(|e| e.to_string()),
Some(PodType::Regular) | None => {
// Fall back to regular add_worker for regular pods
router.add_worker(&worker_url).await
}
}
} else { } else {
Err("Pod type is None in PD mode".to_string()) Err("PD mode enabled but router is not a PDRouter".to_string())
} }
} else { } else {
// Fallback to regular worker management // Regular mode or no pod type specified
router.add_worker(&worker_url).await router.add_worker(&worker_url).await
}; };
...@@ -412,7 +427,7 @@ async fn handle_pod_event( ...@@ -412,7 +427,7 @@ async fn handle_pod_event(
async fn handle_pod_deletion( async fn handle_pod_deletion(
pod_info: &PodInfo, pod_info: &PodInfo,
tracked_pods: Arc<Mutex<HashSet<PodInfo>>>, tracked_pods: Arc<Mutex<HashSet<PodInfo>>>,
router: Arc<Router>, router: Arc<dyn RouterTrait>,
port: u16, port: u16,
pd_mode: bool, pd_mode: bool,
) { ) {
...@@ -435,18 +450,34 @@ async fn handle_pod_deletion( ...@@ -435,18 +450,34 @@ async fn handle_pod_deletion(
pod_info.name, pod_info.pod_type, worker_url pod_info.name, pod_info.pod_type, worker_url
); );
// Handle PD mode removal
if pd_mode && pod_info.pod_type.is_some() { if pd_mode && pod_info.pod_type.is_some() {
// Use PD-aware worker removal use crate::routers::pd_router::PDRouter;
if let Some(pod_type) = &pod_info.pod_type {
if let Err(e) = router.remove_pd_worker(&worker_url, pod_type.clone()).await { // Try to downcast to PDRouter for PD-specific removal
error!( if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
"Failed to remove PD worker {} from router: {}", match &pod_info.pod_type {
worker_url, e Some(PodType::Prefill) => {
); if let Err(e) = pd_router.remove_prefill_server(&worker_url).await {
error!("Failed to remove prefill server {}: {}", worker_url, e);
}
}
Some(PodType::Decode) => {
if let Err(e) = pd_router.remove_decode_server(&worker_url).await {
error!("Failed to remove decode server {}: {}", worker_url, e);
}
}
Some(PodType::Regular) | None => {
// Fall back to regular remove_worker
router.remove_worker(&worker_url);
}
} }
} else {
// PD mode but not a PDRouter, use generic removal
router.remove_worker(&worker_url);
} }
} else { } else {
// Fallback to regular worker removal // Regular mode removal
router.remove_worker(&worker_url); router.remove_worker(&worker_url);
} }
} else { } else {
...@@ -462,11 +493,9 @@ async fn handle_pod_deletion( ...@@ -462,11 +493,9 @@ async fn handle_pod_deletion(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::router::Router;
use k8s_openapi::api::core::v1::{Pod, PodCondition, PodSpec, PodStatus}; use k8s_openapi::api::core::v1::{Pod, PodCondition, PodSpec, PodStatus};
use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta; use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta;
use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time; use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time;
use std::sync::RwLock;
// Helper function to create a Pod for testing PodInfo::from_pod // Helper function to create a Pod for testing PodInfo::from_pod
fn create_k8s_pod( fn create_k8s_pod(
...@@ -546,14 +575,14 @@ mod tests { ...@@ -546,14 +575,14 @@ mod tests {
} }
// Helper to create a Router instance for testing event handlers // Helper to create a Router instance for testing event handlers
fn create_test_router() -> Arc<Router> { fn create_test_router() -> Arc<dyn RouterTrait> {
let workers = Arc::new(RwLock::new(Vec::new())); use crate::config::PolicyConfig;
Arc::new(Router::Random { use crate::policies::PolicyFactory;
workers, use crate::routers::router::Router;
timeout_secs: 5,
interval_secs: 1, let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
_health_checker: None, let router = Router::new(vec![], policy, 5, 1).unwrap();
}) Arc::new(router) as Arc<dyn RouterTrait>
} }
// Helper to create a PD config for testing // Helper to create a PD config for testing
......
...@@ -6,7 +6,7 @@ use sglang_router_rs::openai_api_types::{ ...@@ -6,7 +6,7 @@ use sglang_router_rs::openai_api_types::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
SamplingParams, StringOrArray, UserMessageContent, SamplingParams, StringOrArray, UserMessageContent,
}; };
use sglang_router_rs::request_adapter::{RouteableRequest, ToPdRequest}; use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest};
#[test] #[test]
fn test_benchmark_request_creation() { fn test_benchmark_request_creation() {
......
...@@ -8,12 +8,18 @@ ...@@ -8,12 +8,18 @@
//! Note: PD mode is enabled via the pd_disaggregation flag, not as a policy type. //! Note: PD mode is enabled via the pd_disaggregation flag, not as a policy type.
//! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode. //! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode.
// TODO: This test file needs to be updated for the new configuration structure
// where RoutingMode and PolicyConfig are separate
#[cfg(test)] #[cfg(test)]
mod test_pd_routing { mod test_pd_routing {
use rand::Rng; use rand::Rng;
use serde_json::json; use serde_json::json;
use sglang_router_rs::pd_types::PDSelectionPolicy; use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::router::{PolicyConfig, Router}; use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::routers::pd_types::get_hostname;
use sglang_router_rs::routers::pd_types::PDSelectionPolicy;
use sglang_router_rs::routers::RouterFactory;
// Test-only struct to help validate PD request parsing // Test-only struct to help validate PD request parsing
#[derive(Debug)] #[derive(Debug)]
...@@ -116,49 +122,68 @@ mod test_pd_routing { ...@@ -116,49 +122,68 @@ mod test_pd_routing {
#[test] #[test]
fn test_pd_router_configuration() { fn test_pd_router_configuration() {
// Test PrefillDecodeConfig creation with various policies // Test PD router configuration with various policies
// This config is used when pd_disaggregation=true // In the new structure, RoutingMode and PolicyConfig are separate
let configs = vec![ let test_cases = vec![
PolicyConfig::PrefillDecodeConfig { (
selection_policy: PDSelectionPolicy::Random, RoutingMode::PrefillDecode {
prefill_urls: vec![ prefill_urls: vec![
("http://prefill1:8080".to_string(), Some(9000)), ("http://prefill1:8080".to_string(), Some(9000)),
("http://prefill2:8080".to_string(), None), ("http://prefill2:8080".to_string(), None),
], ],
decode_urls: vec![ decode_urls: vec![
"http://decode1:8080".to_string(), "http://decode1:8080".to_string(),
"http://decode2:8080".to_string(), "http://decode2:8080".to_string(),
], ],
timeout_secs: 10, },
interval_secs: 1, PolicyConfig::Random,
}, ),
PolicyConfig::PrefillDecodeConfig { (
selection_policy: PDSelectionPolicy::PowerOfTwo, RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))], prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))],
decode_urls: vec!["http://decode:8080".to_string()], decode_urls: vec!["http://decode:8080".to_string()],
timeout_secs: 5, },
interval_secs: 1, PolicyConfig::PowerOfTwo {
}, load_check_interval_secs: 5,
PolicyConfig::PrefillDecodeConfig { },
selection_policy: PDSelectionPolicy::CacheAware { ),
(
RoutingMode::PrefillDecode {
prefill_urls: vec![
("http://p1:8080".to_string(), Some(9000)),
("http://p2:8080".to_string(), Some(9001)),
("http://p3:8080".to_string(), Some(9002)),
],
decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()],
},
PolicyConfig::CacheAware {
cache_threshold: 0.7, cache_threshold: 0.7,
balance_abs_threshold: 20, balance_abs_threshold: 20,
balance_rel_threshold: 1.2, balance_rel_threshold: 1.2,
eviction_interval_secs: 60,
max_tree_size: 1000000,
}, },
prefill_urls: vec![ ),
("http://p1:8080".to_string(), Some(9000)),
("http://p2:8080".to_string(), Some(9001)),
("http://p3:8080".to_string(), Some(9002)),
],
decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()],
timeout_secs: 10,
interval_secs: 2,
},
]; ];
for config in configs { for (mode, policy) in test_cases {
let config = RouterConfig {
mode,
policy,
host: "127.0.0.1".to_string(),
port: 3001,
max_payload_size: 1024 * 1024,
request_timeout_secs: 60,
worker_startup_timeout_secs: 10,
worker_startup_check_interval_secs: 1,
discovery: None,
metrics: None,
log_dir: None,
log_level: None,
};
// Router creation will fail due to health checks, but config should be valid // Router creation will fail due to health checks, but config should be valid
let result = Router::new(vec![], config); let result = RouterFactory::create_router(&config);
assert!(result.is_err()); assert!(result.is_err());
let error_msg = result.unwrap_err(); let error_msg = result.unwrap_err();
// Error should be about health/timeout, not configuration // Error should be about health/timeout, not configuration
...@@ -225,9 +250,6 @@ mod test_pd_routing { ...@@ -225,9 +250,6 @@ mod test_pd_routing {
#[test] #[test]
fn test_bootstrap_injection_simulation() { fn test_bootstrap_injection_simulation() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::pd_types::get_hostname;
// Since we can't test the actual inject_bootstrap_fields function here // Since we can't test the actual inject_bootstrap_fields function here
// (it's private in the router module), we'll test the expected behavior // (it's private in the router module), we'll test the expected behavior
...@@ -315,8 +337,6 @@ mod test_pd_routing { ...@@ -315,8 +337,6 @@ mod test_pd_routing {
#[test] #[test]
fn test_hostname_extraction() { fn test_hostname_extraction() {
use sglang_router_rs::pd_types::get_hostname;
// Test various URL formats // Test various URL formats
let test_cases = vec![ let test_cases = vec![
("http://localhost:8080", "localhost"), ("http://localhost:8080", "localhost"),
...@@ -662,7 +682,6 @@ mod test_pd_routing { ...@@ -662,7 +682,6 @@ mod test_pd_routing {
#[test] #[test]
fn test_bootstrap_injection_with_benchmark_requests() { fn test_bootstrap_injection_with_benchmark_requests() {
use sglang_router_rs::core::{WorkerFactory, WorkerType}; use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::pd_types::get_hostname;
// Test bootstrap injection with actual benchmark request patterns // Test bootstrap injection with actual benchmark request patterns
let mut benchmark_request = json!({ let mut benchmark_request = json!({
...@@ -790,9 +809,6 @@ mod test_pd_routing { ...@@ -790,9 +809,6 @@ mod test_pd_routing {
#[test] #[test]
fn test_large_batch_bootstrap_injection() { fn test_large_batch_bootstrap_injection() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::pd_types::get_hostname;
// Test bootstrap injection performance with very large batches // Test bootstrap injection performance with very large batches
// This simulates the bench_one_batch_server.py scenario // This simulates the bench_one_batch_server.py scenario
let large_batch_sizes = vec![1024, 4096, 8192]; let large_batch_sizes = vec![1024, 4096, 8192];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment