Unverified Commit 89971c4c authored by Simo Lin, committed by GitHub

[router] refactor router and worker management 4/n (#10756)


Co-authored-by: Chang Su <chang.s.su@oracle.com>
parent 113f8f65
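
The central behavioral change in this diff is that both `Router` and `PDRouter` now carry an `enable_igw` flag and drop the request's `model_id` when inference-gateway (IGW) mode is off, so single-router deployments always select from the full worker pool. Below is a minimal, self-contained sketch of that gating pattern; `Registry` and `Worker` are illustrative stand-ins rather than the real sglang-router types, and only the control flow mirrors the diff.

```rust
// Minimal sketch of the enable_igw gating this commit adds to worker selection.
// `Registry` and `Worker` are illustrative stand-ins, not the real router types.
struct Worker(String);

struct Registry;

impl Registry {
    fn get_by_model(&self, model: &str) -> Vec<Worker> {
        vec![Worker(format!("worker-for-{model}"))]
    }

    fn get_all(&self) -> Vec<Worker> {
        vec![
            Worker("worker-a".to_string()),
            Worker("worker-b".to_string()),
        ]
    }
}

fn select_workers(registry: &Registry, enable_igw: bool, model_id: Option<&str>) -> Vec<Worker> {
    // Mirrors `let effective_model_id = if !self.enable_igw { None } else { model_id };`
    // When IGW mode is off, the per-request model filter is dropped entirely.
    let effective_model_id = if !enable_igw { None } else { model_id };
    match effective_model_id {
        Some(model) => registry.get_by_model(model),
        None => registry.get_all(),
    }
}

fn main() {
    let registry = Registry;
    // Single-router mode: the requested model is ignored, all workers are candidates.
    assert_eq!(select_workers(&registry, false, Some("llama")).len(), 2);
    // IGW (multi-router) mode: selection is scoped to the requested model.
    println!("{}", select_workers(&registry, true, Some("llama"))[0].0);
}
```
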
@@ -41,6 +41,7 @@ pub struct PDRouter {
     pub prefill_client: Client,
     pub retry_config: RetryConfig,
     pub api_key: Option<String>,
+    pub enable_igw: bool,
     prefill_drain_tx: mpsc::Sender<reqwest::Response>,
 }
@@ -317,6 +318,7 @@ impl PDRouter {
             prefill_drain_tx,
             retry_config: ctx.router_config.effective_retry_config(),
             api_key: ctx.router_config.api_key.clone(),
+            enable_igw: ctx.router_config.enable_igw,
         })
     }
@@ -849,7 +851,14 @@ impl PDRouter {
         request_text: Option<&str>,
         model_id: Option<&str>,
     ) -> Result<(Arc<dyn Worker>, Arc<dyn Worker>), String> {
-        let prefill_workers = if let Some(model) = model_id {
+        let effective_model_id = if !self.enable_igw { None } else { model_id };
+        debug!(
+            "Selecting PD pair: enable_igw={}, model_id={:?}, effective_model_id={:?}",
+            self.enable_igw, model_id, effective_model_id
+        );
+
+        let prefill_workers = if let Some(model) = effective_model_id {
             self.worker_registry
                 .get_by_model_fast(model)
                 .into_iter()
@@ -859,7 +868,7 @@ impl PDRouter {
             self.worker_registry.get_prefill_workers()
         };
-        let decode_workers = if let Some(model) = model_id {
+        let decode_workers = if let Some(model) = effective_model_id {
             self.worker_registry
                 .get_by_model_fast(model)
                 .into_iter()
@@ -1797,6 +1806,7 @@ mod tests {
             prefill_drain_tx: mpsc::channel(100).0,
             retry_config: RetryConfig::default(),
             api_key: Some("test_api_key".to_string()),
+            enable_igw: false,
         }
     }
...
@@ -35,6 +35,7 @@ pub struct Router {
     policy_registry: Arc<PolicyRegistry>,
     client: Client,
     dp_aware: bool,
+    enable_igw: bool,
     retry_config: RetryConfig,
     _worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
     _load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
@@ -93,6 +94,7 @@ impl Router {
             policy_registry: ctx.policy_registry.clone(),
             client: ctx.client.clone(),
             dp_aware: ctx.router_config.dp_aware,
+            enable_igw: ctx.router_config.enable_igw,
             retry_config: ctx.router_config.effective_retry_config(),
             _worker_loads: worker_loads,
             _load_monitor_handle: load_monitor_handle,
@@ -162,9 +164,11 @@ impl Router {
         model_id: Option<&str>,
         text: Option<&str>,
     ) -> Option<Arc<dyn Worker>> {
+        let effective_model_id = if !self.enable_igw { None } else { model_id };
+
         // Get workers for the specified model O(1), filtered by connection mode
         let workers = self.worker_registry.get_workers_filtered(
-            model_id,
+            effective_model_id,
             Some(WorkerType::Regular),
             Some(ConnectionMode::Http),
             false, // get all workers, we'll filter by is_available() next
@@ -1106,6 +1110,7 @@ mod tests {
             retry_config: RetryConfig::default(),
             _worker_loads: Arc::new(rx),
             _load_monitor_handle: None,
+            enable_igw: false,
         }
     }
...
@@ -4,12 +4,14 @@
 //! - Single Router Mode (enable_igw=false): Router owns workers directly
 //! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything

-use crate::core::{Worker, WorkerRegistry, WorkerType};
+use crate::config::{ConnectionMode, RoutingMode};
+use crate::core::{WorkerRegistry, WorkerType};
 use crate::protocols::spec::{
     ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
     ResponsesRequest,
 };
 use crate::routers::RouterTrait;
+use crate::server::{AppContext, ServerConfig};
 use async_trait::async_trait;
 use axum::{
     body::Body,
@@ -19,9 +21,8 @@ use axum::{
 };
 use dashmap::DashMap;
 use std::sync::Arc;
-use tracing::info;
+use tracing::{debug, info, warn};

-/// Router identifier
 #[derive(Debug, Clone, Hash, Eq, PartialEq)]
 pub struct RouterId(String);
@@ -35,30 +36,120 @@ impl RouterId {
     }
 }

-/// Router Manager - Central coordinator for routers and workers
 pub struct RouterManager {
-    /// Worker registry (single source of truth in multi-router mode)
     worker_registry: Arc<WorkerRegistry>,
-    /// All routers managed by this manager
-    /// RouterId examples: "http-regular", "http-pd", "grpc-regular", "grpc-pd"
     routers: Arc<DashMap<RouterId, Arc<dyn RouterTrait>>>,
-    /// Default router for requests without specific routing
     default_router: Arc<std::sync::RwLock<Option<RouterId>>>,
+    enable_igw: bool,
 }

 impl RouterManager {
-    /// Create a new router manager with shared registries
     pub fn new(worker_registry: Arc<WorkerRegistry>) -> Self {
         Self {
             worker_registry,
             routers: Arc::new(DashMap::new()),
             default_router: Arc::new(std::sync::RwLock::new(None)),
+            enable_igw: false, // Will be set properly in from_config
         }
     }

+    pub async fn from_config(
+        config: &ServerConfig,
+        app_context: &Arc<AppContext>,
+    ) -> Result<Arc<Self>, String> {
+        use crate::routers::RouterFactory;
+
+        let mut manager = Self::new(app_context.worker_registry.clone());
+        manager.enable_igw = config.router_config.enable_igw;
+        let manager = Arc::new(manager);
+
+        if config.router_config.enable_igw {
+            info!("Initializing RouterManager in multi-router mode (IGW)");
+
+            match RouterFactory::create_regular_router(app_context).await {
+                Ok(http_regular) => {
+                    info!("Created HTTP Regular router");
+                    manager.register_router(
+                        RouterId::new("http-regular".to_string()),
+                        Arc::from(http_regular),
+                    );
+                }
+                Err(e) => {
+                    warn!("Failed to create HTTP Regular router: {e}");
+                }
+            }
+
+            match RouterFactory::create_pd_router(
+                None,
+                None,
+                &config.router_config.policy,
+                app_context,
+            )
+            .await
+            {
+                Ok(http_pd) => {
+                    info!("Created HTTP PD router");
+                    manager
+                        .register_router(RouterId::new("http-pd".to_string()), Arc::from(http_pd));
+                }
+                Err(e) => {
+                    warn!("Failed to create HTTP PD router: {e}");
+                }
+            }
+
+            // TODO: Add gRPC routers once we have dynamic tokenizer loading
+
+            info!(
+                "RouterManager initialized with {} routers for multi-router mode",
+                manager.router_count()
+            );
+        } else {
+            info!("Initializing RouterManager in single-router mode");
+
+            let single_router = Arc::from(RouterFactory::create_router(app_context).await?);
+            let router_id = Self::determine_router_id(
+                &config.router_config.mode,
+                &config.router_config.connection_mode,
+            );
+            info!("Created single router with ID: {}", router_id.as_str());
+            manager.register_router(router_id.clone(), single_router);
+            manager.set_default_router(router_id);
+        }
+
+        if manager.router_count() == 0 {
+            return Err("No routers could be initialized".to_string());
+        }
+
+        Ok(manager)
+    }
+
+    pub fn determine_router_id(
+        routing_mode: &RoutingMode,
+        connection_mode: &ConnectionMode,
+    ) -> RouterId {
+        match (connection_mode, routing_mode) {
+            (ConnectionMode::Http, RoutingMode::Regular { .. }) => {
+                RouterId::new("http-regular".to_string())
+            }
+            (ConnectionMode::Http, RoutingMode::PrefillDecode { .. }) => {
+                RouterId::new("http-pd".to_string())
+            }
+            (ConnectionMode::Http, RoutingMode::OpenAI { .. }) => {
+                RouterId::new("http-openai".to_string())
+            }
+            (ConnectionMode::Grpc, RoutingMode::Regular { .. }) => {
+                RouterId::new("grpc-regular".to_string())
+            }
+            (ConnectionMode::Grpc, RoutingMode::PrefillDecode { .. }) => {
+                RouterId::new("grpc-pd".to_string())
+            }
+            (ConnectionMode::Grpc, RoutingMode::OpenAI { .. }) => {
+                RouterId::new("grpc-regular".to_string())
+            }
+        }
+    }

-    /// Register a router with the manager
     pub fn register_router(&self, id: RouterId, router: Arc<dyn RouterTrait>) {
         self.routers.insert(id.clone(), router);
@@ -69,18 +160,15 @@ impl RouterManager {
         }
     }

-    /// Set the default router
     pub fn set_default_router(&self, id: RouterId) {
         let mut default_router = self.default_router.write().unwrap();
         *default_router = Some(id);
     }

-    /// Get the number of registered routers
     pub fn router_count(&self) -> usize {
         self.routers.len()
     }

-    /// Get router for a specific model based on worker types
     pub fn get_router_for_model(&self, model_id: &str) -> Option<Arc<dyn RouterTrait>> {
         let workers = self.worker_registry.get_by_model(model_id);
@@ -111,21 +199,25 @@
         }
     }

-    /// Get workers for routing decision
-    pub fn get_workers_for_request(&self, model_id: Option<&str>) -> Vec<Arc<dyn Worker>> {
-        if let Some(model) = model_id {
-            self.worker_registry.get_by_model(model)
-        } else {
-            self.worker_registry.get_all()
-        }
-    }
-
-    /// Get the appropriate router for a request based on headers and request content
     pub fn select_router_for_request(
         &self,
         headers: Option<&HeaderMap>,
         model_id: Option<&str>,
     ) -> Option<Arc<dyn RouterTrait>> {
+        // In single-router mode (enable_igw=false), always use the default router
+        if !self.enable_igw {
+            let default_router = self.default_router.read().unwrap();
+            if let Some(ref default_id) = *default_router {
+                debug!(
+                    "Single-router mode: using default router {} for model {:?}",
+                    default_id.as_str(),
+                    model_id
+                );
+                return self.routers.get(default_id).map(|r| r.clone());
+            }
+        }
+
+        // Multi-router mode logic follows
         let _priority_threshold = headers.and_then(|h| {
             h.get("x-worker-priority")
                 .and_then(|v| v.to_str().ok())
@@ -176,10 +268,6 @@
             score += 1.0;
         }

-        // Get workers for this router and evaluate based on priority/cost
-        // Note: This would require routers to expose their workers or stats
-        // For now, we'll use a simple selection based on router type
-
         // TODO: Once routers expose worker stats, we can evaluate:
         // - Average worker priority vs priority_threshold
         // - Average worker cost vs max_cost
@@ -201,16 +289,11 @@ impl RouterTrait for RouterManager {
         self
     }

-    /// Health check - return 503 if no routers available
     async fn health(&self, _req: Request<Body>) -> Response {
-        // Health check should succeed if RouterManager exists, even without routers
-        // Individual router health can be checked via specific endpoints
         (StatusCode::OK, "RouterManager is healthy").into_response()
     }

-    /// Health generate - check if any router can handle generate requests
     async fn health_generate(&self, _req: Request<Body>) -> Response {
-        // Return 503 since we have no routers with workers
         // TODO: Should check if any router has healthy workers
         (
             StatusCode::SERVICE_UNAVAILABLE,
@@ -219,10 +302,8 @@
         .into_response()
     }

-    /// Get server information - aggregate from all routers
     async fn get_server_info(&self, _req: Request<Body>) -> Response {
         // TODO: Aggregate info from all routers with healthy workers
-        // For now, return basic info about the RouterManager
         (
             StatusCode::OK,
             serde_json::json!({
@@ -235,9 +316,7 @@
         .into_response()
     }

-    /// Get available models - query from worker registry
     async fn get_models(&self, _req: Request<Body>) -> Response {
-        // Get models from worker registry
         let models = self.worker_registry.get_models();
         if models.is_empty() {
@@ -254,10 +333,8 @@
         }
     }

-    /// Get model information
     async fn get_model_info(&self, _req: Request<Body>) -> Response {
         // TODO: Extract model from request and route to appropriate router
-        // For now, return not implemented
         (
             StatusCode::NOT_IMPLEMENTED,
             "Model info endpoint not yet implemented in RouterManager",
@@ -265,22 +342,17 @@
         .into_response()
     }

-    /// Route a generate request
     async fn route_generate(
         &self,
         headers: Option<&HeaderMap>,
         body: &GenerateRequest,
         _model_id: Option<&str>,
     ) -> Response {
-        // Select router based on headers
-        // GenerateRequest doesn't have a model field
         let router = self.select_router_for_request(headers, None);
         if let Some(router) = router {
-            // In multi-model mode, pass None since GenerateRequest doesn't have model field
             router.route_generate(headers, body, None).await
         } else {
-            // Return 404 when no router is available for the request
             (
                 StatusCode::NOT_FOUND,
                 "No router available for this request",
@@ -289,7 +361,6 @@
         }
     }

-    /// Route a chat completion request
     async fn route_chat(
         &self,
         headers: Option<&HeaderMap>,
@@ -299,10 +370,8 @@
         let router = self.select_router_for_request(headers, Some(&body.model));
         if let Some(router) = router {
-            // In multi-model mode, pass the model_id to the router
             router.route_chat(headers, body, Some(&body.model)).await
         } else {
-            // Return 404 when the specified model is not found
             (
                 StatusCode::NOT_FOUND,
                 format!("Model '{}' not found or no router available", body.model),
@@ -311,7 +380,6 @@
         }
     }

-    /// Route a completion request
     async fn route_completion(
         &self,
         headers: Option<&HeaderMap>,
@@ -321,12 +389,10 @@
         let router = self.select_router_for_request(headers, Some(&body.model));
         if let Some(router) = router {
-            // In multi-model mode, pass the model_id to the router
             router
                 .route_completion(headers, body, Some(&body.model))
                 .await
         } else {
-            // Return 404 when the specified model is not found
             (
                 StatusCode::NOT_FOUND,
                 format!("Model '{}' not found or no router available", body.model),
@@ -348,26 +414,6 @@
         .into_response()
     }

-    async fn delete_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
-        (
-            StatusCode::NOT_IMPLEMENTED,
-            "responses api not yet implemented in inference gateway mode",
-        )
-            .into_response()
-    }
-
-    async fn list_response_input_items(
-        &self,
-        _headers: Option<&HeaderMap>,
-        _response_id: &str,
-    ) -> Response {
-        (
-            StatusCode::NOT_IMPLEMENTED,
-            "responses api not yet implemented in inference gateway mode",
-        )
-            .into_response()
-    }
-
     async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
         let router = self.select_router_for_request(headers, None);
         if let Some(router) = router {
@@ -394,7 +440,26 @@
         }
     }

-    /// Route embeddings request
+    async fn delete_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
+        (
+            StatusCode::NOT_IMPLEMENTED,
+            "responses api not yet implemented in inference gateway mode",
+        )
+            .into_response()
+    }
+
+    async fn list_response_input_items(
+        &self,
+        _headers: Option<&HeaderMap>,
+        _response_id: &str,
+    ) -> Response {
+        (
+            StatusCode::NOT_IMPLEMENTED,
+            "responses api not yet implemented in inference gateway mode",
+        )
+            .into_response()
+    }
+
     async fn route_embeddings(
         &self,
         headers: Option<&HeaderMap>,
@@ -408,7 +473,6 @@
                 .route_embeddings(headers, body, Some(&body.model))
                 .await
         } else {
-            // Return 404 when the specified model is not found
             (
                 StatusCode::NOT_FOUND,
                 format!("Model '{}' not found or no router available", body.model),
@@ -417,14 +481,12 @@
         }
     }

-    /// Route rerank request
     async fn route_rerank(
         &self,
         headers: Option<&HeaderMap>,
         body: &RerankRequest,
         model_id: Option<&str>,
     ) -> Response {
-        // Try to select a router based on headers
         let router = self.select_router_for_request(headers, None);
         if let Some(router) = router {
@@ -438,10 +500,8 @@
         }
     }

-    /// Flush cache on all routers and workers
     async fn flush_cache(&self) -> Response {
         // TODO: Call flush_cache on all routers that have workers
-        // For now, return success if we have any routers
         if self.routers.is_empty() {
             (StatusCode::SERVICE_UNAVAILABLE, "No routers configured").into_response()
         } else {
@@ -450,9 +510,7 @@
         }
     }

-    /// Get worker loads from all routers
     async fn get_worker_loads(&self) -> Response {
-        // Return worker loads from the registry
         let workers = self.worker_registry.get_all();
         let loads: Vec<serde_json::Value> = workers
             .iter()
@@ -476,12 +534,10 @@
         .into_response()
     }

-    /// Get router type name
     fn router_type(&self) -> &'static str {
         "manager"
     }

-    /// Server readiness check - check if any router is ready
     fn readiness(&self) -> Response {
         if self.routers.is_empty() {
             (StatusCode::SERVICE_UNAVAILABLE, "No routers configured").into_response()
@@ -492,9 +548,6 @@
     }
 }

-// Note: get_first_available_router removed - we now properly handle
-// router selection based on model and worker availability
-
 impl std::fmt::Debug for RouterManager {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("RouterManager")
...
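
`RouterManager::from_config` now owns router construction for both modes, and `determine_router_id` gives single-router mode a stable identifier derived from the configured connection and routing modes (note that a gRPC + OpenAI configuration currently falls back to the `grpc-regular` ID). The sketch below reproduces only that mapping, with local enums standing in for the real `ConnectionMode` and `RoutingMode` config types.

```rust
// Illustrative stand-ins for the router config enums; only the
// (connection, routing) -> router-ID mapping mirrors determine_router_id.
enum ConnectionMode {
    Http,
    Grpc,
}

enum RoutingMode {
    Regular,
    PrefillDecode,
    OpenAI,
}

fn determine_router_id(routing: &RoutingMode, connection: &ConnectionMode) -> &'static str {
    match (connection, routing) {
        (ConnectionMode::Http, RoutingMode::Regular) => "http-regular",
        (ConnectionMode::Http, RoutingMode::PrefillDecode) => "http-pd",
        (ConnectionMode::Http, RoutingMode::OpenAI) => "http-openai",
        (ConnectionMode::Grpc, RoutingMode::Regular) => "grpc-regular",
        (ConnectionMode::Grpc, RoutingMode::PrefillDecode) => "grpc-pd",
        // The diff maps a gRPC + OpenAI configuration onto the regular gRPC router ID.
        (ConnectionMode::Grpc, RoutingMode::OpenAI) => "grpc-regular",
    }
}

fn main() {
    assert_eq!(
        determine_router_id(&RoutingMode::Regular, &ConnectionMode::Http),
        "http-regular"
    );
    assert_eq!(
        determine_router_id(&RoutingMode::PrefillDecode, &ConnectionMode::Http),
        "http-pd"
    );
    println!(
        "{}",
        determine_router_id(&RoutingMode::OpenAI, &ConnectionMode::Grpc)
    );
}
```
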
@@ -14,10 +14,7 @@ use crate::{
         worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
     },
     reasoning_parser::ParserFactory,
-    routers::{
-        router_manager::{RouterId, RouterManager},
-        RouterFactory, RouterTrait,
-    },
+    routers::{router_manager::RouterManager, RouterTrait},
     service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
     tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
     tool_parser::ParserRegistry,
@@ -64,10 +61,8 @@ impl AppContext {
         let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
         let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));

-        // Initialize gRPC-specific components only when in gRPC mode
         let (tokenizer, reasoning_parser_factory, tool_parser_registry) =
             if router_config.connection_mode == ConnectionMode::Grpc {
-                // Get tokenizer path (required for gRPC mode)
                 let tokenizer_path = router_config
                     .tokenizer_path
                     .clone()
@@ -77,7 +72,6 @@
                             .to_string()
                     })?;

-                // Initialize all gRPC components
                 let tokenizer = Some(
                     tokenizer_factory::create_tokenizer(&tokenizer_path)
                         .map_err(|e| format!("Failed to create tokenizer: {e}"))?,
@@ -87,7 +81,6 @@
                 (tokenizer, reasoning_parser_factory, tool_parser_registry)
             } else {
-                // HTTP mode doesn't need these components
                 (None, None, None)
             };
@@ -96,7 +89,6 @@
         let router_manager = None;

-        // Initialize response storage based on configuration
         let response_storage: SharedResponseStorage = match router_config.history_backend {
             HistoryBackend::Memory => Arc::new(MemoryResponseStorage::new()),
             HistoryBackend::None => Arc::new(NoOpResponseStorage::new()),
@@ -125,12 +117,10 @@ pub struct AppState {
     pub router_manager: Option<Arc<RouterManager>>,
 }

-// Fallback handler for unmatched routes
 async fn sink_handler() -> Response {
     StatusCode::NOT_FOUND.into_response()
 }

-// Health check endpoints
 async fn liveness(State(state): State<Arc<AppState>>) -> Response {
     state.router.liveness()
 }
@@ -257,7 +247,6 @@ async fn v1_responses_delete(
     Path(response_id): Path<String>,
     headers: http::HeaderMap,
 ) -> Response {
-    // Python server does not support this yet
     state
         .router
         .delete_response(Some(&headers), &response_id)
@@ -269,15 +258,12 @@ async fn v1_responses_list_input_items(
     Path(response_id): Path<String>,
     headers: http::HeaderMap,
 ) -> Response {
-    // Python server does not support this yet
     state
         .router
         .list_response_input_items(Some(&headers), &response_id)
         .await
 }

-// ---------- Worker management endpoints (Legacy) ----------
-
 #[derive(Deserialize)]
 struct AddWorkerQuery {
     url: String,
@@ -288,7 +274,6 @@ async fn add_worker(
     State(state): State<Arc<AppState>>,
     Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>,
 ) -> Response {
-    // Use centralized WorkerManager with full context
     let result = WorkerManager::add_worker(&url, &api_key, &state.context).await;
     match result {
@@ -298,7 +283,6 @@ async fn add_worker(
 }

 async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
-    // Use centralized WorkerManager instead of router's get_worker_urls
     let worker_list = WorkerManager::get_worker_urls(&state.context.worker_registry);
     Json(json!({ "urls": worker_list })).into_response()
 }
@@ -307,7 +291,6 @@ async fn remove_worker(
     State(state): State<Arc<AppState>>,
     Query(AddWorkerQuery { url, .. }): Query<AddWorkerQuery>,
 ) -> Response {
-    // Use centralized WorkerManager with full context
     let result = WorkerManager::remove_worker(&url, &state.context);
     match result {
@@ -324,14 +307,10 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Respons
     state.router.get_worker_loads().await
 }

-// ---------- Worker management endpoints (RESTful) ----------
-
-/// POST /workers - Add a new worker with full configuration
 async fn create_worker(
     State(state): State<Arc<AppState>>,
     Json(config): Json<WorkerConfigRequest>,
 ) -> Response {
-    // In single router mode, use centralized WorkerManager with full context
     let result = WorkerManager::add_worker_from_config(&config, &state.context).await;
     match result {
@@ -353,9 +332,7 @@ async fn create_worker(
     }
 }

-/// GET /workers - List all workers with details
 async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
-    // In single router mode, get detailed worker info from registry
     let workers = state.context.worker_registry.get_all();
     let response = serde_json::json!({
         "workers": workers.iter().map(|worker| {
@@ -374,7 +351,6 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
                 "cost": worker.cost(),
             });

-            // Add bootstrap_port for Prefill workers
             if let WorkerType::Prefill { bootstrap_port } = worker.worker_type() {
                 worker_info["bootstrap_port"] = serde_json::json!(bootstrap_port);
             }
@@ -391,7 +367,6 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
     Json(response).into_response()
 }

-/// GET /workers/{url} - Get specific worker info
 async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
     let workers = WorkerManager::get_worker_urls(&state.context.worker_registry);
     if workers.contains(&url) {
@@ -410,9 +385,7 @@ async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>)
     }
 }

-/// DELETE /workers/{url} - Remove a worker
 async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
-    // In single router mode, use centralized WorkerManager with full context
     let result = WorkerManager::remove_worker(&url, &state.context);
     match result {
@@ -447,14 +420,12 @@ pub struct ServerConfig {
     pub request_id_headers: Option<Vec<String>>,
 }

-/// Build the Axum application with all routes and middleware
 pub fn build_app(
     app_state: Arc<AppState>,
     max_payload_size: usize,
     request_id_headers: Vec<String>,
     cors_allowed_origins: Vec<String>,
 ) -> Router {
-    // Create routes
     let protected_routes = Router::new()
         .route("/generate", post(generate))
         .route("/v1/chat/completions", post(v1_chat_completions))
@@ -494,20 +465,17 @@ pub fn build_app(
         .route("/flush_cache", post(flush_cache))
         .route("/get_loads", get(get_loads));

-    // Worker management routes
     let worker_routes = Router::new()
         .route("/workers", post(create_worker))
         .route("/workers", get(list_workers_rest))
         .route("/workers/{url}", get(get_worker))
         .route("/workers/{url}", delete(delete_worker));

-    // Build app with all routes and middleware
     Router::new()
         .merge(protected_routes)
         .merge(public_routes)
         .merge(admin_routes)
         .merge(worker_routes)
-        // Request body size limiting
         .layer(tower_http::limit::RequestBodyLimitLayer::new(
             max_payload_size,
         ))
@@ -519,7 +487,6 @@ pub fn build_app(
 }

 pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
-    // Only initialize logging if not already done (for Python bindings support)
     static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);
     let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
@@ -545,9 +512,8 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
         None
     };

-    // Initialize prometheus metrics exporter
-    if let Some(prometheus_config) = config.prometheus_config {
-        metrics::start_prometheus(prometheus_config);
+    if let Some(prometheus_config) = &config.prometheus_config {
+        metrics::start_prometheus(prometheus_config.clone());
     }

     info!(
@@ -569,7 +535,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
         .build()
         .expect("Failed to create HTTP client");

-    // Create the application context with all dependencies
     let app_context = AppContext::new(
         config.router_config.clone(),
         client.clone(),
@@ -597,67 +562,9 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
         worker_stats.total_workers, worker_stats.healthy_workers
     );

-    // Create the appropriate router based on enable_igw flag
-    let (router, router_manager): (Arc<dyn RouterTrait>, Option<Arc<RouterManager>>) =
-        if config.router_config.enable_igw {
-            info!("Multi-router mode enabled (enable_igw=true)");
-
-            // Create RouterManager with shared registries from AppContext
-            let router_manager = Arc::new(RouterManager::new(app_context.worker_registry.clone()));
-
-            // 1. HTTP Regular Router
-            match RouterFactory::create_regular_router(&app_context).await {
-                Ok(http_regular) => {
-                    info!("Created HTTP Regular router");
-                    router_manager.register_router(
-                        RouterId::new("http-regular".to_string()),
-                        Arc::from(http_regular),
-                    );
-                }
-                Err(e) => {
-                    warn!("Failed to create HTTP Regular router: {e}");
-                }
-            }
-
-            // 2. HTTP PD Router
-            match RouterFactory::create_pd_router(
-                None,
-                None,
-                &config.router_config.policy,
-                &app_context,
-            )
-            .await
-            {
-                Ok(http_pd) => {
-                    info!("Created HTTP PD router");
-                    router_manager
-                        .register_router(RouterId::new("http-pd".to_string()), Arc::from(http_pd));
-                }
-                Err(e) => {
-                    warn!("Failed to create HTTP PD router: {e}");
-                }
-            }
-
-            // TODO: Add gRPC routers once we have dynamic tokenizer loading
-
-            info!(
-                "RouterManager initialized with {} routers",
-                router_manager.router_count()
-            );
-
-            (
-                router_manager.clone() as Arc<dyn RouterTrait>,
-                Some(router_manager),
-            )
-        } else {
-            info!("Single router mode (enable_igw=false)");
-            // Create single router with the context
-            (
-                Arc::from(RouterFactory::create_router(&app_context).await?),
-                None,
-            )
-        };
+    let router_manager = RouterManager::from_config(&config, &app_context).await?;
+    let router: Arc<dyn RouterTrait> = router_manager.clone();

-    // Start health checker for all workers in the registry
     let _health_checker = app_context
         .worker_registry
         .start_health_checker(config.router_config.health_check.check_interval_secs);
@@ -666,14 +573,12 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
         config.router_config.health_check.check_interval_secs
     );

-    // Set up concurrency limiter with queue if configured
     let (limiter, processor) = middleware::ConcurrencyLimiter::new(
         app_context.rate_limiter.clone(),
         config.router_config.queue_size,
         Duration::from_secs(config.router_config.queue_timeout_secs),
     );

-    // Start queue processor if enabled
     if let Some(processor) = processor {
         spawn(processor.run());
         info!(
@@ -682,21 +587,18 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
         );
     }

-    // Create app state with router and context
     let app_state = Arc::new(AppState {
         router,
         context: app_context.clone(),
         concurrency_queue_tx: limiter.queue_tx.clone(),
-        router_manager,
+        router_manager: Some(router_manager),
     });

-    // Start the service discovery if enabled
     if let Some(service_discovery_config) = config.service_discovery_config {
         if service_discovery_config.enabled {
             let app_context_arc = Arc::clone(&app_state.context);
             match start_service_discovery(service_discovery_config, app_context_arc).await {
                 Ok(handle) => {
                     info!("Service discovery started");
-                    // Spawn a task to handle the service discovery thread
                     spawn(async move {
                         if let Err(e) = handle.await {
                             error!("Service discovery task failed: {:?}", e);
@@ -725,7 +627,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
         ]
     });

-    // Build the application
     let app = build_app(
         app_state,
         config.max_payload_size,
@@ -744,7 +645,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
     Ok(())
 }

-// Graceful shutdown handler
 async fn shutdown_signal() {
     let ctrl_c = async {
         signal::ctrl_c()
@@ -773,19 +673,16 @@
     }
 }

-// CORS Layer Creation
 fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
     use tower_http::cors::Any;

     let cors = if allowed_origins.is_empty() {
-        // Allow all origins if none specified
         tower_http::cors::CorsLayer::new()
             .allow_origin(Any)
             .allow_methods(Any)
             .allow_headers(Any)
             .expose_headers(Any)
     } else {
-        // Restrict to specific origins
         let origins: Vec<http::HeaderValue> = allowed_origins
             .into_iter()
             .filter_map(|origin| origin.parse().ok())
...
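
After this change, `startup` no longer branches on `enable_igw`: it always builds a `RouterManager` via `from_config`, uses it as the app-level `RouterTrait` object, and stores the same `Arc` in `AppState.router_manager`. The sketch below shows only that wiring shape, with stand-in types for `RouterTrait`, `RouterManager`, and `AppState`; it is not the real server code.

```rust
use std::sync::Arc;

// Stand-in types; only the wiring shape mirrors the updated `startup` in this diff,
// where RouterManager::from_config is the single construction path for both modes.
trait RouterTrait: Send + Sync {
    fn router_type(&self) -> &'static str;
}

struct RouterManager;

impl RouterTrait for RouterManager {
    fn router_type(&self) -> &'static str {
        "manager"
    }
}

struct AppState {
    router: Arc<dyn RouterTrait>,
    router_manager: Option<Arc<RouterManager>>,
}

fn main() {
    // One construction path regardless of enable_igw; the manager is both the
    // app-level RouterTrait object and the stored RouterManager handle.
    let router_manager = Arc::new(RouterManager);
    let router: Arc<dyn RouterTrait> = router_manager.clone();
    let state = AppState {
        router,
        router_manager: Some(router_manager),
    };
    assert!(state.router_manager.is_some());
    println!("{}", state.router.router_type());
}
```
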