Unverified Commit 00eb5eb7 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] refactor router and worker management 2/n (#10666)

parent dab4663b
......@@ -47,18 +47,17 @@ impl RouterFactory {
ConnectionMode::Http => {
// Route to HTTP implementation based on routing mode
match &ctx.router_config.mode {
RoutingMode::Regular { worker_urls } => {
Self::create_regular_router(worker_urls, ctx).await
RoutingMode::Regular { .. } => {
// Workers already initialized in registry
Self::create_regular_router(ctx).await
}
RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
prefill_policy,
decode_policy,
..
} => {
// Workers already initialized in registry
Self::create_pd_router(
prefill_urls,
decode_urls,
prefill_policy.as_ref(),
decode_policy.as_ref(),
&ctx.router_config.policy,
......@@ -76,19 +75,17 @@ impl RouterFactory {
/// Create a regular router
pub async fn create_regular_router(
worker_urls: &[String],
ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
// Create regular router with context
let router = Router::new(worker_urls.to_vec(), ctx).await?;
// Workers should already be initialized in the registry
let router = Router::new(ctx).await?;
Ok(Box::new(router))
}
/// Create a PD router with injected policy
pub async fn create_pd_router(
prefill_urls: &[(String, Option<u16>)],
decode_urls: &[String],
prefill_policy_config: Option<&PolicyConfig>,
decode_policy_config: Option<&PolicyConfig>,
main_policy_config: &PolicyConfig,
......@@ -105,7 +102,8 @@ impl RouterFactory {
ctx.policy_registry.set_decode_policy(decode_policy);
// Create PD router with context (policies are in PolicyRegistry)
let router = PDRouter::new(prefill_urls.to_vec(), decode_urls.to_vec(), ctx).await?;
// Workers should already be initialized in the registry
let router = PDRouter::new(ctx).await?;
Ok(Box::new(router))
}
......
......@@ -3,7 +3,7 @@
use super::pd_types::{api_path, PDRouterError};
use crate::config::types::RetryConfig;
use crate::core::{
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor,
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, RetryExecutor,
Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
......@@ -371,12 +371,30 @@ impl PDRouter {
}
}
#[allow(clippy::too_many_arguments)]
pub async fn new(
prefill_urls: Vec<(String, Option<u16>)>,
decode_urls: Vec<String>,
ctx: &Arc<crate::server::AppContext>,
) -> Result<Self, String> {
pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
let prefill_workers = ctx.worker_registry.get_workers_filtered(
None, // any model
Some(WorkerType::Prefill {
bootstrap_port: None,
}),
Some(ConnectionMode::Http),
false, // include all workers
);
let decode_workers = ctx.worker_registry.get_workers_filtered(
None, // any model
Some(WorkerType::Decode),
Some(ConnectionMode::Http),
false, // include all workers
);
// Get all worker URLs for monitoring
let all_urls: Vec<String> = prefill_workers
.iter()
.chain(decode_workers.iter())
.map(|w| w.url().to_string())
.collect();
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
......@@ -386,60 +404,6 @@ impl PDRouter {
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Register prefill workers in the registry
for (url, port) in prefill_urls {
let worker = BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Prefill {
bootstrap_port: port,
})
.circuit_breaker_config(core_cb_config.clone())
.health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
})
.build();
ctx.worker_registry.register(Arc::new(worker));
}
// Register decode workers in the registry
for url in decode_urls {
let worker = BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Decode)
.circuit_breaker_config(core_cb_config.clone())
.health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
})
.build();
ctx.worker_registry.register(Arc::new(worker));
}
// Get all workers from registry for health check
let all_workers = ctx.worker_registry.get_all();
let all_urls: Vec<String> = all_workers
.iter()
.map(|worker| worker.url().to_string())
.collect();
if !all_urls.is_empty() {
crate::routers::http::router::Router::wait_for_healthy_workers(
&all_urls,
ctx.router_config.worker_startup_timeout_secs,
ctx.router_config.worker_startup_check_interval_secs,
)
.await?;
}
// Initialize cache-aware policies with workers from registry
// Note: We need to get workers by type and convert to Box<dyn Worker> for CacheAwarePolicy
// This is a temporary workaround until CacheAwarePolicy is updated to work with Arc<dyn Worker>
// TODO: Update CacheAwarePolicy to accept Arc<dyn Worker> instead of Box<dyn Worker>
// Set up background load monitoring for power-of-two selection
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx);
......@@ -471,11 +435,8 @@ impl PDRouter {
None
};
// Note: Health checking is now handled centrally by RouterManager
// Individual routers no longer need to manage health checkers
// Build a dedicated prefill client for fire-and-forget semantics
let prefill_client = reqwest::Client::builder()
let prefill_client = Client::builder()
.pool_max_idle_per_host(0)
.http1_only()
.connect_timeout(Duration::from_millis(300))
......@@ -489,6 +450,7 @@ impl PDRouter {
// Spawn a coordinator with limited concurrent drain tasks
// This prevents unbounded task spawning under extreme load
// TODO reevaluate a simpler approach (e.g. do we really need to deal with fire and forget)
tokio::spawn(async move {
info!("Prefill drain coordinator started");
......@@ -513,7 +475,7 @@ impl PDRouter {
// Drain the response body efficiently
// Use streaming to avoid loading entire body into memory
let start = std::time::Instant::now();
let start = Instant::now();
let mut stream = response.bytes_stream();
let mut bytes_drained = 0;
......
use crate::config::types::RetryConfig;
use crate::core::{
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, RetryExecutor,
is_retryable_status, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, RetryExecutor,
Worker, WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
......@@ -47,31 +47,19 @@ pub struct Router {
impl Router {
/// Create a new router with injected policy and client
#[allow(clippy::too_many_arguments)]
pub async fn new(
worker_urls: Vec<String>,
ctx: &Arc<crate::server::AppContext>,
) -> Result<Self, String> {
pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
let workers = ctx.worker_registry.get_workers_filtered(
None, // any model
Some(WorkerType::Regular),
Some(ConnectionMode::Http),
false, // include all workers
);
// Update active workers gauge
RouterMetrics::set_active_workers(worker_urls.len());
// Wait for workers to be healthy (skip if empty - for service discovery mode)
if !worker_urls.is_empty() {
Self::wait_for_healthy_workers(
&worker_urls,
ctx.router_config.worker_startup_timeout_secs,
ctx.router_config.worker_startup_check_interval_secs,
)
.await?;
}
RouterMetrics::set_active_workers(workers.len());
let worker_urls = if ctx.router_config.dp_aware {
// worker address now in the format of "http://host:port@dp_rank"
Self::get_dp_aware_workers(&worker_urls, &ctx.router_config.api_key)
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?
} else {
worker_urls
};
// Get worker URLs for monitoring
let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
......@@ -82,40 +70,14 @@ impl Router {
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Register workers in the registry
// In IGW mode, we need to fetch model info from workers
for url in &worker_urls {
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
// For now, create worker without model_id
let worker = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Regular)
.circuit_breaker_config(core_cb_config.clone())
.health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
})
.build();
let worker_arc = Arc::new(worker);
ctx.worker_registry.register(worker_arc.clone());
// Notify PolicyRegistry about the new worker
let model_id = worker_arc.model_id();
let policy = ctx.policy_registry.on_worker_added(model_id, None);
// If this is a cache-aware policy and it's the first worker for this model,
// initialize it with the worker
if policy.name() == "cache_aware" {
if let Some(cache_aware) = policy
// Initialize cache-aware policy with workers if needed
let default_policy = ctx.policy_registry.get_default_policy();
if default_policy.name() == "cache_aware" {
if let Some(cache_aware) = default_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
let worker_dyn: Arc<dyn Worker> = worker_arc.clone();
cache_aware.init_workers(std::slice::from_ref(&worker_dyn));
}
cache_aware.init_workers(&workers);
}
}
......@@ -124,7 +86,6 @@ impl Router {
let worker_loads = Arc::new(rx);
// Check if default policy is power_of_two for load monitoring
let default_policy = ctx.policy_registry.get_default_policy();
let load_monitor_handle = if default_policy.name() == "power_of_two" {
let monitor_urls = worker_urls.clone();
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
......
......@@ -19,8 +19,10 @@ pub mod grpc;
pub mod header_utils;
pub mod http;
pub mod router_manager;
pub mod worker_initializer;
pub use factory::RouterFactory;
pub use worker_initializer::WorkerInitializer;
// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
pub use http::{openai_router, pd_router, pd_types, router};
......
// Worker Initialization Module
// Separates worker lifecycle management from router construction
use crate::config::types::{ConnectionMode as ConfigConnectionMode, RouterConfig, RoutingMode};
use crate::core::{
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, HealthConfig, WorkerRegistry,
WorkerType,
};
use std::sync::Arc;
use std::time::Duration;
use tracing::{info, warn};
/// Creates workers from the routing configuration and registers them in a
/// `WorkerRegistry`, keeping that concern out of the router constructors.
///
/// This is a stateless unit struct: every operation is an associated function
/// that receives the configuration and the target registry as arguments.
pub struct WorkerInitializer;
impl WorkerInitializer {
/// Initialize workers based on configuration and register them in the WorkerRegistry
pub async fn initialize_workers(
config: &RouterConfig,
worker_registry: &Arc<WorkerRegistry>,
) -> Result<(), String> {
info!("Initializing workers for routing mode: {:?}", config.mode);
match &config.mode {
RoutingMode::Regular { worker_urls } => {
Self::create_regular_workers(
worker_urls,
&config.connection_mode,
config,
worker_registry,
)
.await?;
}
RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
..
} => {
Self::create_prefill_workers(
prefill_urls,
&config.connection_mode,
config,
worker_registry,
)
.await?;
Self::create_decode_workers(
decode_urls,
&config.connection_mode,
config,
worker_registry,
)
.await?;
}
RoutingMode::OpenAI { .. } => {
info!("OpenAI routing mode - no local workers to initialize");
}
}
// Wait for workers to be healthy if any were registered
if worker_registry.stats().total_workers > 0 {
Self::wait_for_healthy_workers(
worker_registry,
config.worker_startup_timeout_secs,
config.worker_startup_check_interval_secs,
)
.await?;
}
Ok(())
}
/// Create regular workers for standard routing mode
async fn create_regular_workers(
urls: &[String],
config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
) -> Result<(), String> {
info!("Creating {} regular workers", urls.len());
// Convert config connection mode to core connection mode
let connection_mode = Self::convert_connection_mode(config_connection_mode, urls.first());
// Convert circuit breaker config
let circuit_breaker_config = config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Convert health check config
let health_config = HealthConfig {
timeout_secs: config.health_check.timeout_secs,
check_interval_secs: config.health_check.check_interval_secs,
endpoint: config.health_check.endpoint.clone(),
failure_threshold: config.health_check.failure_threshold,
success_threshold: config.health_check.success_threshold,
};
for url in urls {
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Regular)
.connection_mode(connection_mode.clone())
.circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone())
.build();
let worker_id = registry.register(Arc::new(worker));
info!("Registered regular worker {} with ID {:?}", url, worker_id);
}
Ok(())
}
/// Create prefill workers for disaggregated routing mode
async fn create_prefill_workers(
prefill_entries: &[(String, Option<u16>)],
config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
) -> Result<(), String> {
info!("Creating {} prefill workers", prefill_entries.len());
// Convert config connection mode to core connection mode
let connection_mode = Self::convert_connection_mode(
config_connection_mode,
prefill_entries.first().map(|(url, _)| url),
);
// Convert circuit breaker config
let circuit_breaker_config = config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Convert health check config
let health_config = HealthConfig {
timeout_secs: config.health_check.timeout_secs,
check_interval_secs: config.health_check.check_interval_secs,
endpoint: config.health_check.endpoint.clone(),
failure_threshold: config.health_check.failure_threshold,
success_threshold: config.health_check.success_threshold,
};
for (url, bootstrap_port) in prefill_entries {
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Prefill {
bootstrap_port: *bootstrap_port,
})
.connection_mode(connection_mode.clone())
.circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone())
.build();
let worker_id = registry.register(Arc::new(worker));
info!("Registered prefill worker {} with ID {:?}", url, worker_id);
}
Ok(())
}
/// Create decode workers for disaggregated routing mode
async fn create_decode_workers(
urls: &[String],
config_connection_mode: &ConfigConnectionMode,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
) -> Result<(), String> {
info!("Creating {} decode workers", urls.len());
// Convert config connection mode to core connection mode
let connection_mode = Self::convert_connection_mode(config_connection_mode, urls.first());
// Convert circuit breaker config
let circuit_breaker_config = config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Convert health check config
let health_config = HealthConfig {
timeout_secs: config.health_check.timeout_secs,
check_interval_secs: config.health_check.check_interval_secs,
endpoint: config.health_check.endpoint.clone(),
failure_threshold: config.health_check.failure_threshold,
success_threshold: config.health_check.success_threshold,
};
for url in urls {
// TODO: Add DP-aware support when we have dp_rank/dp_size info
let worker = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Decode)
.connection_mode(connection_mode.clone())
.circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone())
.build();
let worker_id = registry.register(Arc::new(worker));
info!("Registered decode worker {} with ID {:?}", url, worker_id);
}
Ok(())
}
/// Convert config connection mode to core connection mode
fn convert_connection_mode(
config_mode: &ConfigConnectionMode,
_sample_url: Option<&String>,
) -> ConnectionMode {
match config_mode {
ConfigConnectionMode::Http => ConnectionMode::Http,
ConfigConnectionMode::Grpc => ConnectionMode::Grpc { port: None },
}
}
/// Wait for workers to become healthy
async fn wait_for_healthy_workers(
registry: &Arc<WorkerRegistry>,
timeout_secs: u64,
check_interval_secs: u64,
) -> Result<(), String> {
let timeout = Duration::from_secs(timeout_secs);
let check_interval = Duration::from_secs(check_interval_secs);
let start_time = std::time::Instant::now();
info!(
"Waiting for workers to become healthy (timeout: {}s)",
timeout_secs
);
loop {
let stats = registry.stats();
if stats.healthy_workers > 0 {
info!(
"Workers healthy: {}/{} workers are ready",
stats.healthy_workers, stats.total_workers
);
// If we have at least one healthy worker, we can proceed
// This allows partial degradation rather than total failure
return Ok(());
}
if start_time.elapsed() > timeout {
let error_msg = format!(
"Timeout waiting for workers to become healthy after {}s. Total workers: {}, Healthy: {}",
timeout_secs, stats.total_workers, stats.healthy_workers
);
warn!("{}", error_msg);
// If we have workers but none are healthy, it's still a failure
if stats.total_workers > 0 {
return Err(error_msg);
} else {
// No workers at all might be OK for some configurations
warn!("No workers registered, proceeding anyway");
return Ok(());
}
}
tokio::time::sleep(check_interval).await;
}
}
/// Initialize workers for gRPC connections specifically
/// This is used when gRPC clients are pre-connected
pub async fn initialize_grpc_workers(
worker_urls: &[String],
worker_type: WorkerType,
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
grpc_clients: &mut std::collections::HashMap<String, crate::grpc::SglangSchedulerClient>,
) -> Result<(), String> {
info!(
"Creating {} gRPC workers of type {:?}",
worker_urls.len(),
worker_type
);
// Convert circuit breaker config
let circuit_breaker_config = config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Convert health check config
let health_config = HealthConfig {
timeout_secs: config.health_check.timeout_secs,
check_interval_secs: config.health_check.check_interval_secs,
endpoint: config.health_check.endpoint.clone(),
failure_threshold: config.health_check.failure_threshold,
success_threshold: config.health_check.success_threshold,
};
for url in worker_urls {
if let Some(client) = grpc_clients.remove(url) {
let worker = BasicWorkerBuilder::new(url.clone())
.worker_type(worker_type.clone())
.connection_mode(ConnectionMode::Grpc { port: None })
.circuit_breaker_config(core_cb_config.clone())
.health_config(health_config.clone())
.grpc_client(client)
.build();
let worker_id = registry.register(Arc::new(worker));
info!("Registered gRPC worker {} with ID {:?}", url, worker_id);
} else {
warn!("No gRPC client available for worker {}, skipping", url);
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `convert_connection_mode` maps each config mode to its core
    /// counterpart, whether or not a sample URL is supplied.
    #[test]
    fn test_convert_connection_mode() {
        let http_url = "http://localhost:8080".to_string();
        let grpc_url = "grpc://localhost:50051".to_string();

        // HTTP mode
        let http_mode =
            WorkerInitializer::convert_connection_mode(&ConfigConnectionMode::Http, Some(&http_url));
        assert!(matches!(http_mode, ConnectionMode::Http));

        // gRPC mode
        let grpc_mode =
            WorkerInitializer::convert_connection_mode(&ConfigConnectionMode::Grpc, Some(&grpc_url));
        assert!(matches!(grpc_mode, ConnectionMode::Grpc { .. }));

        // No URL provided
        let no_url_mode =
            WorkerInitializer::convert_connection_mode(&ConfigConnectionMode::Http, None);
        assert!(matches!(no_url_mode, ConnectionMode::Http));
    }
}
......@@ -14,6 +14,7 @@ use crate::{
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
},
reasoning_parser::ParserFactory,
routers::WorkerInitializer,
routers::{
router_manager::{RouterId, RouterManager},
RouterFactory, RouterTrait,
......@@ -594,6 +595,22 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
let app_context = Arc::new(app_context);
// Initialize workers before creating routers
// This separates worker lifecycle from router lifecycle
info!(
"Initializing workers for routing mode: {:?}",
config.router_config.mode
);
WorkerInitializer::initialize_workers(&config.router_config, &app_context.worker_registry)
.await
.map_err(|e| format!("Failed to initialize workers: {}", e))?;
let worker_stats = app_context.worker_registry.stats();
info!(
"Workers initialized: {} total, {} healthy",
worker_stats.total_workers, worker_stats.healthy_workers
);
// Create the appropriate router based on enable_igw flag
let (router, router_manager): (Arc<dyn RouterTrait>, Option<Arc<RouterManager>>) =
if config.router_config.enable_igw {
......@@ -608,12 +625,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
));
// 1. HTTP Regular Router
match RouterFactory::create_regular_router(
&[], // Empty worker list - workers added later
&app_context,
)
.await
{
match RouterFactory::create_regular_router(&app_context).await {
Ok(http_regular) => {
info!("Created HTTP Regular router");
router_manager.register_router(
......@@ -628,8 +640,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
// 2. HTTP PD Router
match RouterFactory::create_pd_router(
&[],
&[],
None,
None,
&config.router_config.policy,
......@@ -684,7 +694,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
// Start queue processor if enabled
if let Some(processor) = processor {
tokio::spawn(processor.run());
spawn(processor.run());
info!(
"Started request queue with size: {}, timeout: {}s",
config.router_config.queue_size, config.router_config.queue_timeout_secs
......
......@@ -606,7 +606,7 @@ mod tests {
response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
});
let router = Router::new(vec![], &app_context).await.unwrap();
let router = Router::new(&app_context).await.unwrap();
Arc::new(router) as Arc<dyn RouterTrait>
}
......
......@@ -101,6 +101,14 @@ impl TestContext {
// Create app context
let app_context = common::create_test_context(config.clone());
// Initialize workers in the registry before creating router
if !worker_urls.is_empty() {
use sglang_router_rs::routers::WorkerInitializer;
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry)
.await
.expect("Failed to initialize workers");
}
// Create router
let router = RouterFactory::create_router(&app_context).await.unwrap();
let router = Arc::from(router);
......
......@@ -39,9 +39,20 @@ impl TestContext {
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
}
config.mode = RoutingMode::Regular { worker_urls };
config.mode = RoutingMode::Regular {
worker_urls: worker_urls.clone(),
};
let app_context = common::create_test_context(config.clone());
// Initialize workers in the registry before creating router
if !worker_urls.is_empty() {
use sglang_router_rs::routers::WorkerInitializer;
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry)
.await
.expect("Failed to initialize workers");
}
let app_context = common::create_test_context(config);
let router = RouterFactory::create_router(&app_context).await.unwrap();
let router = Arc::from(router);
......
......@@ -40,9 +40,20 @@ impl TestContext {
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
}
config.mode = RoutingMode::Regular { worker_urls };
config.mode = RoutingMode::Regular {
worker_urls: worker_urls.clone(),
};
let app_context = common::create_test_context(config.clone());
// Initialize workers in the registry before creating router
if !worker_urls.is_empty() {
use sglang_router_rs::routers::WorkerInitializer;
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry)
.await
.expect("Failed to initialize workers");
}
let app_context = common::create_test_context(config);
let router = RouterFactory::create_router(&app_context).await.unwrap();
let router = Arc::from(router);
......
......@@ -207,19 +207,21 @@ mod test_pd_routing {
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
};
// Router creation will fail due to health checks, but config should be valid
let app_context =
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None)
.expect("Failed to create AppContext");
let app_context = std::sync::Arc::new(app_context);
let result = RouterFactory::create_router(&app_context).await;
assert!(result.is_err());
let error_msg = result.unwrap_err();
// Error should be about health/timeout, not configuration
assert!(
error_msg.contains("healthy") || error_msg.contains("timeout"),
"Unexpected error: {}",
error_msg
result.is_ok(),
"Router creation should succeed with empty worker"
);
// Verify that no workers are registered since we didn't initialize them
let stats = app_context.worker_registry.stats();
assert_eq!(
stats.total_workers, 0,
"No workers should be registered without initialization"
);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment