Unverified Commit 36efd5be authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] refactor router and worker management 1/n (#10664)

parent 68cdc189
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
use crate::config::types::RetryConfig; use crate::config::types::RetryConfig;
use crate::core::{ use crate::core::{
BasicWorkerBuilder, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType,
}; };
use crate::grpc::SglangSchedulerClient; use crate::grpc::SglangSchedulerClient;
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy; use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::reasoning_parser::ParserFactory; use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterTrait, WorkerManagement}; use crate::routers::{RouterTrait, WorkerManagement};
use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::traits::Tokenizer;
...@@ -19,21 +19,17 @@ use axum::{ ...@@ -19,21 +19,17 @@ use axum::{
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, RwLock}; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tracing::{info, warn}; use tracing::{info, warn};
/// gRPC PD (Prefill-Decode) router implementation for SGLang /// gRPC PD (Prefill-Decode) router implementation for SGLang
#[allow(dead_code)] // Fields will be used once implementation is complete #[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcPDRouter { pub struct GrpcPDRouter {
/// Prefill worker connections /// Centralized worker registry
prefill_workers: Arc<RwLock<Vec<Arc<dyn Worker>>>>, worker_registry: Arc<WorkerRegistry>,
/// Decode worker connections /// Centralized policy registry
decode_workers: Arc<RwLock<Vec<Arc<dyn Worker>>>>, policy_registry: Arc<PolicyRegistry>,
/// gRPC clients for prefill workers
prefill_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
/// gRPC clients for decode workers
decode_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
/// Load balancing policy for prefill /// Load balancing policy for prefill
prefill_policy: Arc<dyn LoadBalancingPolicy>, prefill_policy: Arc<dyn LoadBalancingPolicy>,
/// Load balancing policy for decode /// Load balancing policy for decode
...@@ -44,9 +40,6 @@ pub struct GrpcPDRouter { ...@@ -44,9 +40,6 @@ pub struct GrpcPDRouter {
reasoning_parser_factory: ParserFactory, reasoning_parser_factory: ParserFactory,
/// Tool parser registry for function/tool calls /// Tool parser registry for function/tool calls
tool_parser_registry: &'static ParserRegistry, tool_parser_registry: &'static ParserRegistry,
/// Worker health checkers
_prefill_health_checker: Option<HealthChecker>,
_decode_health_checker: Option<HealthChecker>,
/// Configuration /// Configuration
timeout_secs: u64, timeout_secs: u64,
interval_secs: u64, interval_secs: u64,
...@@ -65,6 +58,10 @@ impl GrpcPDRouter { ...@@ -65,6 +58,10 @@ impl GrpcPDRouter {
decode_policy: Arc<dyn LoadBalancingPolicy>, decode_policy: Arc<dyn LoadBalancingPolicy>,
ctx: &Arc<crate::server::AppContext>, ctx: &Arc<crate::server::AppContext>,
) -> Result<Self, String> { ) -> Result<Self, String> {
// Get registries from context
let worker_registry = ctx.worker_registry.clone();
let policy_registry = ctx.policy_registry.clone();
// Update metrics // Update metrics
RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len()); RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len());
...@@ -126,10 +123,9 @@ impl GrpcPDRouter { ...@@ -126,10 +123,9 @@ impl GrpcPDRouter {
return Err("Failed to connect to any gRPC workers".to_string()); return Err("Failed to connect to any gRPC workers".to_string());
} }
// Create Prefill Worker trait objects with gRPC connection mode // Create Prefill Worker trait objects with gRPC connection mode and register them
let prefill_workers: Vec<Arc<dyn Worker>> = prefill_urls for (url, bootstrap_port) in &prefill_urls {
.iter() if let Some(client) = prefill_grpc_clients.remove(url) {
.map(|(url, bootstrap_port)| {
let worker = BasicWorkerBuilder::new(url.clone()) let worker = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Prefill { .worker_type(WorkerType::Prefill {
bootstrap_port: *bootstrap_port, bootstrap_port: *bootstrap_port,
...@@ -145,15 +141,17 @@ impl GrpcPDRouter { ...@@ -145,15 +141,17 @@ impl GrpcPDRouter {
failure_threshold: ctx.router_config.health_check.failure_threshold, failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold, success_threshold: ctx.router_config.health_check.success_threshold,
}) })
.grpc_client(client)
.build(); .build();
Arc::new(worker) as Arc<dyn Worker>
}) // Register worker in the centralized registry
.collect(); worker_registry.register(Arc::new(worker));
}
// Create Decode Worker trait objects with gRPC connection mode }
let decode_workers: Vec<Arc<dyn Worker>> = decode_urls
.iter() // Create Decode Worker trait objects with gRPC connection mode and register them
.map(|url| { for url in &decode_urls {
if let Some(client) = decode_grpc_clients.remove(url) {
let worker = BasicWorkerBuilder::new(url.clone()) let worker = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Decode) .worker_type(WorkerType::Decode)
.connection_mode(crate::core::ConnectionMode::Grpc { port: None }) .connection_mode(crate::core::ConnectionMode::Grpc { port: None })
...@@ -165,12 +163,23 @@ impl GrpcPDRouter { ...@@ -165,12 +163,23 @@ impl GrpcPDRouter {
failure_threshold: ctx.router_config.health_check.failure_threshold, failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold, success_threshold: ctx.router_config.health_check.success_threshold,
}) })
.grpc_client(client)
.build(); .build();
Arc::new(worker) as Arc<dyn Worker>
})
.collect();
// Initialize policies with workers if needed // Register worker in the centralized registry
worker_registry.register(Arc::new(worker));
}
}
// Initialize policies with workers if needed - filter for gRPC workers only
let prefill_workers = worker_registry.get_workers_filtered(
None, // any model
Some(WorkerType::Prefill {
bootstrap_port: None,
}),
Some(crate::core::ConnectionMode::Grpc { port: None }),
false, // include unhealthy workers during initialization
);
if let Some(cache_aware) = prefill_policy if let Some(cache_aware) = prefill_policy
.as_any() .as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>() .downcast_ref::<crate::policies::CacheAwarePolicy>()
...@@ -178,6 +187,12 @@ impl GrpcPDRouter { ...@@ -178,6 +187,12 @@ impl GrpcPDRouter {
cache_aware.init_workers(&prefill_workers); cache_aware.init_workers(&prefill_workers);
} }
let decode_workers = worker_registry.get_workers_filtered(
None, // any model
Some(WorkerType::Decode),
Some(crate::core::ConnectionMode::Grpc { port: None }),
false, // include unhealthy workers during initialization
);
if let Some(cache_aware) = decode_policy if let Some(cache_aware) = decode_policy
.as_any() .as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>() .downcast_ref::<crate::policies::CacheAwarePolicy>()
...@@ -185,30 +200,16 @@ impl GrpcPDRouter { ...@@ -185,30 +200,16 @@ impl GrpcPDRouter {
cache_aware.init_workers(&decode_workers); cache_aware.init_workers(&decode_workers);
} }
let prefill_workers = Arc::new(RwLock::new(prefill_workers)); // No need for local health checkers - WorkerRegistry handles health checking
let decode_workers = Arc::new(RwLock::new(decode_workers));
let prefill_health_checker = crate::core::start_health_checker(
Arc::clone(&prefill_workers),
ctx.router_config.worker_startup_check_interval_secs,
);
let decode_health_checker = crate::core::start_health_checker(
Arc::clone(&decode_workers),
ctx.router_config.worker_startup_check_interval_secs,
);
Ok(GrpcPDRouter { Ok(GrpcPDRouter {
prefill_workers, worker_registry,
decode_workers, policy_registry,
prefill_grpc_clients: Arc::new(RwLock::new(prefill_grpc_clients)),
decode_grpc_clients: Arc::new(RwLock::new(decode_grpc_clients)),
prefill_policy, prefill_policy,
decode_policy, decode_policy,
tokenizer, tokenizer,
reasoning_parser_factory, reasoning_parser_factory,
tool_parser_registry, tool_parser_registry,
_prefill_health_checker: Some(prefill_health_checker),
_decode_health_checker: Some(decode_health_checker),
timeout_secs: ctx.router_config.worker_startup_timeout_secs, timeout_secs: ctx.router_config.worker_startup_timeout_secs,
interval_secs: ctx.router_config.worker_startup_check_interval_secs, interval_secs: ctx.router_config.worker_startup_check_interval_secs,
dp_aware: ctx.router_config.dp_aware, dp_aware: ctx.router_config.dp_aware,
...@@ -221,15 +222,23 @@ impl GrpcPDRouter { ...@@ -221,15 +222,23 @@ impl GrpcPDRouter {
impl std::fmt::Debug for GrpcPDRouter { impl std::fmt::Debug for GrpcPDRouter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let prefill_workers = self.worker_registry.get_workers_filtered(
None,
Some(WorkerType::Prefill {
bootstrap_port: None,
}),
Some(crate::core::ConnectionMode::Grpc { port: None }),
false,
);
let decode_workers = self.worker_registry.get_workers_filtered(
None,
Some(WorkerType::Decode),
Some(crate::core::ConnectionMode::Grpc { port: None }),
false,
);
f.debug_struct("GrpcPDRouter") f.debug_struct("GrpcPDRouter")
.field( .field("prefill_workers_count", &prefill_workers.len())
"prefill_workers_count", .field("decode_workers_count", &decode_workers.len())
&self.prefill_workers.read().unwrap().len(),
)
.field(
"decode_workers_count",
&self.decode_workers.read().unwrap().len(),
)
.field("timeout_secs", &self.timeout_secs) .field("timeout_secs", &self.timeout_secs)
.field("interval_secs", &self.interval_secs) .field("interval_secs", &self.interval_secs)
.field("dp_aware", &self.dp_aware) .field("dp_aware", &self.dp_aware)
...@@ -351,6 +360,28 @@ impl WorkerManagement for GrpcPDRouter { ...@@ -351,6 +360,28 @@ impl WorkerManagement for GrpcPDRouter {
fn remove_worker(&self, _worker_url: &str) {} fn remove_worker(&self, _worker_url: &str) {}
fn get_worker_urls(&self) -> Vec<String> { fn get_worker_urls(&self) -> Vec<String> {
vec![] let mut urls = Vec::new();
// Get gRPC prefill worker URLs only
let prefill_workers = self.worker_registry.get_workers_filtered(
None,
Some(WorkerType::Prefill {
bootstrap_port: None,
}),
Some(crate::core::ConnectionMode::Grpc { port: None }),
false,
);
urls.extend(prefill_workers.iter().map(|w| w.url().to_string()));
// Get gRPC decode worker URLs only
let decode_workers = self.worker_registry.get_workers_filtered(
None,
Some(WorkerType::Decode),
Some(crate::core::ConnectionMode::Grpc { port: None }),
false,
);
urls.extend(decode_workers.iter().map(|w| w.url().to_string()));
urls
} }
} }
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
use crate::config::types::RetryConfig; use crate::config::types::RetryConfig;
use crate::core::{ use crate::core::{
BasicWorkerBuilder, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType,
}; };
use crate::grpc::SglangSchedulerClient; use crate::grpc::SglangSchedulerClient;
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy; use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::reasoning_parser::ParserFactory; use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterTrait, WorkerManagement}; use crate::routers::{RouterTrait, WorkerManagement};
use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::traits::Tokenizer;
...@@ -19,17 +19,17 @@ use axum::{ ...@@ -19,17 +19,17 @@ use axum::{
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, RwLock}; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tracing::{info, warn}; use tracing::{info, warn};
/// gRPC router implementation for SGLang /// gRPC router implementation for SGLang
#[allow(dead_code)] // Fields will be used once implementation is complete #[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcRouter { pub struct GrpcRouter {
/// Worker connections /// Centralized worker registry
workers: Arc<RwLock<Vec<Arc<dyn Worker>>>>, worker_registry: Arc<WorkerRegistry>,
/// gRPC clients for each worker /// Centralized policy registry
grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>, policy_registry: Arc<PolicyRegistry>,
/// Load balancing policy /// Load balancing policy
policy: Arc<dyn LoadBalancingPolicy>, policy: Arc<dyn LoadBalancingPolicy>,
/// Tokenizer for handling text encoding/decoding /// Tokenizer for handling text encoding/decoding
...@@ -38,8 +38,6 @@ pub struct GrpcRouter { ...@@ -38,8 +38,6 @@ pub struct GrpcRouter {
reasoning_parser_factory: ParserFactory, reasoning_parser_factory: ParserFactory,
/// Tool parser registry for function/tool calls /// Tool parser registry for function/tool calls
tool_parser_registry: &'static ParserRegistry, tool_parser_registry: &'static ParserRegistry,
/// Worker health checker
_health_checker: Option<HealthChecker>,
/// Configuration /// Configuration
timeout_secs: u64, timeout_secs: u64,
interval_secs: u64, interval_secs: u64,
...@@ -102,10 +100,11 @@ impl GrpcRouter { ...@@ -102,10 +100,11 @@ impl GrpcRouter {
return Err("Failed to connect to any gRPC workers".to_string()); return Err("Failed to connect to any gRPC workers".to_string());
} }
// Create Worker trait objects with gRPC connection mode // Get registries from context
let mut workers: Vec<Arc<dyn Worker>> = Vec::new(); let worker_registry = ctx.worker_registry.clone();
let policy_registry = ctx.policy_registry.clone();
// Move clients from the HashMap to the workers // Create Worker trait objects with gRPC connection mode and register them
for url in &worker_urls { for url in &worker_urls {
if let Some(client) = grpc_clients.remove(url) { if let Some(client) = grpc_clients.remove(url) {
let worker = BasicWorkerBuilder::new(url.clone()) let worker = BasicWorkerBuilder::new(url.clone())
...@@ -122,12 +121,21 @@ impl GrpcRouter { ...@@ -122,12 +121,21 @@ impl GrpcRouter {
.grpc_client(client) .grpc_client(client)
.build(); .build();
workers.push(Arc::new(worker) as Arc<dyn Worker>); // Register worker in the centralized registry
worker_registry.register(Arc::new(worker));
} else { } else {
warn!("No gRPC client for worker {}, skipping", url); warn!("No gRPC client for worker {}, skipping", url);
} }
} }
// Get only gRPC workers from registry for policy initialization
let workers = worker_registry.get_workers_filtered(
None, // any model
Some(WorkerType::Regular),
Some(crate::core::ConnectionMode::Grpc { port: None }),
false, // include unhealthy workers during initialization
);
// Initialize policy with workers if needed // Initialize policy with workers if needed
if let Some(cache_aware) = policy if let Some(cache_aware) = policy
.as_any() .as_any()
...@@ -136,20 +144,15 @@ impl GrpcRouter { ...@@ -136,20 +144,15 @@ impl GrpcRouter {
cache_aware.init_workers(&workers); cache_aware.init_workers(&workers);
} }
let workers = Arc::new(RwLock::new(workers)); // No need for local health checkers - WorkerRegistry handles health checking
let health_checker = crate::core::start_health_checker(
Arc::clone(&workers),
ctx.router_config.worker_startup_check_interval_secs,
);
Ok(GrpcRouter { Ok(GrpcRouter {
workers, worker_registry,
grpc_clients: Arc::new(RwLock::new(grpc_clients)), policy_registry,
policy, policy,
tokenizer, tokenizer,
reasoning_parser_factory, reasoning_parser_factory,
tool_parser_registry, tool_parser_registry,
_health_checker: Some(health_checker),
timeout_secs: ctx.router_config.worker_startup_timeout_secs, timeout_secs: ctx.router_config.worker_startup_timeout_secs,
interval_secs: ctx.router_config.worker_startup_check_interval_secs, interval_secs: ctx.router_config.worker_startup_check_interval_secs,
dp_aware: ctx.router_config.dp_aware, dp_aware: ctx.router_config.dp_aware,
...@@ -162,8 +165,9 @@ impl GrpcRouter { ...@@ -162,8 +165,9 @@ impl GrpcRouter {
impl std::fmt::Debug for GrpcRouter { impl std::fmt::Debug for GrpcRouter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let stats = self.worker_registry.stats();
f.debug_struct("GrpcRouter") f.debug_struct("GrpcRouter")
.field("workers_count", &self.workers.read().unwrap().len()) .field("workers_count", &stats.total_workers)
.field("timeout_secs", &self.timeout_secs) .field("timeout_secs", &self.timeout_secs)
.field("interval_secs", &self.interval_secs) .field("interval_secs", &self.interval_secs)
.field("dp_aware", &self.dp_aware) .field("dp_aware", &self.dp_aware)
...@@ -285,9 +289,13 @@ impl WorkerManagement for GrpcRouter { ...@@ -285,9 +289,13 @@ impl WorkerManagement for GrpcRouter {
fn remove_worker(&self, _worker_url: &str) {} fn remove_worker(&self, _worker_url: &str) {}
fn get_worker_urls(&self) -> Vec<String> { fn get_worker_urls(&self) -> Vec<String> {
self.workers self.worker_registry
.read() .get_workers_filtered(
.unwrap() None, // any model
Some(WorkerType::Regular),
Some(crate::core::ConnectionMode::Grpc { port: None }),
false, // include all workers
)
.iter() .iter()
.map(|w| w.url().to_string()) .map(|w| w.url().to_string())
.collect() .collect()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment