Unverified Commit b93acd70 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router] minor code clean up in server startup (#10470)

parent 86a32bb5
use crate::config::RouterConfig; use crate::{
use crate::core::WorkerRegistry; config::{ConnectionMode, RouterConfig},
use crate::logging::{self, LoggingConfig}; core::{WorkerRegistry, WorkerType},
use crate::metrics::{self, PrometheusConfig}; logging::{self, LoggingConfig},
use crate::middleware::TokenBucket; metrics::{self, PrometheusConfig},
use crate::policies::PolicyRegistry; middleware::{self, QueuedRequest, TokenBucket},
use crate::protocols::spec::{ policies::PolicyRegistry,
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, protocols::{
ResponsesRequest, V1RerankReqInput, spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest,
RerankRequest, ResponsesRequest, V1RerankReqInput,
},
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
},
reasoning_parser::ParserFactory,
routers::{
router_manager::{RouterId, RouterManager},
RouterFactory, RouterTrait,
},
service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
tool_parser::ParserRegistry,
}; };
use crate::protocols::worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse};
use crate::reasoning_parser::ParserFactory;
use crate::routers::router_manager::{RouterId, RouterManager};
use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer};
use crate::tool_parser::ParserRegistry;
use axum::{ use axum::{
extract::{Path, Query, Request, State}, extract::{Path, Query, Request, State},
http::StatusCode, http::StatusCode,
response::{IntoResponse, Response}, response::{IntoResponse, Response},
routing::{delete, get, post}, routing::{delete, get, post},
Json, Router, serve, Json, Router,
}; };
use reqwest::Client; use reqwest::Client;
use std::collections::HashMap; use serde::Deserialize;
use std::sync::atomic::{AtomicBool, Ordering}; use serde_json::json;
use std::sync::Arc; use std::{
use std::time::Duration; sync::atomic::{AtomicBool, Ordering},
use tokio::net::TcpListener; sync::Arc,
use tokio::signal; time::Duration,
use tokio::spawn; };
use tokio::{net::TcpListener, signal, spawn};
use tracing::{error, info, warn, Level}; use tracing::{error, info, warn, Level};
#[derive(Clone)] #[derive(Clone)]
...@@ -40,9 +47,9 @@ pub struct AppContext { ...@@ -40,9 +47,9 @@ pub struct AppContext {
pub tokenizer: Option<Arc<dyn Tokenizer>>, pub tokenizer: Option<Arc<dyn Tokenizer>>,
pub reasoning_parser_factory: Option<ParserFactory>, pub reasoning_parser_factory: Option<ParserFactory>,
pub tool_parser_registry: Option<&'static ParserRegistry>, pub tool_parser_registry: Option<&'static ParserRegistry>,
pub worker_registry: Arc<WorkerRegistry>, // Shared worker registry pub worker_registry: Arc<WorkerRegistry>,
pub policy_registry: Arc<PolicyRegistry>, // Shared policy registry pub policy_registry: Arc<PolicyRegistry>,
pub router_manager: Option<Arc<RouterManager>>, // Only present when enable_igw=true pub router_manager: Option<Arc<RouterManager>>,
} }
impl AppContext { impl AppContext {
...@@ -57,7 +64,7 @@ impl AppContext { ...@@ -57,7 +64,7 @@ impl AppContext {
// Initialize gRPC-specific components only when in gRPC mode // Initialize gRPC-specific components only when in gRPC mode
let (tokenizer, reasoning_parser_factory, tool_parser_registry) = let (tokenizer, reasoning_parser_factory, tool_parser_registry) =
if router_config.connection_mode == crate::config::ConnectionMode::Grpc { if router_config.connection_mode == ConnectionMode::Grpc {
// Get tokenizer path (required for gRPC mode) // Get tokenizer path (required for gRPC mode)
let tokenizer_path = router_config let tokenizer_path = router_config
.tokenizer_path .tokenizer_path
...@@ -71,7 +78,7 @@ impl AppContext { ...@@ -71,7 +78,7 @@ impl AppContext {
// Initialize all gRPC components // Initialize all gRPC components
let tokenizer = Some( let tokenizer = Some(
tokenizer_factory::create_tokenizer(&tokenizer_path) tokenizer_factory::create_tokenizer(&tokenizer_path)
.map_err(|e| format!("Failed to create tokenizer: {}", e))?, .map_err(|e| format!("Failed to create tokenizer: {e}"))?,
); );
let reasoning_parser_factory = Some(ParserFactory::new()); let reasoning_parser_factory = Some(ParserFactory::new());
let tool_parser_registry = Some(ParserRegistry::new()); let tool_parser_registry = Some(ParserRegistry::new());
...@@ -82,14 +89,10 @@ impl AppContext { ...@@ -82,14 +89,10 @@ impl AppContext {
(None, None, None) (None, None, None)
}; };
// Initialize shared registries
let worker_registry = Arc::new(WorkerRegistry::new()); let worker_registry = Arc::new(WorkerRegistry::new());
let policy_registry = Arc::new(PolicyRegistry::new( let policy_registry = Arc::new(PolicyRegistry::new(router_config.policy.clone()));
router_config.policy.clone(), // Use default policy from config
));
// Initialize RouterManager only when enable_igw is true let router_manager = None;
let router_manager = None; // Will be initialized in startup() based on config
Ok(Self { Ok(Self {
client, client,
...@@ -109,7 +112,7 @@ impl AppContext { ...@@ -109,7 +112,7 @@ impl AppContext {
pub struct AppState { pub struct AppState {
pub router: Arc<dyn RouterTrait>, pub router: Arc<dyn RouterTrait>,
pub context: Arc<AppContext>, pub context: Arc<AppContext>,
pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<crate::middleware::QueuedRequest>>, pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<QueuedRequest>>,
} }
// Fallback handler for unmatched routes // Fallback handler for unmatched routes
...@@ -265,23 +268,18 @@ async fn v1_responses_list_input_items( ...@@ -265,23 +268,18 @@ async fn v1_responses_list_input_items(
.await .await
} }
// Worker management endpoints // ---------- Worker management endpoints (Legacy) ----------
#[derive(Deserialize)]
struct UrlQuery {
url: String,
}
async fn add_worker( async fn add_worker(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Query(params): Query<HashMap<String, String>>, Query(UrlQuery { url }): Query<UrlQuery>,
) -> Response { ) -> Response {
let worker_url = match params.get("url") { match state.router.add_worker(&url).await {
Some(url) => url.to_string(),
None => {
return (
StatusCode::BAD_REQUEST,
"Worker URL required. Provide 'url' query parameter",
)
.into_response();
}
};
match state.router.add_worker(&worker_url).await {
Ok(message) => (StatusCode::OK, message).into_response(), Ok(message) => (StatusCode::OK, message).into_response(),
Err(error) => (StatusCode::BAD_REQUEST, error).into_response(), Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
} }
...@@ -294,17 +292,12 @@ async fn list_workers(State(state): State<Arc<AppState>>) -> Response { ...@@ -294,17 +292,12 @@ async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
async fn remove_worker( async fn remove_worker(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Query(params): Query<HashMap<String, String>>, Query(UrlQuery { url }): Query<UrlQuery>,
) -> Response { ) -> Response {
let worker_url = match params.get("url") { state.router.remove_worker(&url);
Some(url) => url.to_string(),
None => return StatusCode::BAD_REQUEST.into_response(),
};
state.router.remove_worker(&worker_url);
( (
StatusCode::OK, StatusCode::OK,
format!("Successfully removed worker: {}", worker_url), format!("Successfully removed worker: {url}"),
) )
.into_response() .into_response()
} }
...@@ -317,7 +310,7 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Respons ...@@ -317,7 +310,7 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Respons
state.router.get_worker_loads().await state.router.get_worker_loads().await
} }
// New RESTful worker management endpoints (when enable_igw=true) // ---------- Worker management endpoints (RESTful) ----------
/// POST /workers - Add a new worker with full configuration /// POST /workers - Add a new worker with full configuration
async fn create_worker( async fn create_worker(
...@@ -374,7 +367,7 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response { ...@@ -374,7 +367,7 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
}); });
// Add bootstrap_port for Prefill workers // Add bootstrap_port for Prefill workers
if let crate::core::WorkerType::Prefill { bootstrap_port } = worker.worker_type() { if let WorkerType::Prefill { bootstrap_port } = worker.worker_type() {
worker_info["bootstrap_port"] = serde_json::json!(bootstrap_port); worker_info["bootstrap_port"] = serde_json::json!(bootstrap_port);
} }
...@@ -384,7 +377,7 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response { ...@@ -384,7 +377,7 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
"stats": { "stats": {
"prefill_count": state.context.worker_registry.get_prefill_workers().len(), "prefill_count": state.context.worker_registry.get_prefill_workers().len(),
"decode_count": state.context.worker_registry.get_decode_workers().len(), "decode_count": state.context.worker_registry.get_decode_workers().len(),
"regular_count": state.context.worker_registry.get_by_type(&crate::core::WorkerType::Regular).len(), "regular_count": state.context.worker_registry.get_by_type(&WorkerType::Regular).len(),
} }
}); });
Json(response).into_response() Json(response).into_response()
...@@ -392,33 +385,29 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response { ...@@ -392,33 +385,29 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
} }
/// GET /workers/{url} - Get specific worker info /// GET /workers/{url} - Get specific worker info
async fn get_worker( async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
State(state): State<Arc<AppState>>,
axum::extract::Path(url): axum::extract::Path<String>,
) -> Response {
if let Some(router_manager) = &state.context.router_manager { if let Some(router_manager) = &state.context.router_manager {
if let Some(worker) = router_manager.get_worker(&url) { if let Some(worker) = router_manager.get_worker(&url) {
Json(worker).into_response() Json(worker).into_response()
} else { } else {
let error = WorkerErrorResponse { let error = WorkerErrorResponse {
error: format!("Worker {} not found", url), error: format!("Worker {url} not found"),
code: "WORKER_NOT_FOUND".to_string(), code: "WORKER_NOT_FOUND".to_string(),
}; };
(StatusCode::NOT_FOUND, Json(error)).into_response() (StatusCode::NOT_FOUND, Json(error)).into_response()
} }
} else { } else {
// In single router mode, check if worker exists
let workers = state.router.get_worker_urls(); let workers = state.router.get_worker_urls();
if workers.contains(&url) { if workers.contains(&url) {
let worker_info = serde_json::json!({ Json(json!({
"url": url, "url": url,
"model_id": "unknown", "model_id": "unknown",
"is_healthy": true "is_healthy": true
}); }))
Json(worker_info).into_response() .into_response()
} else { } else {
let error = WorkerErrorResponse { let error = WorkerErrorResponse {
error: format!("Worker {} not found", url), error: format!("Worker {url} not found"),
code: "WORKER_NOT_FOUND".to_string(), code: "WORKER_NOT_FOUND".to_string(),
}; };
(StatusCode::NOT_FOUND, Json(error)).into_response() (StatusCode::NOT_FOUND, Json(error)).into_response()
...@@ -427,10 +416,7 @@ async fn get_worker( ...@@ -427,10 +416,7 @@ async fn get_worker(
} }
/// DELETE /workers/{url} - Remove a worker /// DELETE /workers/{url} - Remove a worker
async fn delete_worker( async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
State(state): State<Arc<AppState>>,
axum::extract::Path(url): axum::extract::Path<String>,
) -> Response {
if let Some(router_manager) = &state.context.router_manager { if let Some(router_manager) = &state.context.router_manager {
match router_manager.remove_worker_from_registry(&url) { match router_manager.remove_worker_from_registry(&url) {
Ok(response) => (StatusCode::OK, Json(response)).into_response(), Ok(response) => (StatusCode::OK, Json(response)).into_response(),
...@@ -441,7 +427,7 @@ async fn delete_worker( ...@@ -441,7 +427,7 @@ async fn delete_worker(
state.router.remove_worker(&url); state.router.remove_worker(&url);
let response = WorkerApiResponse { let response = WorkerApiResponse {
success: true, success: true,
message: format!("Worker {} removed successfully", url), message: format!("Worker {url} removed successfully"),
worker: None, worker: None,
}; };
(StatusCode::OK, Json(response)).into_response() (StatusCode::OK, Json(response)).into_response()
...@@ -489,7 +475,7 @@ pub fn build_app( ...@@ -489,7 +475,7 @@ pub fn build_app(
) )
.route_layer(axum::middleware::from_fn_with_state( .route_layer(axum::middleware::from_fn_with_state(
app_state.clone(), app_state.clone(),
crate::middleware::concurrency_limit_middleware, middleware::concurrency_limit_middleware,
)); ));
let public_routes = Router::new() let public_routes = Router::new()
...@@ -513,7 +499,7 @@ pub fn build_app( ...@@ -513,7 +499,7 @@ pub fn build_app(
.route("/workers", post(create_worker)) .route("/workers", post(create_worker))
.route("/workers", get(list_workers_rest)) .route("/workers", get(list_workers_rest))
.route("/workers/{url}", get(get_worker)) .route("/workers/{url}", get(get_worker))
.route("/workers/{url}", axum::routing::delete(delete_worker)); .route("/workers/{url}", delete(delete_worker));
// Build app with all routes and middleware // Build app with all routes and middleware
Router::new() Router::new()
...@@ -525,17 +511,10 @@ pub fn build_app( ...@@ -525,17 +511,10 @@ pub fn build_app(
.layer(tower_http::limit::RequestBodyLimitLayer::new( .layer(tower_http::limit::RequestBodyLimitLayer::new(
max_payload_size, max_payload_size,
)) ))
// Logging layer - must be added BEFORE request ID layer in the code .layer(middleware::create_logging_layer())
// so it executes AFTER request ID layer at runtime (layers execute bottom-up) .layer(middleware::RequestIdLayer::new(request_id_headers))
// This way the TraceLayer can see the request ID that was added to extensions
.layer(crate::middleware::create_logging_layer())
// Request ID layer - adds request ID to extensions first
.layer(crate::middleware::RequestIdLayer::new(request_id_headers))
// CORS (should be outermost)
.layer(create_cors_layer(cors_allowed_origins)) .layer(create_cors_layer(cors_allowed_origins))
// Fallback
.fallback(sink_handler) .fallback(sink_handler)
// State - apply last to get Router<Arc<AppState>>
.with_state(app_state) .with_state(app_state)
} }
...@@ -551,7 +530,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -551,7 +530,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
.and_then(|s| match s.to_uppercase().parse::<Level>() { .and_then(|s| match s.to_uppercase().parse::<Level>() {
Ok(l) => Some(l), Ok(l) => Some(l),
Err(_) => { Err(_) => {
warn!("Invalid log level string: '{}'. Defaulting to INFO.", s); warn!("Invalid log level string: '{s}'. Defaulting to INFO.");
None None
} }
}) })
...@@ -582,11 +561,11 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -582,11 +561,11 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
let client = Client::builder() let client = Client::builder()
.pool_idle_timeout(Some(Duration::from_secs(50))) .pool_idle_timeout(Some(Duration::from_secs(50)))
.pool_max_idle_per_host(500) // Increase to 500 connections per host .pool_max_idle_per_host(500)
.timeout(Duration::from_secs(config.request_timeout_secs)) .timeout(Duration::from_secs(config.request_timeout_secs))
.connect_timeout(Duration::from_secs(10)) // Separate connection timeout .connect_timeout(Duration::from_secs(10))
.tcp_nodelay(true) .tcp_nodelay(true)
.tcp_keepalive(Some(Duration::from_secs(30))) // Keep connections alive .tcp_keepalive(Some(Duration::from_secs(30)))
.build() .build()
.expect("Failed to create HTTP client"); .expect("Failed to create HTTP client");
...@@ -612,9 +591,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -612,9 +591,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
app_context.policy_registry.clone(), app_context.policy_registry.clone(),
); );
// Create HTTP routers at startup (with empty worker lists)
// Workers will be added to these routers dynamically via RouterManager's worker registry
// 1. HTTP Regular Router // 1. HTTP Regular Router
match RouterFactory::create_regular_router( match RouterFactory::create_regular_router(
&[], // Empty worker list - workers added later &[], // Empty worker list - workers added later
...@@ -631,16 +607,16 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -631,16 +607,16 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
); );
} }
Err(e) => { Err(e) => {
warn!("Failed to create HTTP Regular router: {}", e); warn!("Failed to create HTTP Regular router: {e}");
} }
} }
// 2. HTTP PD Router // 2. HTTP PD Router
match RouterFactory::create_pd_router( match RouterFactory::create_pd_router(
&[], // Empty prefill URLs &[],
&[], // Empty decode URLs &[],
None, // Use default prefill policy None,
None, // Use default decode policy None,
&config.router_config.policy, &config.router_config.policy,
&app_context, &app_context,
) )
...@@ -655,16 +631,11 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -655,16 +631,11 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
); );
} }
Err(e) => { Err(e) => {
warn!("Failed to create HTTP PD router: {}", e); warn!("Failed to create HTTP PD router: {e}");
} }
} }
// TODO: Add gRPC routers once we have dynamic tokenizer loading // TODO: Add gRPC routers once we have dynamic tokenizer loading
// Currently gRPC routers require tokenizer to be initialized first,
// but each model needs its own tokenizer. Once we implement dynamic
// tokenizer loading per model, we can enable gRPC routers here:
// - RouterType::GrpcRegular (RouterId: "grpc-regular")
// - RouterType::GrpcPd (RouterId: "grpc-pd")
info!( info!(
"RouterManager initialized with {} routers", "RouterManager initialized with {} routers",
...@@ -687,7 +658,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -687,7 +658,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
); );
// Set up concurrency limiter with queue if configured // Set up concurrency limiter with queue if configured
let (limiter, processor) = crate::middleware::ConcurrencyLimiter::new( let (limiter, processor) = middleware::ConcurrencyLimiter::new(
app_context.rate_limiter.clone(), app_context.rate_limiter.clone(),
config.router_config.queue_size, config.router_config.queue_size,
Duration::from_secs(config.router_config.queue_timeout_secs), Duration::from_secs(config.router_config.queue_timeout_secs),
...@@ -724,7 +695,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -724,7 +695,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
}); });
} }
Err(e) => { Err(e) => {
error!("Failed to start service discovery: {}", e); error!("Failed to start service discovery: {e}");
warn!("Continuing without service discovery"); warn!("Continuing without service discovery");
} }
} }
...@@ -736,7 +707,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -736,7 +707,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
app_state.router.get_worker_urls() app_state.router.get_worker_urls()
); );
// Configure request ID headers
let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| { let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
vec![ vec![
"x-request-id".to_string(), "x-request-id".to_string(),
...@@ -754,15 +724,10 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -754,15 +724,10 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
config.router_config.cors_allowed_origins.clone(), config.router_config.cors_allowed_origins.clone(),
); );
// Create TCP listener - use the configured host
let addr = format!("{}:{}", config.host, config.port); let addr = format!("{}:{}", config.host, config.port);
let listener = TcpListener::bind(&addr).await?; let listener = TcpListener::bind(&addr).await?;
// Start server with graceful shutdown
info!("Starting server on {}", addr); info!("Starting server on {}", addr);
serve(listener, app)
// Serve the application with graceful shutdown
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal()) .with_graceful_shutdown(shutdown_signal())
.await .await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?; .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment