Unverified Commit 1d1ce624 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] refactor router and worker management 2.5/n (#10677)

parent 60e2a7ce
...@@ -9,6 +9,7 @@ use super::{ ...@@ -9,6 +9,7 @@ use super::{
RoundRobinPolicy, RoundRobinPolicy,
}; };
use crate::config::types::PolicyConfig; use crate::config::types::PolicyConfig;
use crate::core::Worker;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
...@@ -255,6 +256,81 @@ impl PolicyRegistry { ...@@ -255,6 +256,81 @@ impl PolicyRegistry {
.map(Arc::clone) .map(Arc::clone)
.unwrap_or_else(|| self.get_default_policy()) .unwrap_or_else(|| self.get_default_policy())
} }
/// Seed the cache-aware policy of `model_id` with `workers`, when that model uses one.
///
/// Nothing happens unless a policy is found for `model_id`, that policy reports the
/// name `"cache_aware"`, and it downcasts to [`CacheAwarePolicy`].
/// This should be called after workers are registered for a model.
pub fn init_cache_aware_policy(&self, model_id: &str, workers: &[Arc<dyn Worker>]) {
    // No policy known for this model: nothing to initialize.
    let policy = match self.get_policy(model_id) {
        Some(found) => found,
        None => return,
    };

    // Only cache-aware policies keep per-worker state that needs seeding.
    if policy.name() != "cache_aware" {
        return;
    }

    // The name check alone does not give access to the concrete type.
    let cache_aware = match policy.as_any().downcast_ref::<CacheAwarePolicy>() {
        Some(concrete) => concrete,
        None => return,
    };

    debug!(
        "Initializing cache-aware policy with {} workers for model {}",
        workers.len(),
        model_id
    );
    cache_aware.init_workers(workers);
}
/// Drop the worker identified by `worker_url` from the cache-aware policy of `model_id`.
///
/// This is a no-op unless a policy is found for `model_id`, it reports the name
/// `"cache_aware"`, and it downcasts to [`CacheAwarePolicy`].
/// This should be called when a worker is being removed.
pub fn remove_worker_from_cache_aware(&self, model_id: &str, worker_url: &str) {
    // Bail out early when the model has no policy or the policy is not cache-aware.
    let policy = match self.get_policy(model_id) {
        Some(found) if found.name() == "cache_aware" => found,
        _ => return,
    };

    if let Some(concrete) = policy.as_any().downcast_ref::<CacheAwarePolicy>() {
        concrete.remove_worker_by_url(worker_url);
        // Logged after the removal call, whatever that call did internally.
        debug!(
            "Removed worker {} from cache-aware policy for model {}",
            worker_url, model_id
        );
    }
}
/// Seed the PD-mode (prefill / decode) cache-aware policies with their workers.
///
/// The prefill slot is handled first, then the decode slot. For each slot the policy
/// is initialized only when one is set, it reports the name `"cache_aware"`, it
/// downcasts to [`CacheAwarePolicy`], and the matching worker slice is non-empty.
///
/// # Panics
///
/// Panics if reading either policy lock fails, because the `read()` result is unwrapped.
pub fn init_pd_cache_aware_policies(
    &self,
    prefill_workers: &[Arc<dyn Worker>],
    decode_workers: &[Arc<dyn Worker>],
) {
    // Prefill side. The read guard lives until the end of this scope.
    {
        let prefill_slot = self.prefill_policy.read().unwrap();
        let prefill_cache_aware = prefill_slot
            .as_ref()
            .filter(|policy| policy.name() == "cache_aware")
            .and_then(|policy| policy.as_any().downcast_ref::<CacheAwarePolicy>());

        match prefill_cache_aware {
            Some(concrete) if !prefill_workers.is_empty() => {
                debug!(
                    "Initializing prefill cache-aware policy with {} workers",
                    prefill_workers.len()
                );
                concrete.init_workers(prefill_workers);
            }
            _ => {}
        }
    }

    // Decode side, handled the same way once the prefill guard has been released.
    {
        let decode_slot = self.decode_policy.read().unwrap();
        let decode_cache_aware = decode_slot
            .as_ref()
            .filter(|policy| policy.name() == "cache_aware")
            .and_then(|policy| policy.as_any().downcast_ref::<CacheAwarePolicy>());

        match decode_cache_aware {
            Some(concrete) if !decode_workers.is_empty() => {
                debug!(
                    "Initializing decode cache-aware policy with {} workers",
                    decode_workers.len()
                );
                concrete.init_workers(decode_workers);
            }
            _ => {}
        }
    }
}
} }
impl std::fmt::Debug for PolicyRegistry { impl std::fmt::Debug for PolicyRegistry {
......
...@@ -232,18 +232,12 @@ impl PDRouter { ...@@ -232,18 +232,12 @@ impl PDRouter {
// Notify PolicyRegistry about the new worker // Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id(); let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None); self.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy, update it with all workers for this model // Initialize cache-aware policy if applicable
if policy.name() == "cache_aware" { let model_workers = self.worker_registry.get_by_model_fast(model_id);
if let Some(cache_aware) = policy self.policy_registry
.as_any() .init_cache_aware_policy(model_id, &model_workers);
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
let model_workers = self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
info!("Added prefill server: {}", url); info!("Added prefill server: {}", url);
Ok(format!("Successfully added prefill server: {}", url)) Ok(format!("Successfully added prefill server: {}", url))
...@@ -272,18 +266,12 @@ impl PDRouter { ...@@ -272,18 +266,12 @@ impl PDRouter {
// Notify PolicyRegistry about the new worker // Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id(); let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None); self.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy, update it with all workers for this model // Initialize cache-aware policy if applicable
if policy.name() == "cache_aware" { let model_workers = self.worker_registry.get_by_model_fast(model_id);
if let Some(cache_aware) = policy self.policy_registry
.as_any() .init_cache_aware_policy(model_id, &model_workers);
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
let model_workers = self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
info!("Added decode server: {}", url); info!("Added decode server: {}", url);
Ok(format!("Successfully added decode server: {}", url)) Ok(format!("Successfully added decode server: {}", url))
...@@ -307,17 +295,9 @@ impl PDRouter { ...@@ -307,17 +295,9 @@ impl PDRouter {
// Notify PolicyRegistry about the removed worker // Notify PolicyRegistry about the removed worker
self.policy_registry.on_worker_removed(&model_id); self.policy_registry.on_worker_removed(&model_id);
// Get the policy for this model to update cache-aware if needed // Remove from cache-aware policy if applicable
if let Some(policy) = self.policy_registry.get_policy(&model_id) { self.policy_registry
if policy.name() == "cache_aware" { .remove_worker_from_cache_aware(&model_id, url);
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(url);
}
}
}
} }
if removed.is_some() { if removed.is_some() {
...@@ -348,17 +328,9 @@ impl PDRouter { ...@@ -348,17 +328,9 @@ impl PDRouter {
// Notify PolicyRegistry about the removed worker // Notify PolicyRegistry about the removed worker
self.policy_registry.on_worker_removed(&model_id); self.policy_registry.on_worker_removed(&model_id);
// Get the policy for this model to update cache-aware if needed // Remove from cache-aware policy if applicable
if let Some(policy) = self.policy_registry.get_policy(&model_id) { self.policy_registry
if policy.name() == "cache_aware" { .remove_worker_from_cache_aware(&model_id, url);
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(url);
}
}
}
} }
if removed.is_some() { if removed.is_some() {
...@@ -2226,15 +2198,6 @@ mod tests { ...@@ -2226,15 +2198,6 @@ mod tests {
assert_eq!(prefill_workers.len(), 1); assert_eq!(prefill_workers.len(), 1);
} }
// ============= Bootstrap Injection Tests =============
// Note: These tests are commented out as we've moved to the optimized bootstrap injection
// approach that doesn't use the Bootstrap trait on GenerateReqInput anymore.
// TODO: Add new tests for the optimized bootstrap injection approach using
// RequestWithBootstrap and BatchRequestWithBootstrap wrappers
// ============= Worker Selection Tests =============
#[tokio::test] #[tokio::test]
async fn test_select_healthy_prefill_worker() { async fn test_select_healthy_prefill_worker() {
let router = create_test_pd_router(); let router = create_test_pd_router();
......
...@@ -70,21 +70,15 @@ impl Router { ...@@ -70,21 +70,15 @@ impl Router {
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
}; };
// Initialize cache-aware policy with workers if needed // Cache-aware policies are initialized in WorkerInitializer
let default_policy = ctx.policy_registry.get_default_policy();
if default_policy.name() == "cache_aware" {
if let Some(cache_aware) = default_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.init_workers(&workers);
}
}
// Setup load monitoring for PowerOfTwo policy // Setup load monitoring for PowerOfTwo policy
let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx); let worker_loads = Arc::new(rx);
// Get default policy to check if we need load monitoring
let default_policy = ctx.policy_registry.get_default_policy();
// Check if default policy is power_of_two for load monitoring // Check if default policy is power_of_two for load monitoring
let load_monitor_handle = if default_policy.name() == "power_of_two" { let load_monitor_handle = if default_policy.name() == "power_of_two" {
let monitor_urls = worker_urls.clone(); let monitor_urls = worker_urls.clone();
...@@ -964,19 +958,13 @@ impl Router { ...@@ -964,19 +958,13 @@ impl Router {
// Notify PolicyRegistry about the new worker // Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id(); let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None); self.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy, update it with all workers for this model // Initialize cache-aware policy if applicable
if policy.name() == "cache_aware" { let model_workers =
if let Some(cache_aware) = policy self.worker_registry.get_by_model_fast(model_id);
.as_any() self.policy_registry
.downcast_ref::<crate::policies::CacheAwarePolicy>( .init_cache_aware_policy(model_id, &model_workers);
) {
let model_workers =
self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
worker_added = true; worker_added = true;
} }
...@@ -1000,20 +988,12 @@ impl Router { ...@@ -1000,20 +988,12 @@ impl Router {
// Notify PolicyRegistry about the new worker // Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id(); let model_id = worker_arc.model_id();
let policy = self.policy_registry.on_worker_added(model_id, None); self.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy, add this worker to it // Initialize cache-aware policy if applicable
if policy.name() == "cache_aware" { let model_workers = self.worker_registry.get_by_model_fast(model_id);
if let Some(cache_aware) = policy self.policy_registry
.as_any() .init_cache_aware_policy(model_id, &model_workers);
.downcast_ref::<crate::policies::CacheAwarePolicy>(
) {
// Get all workers for this model
let model_workers =
self.worker_registry.get_by_model_fast(model_id);
cache_aware.init_workers(&model_workers);
}
}
} }
RouterMetrics::set_active_workers(self.worker_registry.get_all().len()); RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
...@@ -1084,20 +1064,11 @@ impl Router { ...@@ -1084,20 +1064,11 @@ impl Router {
RouterMetrics::set_active_workers(self.worker_registry.get_all().len()); RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
// If any models are using cache aware policy, remove the workers from the tree
// Check each removed worker's model and get its policy
for dp_url in removed_workers.iter() { for dp_url in removed_workers.iter() {
if let Some(worker) = self.worker_registry.get_by_url(dp_url) { if let Some(worker) = self.worker_registry.get_by_url(dp_url) {
let model_id = worker.model_id(); let model_id = worker.model_id();
if let Some(policy) = self.policy_registry.get_policy(model_id) { self.policy_registry
if let Some(cache_aware) = policy .remove_worker_from_cache_aware(model_id, dp_url);
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(dp_url);
info!("Removed worker from cache-aware tree: {}", dp_url);
}
}
} }
} }
} else { } else {
...@@ -1118,16 +1089,8 @@ impl Router { ...@@ -1118,16 +1089,8 @@ impl Router {
RouterMetrics::set_active_workers(self.worker_registry.get_all().len()); RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
} }
// If the model is using cache aware policy, remove the worker from the tree self.policy_registry
if let Some(policy) = self.policy_registry.get_policy(&model_id) { .remove_worker_from_cache_aware(&model_id, worker_url);
if let Some(cache_aware) = policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker_by_url(worker_url);
info!("Removed worker from cache-aware tree: {}", worker_url);
}
}
} }
} }
......
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
use crate::config::types::{ConnectionMode as ConfigConnectionMode, RouterConfig, RoutingMode}; use crate::config::types::{ConnectionMode as ConfigConnectionMode, RouterConfig, RoutingMode};
use crate::core::{ use crate::core::{
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, HealthConfig, WorkerRegistry, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, HealthConfig, Worker, WorkerRegistry,
WorkerType, WorkerType,
}; };
use crate::policies::PolicyRegistry;
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tracing::{info, warn}; use tracing::{info, warn};
...@@ -19,6 +21,7 @@ impl WorkerInitializer { ...@@ -19,6 +21,7 @@ impl WorkerInitializer {
pub async fn initialize_workers( pub async fn initialize_workers(
config: &RouterConfig, config: &RouterConfig,
worker_registry: &Arc<WorkerRegistry>, worker_registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> { ) -> Result<(), String> {
info!("Initializing workers for routing mode: {:?}", config.mode); info!("Initializing workers for routing mode: {:?}", config.mode);
...@@ -29,6 +32,7 @@ impl WorkerInitializer { ...@@ -29,6 +32,7 @@ impl WorkerInitializer {
&config.connection_mode, &config.connection_mode,
config, config,
worker_registry, worker_registry,
policy_registry,
) )
.await?; .await?;
} }
...@@ -42,6 +46,7 @@ impl WorkerInitializer { ...@@ -42,6 +46,7 @@ impl WorkerInitializer {
&config.connection_mode, &config.connection_mode,
config, config,
worker_registry, worker_registry,
policy_registry,
) )
.await?; .await?;
Self::create_decode_workers( Self::create_decode_workers(
...@@ -49,6 +54,7 @@ impl WorkerInitializer { ...@@ -49,6 +54,7 @@ impl WorkerInitializer {
&config.connection_mode, &config.connection_mode,
config, config,
worker_registry, worker_registry,
policy_registry,
) )
.await?; .await?;
} }
...@@ -76,6 +82,7 @@ impl WorkerInitializer { ...@@ -76,6 +82,7 @@ impl WorkerInitializer {
config_connection_mode: &ConfigConnectionMode, config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig, config: &RouterConfig,
registry: &Arc<WorkerRegistry>, registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> { ) -> Result<(), String> {
info!("Creating {} regular workers", urls.len()); info!("Creating {} regular workers", urls.len());
...@@ -100,6 +107,8 @@ impl WorkerInitializer { ...@@ -100,6 +107,8 @@ impl WorkerInitializer {
success_threshold: config.health_check.success_threshold, success_threshold: config.health_check.success_threshold,
}; };
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for url in urls { for url in urls {
// TODO: Add DP-aware support when we have dp_rank/dp_size info // TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone()) let worker = BasicWorkerBuilder::new(url.clone())
...@@ -109,8 +118,28 @@ impl WorkerInitializer { ...@@ -109,8 +118,28 @@ impl WorkerInitializer {
.health_config(health_config.clone()) .health_config(health_config.clone())
.build(); .build();
let worker_id = registry.register(Arc::new(worker)); let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id();
let worker_id = registry.register(Arc::clone(&worker_arc));
info!("Registered regular worker {} with ID {:?}", url, worker_id); info!("Registered regular worker {} with ID {:?}", url, worker_id);
// Track workers by model for cache-aware policy initialization
registered_workers
.entry(model_id.to_string())
.or_default()
.push(Arc::clone(&worker_arc));
// Notify policy registry about the worker
if let Some(policy_reg) = policy_registry {
policy_reg.on_worker_added(model_id, None);
}
}
// Initialize cache-aware policies with all workers for each model
if let Some(policy_reg) = policy_registry {
for (model_id, workers) in registered_workers {
policy_reg.init_cache_aware_policy(&model_id, &workers);
}
} }
Ok(()) Ok(())
...@@ -122,6 +151,7 @@ impl WorkerInitializer { ...@@ -122,6 +151,7 @@ impl WorkerInitializer {
config_connection_mode: &ConfigConnectionMode, config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig, config: &RouterConfig,
registry: &Arc<WorkerRegistry>, registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> { ) -> Result<(), String> {
info!("Creating {} prefill workers", prefill_entries.len()); info!("Creating {} prefill workers", prefill_entries.len());
...@@ -149,6 +179,8 @@ impl WorkerInitializer { ...@@ -149,6 +179,8 @@ impl WorkerInitializer {
success_threshold: config.health_check.success_threshold, success_threshold: config.health_check.success_threshold,
}; };
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for (url, bootstrap_port) in prefill_entries { for (url, bootstrap_port) in prefill_entries {
// TODO: Add DP-aware support when we have dp_rank/dp_size info // TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone()) let worker = BasicWorkerBuilder::new(url.clone())
...@@ -160,8 +192,33 @@ impl WorkerInitializer { ...@@ -160,8 +192,33 @@ impl WorkerInitializer {
.health_config(health_config.clone()) .health_config(health_config.clone())
.build(); .build();
let worker_id = registry.register(Arc::new(worker)); let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id();
let worker_id = registry.register(Arc::clone(&worker_arc));
info!("Registered prefill worker {} with ID {:?}", url, worker_id); info!("Registered prefill worker {} with ID {:?}", url, worker_id);
// Track workers by model for cache-aware policy initialization
registered_workers
.entry(model_id.to_string())
.or_default()
.push(Arc::clone(&worker_arc));
// Notify policy registry about the worker
if let Some(policy_reg) = policy_registry {
policy_reg.on_worker_added(model_id, None);
}
}
// Initialize cache-aware policies for PD mode
if let Some(policy_reg) = policy_registry {
// Collect all prefill workers
let all_prefill_workers: Vec<Arc<dyn Worker>> = registered_workers
.values()
.flat_map(|workers| workers.iter().cloned())
.collect();
// Initialize PD policies (will handle both prefill and decode, but we only have prefill here)
policy_reg.init_pd_cache_aware_policies(&all_prefill_workers, &[]);
} }
Ok(()) Ok(())
...@@ -173,6 +230,7 @@ impl WorkerInitializer { ...@@ -173,6 +230,7 @@ impl WorkerInitializer {
config_connection_mode: &ConfigConnectionMode, config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig, config: &RouterConfig,
registry: &Arc<WorkerRegistry>, registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> { ) -> Result<(), String> {
info!("Creating {} decode workers", urls.len()); info!("Creating {} decode workers", urls.len());
...@@ -197,6 +255,8 @@ impl WorkerInitializer { ...@@ -197,6 +255,8 @@ impl WorkerInitializer {
success_threshold: config.health_check.success_threshold, success_threshold: config.health_check.success_threshold,
}; };
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for url in urls { for url in urls {
// TODO: Add DP-aware support when we have dp_rank/dp_size info // TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone()) let worker = BasicWorkerBuilder::new(url.clone())
...@@ -206,8 +266,33 @@ impl WorkerInitializer { ...@@ -206,8 +266,33 @@ impl WorkerInitializer {
.health_config(health_config.clone()) .health_config(health_config.clone())
.build(); .build();
let worker_id = registry.register(Arc::new(worker)); let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id();
let worker_id = registry.register(Arc::clone(&worker_arc));
info!("Registered decode worker {} with ID {:?}", url, worker_id); info!("Registered decode worker {} with ID {:?}", url, worker_id);
// Track workers by model for cache-aware policy initialization
registered_workers
.entry(model_id.to_string())
.or_default()
.push(Arc::clone(&worker_arc));
// Notify policy registry about the worker
if let Some(policy_reg) = policy_registry {
policy_reg.on_worker_added(model_id, None);
}
}
// Initialize cache-aware policies for PD mode
if let Some(policy_reg) = policy_registry {
// Collect all decode workers
let all_decode_workers: Vec<Arc<dyn Worker>> = registered_workers
.values()
.flat_map(|workers| workers.iter().cloned())
.collect();
// Initialize PD policies (will handle both prefill and decode, but we only have decode here)
policy_reg.init_pd_cache_aware_policies(&[], &all_decode_workers);
} }
Ok(()) Ok(())
...@@ -281,7 +366,8 @@ impl WorkerInitializer { ...@@ -281,7 +366,8 @@ impl WorkerInitializer {
worker_type: WorkerType, worker_type: WorkerType,
config: &RouterConfig, config: &RouterConfig,
registry: &Arc<WorkerRegistry>, registry: &Arc<WorkerRegistry>,
grpc_clients: &mut std::collections::HashMap<String, crate::grpc::SglangSchedulerClient>, policy_registry: Option<&Arc<PolicyRegistry>>,
grpc_clients: &mut HashMap<String, crate::grpc::SglangSchedulerClient>,
) -> Result<(), String> { ) -> Result<(), String> {
info!( info!(
"Creating {} gRPC workers of type {:?}", "Creating {} gRPC workers of type {:?}",
...@@ -307,6 +393,8 @@ impl WorkerInitializer { ...@@ -307,6 +393,8 @@ impl WorkerInitializer {
success_threshold: config.health_check.success_threshold, success_threshold: config.health_check.success_threshold,
}; };
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for url in worker_urls { for url in worker_urls {
if let Some(client) = grpc_clients.remove(url) { if let Some(client) = grpc_clients.remove(url) {
let worker = BasicWorkerBuilder::new(url.clone()) let worker = BasicWorkerBuilder::new(url.clone())
...@@ -317,13 +405,33 @@ impl WorkerInitializer { ...@@ -317,13 +405,33 @@ impl WorkerInitializer {
.grpc_client(client) .grpc_client(client)
.build(); .build();
let worker_id = registry.register(Arc::new(worker)); let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
let model_id = worker_arc.model_id();
let worker_id = registry.register(Arc::clone(&worker_arc));
info!("Registered gRPC worker {} with ID {:?}", url, worker_id); info!("Registered gRPC worker {} with ID {:?}", url, worker_id);
// Track workers by model for cache-aware policy initialization
registered_workers
.entry(model_id.to_string())
.or_default()
.push(Arc::clone(&worker_arc));
// Notify policy registry about the worker
if let Some(policy_reg) = policy_registry {
policy_reg.on_worker_added(model_id, None);
}
} else { } else {
warn!("No gRPC client available for worker {}, skipping", url); warn!("No gRPC client available for worker {}, skipping", url);
} }
} }
// Initialize cache-aware policies with all workers for each model
if let Some(policy_reg) = policy_registry {
for (model_id, workers) in registered_workers {
policy_reg.init_cache_aware_policy(&model_id, &workers);
}
}
Ok(()) Ok(())
} }
} }
......
...@@ -595,15 +595,17 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -595,15 +595,17 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
let app_context = Arc::new(app_context); let app_context = Arc::new(app_context);
// Initialize workers before creating routers
// This separates worker lifecycle from router lifecycle
info!( info!(
"Initializing workers for routing mode: {:?}", "Initializing workers for routing mode: {:?}",
config.router_config.mode config.router_config.mode
); );
WorkerInitializer::initialize_workers(&config.router_config, &app_context.worker_registry) WorkerInitializer::initialize_workers(
.await &config.router_config,
.map_err(|e| format!("Failed to initialize workers: {}", e))?; &app_context.worker_registry,
Some(&app_context.policy_registry),
)
.await
.map_err(|e| format!("Failed to initialize workers: {}", e))?;
let worker_stats = app_context.worker_registry.stats(); let worker_stats = app_context.worker_registry.stats();
info!( info!(
......
...@@ -104,7 +104,7 @@ impl TestContext { ...@@ -104,7 +104,7 @@ impl TestContext {
// Initialize workers in the registry before creating router // Initialize workers in the registry before creating router
if !worker_urls.is_empty() { if !worker_urls.is_empty() {
use sglang_router_rs::routers::WorkerInitializer; use sglang_router_rs::routers::WorkerInitializer;
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry) WorkerInitializer::initialize_workers(&config, &app_context.worker_registry, None)
.await .await
.expect("Failed to initialize workers"); .expect("Failed to initialize workers");
} }
......
...@@ -48,7 +48,7 @@ impl TestContext { ...@@ -48,7 +48,7 @@ impl TestContext {
// Initialize workers in the registry before creating router // Initialize workers in the registry before creating router
if !worker_urls.is_empty() { if !worker_urls.is_empty() {
use sglang_router_rs::routers::WorkerInitializer; use sglang_router_rs::routers::WorkerInitializer;
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry) WorkerInitializer::initialize_workers(&config, &app_context.worker_registry, None)
.await .await
.expect("Failed to initialize workers"); .expect("Failed to initialize workers");
} }
......
...@@ -49,7 +49,7 @@ impl TestContext { ...@@ -49,7 +49,7 @@ impl TestContext {
// Initialize workers in the registry before creating router // Initialize workers in the registry before creating router
if !worker_urls.is_empty() { if !worker_urls.is_empty() {
use sglang_router_rs::routers::WorkerInitializer; use sglang_router_rs::routers::WorkerInitializer;
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry) WorkerInitializer::initialize_workers(&config, &app_context.worker_registry, None)
.await .await
.expect("Failed to initialize workers"); .expect("Failed to initialize workers");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment