Unverified Commit 00eb5eb7 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] refactor router and worker management 2/n (#10666)

parent dab4663b
...@@ -47,18 +47,17 @@ impl RouterFactory { ...@@ -47,18 +47,17 @@ impl RouterFactory {
ConnectionMode::Http => { ConnectionMode::Http => {
// Route to HTTP implementation based on routing mode // Route to HTTP implementation based on routing mode
match &ctx.router_config.mode { match &ctx.router_config.mode {
RoutingMode::Regular { worker_urls } => { RoutingMode::Regular { .. } => {
Self::create_regular_router(worker_urls, ctx).await // Workers already initialized in registry
Self::create_regular_router(ctx).await
} }
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
prefill_policy, prefill_policy,
decode_policy, decode_policy,
..
} => { } => {
// Workers already initialized in registry
Self::create_pd_router( Self::create_pd_router(
prefill_urls,
decode_urls,
prefill_policy.as_ref(), prefill_policy.as_ref(),
decode_policy.as_ref(), decode_policy.as_ref(),
&ctx.router_config.policy, &ctx.router_config.policy,
...@@ -76,19 +75,17 @@ impl RouterFactory { ...@@ -76,19 +75,17 @@ impl RouterFactory {
/// Create a regular router /// Create a regular router
pub async fn create_regular_router( pub async fn create_regular_router(
worker_urls: &[String],
ctx: &Arc<AppContext>, ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> { ) -> Result<Box<dyn RouterTrait>, String> {
// Create regular router with context // Create regular router with context
let router = Router::new(worker_urls.to_vec(), ctx).await?; // Workers should already be initialized in the registry
let router = Router::new(ctx).await?;
Ok(Box::new(router)) Ok(Box::new(router))
} }
/// Create a PD router with injected policy /// Create a PD router with injected policy
pub async fn create_pd_router( pub async fn create_pd_router(
prefill_urls: &[(String, Option<u16>)],
decode_urls: &[String],
prefill_policy_config: Option<&PolicyConfig>, prefill_policy_config: Option<&PolicyConfig>,
decode_policy_config: Option<&PolicyConfig>, decode_policy_config: Option<&PolicyConfig>,
main_policy_config: &PolicyConfig, main_policy_config: &PolicyConfig,
...@@ -105,7 +102,8 @@ impl RouterFactory { ...@@ -105,7 +102,8 @@ impl RouterFactory {
ctx.policy_registry.set_decode_policy(decode_policy); ctx.policy_registry.set_decode_policy(decode_policy);
// Create PD router with context (policies are in PolicyRegistry) // Create PD router with context (policies are in PolicyRegistry)
let router = PDRouter::new(prefill_urls.to_vec(), decode_urls.to_vec(), ctx).await?; // Workers should already be initialized in the registry
let router = PDRouter::new(ctx).await?;
Ok(Box::new(router)) Ok(Box::new(router))
} }
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
use super::pd_types::{api_path, PDRouterError}; use super::pd_types::{api_path, PDRouterError};
use crate::config::types::RetryConfig; use crate::config::types::RetryConfig;
use crate::core::{ use crate::core::{
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor, is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, RetryExecutor,
Worker, WorkerLoadGuard, WorkerRegistry, WorkerType, Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
}; };
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
...@@ -371,12 +371,30 @@ impl PDRouter { ...@@ -371,12 +371,30 @@ impl PDRouter {
} }
} }
#[allow(clippy::too_many_arguments)] pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
pub async fn new( let prefill_workers = ctx.worker_registry.get_workers_filtered(
prefill_urls: Vec<(String, Option<u16>)>, None, // any model
decode_urls: Vec<String>, Some(WorkerType::Prefill {
ctx: &Arc<crate::server::AppContext>, bootstrap_port: None,
) -> Result<Self, String> { }),
Some(ConnectionMode::Http),
false, // include all workers
);
let decode_workers = ctx.worker_registry.get_workers_filtered(
None, // any model
Some(WorkerType::Decode),
Some(ConnectionMode::Http),
false, // include all workers
);
// Get all worker URLs for monitoring
let all_urls: Vec<String> = prefill_workers
.iter()
.chain(decode_workers.iter())
.map(|w| w.url().to_string())
.collect();
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig // Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config(); let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig { let core_cb_config = CircuitBreakerConfig {
...@@ -386,60 +404,6 @@ impl PDRouter { ...@@ -386,60 +404,6 @@ impl PDRouter {
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
}; };
// Register prefill workers in the registry
for (url, port) in prefill_urls {
let worker = BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Prefill {
bootstrap_port: port,
})
.circuit_breaker_config(core_cb_config.clone())
.health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
})
.build();
ctx.worker_registry.register(Arc::new(worker));
}
// Register decode workers in the registry
for url in decode_urls {
let worker = BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Decode)
.circuit_breaker_config(core_cb_config.clone())
.health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
})
.build();
ctx.worker_registry.register(Arc::new(worker));
}
// Get all workers from registry for health check
let all_workers = ctx.worker_registry.get_all();
let all_urls: Vec<String> = all_workers
.iter()
.map(|worker| worker.url().to_string())
.collect();
if !all_urls.is_empty() {
crate::routers::http::router::Router::wait_for_healthy_workers(
&all_urls,
ctx.router_config.worker_startup_timeout_secs,
ctx.router_config.worker_startup_check_interval_secs,
)
.await?;
}
// Initialize cache-aware policies with workers from registry
// Note: We need to get workers by type and convert to Box<dyn Worker> for CacheAwarePolicy
// This is a temporary workaround until CacheAwarePolicy is updated to work with Arc<dyn Worker>
// TODO: Update CacheAwarePolicy to accept Arc<dyn Worker> instead of Box<dyn Worker>
// Set up background load monitoring for power-of-two selection // Set up background load monitoring for power-of-two selection
let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx); let worker_loads = Arc::new(rx);
...@@ -471,11 +435,8 @@ impl PDRouter { ...@@ -471,11 +435,8 @@ impl PDRouter {
None None
}; };
// Note: Health checking is now handled centrally by RouterManager
// Individual routers no longer need to manage health checkers
// Build a dedicated prefill client for fire-and-forget semantics // Build a dedicated prefill client for fire-and-forget semantics
let prefill_client = reqwest::Client::builder() let prefill_client = Client::builder()
.pool_max_idle_per_host(0) .pool_max_idle_per_host(0)
.http1_only() .http1_only()
.connect_timeout(Duration::from_millis(300)) .connect_timeout(Duration::from_millis(300))
...@@ -489,6 +450,7 @@ impl PDRouter { ...@@ -489,6 +450,7 @@ impl PDRouter {
// Spawn a coordinator with limited concurrent drain tasks // Spawn a coordinator with limited concurrent drain tasks
// This prevents unbounded task spawning under extreme load // This prevents unbounded task spawning under extreme load
// TODO reevaluate a simpler approach (e.g. do we really need to deal with fire and forget)
tokio::spawn(async move { tokio::spawn(async move {
info!("Prefill drain coordinator started"); info!("Prefill drain coordinator started");
...@@ -513,7 +475,7 @@ impl PDRouter { ...@@ -513,7 +475,7 @@ impl PDRouter {
// Drain the response body efficiently // Drain the response body efficiently
// Use streaming to avoid loading entire body into memory // Use streaming to avoid loading entire body into memory
let start = std::time::Instant::now(); let start = Instant::now();
let mut stream = response.bytes_stream(); let mut stream = response.bytes_stream();
let mut bytes_drained = 0; let mut bytes_drained = 0;
......
use crate::config::types::RetryConfig; use crate::config::types::RetryConfig;
use crate::core::{ use crate::core::{
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor, is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, RetryExecutor,
Worker, WorkerRegistry, WorkerType, Worker, WorkerRegistry, WorkerType,
}; };
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
...@@ -47,31 +47,19 @@ pub struct Router { ...@@ -47,31 +47,19 @@ pub struct Router {
impl Router { impl Router {
/// Create a new router with injected policy and client /// Create a new router with injected policy and client
#[allow(clippy::too_many_arguments)] pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
pub async fn new( let workers = ctx.worker_registry.get_workers_filtered(
worker_urls: Vec<String>, None, // any model
ctx: &Arc<crate::server::AppContext>, Some(WorkerType::Regular),
) -> Result<Self, String> { Some(ConnectionMode::Http),
false, // include all workers
);
// Update active workers gauge // Update active workers gauge
RouterMetrics::set_active_workers(worker_urls.len()); RouterMetrics::set_active_workers(workers.len());
// Wait for workers to be healthy (skip if empty - for service discovery mode)
if !worker_urls.is_empty() {
Self::wait_for_healthy_workers(
&worker_urls,
ctx.router_config.worker_startup_timeout_secs,
ctx.router_config.worker_startup_check_interval_secs,
)
.await?;
}
let worker_urls = if ctx.router_config.dp_aware { // Get worker URLs for monitoring
// worker address now in the format of "http://host:port@dp_rank" let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();
Self::get_dp_aware_workers(&worker_urls, &ctx.router_config.api_key)
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?
} else {
worker_urls
};
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig // Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config(); let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
...@@ -82,40 +70,14 @@ impl Router { ...@@ -82,40 +70,14 @@ impl Router {
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs), window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
}; };
// Register workers in the registry // Initialize cache-aware policy with workers if needed
// In IGW mode, we need to fetch model info from workers let default_policy = ctx.policy_registry.get_default_policy();
for url in &worker_urls { if default_policy.name() == "cache_aware" {
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint if let Some(cache_aware) = default_policy
// For now, create worker without model_id
let worker = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Regular)
.circuit_breaker_config(core_cb_config.clone())
.health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
})
.build();
let worker_arc = Arc::new(worker);
ctx.worker_registry.register(worker_arc.clone());
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = ctx.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy and it's the first worker for this model,
// initialize it with the worker
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
.as_any() .as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>() .downcast_ref::<crate::policies::CacheAwarePolicy>()
{ {
let worker_dyn: Arc<dyn Worker> = worker_arc.clone(); cache_aware.init_workers(&workers);
cache_aware.init_workers(std::slice::from_ref(&worker_dyn));
}
} }
} }
...@@ -124,7 +86,6 @@ impl Router { ...@@ -124,7 +86,6 @@ impl Router {
let worker_loads = Arc::new(rx); let worker_loads = Arc::new(rx);
// Check if default policy is power_of_two for load monitoring // Check if default policy is power_of_two for load monitoring
let default_policy = ctx.policy_registry.get_default_policy();
let load_monitor_handle = if default_policy.name() == "power_of_two" { let load_monitor_handle = if default_policy.name() == "power_of_two" {
let monitor_urls = worker_urls.clone(); let monitor_urls = worker_urls.clone();
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs; let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
......
...@@ -19,8 +19,10 @@ pub mod grpc; ...@@ -19,8 +19,10 @@ pub mod grpc;
pub mod header_utils; pub mod header_utils;
pub mod http; pub mod http;
pub mod router_manager; pub mod router_manager;
pub mod worker_initializer;
pub use factory::RouterFactory; pub use factory::RouterFactory;
pub use worker_initializer::WorkerInitializer;
// Re-export HTTP routers for convenience (keeps routers::openai_router path working) // Re-export HTTP routers for convenience (keeps routers::openai_router path working)
pub use http::{openai_router, pd_router, pd_types, router}; pub use http::{openai_router, pd_router, pd_types, router};
......
// Worker Initialization Module
// Separates worker lifecycle management from router construction
use crate::config::types::{ConnectionMode as ConfigConnectionMode, RouterConfig, RoutingMode};
use crate::core::{
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, HealthConfig, WorkerRegistry,
WorkerType,
};
use std::sync::Arc;
use std::time::Duration;
use tracing::{info, warn};
/// WorkerInitializer handles the creation and registration of workers
/// based on routing configuration, separating this concern from router constructors
pub struct WorkerInitializer;
impl WorkerInitializer {
/// Initialize workers based on configuration and register them in the WorkerRegistry
pub async fn initialize_workers(
config: &RouterConfig,
worker_registry: &Arc<WorkerRegistry>,
) -> Result<(), String> {
info!("Initializing workers for routing mode: {:?}", config.mode);
match &config.mode {
RoutingMode::Regular { worker_urls } => {
Self::create_regular_workers(
worker_urls,
&config.connection_mode,
config,
worker_registry,
)
.await?;
}
RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
..
} => {
Self::create_prefill_workers(
prefill_urls,
&config.connection_mode,
config,
worker_registry,
)
.await?;
Self::create_decode_workers(
decode_urls,
&config.connection_mode,
config,
worker_registry,
)
.await?;
}
RoutingMode::OpenAI { .. } => {
info!("OpenAI routing mode - no local workers to initialize");
}
}
// Wait for workers to be healthy if any were registered
if worker_registry.stats().total_workers > 0 {
Self::wait_for_healthy_workers(
worker_registry,
config.worker_startup_timeout_secs,
config.worker_startup_check_interval_secs,
)
.await?;
}
Ok(())
}
/// Create regular workers for standard routing mode
async fn create_regular_workers(
urls: &[String],
config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
) -> Result<(), String> {
info!("Creating {} regular workers", urls.len());
// Convert config connection mode to core connection mode
let connection_mode = Self::convert_connection_mode(config_connection_mode, urls.first());
// Convert circuit breaker config
let circuit_breaker_config = config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Convert health check config
let health_config = HealthConfig {
timeout_secs: config.health_check.timeout_secs,
check_interval_secs: config.health_check.check_interval_secs,
endpoint: config.health_check.endpoint.clone(),
failure_threshold: config.health_check.failure_threshold,
success_threshold: config.health_check.success_threshold,
};
for url in urls {
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Regular)
.connection_mode(connection_mode.clone())
.circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone())
.build();
let worker_id = registry.register(Arc::new(worker));
info!("Registered regular worker {} with ID {:?}", url, worker_id);
}
Ok(())
}
/// Create prefill workers for disaggregated routing mode
async fn create_prefill_workers(
prefill_entries: &[(String, Option<u16>)],
config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
) -> Result<(), String> {
info!("Creating {} prefill workers", prefill_entries.len());
// Convert config connection mode to core connection mode
let connection_mode = Self::convert_connection_mode(
config_connection_mode,
prefill_entries.first().map(|(url, _)| url),
);
// Convert circuit breaker config
let circuit_breaker_config = config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Convert health check config
let health_config = HealthConfig {
timeout_secs: config.health_check.timeout_secs,
check_interval_secs: config.health_check.check_interval_secs,
endpoint: config.health_check.endpoint.clone(),
failure_threshold: config.health_check.failure_threshold,
success_threshold: config.health_check.success_threshold,
};
for (url, bootstrap_port) in prefill_entries {
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Prefill {
bootstrap_port: *bootstrap_port,
})
.connection_mode(connection_mode.clone())
.circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone())
.build();
let worker_id = registry.register(Arc::new(worker));
info!("Registered prefill worker {} with ID {:?}", url, worker_id);
}
Ok(())
}
/// Create decode workers for disaggregated routing mode
async fn create_decode_workers(
urls: &[String],
config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
) -> Result<(), String> {
info!("Creating {} decode workers", urls.len());
// Convert config connection mode to core connection mode
let connection_mode = Self::convert_connection_mode(config_connection_mode, urls.first());
// Convert circuit breaker config
let circuit_breaker_config = config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Convert health check config
let health_config = HealthConfig {
timeout_secs: config.health_check.timeout_secs,
check_interval_secs: config.health_check.check_interval_secs,
endpoint: config.health_check.endpoint.clone(),
failure_threshold: config.health_check.failure_threshold,
success_threshold: config.health_check.success_threshold,
};
for url in urls {
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Decode)
.connection_mode(connection_mode.clone())
.circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone())
.build();
let worker_id = registry.register(Arc::new(worker));
info!("Registered decode worker {} with ID {:?}", url, worker_id);
}
Ok(())
}
/// Convert config connection mode to core connection mode
fn convert_connection_mode(
config_mode: &ConfigConnectionMode,
_sample_url: Option<&String>,
) -> ConnectionMode {
match config_mode {
ConfigConnectionMode::Http => ConnectionMode::Http,
ConfigConnectionMode::Grpc => ConnectionMode::Grpc { port: None },
}
}
/// Wait for workers to become healthy
async fn wait_for_healthy_workers(
registry: &Arc<WorkerRegistry>,
timeout_secs: u64,
check_interval_secs: u64,
) -> Result<(), String> {
let timeout = Duration::from_secs(timeout_secs);
let check_interval = Duration::from_secs(check_interval_secs);
let start_time = std::time::Instant::now();
info!(
"Waiting for workers to become healthy (timeout: {}s)",
timeout_secs
);
loop {
let stats = registry.stats();
if stats.healthy_workers > 0 {
info!(
"Workers healthy: {}/{} workers are ready",
stats.healthy_workers, stats.total_workers
);
// If we have at least one healthy worker, we can proceed
// This allows partial degradation rather than total failure
return Ok(());
}
if start_time.elapsed() > timeout {
let error_msg = format!(
"Timeout waiting for workers to become healthy after {}s. Total workers: {}, Healthy: {}",
timeout_secs, stats.total_workers, stats.healthy_workers
);
warn!("{}", error_msg);
// If we have workers but none are healthy, it's still a failure
if stats.total_workers > 0 {
return Err(error_msg);
} else {
// No workers at all might be OK for some configurations
warn!("No workers registered, proceeding anyway");
return Ok(());
}
}
tokio::time::sleep(check_interval).await;
}
}
/// Initialize workers for gRPC connections specifically
/// This is used when gRPC clients are pre-connected
pub async fn initialize_grpc_workers(
worker_urls: &[String],
worker_type: WorkerType,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
grpc_clients: &mut std::collections::HashMap<String, crate::grpc::SglangSchedulerClient>,
) -> Result<(), String> {
info!(
"Creating {} gRPC workers of type {:?}",
worker_urls.len(),
worker_type
);
// Convert circuit breaker config
let circuit_breaker_config = config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Convert health check config
let health_config = HealthConfig {
timeout_secs: config.health_check.timeout_secs,
check_interval_secs: config.health_check.check_interval_secs,
endpoint: config.health_check.endpoint.clone(),
failure_threshold: config.health_check.failure_threshold,
success_threshold: config.health_check.success_threshold,
};
for url in worker_urls {
if let Some(client) = grpc_clients.remove(url) {
let worker = BasicWorkerBuilder::new(url.clone())
.worker_type(worker_type.clone())
.connection_mode(ConnectionMode::Grpc { port: None })
.circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone())
.grpc_client(client)
.build();
let worker_id = registry.register(Arc::new(worker));
info!("Registered gRPC worker {} with ID {:?}", url, worker_id);
} else {
warn!("No gRPC client available for worker {}, skipping", url);
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_convert_connection_mode() {
// HTTP mode
assert!(matches!(
WorkerInitializer::convert_connection_mode(
&ConfigConnectionMode::Http,
Some(&"http://localhost:8080".to_string())
),
ConnectionMode::Http
));
// gRPC mode
assert!(matches!(
WorkerInitializer::convert_connection_mode(
&ConfigConnectionMode::Grpc,
Some(&"grpc://localhost:50051".to_string())
),
ConnectionMode::Grpc { .. }
));
// No URL provided
assert!(matches!(
WorkerInitializer::convert_connection_mode(&ConfigConnectionMode::Http, None),
ConnectionMode::Http
));
}
}
...@@ -14,6 +14,7 @@ use crate::{ ...@@ -14,6 +14,7 @@ use crate::{
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse}, worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
}, },
reasoning_parser::ParserFactory, reasoning_parser::ParserFactory,
routers::WorkerInitializer,
routers::{ routers::{
router_manager::{RouterId, RouterManager}, router_manager::{RouterId, RouterManager},
RouterFactory, RouterTrait, RouterFactory, RouterTrait,
...@@ -594,6 +595,22 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -594,6 +595,22 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
let app_context = Arc::new(app_context); let app_context = Arc::new(app_context);
// Initialize workers before creating routers
// This separates worker lifecycle from router lifecycle
info!(
"Initializing workers for routing mode: {:?}",
config.router_config.mode
);
WorkerInitializer::initialize_workers(&config.router_config, &app_context.worker_registry)
.await
.map_err(|e| format!("Failed to initialize workers: {}", e))?;
let worker_stats = app_context.worker_registry.stats();
info!(
"Workers initialized: {} total, {} healthy",
worker_stats.total_workers, worker_stats.healthy_workers
);
// Create the appropriate router based on enable_igw flag // Create the appropriate router based on enable_igw flag
let (router, router_manager): (Arc<dyn RouterTrait>, Option<Arc<RouterManager>>) = let (router, router_manager): (Arc<dyn RouterTrait>, Option<Arc<RouterManager>>) =
if config.router_config.enable_igw { if config.router_config.enable_igw {
...@@ -608,12 +625,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -608,12 +625,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
)); ));
// 1. HTTP Regular Router // 1. HTTP Regular Router
match RouterFactory::create_regular_router( match RouterFactory::create_regular_router(&app_context).await {
&[], // Empty worker list - workers added later
&app_context,
)
.await
{
Ok(http_regular) => { Ok(http_regular) => {
info!("Created HTTP Regular router"); info!("Created HTTP Regular router");
router_manager.register_router( router_manager.register_router(
...@@ -628,8 +640,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -628,8 +640,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
// 2. HTTP PD Router // 2. HTTP PD Router
match RouterFactory::create_pd_router( match RouterFactory::create_pd_router(
&[],
&[],
None, None,
None, None,
&config.router_config.policy, &config.router_config.policy,
...@@ -684,7 +694,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -684,7 +694,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
// Start queue processor if enabled // Start queue processor if enabled
if let Some(processor) = processor { if let Some(processor) = processor {
tokio::spawn(processor.run()); spawn(processor.run());
info!( info!(
"Started request queue with size: {}, timeout: {}s", "Started request queue with size: {}, timeout: {}s",
config.router_config.queue_size, config.router_config.queue_timeout_secs config.router_config.queue_size, config.router_config.queue_timeout_secs
......
...@@ -606,7 +606,7 @@ mod tests { ...@@ -606,7 +606,7 @@ mod tests {
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()), response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
}); });
let router = Router::new(vec![], &app_context).await.unwrap(); let router = Router::new(&app_context).await.unwrap();
Arc::new(router) as Arc<dyn RouterTrait> Arc::new(router) as Arc<dyn RouterTrait>
} }
......
...@@ -101,6 +101,14 @@ impl TestContext { ...@@ -101,6 +101,14 @@ impl TestContext {
// Create app context // Create app context
let app_context = common::create_test_context(config.clone()); let app_context = common::create_test_context(config.clone());
// Initialize workers in the registry before creating router
if !worker_urls.is_empty() {
use sglang_router_rs::routers::WorkerInitializer;
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry)
.await
.expect("Failed to initialize workers");
}
// Create router // Create router
let router = RouterFactory::create_router(&app_context).await.unwrap(); let router = RouterFactory::create_router(&app_context).await.unwrap();
let router = Arc::from(router); let router = Arc::from(router);
......
...@@ -39,9 +39,20 @@ impl TestContext { ...@@ -39,9 +39,20 @@ impl TestContext {
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
} }
config.mode = RoutingMode::Regular { worker_urls }; config.mode = RoutingMode::Regular {
worker_urls: worker_urls.clone(),
};
let app_context = common::create_test_context(config.clone());
// Initialize workers in the registry before creating router
if !worker_urls.is_empty() {
use sglang_router_rs::routers::WorkerInitializer;
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry)
.await
.expect("Failed to initialize workers");
}
let app_context = common::create_test_context(config);
let router = RouterFactory::create_router(&app_context).await.unwrap(); let router = RouterFactory::create_router(&app_context).await.unwrap();
let router = Arc::from(router); let router = Arc::from(router);
......
...@@ -40,9 +40,20 @@ impl TestContext { ...@@ -40,9 +40,20 @@ impl TestContext {
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
} }
config.mode = RoutingMode::Regular { worker_urls }; config.mode = RoutingMode::Regular {
worker_urls: worker_urls.clone(),
};
let app_context = common::create_test_context(config.clone());
// Initialize workers in the registry before creating router
if !worker_urls.is_empty() {
use sglang_router_rs::routers::WorkerInitializer;
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry)
.await
.expect("Failed to initialize workers");
}
let app_context = common::create_test_context(config);
let router = RouterFactory::create_router(&app_context).await.unwrap(); let router = RouterFactory::create_router(&app_context).await.unwrap();
let router = Arc::from(router); let router = Arc::from(router);
......
...@@ -207,19 +207,21 @@ mod test_pd_routing { ...@@ -207,19 +207,21 @@ mod test_pd_routing {
history_backend: sglang_router_rs::config::HistoryBackend::Memory, history_backend: sglang_router_rs::config::HistoryBackend::Memory,
}; };
// Router creation will fail due to health checks, but config should be valid
let app_context = let app_context =
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None) sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None)
.expect("Failed to create AppContext"); .expect("Failed to create AppContext");
let app_context = std::sync::Arc::new(app_context); let app_context = std::sync::Arc::new(app_context);
let result = RouterFactory::create_router(&app_context).await; let result = RouterFactory::create_router(&app_context).await;
assert!(result.is_err());
let error_msg = result.unwrap_err();
// Error should be about health/timeout, not configuration
assert!( assert!(
error_msg.contains("healthy") || error_msg.contains("timeout"), result.is_ok(),
"Unexpected error: {}", "Router creation should succeed with empty worker"
error_msg );
// Verify that no workers are registered since we didn't initialize them
let stats = app_context.worker_registry.stats();
assert_eq!(
stats.total_workers, 0,
"No workers should be registered without initialization"
); );
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment