//! Router Manager for coordinating multiple routers and workers //! //! Provides centralized management based on enable_igw flag: //! - Single Router Mode (enable_igw=false): Router owns workers directly //! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything use std::sync::Arc; use async_trait::async_trait; use axum::{ body::Body, extract::Request, http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, }; use dashmap::DashMap; use serde_json::Value; use tracing::{debug, info, warn}; use crate::{ config::{ConnectionMode, RoutingMode}, core::{WorkerRegistry, WorkerType}, protocols::{ chat::ChatCompletionRequest, completion::CompletionRequest, embedding::EmbeddingRequest, generate::GenerateRequest, rerank::RerankRequest, responses::{ResponsesGetParams, ResponsesRequest}, }, routers::RouterTrait, server::{AppContext, ServerConfig}, }; #[derive(Debug, Clone, Hash, Eq, PartialEq)] pub struct RouterId(String); impl RouterId { pub fn new(id: String) -> Self { Self(id) } pub fn as_str(&self) -> &str { &self.0 } } pub struct RouterManager { worker_registry: Arc, routers: Arc>>, default_router: Arc>>, enable_igw: bool, } impl RouterManager { pub fn new(worker_registry: Arc) -> Self { Self { worker_registry, routers: Arc::new(DashMap::new()), default_router: Arc::new(std::sync::RwLock::new(None)), enable_igw: false, // Will be set properly in from_config } } pub async fn from_config( config: &ServerConfig, app_context: &Arc, ) -> Result, String> { use crate::routers::RouterFactory; let mut manager = Self::new(app_context.worker_registry.clone()); manager.enable_igw = config.router_config.enable_igw; let manager = Arc::new(manager); if config.router_config.enable_igw { info!("Initializing RouterManager in multi-router mode (IGW)"); match RouterFactory::create_regular_router(app_context).await { Ok(http_regular) => { info!("Created HTTP Regular router"); manager.register_router( RouterId::new("http-regular".to_string()), Arc::from(http_regular), ); } Err(e) => { warn!("Failed to create HTTP Regular router: {e}"); } } match RouterFactory::create_pd_router( None, None, &config.router_config.policy, app_context, ) .await { Ok(http_pd) => { info!("Created HTTP PD router"); manager .register_router(RouterId::new("http-pd".to_string()), Arc::from(http_pd)); } Err(e) => { warn!("Failed to create HTTP PD router: {e}"); } } // TODO: Add gRPC routers once we have dynamic tokenizer loading info!( "RouterManager initialized with {} routers for multi-router mode", manager.router_count() ); } else { info!("Initializing RouterManager in single-router mode"); let single_router = Arc::from(RouterFactory::create_router(app_context).await?); let router_id = Self::determine_router_id( &config.router_config.mode, &config.router_config.connection_mode, ); info!("Created single router with ID: {}", router_id.as_str()); manager.register_router(router_id.clone(), single_router); manager.set_default_router(router_id); } if manager.router_count() == 0 { return Err("No routers could be initialized".to_string()); } Ok(manager) } pub fn determine_router_id( routing_mode: &RoutingMode, connection_mode: &ConnectionMode, ) -> RouterId { match (connection_mode, routing_mode) { (ConnectionMode::Http, RoutingMode::Regular { .. }) => { RouterId::new("http-regular".to_string()) } (ConnectionMode::Http, RoutingMode::PrefillDecode { .. }) => { RouterId::new("http-pd".to_string()) } (ConnectionMode::Http, RoutingMode::OpenAI { .. }) => { RouterId::new("http-openai".to_string()) } (ConnectionMode::Grpc, RoutingMode::Regular { .. }) => { RouterId::new("grpc-regular".to_string()) } (ConnectionMode::Grpc, RoutingMode::PrefillDecode { .. }) => { RouterId::new("grpc-pd".to_string()) } (ConnectionMode::Grpc, RoutingMode::OpenAI { .. }) => { RouterId::new("grpc-regular".to_string()) } } } pub fn register_router(&self, id: RouterId, router: Arc) { self.routers.insert(id.clone(), router); let mut default_router = self.default_router.write().unwrap(); if default_router.is_none() { *default_router = Some(id.clone()); info!("Set default router to {}", id.as_str()); } } pub fn set_default_router(&self, id: RouterId) { let mut default_router = self.default_router.write().unwrap(); *default_router = Some(id); } pub fn router_count(&self) -> usize { self.routers.len() } pub fn get_router_for_model(&self, model_id: &str) -> Option> { let workers = self.worker_registry.get_by_model(model_id); if !workers.is_empty() { let has_pd_workers = workers.iter().any(|w| { matches!( w.worker_type(), WorkerType::Prefill { .. } | WorkerType::Decode ) }); let router_id = if has_pd_workers { RouterId::new("http-pd".to_string()) } else { RouterId::new("http-regular".to_string()) }; if let Some(router) = self.routers.get(&router_id) { return Some(router.clone()); } } let default_router = self.default_router.read().unwrap(); if let Some(ref default_id) = *default_router { self.routers.get(default_id).map(|r| r.clone()) } else { None } } pub fn select_router_for_request( &self, headers: Option<&HeaderMap>, model_id: Option<&str>, ) -> Option> { // In single-router mode (enable_igw=false), always use the default router if !self.enable_igw { let default_router = self.default_router.read().unwrap(); if let Some(ref default_id) = *default_router { debug!( "Single-router mode: using default router {} for model {:?}", default_id.as_str(), model_id ); return self.routers.get(default_id).map(|r| r.clone()); } } // Multi-router mode logic follows let _priority_threshold = headers.and_then(|h| { h.get("x-worker-priority") .and_then(|v| v.to_str().ok()) .and_then(|s| s.parse::().ok()) }); let _max_cost = headers.and_then(|h| { h.get("x-max-cost") .and_then(|v| v.to_str().ok()) .and_then(|s| s.parse::().ok()) }); let prefer_pd = headers .and_then(|h| { h.get("x-prefer-pd") .and_then(|v| v.to_str().ok()) .map(|s| s == "true" || s == "1") }) .unwrap_or(false); let candidate_routers = if let Some(model) = model_id { if let Some(router) = self.get_router_for_model(model) { vec![router] } else { Vec::new() } } else { self.routers .iter() .map(|entry| entry.value().clone()) .collect::>() }; if candidate_routers.is_empty() { return None; } let mut best_router = None; let mut best_score = 0.0; for router in candidate_routers { let mut score = 1.0; let is_pd = router.is_pd_mode(); if prefer_pd && is_pd { score += 2.0; } else if !prefer_pd && !is_pd { score += 1.0; } // TODO: Once routers expose worker stats, we can evaluate: // - Average worker priority vs priority_threshold // - Average worker cost vs max_cost // - Current load and health status if score > best_score { best_score = score; best_router = Some(router); } } best_router } } #[async_trait] impl RouterTrait for RouterManager { fn as_any(&self) -> &dyn std::any::Any { self } async fn health_generate(&self, _req: Request) -> Response { // TODO: Should check if any router has healthy workers ( StatusCode::SERVICE_UNAVAILABLE, "No routers with healthy workers available", ) .into_response() } async fn get_server_info(&self, _req: Request) -> Response { // TODO: Aggregate info from all routers with healthy workers ( StatusCode::OK, serde_json::json!({ "router_manager": true, "routers_count": self.routers.len(), "workers_count": self.worker_registry.get_all().len() }) .to_string(), ) .into_response() } async fn get_models(&self, req: Request) -> Response { // In single-router mode, delegate to the router (especially for OpenAI mode) if !self.enable_igw { let router = self .default_router .read() .expect("Default router lock is poisoned") .as_ref() .and_then(|id| self.routers.get(id).map(|r| r.value().clone())); if let Some(router) = router { return router.get_models(req).await; } } // In multi-router mode, aggregate models from worker registry let models = self.worker_registry.get_models(); if models.is_empty() { (StatusCode::SERVICE_UNAVAILABLE, "No models available").into_response() } else { ( StatusCode::OK, serde_json::json!({ "models": models }) .to_string(), ) .into_response() } } async fn get_model_info(&self, _req: Request) -> Response { // TODO: Extract model from request and route to appropriate router ( StatusCode::NOT_IMPLEMENTED, "Model info endpoint not yet implemented in RouterManager", ) .into_response() } async fn route_generate( &self, headers: Option<&HeaderMap>, body: &GenerateRequest, _model_id: Option<&str>, ) -> Response { let router = self.select_router_for_request(headers, None); if let Some(router) = router { router.route_generate(headers, body, None).await } else { ( StatusCode::NOT_FOUND, "No router available for this request", ) .into_response() } } async fn route_chat( &self, headers: Option<&HeaderMap>, body: &ChatCompletionRequest, _model_id: Option<&str>, ) -> Response { let router = self.select_router_for_request(headers, Some(&body.model)); if let Some(router) = router { router.route_chat(headers, body, Some(&body.model)).await } else { ( StatusCode::NOT_FOUND, format!("Model '{}' not found or no router available", body.model), ) .into_response() } } async fn route_completion( &self, headers: Option<&HeaderMap>, body: &CompletionRequest, _model_id: Option<&str>, ) -> Response { let router = self.select_router_for_request(headers, Some(&body.model)); if let Some(router) = router { router .route_completion(headers, body, Some(&body.model)) .await } else { ( StatusCode::NOT_FOUND, format!("Model '{}' not found or no router available", body.model), ) .into_response() } } async fn route_responses( &self, headers: Option<&HeaderMap>, body: &ResponsesRequest, model_id: Option<&str>, ) -> Response { let selected_model = body.model.as_deref().or(model_id); let router = self.select_router_for_request(headers, selected_model); if let Some(router) = router { router.route_responses(headers, body, selected_model).await } else { ( StatusCode::NOT_FOUND, "No router available to handle responses request", ) .into_response() } } async fn delete_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { ( StatusCode::NOT_IMPLEMENTED, "responses api not yet implemented in inference gateway mode", ) .into_response() } async fn list_response_input_items( &self, _headers: Option<&HeaderMap>, _response_id: &str, ) -> Response { ( StatusCode::NOT_IMPLEMENTED, "responses api not yet implemented in inference gateway mode", ) .into_response() } async fn get_response( &self, headers: Option<&HeaderMap>, response_id: &str, params: &ResponsesGetParams, ) -> Response { let router = self.select_router_for_request(headers, None); if let Some(router) = router { router.get_response(headers, response_id, params).await } else { ( StatusCode::NOT_FOUND, format!("No router available to get response '{}'", response_id), ) .into_response() } } async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response { let router = self.select_router_for_request(headers, None); if let Some(router) = router { router.cancel_response(headers, response_id).await } else { ( StatusCode::NOT_FOUND, format!("No router available to cancel response '{}'", response_id), ) .into_response() } } async fn route_embeddings( &self, headers: Option<&HeaderMap>, body: &EmbeddingRequest, _model_id: Option<&str>, ) -> Response { let router = self.select_router_for_request(headers, Some(&body.model)); if let Some(router) = router { router .route_embeddings(headers, body, Some(&body.model)) .await } else { ( StatusCode::NOT_FOUND, format!("Model '{}' not found or no router available", body.model), ) .into_response() } } async fn route_rerank( &self, headers: Option<&HeaderMap>, body: &RerankRequest, model_id: Option<&str>, ) -> Response { let router = self.select_router_for_request(headers, None); if let Some(router) = router { router.route_rerank(headers, body, model_id).await } else { ( StatusCode::NOT_FOUND, "No router available for rerank request", ) .into_response() } } fn router_type(&self) -> &'static str { "manager" } // Conversations API delegates async fn create_conversation(&self, headers: Option<&HeaderMap>, body: &Value) -> Response { let router = self.select_router_for_request(headers, None); if let Some(router) = router { router.create_conversation(headers, body).await } else { ( StatusCode::NOT_FOUND, "No router available to create conversation", ) .into_response() } } async fn get_conversation( &self, headers: Option<&HeaderMap>, conversation_id: &str, ) -> Response { let router = self.select_router_for_request(headers, None); if let Some(router) = router { router.get_conversation(headers, conversation_id).await } else { ( StatusCode::NOT_FOUND, format!( "No router available to get conversation '{}'", conversation_id ), ) .into_response() } } async fn update_conversation( &self, headers: Option<&HeaderMap>, conversation_id: &str, body: &Value, ) -> Response { let router = self.select_router_for_request(headers, None); if let Some(router) = router { router .update_conversation(headers, conversation_id, body) .await } else { ( StatusCode::NOT_FOUND, format!( "No router available to update conversation '{}'", conversation_id ), ) .into_response() } } async fn delete_conversation( &self, headers: Option<&HeaderMap>, conversation_id: &str, ) -> Response { let router = self.select_router_for_request(headers, None); if let Some(router) = router { router.delete_conversation(headers, conversation_id).await } else { ( StatusCode::NOT_FOUND, format!( "No router available to delete conversation '{}'", conversation_id ), ) .into_response() } } async fn list_conversation_items( &self, headers: Option<&HeaderMap>, conversation_id: &str, limit: Option, order: Option, after: Option, ) -> Response { let router = self.select_router_for_request(headers, None); if let Some(router) = router { router .list_conversation_items(headers, conversation_id, limit, order, after) .await } else { ( StatusCode::NOT_FOUND, format!( "No router available to list conversation items for '{}'", conversation_id ), ) .into_response() } } async fn create_conversation_items( &self, headers: Option<&HeaderMap>, conversation_id: &str, body: &Value, ) -> Response { let router = self.select_router_for_request(headers, None); if let Some(router) = router { router .create_conversation_items(headers, conversation_id, body) .await } else { ( StatusCode::NOT_FOUND, format!( "No router available to create conversation items for '{}'", conversation_id ), ) .into_response() } } async fn get_conversation_item( &self, headers: Option<&HeaderMap>, conversation_id: &str, item_id: &str, include: Option>, ) -> Response { let router = self.select_router_for_request(headers, None); if let Some(router) = router { router .get_conversation_item(headers, conversation_id, item_id, include) .await } else { ( StatusCode::NOT_FOUND, format!( "No router available to get conversation item '{}' in '{}'", item_id, conversation_id ), ) .into_response() } } async fn delete_conversation_item( &self, headers: Option<&HeaderMap>, conversation_id: &str, item_id: &str, ) -> Response { let router = self.select_router_for_request(headers, None); if let Some(router) = router { router .delete_conversation_item(headers, conversation_id, item_id) .await } else { ( StatusCode::NOT_FOUND, format!( "No router available to delete conversation item '{}' in '{}'", item_id, conversation_id ), ) .into_response() } } } impl std::fmt::Debug for RouterManager { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("RouterManager") .field("routers_count", &self.routers.len()) .field("workers_count", &self.worker_registry.get_all().len()) .field("default_router", &*self.default_router.read().unwrap()) .finish() } }