Commit 4f8a982d (unverified), authored by Simo Lin, committed by GitHub
Browse files

[router] clean up dependency injector to use ctx (#10000)

parent d966b902
...@@ -83,20 +83,8 @@ impl RouterFactory { ...@@ -83,20 +83,8 @@ impl RouterFactory {
// Create policy // Create policy
let policy = PolicyFactory::create_from_config(policy_config); let policy = PolicyFactory::create_from_config(policy_config);
// Create regular router with injected policy and client // Create regular router with injected policy and context
let router = Router::new( let router = Router::new(worker_urls.to_vec(), policy, ctx).await?;
worker_urls.to_vec(),
policy,
ctx.client.clone(),
ctx.router_config.worker_startup_timeout_secs,
ctx.router_config.worker_startup_check_interval_secs,
ctx.router_config.dp_aware,
ctx.router_config.api_key.clone(),
ctx.router_config.retry.clone(),
ctx.router_config.circuit_breaker.clone(),
ctx.router_config.health_check.clone(),
)
.await?;
Ok(Box::new(router)) Ok(Box::new(router))
} }
...@@ -116,19 +104,13 @@ impl RouterFactory { ...@@ -116,19 +104,13 @@ impl RouterFactory {
let decode_policy = let decode_policy =
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config)); PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
// Create PD router with separate policies and client // Create PD router with separate policies and context
let router = PDRouter::new( let router = PDRouter::new(
prefill_urls.to_vec(), prefill_urls.to_vec(),
decode_urls.to_vec(), decode_urls.to_vec(),
prefill_policy, prefill_policy,
decode_policy, decode_policy,
ctx.client.clone(), ctx,
ctx.router_config.request_timeout_secs,
ctx.router_config.worker_startup_timeout_secs,
ctx.router_config.worker_startup_check_interval_secs,
ctx.router_config.retry.clone(),
ctx.router_config.circuit_breaker.clone(),
ctx.router_config.health_check.clone(),
) )
.await?; .await?;
...@@ -146,46 +128,8 @@ impl RouterFactory { ...@@ -146,46 +128,8 @@ impl RouterFactory {
// Create policy // Create policy
let policy = PolicyFactory::create_from_config(policy_config); let policy = PolicyFactory::create_from_config(policy_config);
// Get tokenizer from context // Create gRPC router with context
let tokenizer = ctx let router = GrpcRouter::new(worker_urls.to_vec(), policy, ctx).await?;
.tokenizer
.as_ref()
.ok_or_else(|| {
"gRPC router requires tokenizer to be initialized in AppContext".to_string()
})?
.clone();
// Get reasoning parser factory from context
let reasoning_parser_factory = ctx
.reasoning_parser_factory
.as_ref()
.ok_or_else(|| {
"gRPC router requires reasoning parser factory to be initialized in AppContext"
.to_string()
})?
.clone();
// Get tool parser registry from context
let tool_parser_registry = ctx.tool_parser_registry.ok_or_else(|| {
"gRPC router requires tool parser registry to be initialized in AppContext".to_string()
})?;
// Create gRPC router
let router = GrpcRouter::new(
worker_urls.to_vec(),
policy,
ctx.router_config.worker_startup_timeout_secs,
ctx.router_config.worker_startup_check_interval_secs,
ctx.router_config.dp_aware,
ctx.router_config.api_key.clone(),
ctx.router_config.effective_retry_config(),
ctx.router_config.effective_circuit_breaker_config(),
ctx.router_config.health_check.clone(),
tokenizer,
reasoning_parser_factory,
tool_parser_registry,
)
.await?;
Ok(Box::new(router)) Ok(Box::new(router))
} }
...@@ -207,47 +151,13 @@ impl RouterFactory { ...@@ -207,47 +151,13 @@ impl RouterFactory {
let decode_policy = let decode_policy =
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config)); PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
// Get tokenizer from context // Create gRPC PD router with context
let tokenizer = ctx
.tokenizer
.as_ref()
.ok_or_else(|| {
"gRPC PD router requires tokenizer to be initialized in AppContext".to_string()
})?
.clone();
// Get reasoning parser factory from context
let reasoning_parser_factory = ctx
.reasoning_parser_factory
.as_ref()
.ok_or_else(|| {
"gRPC PD router requires reasoning parser factory to be initialized in AppContext"
.to_string()
})?
.clone();
// Get tool parser registry from context
let tool_parser_registry = ctx.tool_parser_registry.ok_or_else(|| {
"gRPC PD router requires tool parser registry to be initialized in AppContext"
.to_string()
})?;
// Create gRPC PD router
let router = GrpcPDRouter::new( let router = GrpcPDRouter::new(
prefill_urls.to_vec(), prefill_urls.to_vec(),
decode_urls.to_vec(), decode_urls.to_vec(),
prefill_policy, prefill_policy,
decode_policy, decode_policy,
ctx.router_config.worker_startup_timeout_secs, ctx,
ctx.router_config.worker_startup_check_interval_secs,
ctx.router_config.dp_aware,
ctx.router_config.api_key.clone(),
ctx.router_config.effective_retry_config(),
ctx.router_config.effective_circuit_breaker_config(),
ctx.router_config.health_check.clone(),
tokenizer,
reasoning_parser_factory,
tool_parser_registry,
) )
.await?; .await?;
......
// PD (Prefill-Decode) gRPC Router Implementation // PD (Prefill-Decode) gRPC Router Implementation
use crate::config::types::{ use crate::config::types::RetryConfig;
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
};
use crate::core::{ use crate::core::{
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType, BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
}; };
...@@ -61,27 +58,33 @@ pub struct GrpcPDRouter { ...@@ -61,27 +58,33 @@ pub struct GrpcPDRouter {
impl GrpcPDRouter { impl GrpcPDRouter {
/// Create a new gRPC PD router /// Create a new gRPC PD router
#[allow(clippy::too_many_arguments)]
pub async fn new( pub async fn new(
prefill_urls: Vec<(String, Option<u16>)>, prefill_urls: Vec<(String, Option<u16>)>,
decode_urls: Vec<String>, decode_urls: Vec<String>,
prefill_policy: Arc<dyn LoadBalancingPolicy>, prefill_policy: Arc<dyn LoadBalancingPolicy>,
decode_policy: Arc<dyn LoadBalancingPolicy>, decode_policy: Arc<dyn LoadBalancingPolicy>,
timeout_secs: u64, ctx: &Arc<crate::server::AppContext>,
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
circuit_breaker_config: ConfigCircuitBreakerConfig,
health_check_config: ConfigHealthCheckConfig,
tokenizer: Arc<dyn Tokenizer>,
reasoning_parser_factory: ParserFactory,
tool_parser_registry: &'static ParserRegistry,
) -> Result<Self, String> { ) -> Result<Self, String> {
// Update metrics // Update metrics
RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len()); RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len());
// Extract necessary components from context
let tokenizer = ctx
.tokenizer
.as_ref()
.ok_or_else(|| "gRPC PD router requires tokenizer".to_string())?
.clone();
let reasoning_parser_factory = ctx
.reasoning_parser_factory
.as_ref()
.ok_or_else(|| "gRPC PD router requires reasoning parser factory".to_string())?
.clone();
let tool_parser_registry = ctx
.tool_parser_registry
.ok_or_else(|| "gRPC PD router requires tool parser registry".to_string())?;
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig // Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig { let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold, failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold, success_threshold: circuit_breaker_config.success_threshold,
...@@ -138,11 +141,11 @@ impl GrpcPDRouter { ...@@ -138,11 +141,11 @@ impl GrpcPDRouter {
) )
.with_circuit_breaker_config(core_cb_config.clone()) .with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig { .with_health_config(HealthConfig {
timeout_secs: health_check_config.timeout_secs, timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: health_check_config.check_interval_secs, check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: health_check_config.endpoint.clone(), endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: health_check_config.failure_threshold, failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: health_check_config.success_threshold, success_threshold: ctx.router_config.health_check.success_threshold,
}); });
Box::new(worker) as Box<dyn Worker> Box::new(worker) as Box<dyn Worker>
}) })
...@@ -159,11 +162,11 @@ impl GrpcPDRouter { ...@@ -159,11 +162,11 @@ impl GrpcPDRouter {
) )
.with_circuit_breaker_config(core_cb_config.clone()) .with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig { .with_health_config(HealthConfig {
timeout_secs: health_check_config.timeout_secs, timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: health_check_config.check_interval_secs, check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: health_check_config.endpoint.clone(), endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: health_check_config.failure_threshold, failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: health_check_config.success_threshold, success_threshold: ctx.router_config.health_check.success_threshold,
}); });
Box::new(worker) as Box<dyn Worker> Box::new(worker) as Box<dyn Worker>
}) })
...@@ -187,10 +190,14 @@ impl GrpcPDRouter { ...@@ -187,10 +190,14 @@ impl GrpcPDRouter {
let prefill_workers = Arc::new(RwLock::new(prefill_workers)); let prefill_workers = Arc::new(RwLock::new(prefill_workers));
let decode_workers = Arc::new(RwLock::new(decode_workers)); let decode_workers = Arc::new(RwLock::new(decode_workers));
let prefill_health_checker = let prefill_health_checker = crate::core::start_health_checker(
crate::core::start_health_checker(Arc::clone(&prefill_workers), interval_secs); Arc::clone(&prefill_workers),
let decode_health_checker = ctx.router_config.worker_startup_check_interval_secs,
crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs); );
let decode_health_checker = crate::core::start_health_checker(
Arc::clone(&decode_workers),
ctx.router_config.worker_startup_check_interval_secs,
);
Ok(GrpcPDRouter { Ok(GrpcPDRouter {
prefill_workers, prefill_workers,
...@@ -204,11 +211,11 @@ impl GrpcPDRouter { ...@@ -204,11 +211,11 @@ impl GrpcPDRouter {
tool_parser_registry, tool_parser_registry,
_prefill_health_checker: Some(prefill_health_checker), _prefill_health_checker: Some(prefill_health_checker),
_decode_health_checker: Some(decode_health_checker), _decode_health_checker: Some(decode_health_checker),
timeout_secs, timeout_secs: ctx.router_config.worker_startup_timeout_secs,
interval_secs, interval_secs: ctx.router_config.worker_startup_check_interval_secs,
dp_aware, dp_aware: ctx.router_config.dp_aware,
api_key, api_key: ctx.router_config.api_key.clone(),
retry_config, retry_config: ctx.router_config.effective_retry_config(),
circuit_breaker_config: core_cb_config, circuit_breaker_config: core_cb_config,
}) })
} }
......
// gRPC Router Implementation // gRPC Router Implementation
use crate::config::types::{ use crate::config::types::RetryConfig;
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
};
use crate::core::{ use crate::core::{
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType, BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
}; };
...@@ -54,25 +51,31 @@ pub struct GrpcRouter { ...@@ -54,25 +51,31 @@ pub struct GrpcRouter {
impl GrpcRouter { impl GrpcRouter {
/// Create a new gRPC router /// Create a new gRPC router
#[allow(clippy::too_many_arguments)]
pub async fn new( pub async fn new(
worker_urls: Vec<String>, worker_urls: Vec<String>,
policy: Arc<dyn LoadBalancingPolicy>, policy: Arc<dyn LoadBalancingPolicy>,
timeout_secs: u64, ctx: &Arc<crate::server::AppContext>,
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
circuit_breaker_config: ConfigCircuitBreakerConfig,
health_check_config: ConfigHealthCheckConfig,
tokenizer: Arc<dyn Tokenizer>,
reasoning_parser_factory: ParserFactory,
tool_parser_registry: &'static ParserRegistry,
) -> Result<Self, String> { ) -> Result<Self, String> {
// Update metrics // Update metrics
RouterMetrics::set_active_workers(worker_urls.len()); RouterMetrics::set_active_workers(worker_urls.len());
// Extract necessary components from context
let tokenizer = ctx
.tokenizer
.as_ref()
.ok_or_else(|| "gRPC router requires tokenizer".to_string())?
.clone();
let reasoning_parser_factory = ctx
.reasoning_parser_factory
.as_ref()
.ok_or_else(|| "gRPC router requires reasoning parser factory".to_string())?
.clone();
let tool_parser_registry = ctx
.tool_parser_registry
.ok_or_else(|| "gRPC router requires tool parser registry".to_string())?;
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig // Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig { let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold, failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold, success_threshold: circuit_breaker_config.success_threshold,
...@@ -112,11 +115,11 @@ impl GrpcRouter { ...@@ -112,11 +115,11 @@ impl GrpcRouter {
) )
.with_circuit_breaker_config(core_cb_config.clone()) .with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig { .with_health_config(HealthConfig {
timeout_secs: health_check_config.timeout_secs, timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: health_check_config.check_interval_secs, check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: health_check_config.endpoint.clone(), endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: health_check_config.failure_threshold, failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: health_check_config.success_threshold, success_threshold: ctx.router_config.health_check.success_threshold,
}) })
.with_grpc_client(client); .with_grpc_client(client);
...@@ -135,7 +138,10 @@ impl GrpcRouter { ...@@ -135,7 +138,10 @@ impl GrpcRouter {
} }
let workers = Arc::new(RwLock::new(workers)); let workers = Arc::new(RwLock::new(workers));
let health_checker = crate::core::start_health_checker(Arc::clone(&workers), interval_secs); let health_checker = crate::core::start_health_checker(
Arc::clone(&workers),
ctx.router_config.worker_startup_check_interval_secs,
);
Ok(GrpcRouter { Ok(GrpcRouter {
workers, workers,
...@@ -145,11 +151,11 @@ impl GrpcRouter { ...@@ -145,11 +151,11 @@ impl GrpcRouter {
reasoning_parser_factory, reasoning_parser_factory,
tool_parser_registry, tool_parser_registry,
_health_checker: Some(health_checker), _health_checker: Some(health_checker),
timeout_secs, timeout_secs: ctx.router_config.worker_startup_timeout_secs,
interval_secs, interval_secs: ctx.router_config.worker_startup_check_interval_secs,
dp_aware, dp_aware: ctx.router_config.dp_aware,
api_key, api_key: ctx.router_config.api_key.clone(),
retry_config, retry_config: ctx.router_config.effective_retry_config(),
circuit_breaker_config: core_cb_config, circuit_breaker_config: core_cb_config,
}) })
} }
......
// PD (Prefill-Decode) Router Implementation // PD (Prefill-Decode) Router Implementation
// This module handles routing for disaggregated prefill-decode systems // This module handles routing for disaggregated prefill-decode systems
use super::pd_types::{api_path, PDRouterError}; use super::pd_types::{api_path, PDRouterError};
use crate::config::types::{ use crate::config::types::RetryConfig;
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
};
use crate::core::{ use crate::core::{
is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig,
RetryExecutor, Worker, WorkerFactory, WorkerLoadGuard, WorkerType, RetryExecutor, Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
...@@ -375,15 +372,10 @@ impl PDRouter { ...@@ -375,15 +372,10 @@ impl PDRouter {
decode_urls: Vec<String>, decode_urls: Vec<String>,
prefill_policy: Arc<dyn LoadBalancingPolicy>, prefill_policy: Arc<dyn LoadBalancingPolicy>,
decode_policy: Arc<dyn LoadBalancingPolicy>, decode_policy: Arc<dyn LoadBalancingPolicy>,
client: Client, ctx: &Arc<crate::server::AppContext>,
prefill_request_timeout_secs: u64,
worker_startup_timeout_secs: u64,
worker_startup_check_interval_secs: u64,
retry_config: RetryConfig,
circuit_breaker_config: ConfigCircuitBreakerConfig,
health_check_config: ConfigHealthCheckConfig,
) -> Result<Self, String> { ) -> Result<Self, String> {
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig // Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig { let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold, failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold, success_threshold: circuit_breaker_config.success_threshold,
...@@ -403,11 +395,11 @@ impl PDRouter { ...@@ -403,11 +395,11 @@ impl PDRouter {
) )
.with_circuit_breaker_config(core_cb_config.clone()) .with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig { .with_health_config(HealthConfig {
timeout_secs: health_check_config.timeout_secs, timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: health_check_config.check_interval_secs, check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: health_check_config.endpoint.clone(), endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: health_check_config.failure_threshold, failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: health_check_config.success_threshold, success_threshold: ctx.router_config.health_check.success_threshold,
}); });
Box::new(worker) as Box<dyn Worker> Box::new(worker) as Box<dyn Worker>
}) })
...@@ -419,11 +411,11 @@ impl PDRouter { ...@@ -419,11 +411,11 @@ impl PDRouter {
let worker = BasicWorker::new(url, WorkerType::Decode) let worker = BasicWorker::new(url, WorkerType::Decode)
.with_circuit_breaker_config(core_cb_config.clone()) .with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig { .with_health_config(HealthConfig {
timeout_secs: health_check_config.timeout_secs, timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: health_check_config.check_interval_secs, check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: health_check_config.endpoint.clone(), endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: health_check_config.failure_threshold, failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: health_check_config.success_threshold, success_threshold: ctx.router_config.health_check.success_threshold,
}); });
Box::new(worker) as Box<dyn Worker> Box::new(worker) as Box<dyn Worker>
}) })
...@@ -438,8 +430,8 @@ impl PDRouter { ...@@ -438,8 +430,8 @@ impl PDRouter {
if !all_urls.is_empty() { if !all_urls.is_empty() {
crate::routers::http::router::Router::wait_for_healthy_workers( crate::routers::http::router::Router::wait_for_healthy_workers(
&all_urls, &all_urls,
worker_startup_timeout_secs, ctx.router_config.worker_startup_timeout_secs,
worker_startup_check_interval_secs, ctx.router_config.worker_startup_check_interval_secs,
) )
.await?; .await?;
} }
...@@ -466,8 +458,8 @@ impl PDRouter { ...@@ -466,8 +458,8 @@ impl PDRouter {
let load_monitor_handle = let load_monitor_handle =
if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" { if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" {
let monitor_urls = all_urls.clone(); let monitor_urls = all_urls.clone();
let monitor_interval = worker_startup_check_interval_secs; let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
let monitor_client = client.clone(); let monitor_client = ctx.client.clone();
let prefill_policy_clone = Arc::clone(&prefill_policy); let prefill_policy_clone = Arc::clone(&prefill_policy);
let decode_policy_clone = Arc::clone(&decode_policy); let decode_policy_clone = Arc::clone(&decode_policy);
...@@ -492,11 +484,11 @@ impl PDRouter { ...@@ -492,11 +484,11 @@ impl PDRouter {
// Start health checkers for both worker pools // Start health checkers for both worker pools
let prefill_health_checker = crate::core::start_health_checker( let prefill_health_checker = crate::core::start_health_checker(
Arc::clone(&prefill_workers), Arc::clone(&prefill_workers),
health_check_config.check_interval_secs, ctx.router_config.health_check.check_interval_secs,
); );
let decode_health_checker = crate::core::start_health_checker( let decode_health_checker = crate::core::start_health_checker(
Arc::clone(&decode_workers), Arc::clone(&decode_workers),
health_check_config.check_interval_secs, ctx.router_config.health_check.check_interval_secs,
); );
// Build a dedicated prefill client for fire-and-forget semantics // Build a dedicated prefill client for fire-and-forget semantics
...@@ -504,7 +496,7 @@ impl PDRouter { ...@@ -504,7 +496,7 @@ impl PDRouter {
.pool_max_idle_per_host(0) .pool_max_idle_per_host(0)
.http1_only() .http1_only()
.connect_timeout(Duration::from_millis(300)) .connect_timeout(Duration::from_millis(300))
.timeout(Duration::from_secs(prefill_request_timeout_secs)) .timeout(Duration::from_secs(ctx.router_config.request_timeout_secs))
.build() .build()
.map_err(|e| format!("Failed to build prefill client: {}", e))?; .map_err(|e| format!("Failed to build prefill client: {}", e))?;
...@@ -582,14 +574,16 @@ impl PDRouter { ...@@ -582,14 +574,16 @@ impl PDRouter {
decode_workers, decode_workers,
prefill_policy, prefill_policy,
decode_policy, decode_policy,
worker_startup_timeout_secs, worker_startup_timeout_secs: ctx.router_config.worker_startup_timeout_secs,
worker_startup_check_interval_secs, worker_startup_check_interval_secs: ctx
.router_config
.worker_startup_check_interval_secs,
worker_loads, worker_loads,
load_monitor_handle, load_monitor_handle,
client, client: ctx.client.clone(),
prefill_client, prefill_client,
prefill_drain_tx, prefill_drain_tx,
retry_config, retry_config: ctx.router_config.effective_retry_config(),
circuit_breaker_config: core_cb_config, circuit_breaker_config: core_cb_config,
_prefill_health_checker: Some(prefill_health_checker), _prefill_health_checker: Some(prefill_health_checker),
_decode_health_checker: Some(decode_health_checker), _decode_health_checker: Some(decode_health_checker),
......
use crate::config::types::{ use crate::config::types::RetryConfig;
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
};
use crate::core::{ use crate::core::{
is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig,
RetryExecutor, Worker, WorkerFactory, WorkerType, RetryExecutor, Worker, WorkerFactory, WorkerType,
...@@ -51,14 +48,7 @@ impl Router { ...@@ -51,14 +48,7 @@ impl Router {
pub async fn new( pub async fn new(
worker_urls: Vec<String>, worker_urls: Vec<String>,
policy: Arc<dyn LoadBalancingPolicy>, policy: Arc<dyn LoadBalancingPolicy>,
client: Client, ctx: &Arc<crate::server::AppContext>,
worker_startup_timeout_secs: u64,
worker_startup_check_interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
circuit_breaker_config: ConfigCircuitBreakerConfig,
health_check_config: ConfigHealthCheckConfig,
) -> Result<Self, String> { ) -> Result<Self, String> {
// Update active workers gauge // Update active workers gauge
RouterMetrics::set_active_workers(worker_urls.len()); RouterMetrics::set_active_workers(worker_urls.len());
...@@ -67,21 +57,22 @@ impl Router { ...@@ -67,21 +57,22 @@ impl Router {
if !worker_urls.is_empty() { if !worker_urls.is_empty() {
Self::wait_for_healthy_workers( Self::wait_for_healthy_workers(
&worker_urls, &worker_urls,
worker_startup_timeout_secs, ctx.router_config.worker_startup_timeout_secs,
worker_startup_check_interval_secs, ctx.router_config.worker_startup_check_interval_secs,
) )
.await?; .await?;
} }
let worker_urls = if dp_aware { let worker_urls = if ctx.router_config.dp_aware {
// worker address now in the format of "http://host:port@dp_rank" // worker address now in the format of "http://host:port@dp_rank"
Self::get_dp_aware_workers(&worker_urls, &api_key) Self::get_dp_aware_workers(&worker_urls, &ctx.router_config.api_key)
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))? .map_err(|e| format!("Failed to get dp-aware workers: {}", e))?
} else { } else {
worker_urls worker_urls
}; };
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig // Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig { let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold, failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold, success_threshold: circuit_breaker_config.success_threshold,
...@@ -96,11 +87,11 @@ impl Router { ...@@ -96,11 +87,11 @@ impl Router {
let worker = BasicWorker::new(url.clone(), WorkerType::Regular) let worker = BasicWorker::new(url.clone(), WorkerType::Regular)
.with_circuit_breaker_config(core_cb_config.clone()) .with_circuit_breaker_config(core_cb_config.clone())
.with_health_config(HealthConfig { .with_health_config(HealthConfig {
timeout_secs: health_check_config.timeout_secs, timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: health_check_config.check_interval_secs, check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: health_check_config.endpoint.clone(), endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: health_check_config.failure_threshold, failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: health_check_config.success_threshold, success_threshold: ctx.router_config.health_check.success_threshold,
}); });
Box::new(worker) as Box<dyn Worker> Box::new(worker) as Box<dyn Worker>
}) })
...@@ -117,7 +108,7 @@ impl Router { ...@@ -117,7 +108,7 @@ impl Router {
let workers = Arc::new(RwLock::new(workers)); let workers = Arc::new(RwLock::new(workers));
let health_checker = crate::core::start_health_checker( let health_checker = crate::core::start_health_checker(
Arc::clone(&workers), Arc::clone(&workers),
worker_startup_check_interval_secs, ctx.router_config.worker_startup_check_interval_secs,
); );
// Setup load monitoring for PowerOfTwo policy // Setup load monitoring for PowerOfTwo policy
...@@ -126,9 +117,9 @@ impl Router { ...@@ -126,9 +117,9 @@ impl Router {
let load_monitor_handle = if policy.name() == "power_of_two" { let load_monitor_handle = if policy.name() == "power_of_two" {
let monitor_urls = worker_urls.clone(); let monitor_urls = worker_urls.clone();
let monitor_interval = worker_startup_check_interval_secs; let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
let policy_clone = Arc::clone(&policy); let policy_clone = Arc::clone(&policy);
let client_clone = client.clone(); let client_clone = ctx.client.clone();
Some(Arc::new(tokio::spawn(async move { Some(Arc::new(tokio::spawn(async move {
Self::monitor_worker_loads( Self::monitor_worker_loads(
...@@ -147,12 +138,14 @@ impl Router { ...@@ -147,12 +138,14 @@ impl Router {
Ok(Router { Ok(Router {
workers, workers,
policy, policy,
client, client: ctx.client.clone(),
worker_startup_timeout_secs, worker_startup_timeout_secs: ctx.router_config.worker_startup_timeout_secs,
worker_startup_check_interval_secs, worker_startup_check_interval_secs: ctx
dp_aware, .router_config
api_key, .worker_startup_check_interval_secs,
retry_config, dp_aware: ctx.router_config.dp_aware,
api_key: ctx.router_config.api_key.clone(),
retry_config: ctx.router_config.effective_retry_config(),
circuit_breaker_config: core_cb_config, circuit_breaker_config: core_cb_config,
_worker_loads: worker_loads, _worker_loads: worker_loads,
_load_monitor_handle: load_monitor_handle, _load_monitor_handle: load_monitor_handle,
......
...@@ -579,25 +579,27 @@ mod tests { ...@@ -579,25 +579,27 @@ mod tests {
// Helper to create a Router instance for testing event handlers // Helper to create a Router instance for testing event handlers
async fn create_test_router() -> Arc<dyn RouterTrait> { async fn create_test_router() -> Arc<dyn RouterTrait> {
use crate::config::PolicyConfig; use crate::config::{PolicyConfig, RouterConfig};
use crate::middleware::TokenBucket;
use crate::policies::PolicyFactory; use crate::policies::PolicyFactory;
use crate::routers::http::router::Router; use crate::routers::http::router::Router;
use crate::server::AppContext;
// Create a minimal RouterConfig for testing
let router_config = RouterConfig::default();
// Create AppContext with minimal components
let app_context = Arc::new(AppContext {
client: reqwest::Client::new(),
router_config,
rate_limiter: Arc::new(TokenBucket::new(1000, 1000)),
tokenizer: None, // HTTP mode doesn't need tokenizer
reasoning_parser_factory: None, // HTTP mode doesn't need reasoning parser
tool_parser_registry: None, // HTTP mode doesn't need tool parser
});
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random); let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
let router = Router::new( let router = Router::new(vec![], policy, &app_context).await.unwrap();
vec![],
policy,
reqwest::Client::new(),
5,
1,
false,
None,
crate::config::types::RetryConfig::default(),
crate::config::types::CircuitBreakerConfig::default(),
crate::config::types::HealthCheckConfig::default(),
)
.await
.unwrap();
Arc::new(router) as Arc<dyn RouterTrait> Arc::new(router) as Arc<dyn RouterTrait>
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment