Unverified Commit 1d1ce624 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] refactor router and worker management 2.5/n (#10677)

parent 60e2a7ce
......@@ -9,6 +9,7 @@ use super::{
RoundRobinPolicy,
};
use crate::config::types::PolicyConfig;
use crate::core::Worker;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::{debug, info, warn};
......@@ -255,6 +256,81 @@ impl PolicyRegistry {
.map(Arc::clone)
.unwrap_or_else(|| self.get_default_policy())
}
/// Initialize cache-aware policy with workers if applicable
/// This should be called after workers are registered for a model
pub fn init_cache_aware_policy(&self, model_id: &str, workers: &[Arc<dyn Worker>]) {
// Get the policy for this model
if let Some(policy) = self.get_policy(model_id) {
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy.as_any().downcast_ref::<CacheAwarePolicy>() {
debug!(
"Initializing cache-aware policy with {} workers for model {}",
workers.len(),
model_id
);
cache_aware.init_workers(workers);
}
}
}
}
/// Remove a worker from cache-aware policy if applicable
/// This should be called when a worker is being removed
pub fn remove_worker_from_cache_aware(&self, model_id: &str, worker_url: &str) {
// Get the policy for this model
if let Some(policy) = self.get_policy(model_id) {
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy.as_any().downcast_ref::<CacheAwarePolicy>() {
cache_aware.remove_worker_by_url(worker_url);
debug!(
"Removed worker {} from cache-aware policy for model {}",
worker_url, model_id
);
}
}
}
}
/// Initialize cache-aware policies for PD mode (prefill and decode)
pub fn init_pd_cache_aware_policies(
&self,
prefill_workers: &[Arc<dyn Worker>],
decode_workers: &[Arc<dyn Worker>],
) {
// Initialize prefill policy if it's cache-aware
if let Some(prefill_policy) = self.prefill_policy.read().unwrap().as_ref() {
if prefill_policy.name() == "cache_aware" {
if let Some(cache_aware) =
prefill_policy.as_any().downcast_ref::<CacheAwarePolicy>()
{
if !prefill_workers.is_empty() {
debug!(
"Initializing prefill cache-aware policy with {} workers",
prefill_workers.len()
);
cache_aware.init_workers(prefill_workers);
}
}
}
}
// Initialize decode policy if it's cache-aware
if let Some(decode_policy) = self.decode_policy.read().unwrap().as_ref() {
if decode_policy.name() == "cache_aware" {
if let Some(cache_aware) = decode_policy.as_any().downcast_ref::<CacheAwarePolicy>()
{
if !decode_workers.is_empty() {
debug!(
"Initializing decode cache-aware policy with {} workers",
decode_workers.len()
);
cache_aware.init_workers(decode_workers);
}
}
}
}
}
}
impl std::fmt::Debug for PolicyRegistry {
......
......@@ -232,18 +232,12 @@ impl PDRouter {
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None);
self.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy, update it with all workers for this model
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
let model_workers = self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
// Initialize cache-aware policy if applicable
let model_workers = self.worker_registry.get_by_model_fast(model_id);
self.policy_registry
.init_cache_aware_policy(model_id, &model_workers);
info!("Added prefill server: {}", url);
Ok(format!("Successfully added prefill server: {}", url))
......@@ -272,18 +266,12 @@ impl PDRouter {
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None);
self.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy, update it with all workers for this model
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
let model_workers = self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
// Initialize cache-aware policy if applicable
let model_workers = self.worker_registry.get_by_model_fast(model_id);
self.policy_registry
.init_cache_aware_policy(model_id, &model_workers);
info!("Added decode server: {}", url);
Ok(format!("Successfully added decode server: {}", url))
......@@ -307,17 +295,9 @@ impl PDRouter {
// Notify PolicyRegistry about the removed worker
self.policy_registry.on_worker_removed(&model_id);
// Get the policy for this model to update cache-aware if needed
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(url);
}
}
}
// Remove from cache-aware policy if applicable
self.policy_registry
.remove_worker_from_cache_aware(&model_id, url);
}
if removed.is_some() {
......@@ -348,17 +328,9 @@ impl PDRouter {
// Notify PolicyRegistry about the removed worker
self.policy_registry.on_worker_removed(&model_id);
// Get the policy for this model to update cache-aware if needed
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(url);
}
}
}
// Remove from cache-aware policy if applicable
self.policy_registry
.remove_worker_from_cache_aware(&model_id, url);
}
if removed.is_some() {
......@@ -2226,15 +2198,6 @@ mod tests {
assert_eq!(prefill_workers.len(), 1);
}
// ============= Bootstrap Injection Tests =============
// Note: These tests are commented out as we've moved to the optimized bootstrap injection
// approach that doesn't use the Bootstrap trait on GenerateReqInput anymore.
// TODO: Add new tests for the optimized bootstrap injection approach using
// RequestWithBootstrap and BatchRequestWithBootstrap wrappers
// ============= Worker Selection Tests =============
#[tokio::test]
async fn test_select_healthy_prefill_worker() {
let router = create_test_pd_router();
......
......@@ -70,21 +70,15 @@ impl Router {
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Initialize cache-aware policy with workers if needed
let default_policy = ctx.policy_registry.get_default_policy();
if default_policy.name() == "cache_aware" {
if let Some(cache_aware) = default_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.init_workers(&workers);
}
}
// Cache-aware policies are initialized in WorkerInitializer
// Setup load monitoring for PowerOfTwo policy
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx);
// Get default policy to check if we need load monitoring
let default_policy = ctx.policy_registry.get_default_policy();
// Check if default policy is power_of_two for load monitoring
let load_monitor_handle = if default_policy.name() == "power_of_two" {
let monitor_urls = worker_urls.clone();
......@@ -964,19 +958,13 @@ impl Router {
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy, update it with all workers for this model
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>(
) {
let model_workers =
self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
self.policy_registry.on_worker_added(model_id, None);
// Initialize cache-aware policy if applicable
let model_workers =
self.worker_registry.get_by_model_fast(model_id);
self.policy_registry
.init_cache_aware_policy(model_id, &model_workers);
worker_added = true;
}
......@@ -1000,20 +988,12 @@ impl Router {
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy, add this worker to it
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>(
) {
// Get all workers for this model
let model_workers =
self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
self.policy_registry.on_worker_added(model_id, None);
// Initialize cache-aware policy if applicable
let model_workers = self.worker_registry.get_by_model_fast(model_id);
self.policy_registry
.init_cache_aware_policy(model_id, &model_workers);
}
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
......@@ -1084,20 +1064,11 @@ impl Router {
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
// If any models are using cache aware policy, remove the workers from the tree
// Check each removed worker's model and get its policy
for dp_url in removed_workers.iter() {
if let Some(worker) = self.worker_registry.get_by_url(dp_url) {
let model_id = worker.model_id();
if let Some(policy) = self.policy_registry.get_policy(model_id) {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(dp_url);
info!("Removed worker from cache-aware tree: {}", dp_url);
}
}
self.policy_registry
.remove_worker_from_cache_aware(model_id, dp_url);
}
}
} else {
......@@ -1118,16 +1089,8 @@ impl Router {
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
}
// If the model is using cache aware policy, remove the worker from the tree
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(worker_url);
info!("Removed worker from cache-aware tree: {}", worker_url);
}
}
self.policy_registry
.remove_worker_from_cache_aware(&model_id, worker_url);
}
}
......
......@@ -3,9 +3,11 @@
use crate::config::types::{ConnectionMode as ConfigConnectionMode, RouterConfig, RoutingMode};
use crate::core::{
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, HealthConfig, WorkerRegistry,
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, HealthConfig, Worker, WorkerRegistry,
WorkerType,
};
use crate::policies::PolicyRegistry;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tracing::{info, warn};
......@@ -19,6 +21,7 @@ impl WorkerInitializer {
pub async fn initialize_workers(
config: &RouterConfig,
worker_registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> {
info!("Initializing workers for routing mode: {:?}", config.mode);
......@@ -29,6 +32,7 @@ impl WorkerInitializer {
&config.connection_mode,
config,
worker_registry,
policy_registry,
)
.await?;
}
......@@ -42,6 +46,7 @@ impl WorkerInitializer {
&config.connection_mode,
config,
worker_registry,
policy_registry,
)
.await?;
Self::create_decode_workers(
......@@ -49,6 +54,7 @@ impl WorkerInitializer {
&config.connection_mode,
config,
worker_registry,
policy_registry,
)
.await?;
}
......@@ -76,6 +82,7 @@ impl WorkerInitializer {
config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> {
info!("Creating {} regular workers", urls.len());
......@@ -100,6 +107,8 @@ impl WorkerInitializer {
success_threshold: config.health_check.success_threshold,
};
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for url in urls {
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone())
......@@ -109,8 +118,28 @@ impl WorkerInitializer {
.health_config(health_config.clone())
.build();
let worker_id = registry.register(Arc::new(worker));
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id();
let worker_id = registry.register(Arc::clone(&worker_arc));
info!("Registered regular worker {} with ID {:?}", url, worker_id);
// Track workers by model for cache-aware policy initialization
registered_workers
.entry(model_id.to_string())
.or_default()
.push(Arc::clone(&worker_arc));
// Notify policy registry about the worker
if let Some(policy_reg) = policy_registry {
policy_reg.on_worker_added(model_id, None);
}
}
// Initialize cache-aware policies with all workers for each model
if let Some(policy_reg) = policy_registry {
for (model_id, workers) in registered_workers {
policy_reg.init_cache_aware_policy(&model_id, &workers);
}
}
Ok(())
......@@ -122,6 +151,7 @@ impl WorkerInitializer {
config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> {
info!("Creating {} prefill workers", prefill_entries.len());
......@@ -149,6 +179,8 @@ impl WorkerInitializer {
success_threshold: config.health_check.success_threshold,
};
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for (url, bootstrap_port) in prefill_entries {
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone())
......@@ -160,8 +192,33 @@ impl WorkerInitializer {
.health_config(health_config.clone())
.build();
let worker_id = registry.register(Arc::new(worker));
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id();
let worker_id = registry.register(Arc::clone(&worker_arc));
info!("Registered prefill worker {} with ID {:?}", url, worker_id);
// Track workers by model for cache-aware policy initialization
registered_workers
.entry(model_id.to_string())
.or_default()
.push(Arc::clone(&worker_arc));
// Notify policy registry about the worker
if let Some(policy_reg) = policy_registry {
policy_reg.on_worker_added(model_id, None);
}
}
// Initialize cache-aware policies for PD mode
if let Some(policy_reg) = policy_registry {
// Collect all prefill workers
let all_prefill_workers: Vec<Arc<dyn Worker>> = registered_workers
.values()
.flat_map(|workers| workers.iter().cloned())
.collect();
// Initialize PD policies (will handle both prefill and decode, but we only have prefill here)
policy_reg.init_pd_cache_aware_policies(&all_prefill_workers, &[]);
}
Ok(())
......@@ -173,6 +230,7 @@ impl WorkerInitializer {
config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> {
info!("Creating {} decode workers", urls.len());
......@@ -197,6 +255,8 @@ impl WorkerInitializer {
success_threshold: config.health_check.success_threshold,
};
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for url in urls {
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone())
......@@ -206,8 +266,33 @@ impl WorkerInitializer {
.health_config(health_config.clone())
.build();
let worker_id = registry.register(Arc::new(worker));
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id();
let worker_id = registry.register(Arc::clone(&worker_arc));
info!("Registered decode worker {} with ID {:?}", url, worker_id);
// Track workers by model for cache-aware policy initialization
registered_workers
.entry(model_id.to_string())
.or_default()
.push(Arc::clone(&worker_arc));
// Notify policy registry about the worker
if let Some(policy_reg) = policy_registry {
policy_reg.on_worker_added(model_id, None);
}
}
// Initialize cache-aware policies for PD mode
if let Some(policy_reg) = policy_registry {
// Collect all decode workers
let all_decode_workers: Vec<Arc<dyn Worker>> = registered_workers
.values()
.flat_map(|workers| workers.iter().cloned())
.collect();
// Initialize PD policies (will handle both prefill and decode, but we only have decode here)
policy_reg.init_pd_cache_aware_policies(&[], &all_decode_workers);
}
Ok(())
......@@ -281,7 +366,8 @@ impl WorkerInitializer {
worker_type: WorkerType,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
grpc_clients: &mut std::collections::HashMap<String, crate::grpc::SglangSchedulerClient>,
policy_registry: Option<&Arc<PolicyRegistry>>,
grpc_clients: &mut HashMap<String, crate::grpc::SglangSchedulerClient>,
) -> Result<(), String> {
info!(
"Creating {} gRPC workers of type {:?}",
......@@ -307,6 +393,8 @@ impl WorkerInitializer {
success_threshold: config.health_check.success_threshold,
};
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for url in worker_urls {
if let Some(client) = grpc_clients.remove(url) {
let worker = BasicWorkerBuilder::new(url.clone())
......@@ -317,13 +405,33 @@ impl WorkerInitializer {
.grpc_client(client)
.build();
let worker_id = registry.register(Arc::new(worker));
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id();
let worker_id = registry.register(Arc::clone(&worker_arc));
info!("Registered gRPC worker {} with ID {:?}", url, worker_id);
// Track workers by model for cache-aware policy initialization
registered_workers
.entry(model_id.to_string())
.or_default()
.push(Arc::clone(&worker_arc));
// Notify policy registry about the worker
if let Some(policy_reg) = policy_registry {
policy_reg.on_worker_added(model_id, None);
}
} else {
warn!("No gRPC client available for worker {}, skipping", url);
}
}
// Initialize cache-aware policies with all workers for each model
if let Some(policy_reg) = policy_registry {
for (model_id, workers) in registered_workers {
policy_reg.init_cache_aware_policy(&model_id, &workers);
}
}
Ok(())
}
}
......
......@@ -595,15 +595,17 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
let app_context = Arc::new(app_context);
// Initialize workers before creating routers
// This separates worker lifecycle from router lifecycle
info!(
"Initializing workers for routing mode: {:?}",
config.router_config.mode
);
WorkerInitializer::initialize_workers(&config.router_config, &app_context.worker_registry)
.await
.map_err(|e| format!("Failed to initialize workers: {}", e))?;
WorkerInitializer::initialize_workers(
&config.router_config,
&app_context.worker_registry,
Some(&app_context.policy_registry),
)
.await
.map_err(|e| format!("Failed to initialize workers: {}", e))?;
let worker_stats = app_context.worker_registry.stats();
info!(
......
......@@ -104,7 +104,7 @@ impl TestContext {
// Initialize workers in the registry before creating router
if !worker_urls.is_empty() {
use sglang_router_rs::routers::WorkerInitializer;
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry)
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry, None)
.await
.expect("Failed to initialize workers");
}
......
......@@ -48,7 +48,7 @@ impl TestContext {
// Initialize workers in the registry before creating router
if !worker_urls.is_empty() {
use sglang_router_rs::routers::WorkerInitializer;
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry)
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry, None)
.await
.expect("Failed to initialize workers");
}
......
......@@ -49,7 +49,7 @@ impl TestContext {
// Initialize workers in the registry before creating router
if !worker_urls.is_empty() {
use sglang_router_rs::routers::WorkerInitializer;
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry)
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry, None)
.await
.expect("Failed to initialize workers");
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment