Unverified Commit 5d4fe1ce authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] add move grpc worker management from router to worker manager (#10960)

parent 1b011e68
...@@ -181,16 +181,31 @@ impl WorkerManager { ...@@ -181,16 +181,31 @@ impl WorkerManager {
) -> Result<(), String> { ) -> Result<(), String> {
info!("Starting worker initialization"); info!("Starting worker initialization");
// Determine connection mode from config
let connection_mode = &config.connection_mode;
match &config.mode { match &config.mode {
RoutingMode::Regular { worker_urls } => { RoutingMode::Regular { worker_urls } => match connection_mode {
Self::initialize_regular_workers(worker_urls, config, registry, policy_registry) ConfigConnectionMode::Http => {
Self::initialize_regular_workers(
worker_urls,
config,
registry,
policy_registry,
)
.await?;
}
ConfigConnectionMode::Grpc => {
Self::initialize_grpc_workers(worker_urls, config, registry, policy_registry)
.await?; .await?;
} }
},
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_urls, prefill_urls,
decode_urls, decode_urls,
.. ..
} => { } => match connection_mode {
ConfigConnectionMode::Http => {
let prefill_entries: Vec<(&String, &Option<u16>)> = let prefill_entries: Vec<(&String, &Option<u16>)> =
prefill_urls.iter().map(|(url, port)| (url, port)).collect(); prefill_urls.iter().map(|(url, port)| (url, port)).collect();
...@@ -204,6 +219,17 @@ impl WorkerManager { ...@@ -204,6 +219,17 @@ impl WorkerManager {
Self::initialize_decode_workers(decode_urls, config, registry, policy_registry) Self::initialize_decode_workers(decode_urls, config, registry, policy_registry)
.await?; .await?;
} }
ConfigConnectionMode::Grpc => {
Self::initialize_grpc_pd_workers(
prefill_urls,
decode_urls,
config,
registry,
policy_registry,
)
.await?;
}
},
RoutingMode::OpenAI { .. } => { RoutingMode::OpenAI { .. } => {
info!("OpenAI routing mode - no workers to initialize"); info!("OpenAI routing mode - no workers to initialize");
} }
...@@ -397,6 +423,133 @@ impl WorkerManager { ...@@ -397,6 +423,133 @@ impl WorkerManager {
Ok(()) Ok(())
} }
/// Initialize gRPC workers for regular mode
async fn initialize_grpc_workers(
urls: &[String],
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> {
info!("Creating {} gRPC regular workers", urls.len());
let circuit_breaker_config =
Self::convert_circuit_breaker_config(&config.effective_circuit_breaker_config());
let health_config = Self::convert_health_config(&config.health_check);
let connection_mode = ConnectionMode::Grpc { port: None };
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for url in urls {
let worker = Self::create_basic_worker(
url.clone(),
WorkerType::Regular,
connection_mode.clone(),
config.api_key.clone(),
None,
circuit_breaker_config.clone(),
health_config.clone(),
);
Self::register_worker(worker, registry, &mut registered_workers, policy_registry);
info!(
"Registered gRPC worker at {} (will connect on first use)",
url
);
}
Self::initialize_cache_policies(&registered_workers, registry, policy_registry);
Ok(())
}
/// Initialize gRPC PD (Prefill-Decode) workers
async fn initialize_grpc_pd_workers(
prefill_urls: &[(String, Option<u16>)],
decode_urls: &[String],
config: &RouterConfig,
registry: &Arc<WorkerRegistry>,
policy_registry: Option<&Arc<PolicyRegistry>>,
) -> Result<(), String> {
info!(
"Creating {} gRPC prefill workers and {} gRPC decode workers",
prefill_urls.len(),
decode_urls.len()
);
let circuit_breaker_config =
Self::convert_circuit_breaker_config(&config.effective_circuit_breaker_config());
let health_config = Self::convert_health_config(&config.health_check);
let mut registered_prefill_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
let mut registered_decode_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
for (url, bootstrap_port) in prefill_urls {
let worker_type = WorkerType::Prefill {
bootstrap_port: *bootstrap_port,
};
let connection_mode = ConnectionMode::Grpc {
port: *bootstrap_port,
};
let worker = Self::create_basic_worker(
url.clone(),
worker_type,
connection_mode,
config.api_key.clone(),
None,
circuit_breaker_config.clone(),
health_config.clone(),
);
Self::register_worker(
worker,
registry,
&mut registered_prefill_workers,
policy_registry,
);
info!(
"Registered gRPC prefill worker at {} (will connect on first use)",
url
);
}
// Create decode workers
for url in decode_urls {
let connection_mode = ConnectionMode::Grpc { port: None };
let worker = Self::create_basic_worker(
url.clone(),
WorkerType::Decode,
connection_mode,
config.api_key.clone(),
None,
circuit_breaker_config.clone(),
health_config.clone(),
);
Self::register_worker(
worker,
registry,
&mut registered_decode_workers,
policy_registry,
);
info!(
"Registered gRPC decode worker at {} (will connect on first use)",
url
);
}
if let Some(policy_reg) = policy_registry {
let all_prefill_workers: Vec<Arc<dyn Worker>> = registered_prefill_workers
.values()
.flat_map(|workers| workers.iter().cloned())
.collect();
let all_decode_workers: Vec<Arc<dyn Worker>> = registered_decode_workers
.values()
.flat_map(|workers| workers.iter().cloned())
.collect();
policy_reg.init_pd_cache_aware_policies(&all_prefill_workers, &all_decode_workers);
}
Ok(())
}
/// Add a worker from a configuration request /// Add a worker from a configuration request
pub async fn add_worker_from_config( pub async fn add_worker_from_config(
config: &WorkerConfigRequest, config: &WorkerConfigRequest,
......
//! Factory for creating router instances //! Factory for creating router instances
use super::grpc::pd_router::GrpcPDRouter;
use super::grpc::router::GrpcRouter;
use super::{ use super::{
http::{openai_router::OpenAIRouter, pd_router::PDRouter, router::Router}, http::{openai_router::OpenAIRouter, pd_router::PDRouter, router::Router},
RouterTrait, RouterTrait,
...@@ -15,23 +17,15 @@ pub struct RouterFactory; ...@@ -15,23 +17,15 @@ pub struct RouterFactory;
impl RouterFactory { impl RouterFactory {
/// Create a router instance from application context /// Create a router instance from application context
pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> { pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
// Check connection mode and route to appropriate implementation
match ctx.router_config.connection_mode { match ctx.router_config.connection_mode {
ConnectionMode::Grpc => { ConnectionMode::Grpc => match &ctx.router_config.mode {
// Route to gRPC implementation based on routing mode RoutingMode::Regular { .. } => Self::create_grpc_router(ctx).await,
match &ctx.router_config.mode {
RoutingMode::Regular { worker_urls } => {
Self::create_grpc_router(worker_urls, &ctx.router_config.policy, ctx).await
}
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
prefill_policy, prefill_policy,
decode_policy, decode_policy,
..
} => { } => {
Self::create_grpc_pd_router( Self::create_grpc_pd_router(
prefill_urls,
decode_urls,
prefill_policy.as_ref(), prefill_policy.as_ref(),
decode_policy.as_ref(), decode_policy.as_ref(),
&ctx.router_config.policy, &ctx.router_config.policy,
...@@ -42,21 +36,14 @@ impl RouterFactory { ...@@ -42,21 +36,14 @@ impl RouterFactory {
RoutingMode::OpenAI { .. } => { RoutingMode::OpenAI { .. } => {
Err("OpenAI mode requires HTTP connection_mode".to_string()) Err("OpenAI mode requires HTTP connection_mode".to_string())
} }
} },
} ConnectionMode::Http => match &ctx.router_config.mode {
ConnectionMode::Http => { RoutingMode::Regular { .. } => Self::create_regular_router(ctx).await,
// Route to HTTP implementation based on routing mode
match &ctx.router_config.mode {
RoutingMode::Regular { .. } => {
// Workers already initialized in registry
Self::create_regular_router(ctx).await
}
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_policy, prefill_policy,
decode_policy, decode_policy,
.. ..
} => { } => {
// Workers already initialized in registry
Self::create_pd_router( Self::create_pd_router(
prefill_policy.as_ref(), prefill_policy.as_ref(),
decode_policy.as_ref(), decode_policy.as_ref(),
...@@ -68,8 +55,7 @@ impl RouterFactory { ...@@ -68,8 +55,7 @@ impl RouterFactory {
RoutingMode::OpenAI { worker_urls, .. } => { RoutingMode::OpenAI { worker_urls, .. } => {
Self::create_openai_router(worker_urls.clone(), ctx).await Self::create_openai_router(worker_urls.clone(), ctx).await
} }
} },
}
} }
} }
...@@ -77,8 +63,6 @@ impl RouterFactory { ...@@ -77,8 +63,6 @@ impl RouterFactory {
pub async fn create_regular_router( pub async fn create_regular_router(
ctx: &Arc<AppContext>, ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> { ) -> Result<Box<dyn RouterTrait>, String> {
// Create regular router with context
// Workers should already be initialized in the registry
let router = Router::new(ctx).await?; let router = Router::new(ctx).await?;
Ok(Box::new(router)) Ok(Box::new(router))
...@@ -91,66 +75,41 @@ impl RouterFactory { ...@@ -91,66 +75,41 @@ impl RouterFactory {
main_policy_config: &PolicyConfig, main_policy_config: &PolicyConfig,
ctx: &Arc<AppContext>, ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> { ) -> Result<Box<dyn RouterTrait>, String> {
// Initialize policies in PolicyRegistry - use specific policies if provided, otherwise fall back to main policy
let prefill_policy = let prefill_policy =
PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config)); PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
let decode_policy = let decode_policy =
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config)); PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
// Set the prefill and decode policies in the registry
ctx.policy_registry.set_prefill_policy(prefill_policy); ctx.policy_registry.set_prefill_policy(prefill_policy);
ctx.policy_registry.set_decode_policy(decode_policy); ctx.policy_registry.set_decode_policy(decode_policy);
// Create PD router with context (policies are in PolicyRegistry)
// Workers should already be initialized in the registry
let router = PDRouter::new(ctx).await?; let router = PDRouter::new(ctx).await?;
Ok(Box::new(router)) Ok(Box::new(router))
} }
/// Create a gRPC router with injected policy /// Create a gRPC router with injected policy
pub async fn create_grpc_router( pub async fn create_grpc_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
worker_urls: &[String], let router = GrpcRouter::new(ctx).await?;
policy_config: &PolicyConfig,
ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
use super::grpc::router::GrpcRouter;
// Create policy
let policy = PolicyFactory::create_from_config(policy_config);
// Create gRPC router with context
let router = GrpcRouter::new(worker_urls.to_vec(), policy, ctx).await?;
Ok(Box::new(router)) Ok(Box::new(router))
} }
/// Create a gRPC PD router with tokenizer and worker configuration /// Create a gRPC PD router with tokenizer and worker configuration
pub async fn create_grpc_pd_router( pub async fn create_grpc_pd_router(
prefill_urls: &[(String, Option<u16>)],
decode_urls: &[String],
prefill_policy_config: Option<&PolicyConfig>, prefill_policy_config: Option<&PolicyConfig>,
decode_policy_config: Option<&PolicyConfig>, decode_policy_config: Option<&PolicyConfig>,
main_policy_config: &PolicyConfig, main_policy_config: &PolicyConfig,
ctx: &Arc<AppContext>, ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> { ) -> Result<Box<dyn RouterTrait>, String> {
use super::grpc::pd_router::GrpcPDRouter;
// Create policies - use specific policies if provided, otherwise fall back to main policy
let prefill_policy = let prefill_policy =
PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config)); PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
let decode_policy = let decode_policy =
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config)); PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
// Create gRPC PD router with context ctx.policy_registry.set_prefill_policy(prefill_policy);
let router = GrpcPDRouter::new( ctx.policy_registry.set_decode_policy(decode_policy);
prefill_urls.to_vec(), let router = GrpcPDRouter::new(ctx).await?;
decode_urls.to_vec(),
prefill_policy,
decode_policy,
ctx,
)
.await?;
Ok(Box::new(router)) Ok(Box::new(router))
} }
...@@ -160,7 +119,6 @@ impl RouterFactory { ...@@ -160,7 +119,6 @@ impl RouterFactory {
worker_urls: Vec<String>, worker_urls: Vec<String>,
ctx: &Arc<AppContext>, ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> { ) -> Result<Box<dyn RouterTrait>, String> {
// Use the first worker URL as the OpenAI-compatible base
let base_url = worker_urls let base_url = worker_urls
.first() .first()
.cloned() .cloned()
......
// PD (Prefill-Decode) gRPC Router Implementation // PD (Prefill-Decode) gRPC Router Implementation
use crate::config::types::RetryConfig; use crate::config::types::RetryConfig;
use crate::core::{ use crate::core::{WorkerRegistry, WorkerType};
BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::policies::PolicyRegistry;
use crate::reasoning_parser::ParserFactory; use crate::reasoning_parser::ParserFactory;
use crate::routers::RouterTrait; use crate::routers::RouterTrait;
use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::traits::Tokenizer;
...@@ -18,51 +16,29 @@ use axum::{ ...@@ -18,51 +16,29 @@ use axum::{
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use tracing::info; use tracing::info;
/// gRPC PD (Prefill-Decode) router implementation for SGLang /// gRPC PD (Prefill-Decode) router implementation for SGLang
#[allow(dead_code)] // Fields will be used once implementation is complete #[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcPDRouter { pub struct GrpcPDRouter {
/// Centralized worker registry
worker_registry: Arc<WorkerRegistry>, worker_registry: Arc<WorkerRegistry>,
/// Centralized policy registry
policy_registry: Arc<PolicyRegistry>, policy_registry: Arc<PolicyRegistry>,
/// Load balancing policy for prefill
prefill_policy: Arc<dyn LoadBalancingPolicy>,
/// Load balancing policy for decode
decode_policy: Arc<dyn LoadBalancingPolicy>,
/// Tokenizer for handling text encoding/decoding
tokenizer: Arc<dyn Tokenizer>, tokenizer: Arc<dyn Tokenizer>,
/// Reasoning parser factory for structured reasoning outputs
reasoning_parser_factory: ParserFactory, reasoning_parser_factory: ParserFactory,
/// Tool parser registry for function/tool calls
tool_parser_registry: &'static ParserRegistry, tool_parser_registry: &'static ParserRegistry,
/// Configuration
timeout_secs: u64,
interval_secs: u64,
dp_aware: bool, dp_aware: bool,
api_key: Option<String>, api_key: Option<String>,
retry_config: RetryConfig, retry_config: RetryConfig,
circuit_breaker_config: CircuitBreakerConfig,
} }
impl GrpcPDRouter { impl GrpcPDRouter {
/// Create a new gRPC PD router /// Create a new gRPC PD router
pub async fn new( pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
prefill_urls: Vec<(String, Option<u16>)>,
decode_urls: Vec<String>,
prefill_policy: Arc<dyn LoadBalancingPolicy>,
decode_policy: Arc<dyn LoadBalancingPolicy>,
ctx: &Arc<crate::server::AppContext>,
) -> Result<Self, String> {
// Get registries from context // Get registries from context
let worker_registry = ctx.worker_registry.clone(); let worker_registry = ctx.worker_registry.clone();
let policy_registry = ctx.policy_registry.clone(); let policy_registry = ctx.policy_registry.clone();
// Update metrics
RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len());
// Extract necessary components from context // Extract necessary components from context
let tokenizer = ctx let tokenizer = ctx
.tokenizer .tokenizer
...@@ -78,67 +54,7 @@ impl GrpcPDRouter { ...@@ -78,67 +54,7 @@ impl GrpcPDRouter {
.tool_parser_registry .tool_parser_registry
.ok_or_else(|| "gRPC PD router requires tool parser registry".to_string())?; .ok_or_else(|| "gRPC PD router requires tool parser registry".to_string())?;
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig // Get prefill and decode workers from registry - they should have been created by WorkerManager
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
for (url, bootstrap_port) in &prefill_urls {
let worker = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Prefill {
bootstrap_port: *bootstrap_port,
})
.connection_mode(crate::core::ConnectionMode::Grpc {
port: *bootstrap_port,
})
.circuit_breaker_config(core_cb_config.clone())
.health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
})
// No longer passing pre-initialized client - will be created lazily
.build();
worker_registry.register(Arc::new(worker));
info!(
"Registered gRPC prefill worker at {} (will connect on first use)",
url
);
}
for url in &decode_urls {
let worker = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Decode)
.connection_mode(crate::core::ConnectionMode::Grpc { port: None })
.circuit_breaker_config(core_cb_config.clone())
.health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
})
.build();
worker_registry.register(Arc::new(worker));
info!(
"Registered gRPC decode worker at {} (will connect on first use)",
url
);
}
if prefill_urls.is_empty() && decode_urls.is_empty() {
return Err("No gRPC workers configured".to_string());
}
// Initialize policies with workers if needed - filter for gRPC workers only
let prefill_workers = worker_registry.get_workers_filtered( let prefill_workers = worker_registry.get_workers_filtered(
None, // any model None, // any model
Some(WorkerType::Prefill { Some(WorkerType::Prefill {
...@@ -147,12 +63,6 @@ impl GrpcPDRouter { ...@@ -147,12 +63,6 @@ impl GrpcPDRouter {
Some(crate::core::ConnectionMode::Grpc { port: None }), Some(crate::core::ConnectionMode::Grpc { port: None }),
false, // include unhealthy workers during initialization false, // include unhealthy workers during initialization
); );
if let Some(cache_aware) = prefill_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.init_workers(&prefill_workers);
}
let decode_workers = worker_registry.get_workers_filtered( let decode_workers = worker_registry.get_workers_filtered(
None, // any model None, // any model
...@@ -160,29 +70,26 @@ impl GrpcPDRouter { ...@@ -160,29 +70,26 @@ impl GrpcPDRouter {
Some(crate::core::ConnectionMode::Grpc { port: None }), Some(crate::core::ConnectionMode::Grpc { port: None }),
false, // include unhealthy workers during initialization false, // include unhealthy workers during initialization
); );
if let Some(cache_aware) = decode_policy
.as_any() // Update metrics
.downcast_ref::<crate::policies::CacheAwarePolicy>() RouterMetrics::set_active_workers(prefill_workers.len() + decode_workers.len());
{ info!(
cache_aware.init_workers(&decode_workers); "gRPC PD router found {} prefill and {} decode workers in registry",
} prefill_workers.len(),
decode_workers.len()
);
// No need for local health checkers - WorkerRegistry handles health checking // No need for local health checkers - WorkerRegistry handles health checking
Ok(GrpcPDRouter { Ok(GrpcPDRouter {
worker_registry, worker_registry,
policy_registry, policy_registry,
prefill_policy,
decode_policy,
tokenizer, tokenizer,
reasoning_parser_factory, reasoning_parser_factory,
tool_parser_registry, tool_parser_registry,
timeout_secs: ctx.router_config.worker_startup_timeout_secs,
interval_secs: ctx.router_config.worker_startup_check_interval_secs,
dp_aware: ctx.router_config.dp_aware, dp_aware: ctx.router_config.dp_aware,
api_key: ctx.router_config.api_key.clone(), api_key: ctx.router_config.api_key.clone(),
retry_config: ctx.router_config.effective_retry_config(), retry_config: ctx.router_config.effective_retry_config(),
circuit_breaker_config: core_cb_config,
}) })
} }
} }
...@@ -206,8 +113,6 @@ impl std::fmt::Debug for GrpcPDRouter { ...@@ -206,8 +113,6 @@ impl std::fmt::Debug for GrpcPDRouter {
f.debug_struct("GrpcPDRouter") f.debug_struct("GrpcPDRouter")
.field("prefill_workers_count", &prefill_workers.len()) .field("prefill_workers_count", &prefill_workers.len())
.field("decode_workers_count", &decode_workers.len()) .field("decode_workers_count", &decode_workers.len())
.field("timeout_secs", &self.timeout_secs)
.field("interval_secs", &self.interval_secs)
.field("dp_aware", &self.dp_aware) .field("dp_aware", &self.dp_aware)
.finish() .finish()
} }
......
// gRPC Router Implementation // gRPC Router Implementation
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
...@@ -13,12 +12,10 @@ use axum::{ ...@@ -13,12 +12,10 @@ use axum::{
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use crate::config::types::RetryConfig; use crate::config::types::RetryConfig;
use crate::core::{ use crate::core::{WorkerRegistry, WorkerType};
BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType,
};
use crate::grpc_client::{proto, SglangSchedulerClient}; use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::policies::PolicyRegistry;
use crate::protocols::spec::{ChatCompletionRequest, StringOrArray}; use crate::protocols::spec::{ChatCompletionRequest, StringOrArray};
use crate::reasoning_parser::ParserFactory; use crate::reasoning_parser::ParserFactory;
use crate::routers::RouterTrait; use crate::routers::RouterTrait;
...@@ -38,39 +35,21 @@ pub struct ProcessedMessages { ...@@ -38,39 +35,21 @@ pub struct ProcessedMessages {
} }
/// gRPC router implementation for SGLang /// gRPC router implementation for SGLang
#[allow(dead_code)] // Fields will be used once implementation is complete #[allow(dead_code)]
pub struct GrpcRouter { pub struct GrpcRouter {
/// Centralized worker registry
worker_registry: Arc<WorkerRegistry>, worker_registry: Arc<WorkerRegistry>,
/// Centralized policy registry
policy_registry: Arc<PolicyRegistry>, policy_registry: Arc<PolicyRegistry>,
/// Load balancing policy
policy: Arc<dyn LoadBalancingPolicy>,
/// Tokenizer for handling text encoding/decoding
tokenizer: Arc<dyn Tokenizer>, tokenizer: Arc<dyn Tokenizer>,
/// Reasoning parser factory for structured reasoning outputs
reasoning_parser_factory: ParserFactory, reasoning_parser_factory: ParserFactory,
/// Tool parser registry for function/tool calls
tool_parser_registry: &'static ParserRegistry, tool_parser_registry: &'static ParserRegistry,
/// Configuration
timeout_secs: u64,
interval_secs: u64,
dp_aware: bool, dp_aware: bool,
api_key: Option<String>, api_key: Option<String>,
retry_config: RetryConfig, retry_config: RetryConfig,
circuit_breaker_config: CircuitBreakerConfig,
} }
impl GrpcRouter { impl GrpcRouter {
/// Create a new gRPC router /// Create a new gRPC router
pub async fn new( pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
worker_urls: Vec<String>,
policy: Arc<dyn LoadBalancingPolicy>,
ctx: &Arc<crate::server::AppContext>,
) -> Result<Self, String> {
// Update metrics
RouterMetrics::set_active_workers(worker_urls.len());
// Extract necessary components from context // Extract necessary components from context
let tokenizer = ctx let tokenizer = ctx
.tokenizer .tokenizer
...@@ -86,77 +65,28 @@ impl GrpcRouter { ...@@ -86,77 +65,28 @@ impl GrpcRouter {
.tool_parser_registry .tool_parser_registry
.ok_or_else(|| "gRPC router requires tool parser registry".to_string())?; .ok_or_else(|| "gRPC router requires tool parser registry".to_string())?;
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
};
// Get registries from context
let worker_registry = ctx.worker_registry.clone(); let worker_registry = ctx.worker_registry.clone();
let policy_registry = ctx.policy_registry.clone(); let policy_registry = ctx.policy_registry.clone();
// Create Worker trait objects with gRPC connection mode and register them
// Workers will lazily initialize their gRPC clients on first use
for url in &worker_urls {
let worker = BasicWorkerBuilder::new(url.clone())
.worker_type(WorkerType::Regular)
.connection_mode(crate::core::ConnectionMode::Grpc { port: None })
.circuit_breaker_config(core_cb_config.clone())
.health_config(HealthConfig {
timeout_secs: ctx.router_config.health_check.timeout_secs,
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
endpoint: ctx.router_config.health_check.endpoint.clone(),
failure_threshold: ctx.router_config.health_check.failure_threshold,
success_threshold: ctx.router_config.health_check.success_threshold,
})
.build();
worker_registry.register(Arc::new(worker));
info!(
"Registered gRPC worker at {} (will connect on first use)",
url
);
}
if worker_urls.is_empty() {
return Err("No gRPC workers configured".to_string());
}
// Get only gRPC workers from registry for policy initialization
let workers = worker_registry.get_workers_filtered( let workers = worker_registry.get_workers_filtered(
None, // any model None,
Some(WorkerType::Regular), Some(WorkerType::Regular),
Some(crate::core::ConnectionMode::Grpc { port: None }), Some(crate::core::ConnectionMode::Grpc { port: None }),
false, // include unhealthy workers during initialization false,
); );
// Initialize policy with workers if needed RouterMetrics::set_active_workers(workers.len());
if let Some(cache_aware) = policy info!("gRPC router found {} workers in registry", workers.len());
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.init_workers(&workers);
}
// No need for local health checkers - WorkerRegistry handles health checking
Ok(GrpcRouter { Ok(GrpcRouter {
worker_registry, worker_registry,
policy_registry, policy_registry,
policy,
tokenizer, tokenizer,
reasoning_parser_factory, reasoning_parser_factory,
tool_parser_registry, tool_parser_registry,
timeout_secs: ctx.router_config.worker_startup_timeout_secs,
interval_secs: ctx.router_config.worker_startup_check_interval_secs,
dp_aware: ctx.router_config.dp_aware, dp_aware: ctx.router_config.dp_aware,
api_key: ctx.router_config.api_key.clone(), api_key: ctx.router_config.api_key.clone(),
retry_config: ctx.router_config.effective_retry_config(), retry_config: ctx.router_config.effective_retry_config(),
circuit_breaker_config: core_cb_config,
}) })
} }
...@@ -576,8 +506,6 @@ impl std::fmt::Debug for GrpcRouter { ...@@ -576,8 +506,6 @@ impl std::fmt::Debug for GrpcRouter {
let stats = self.worker_registry.stats(); let stats = self.worker_registry.stats();
f.debug_struct("GrpcRouter") f.debug_struct("GrpcRouter")
.field("workers_count", &stats.total_workers) .field("workers_count", &stats.total_workers)
.field("timeout_secs", &self.timeout_secs)
.field("interval_secs", &self.interval_secs)
.field("dp_aware", &self.dp_aware) .field("dp_aware", &self.dp_aware)
.finish() .finish()
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment