use crate::{ config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode}, core::{ worker_to_info, Job, JobQueue, JobQueueConfig, LoadMonitor, WorkerManager, WorkerRegistry, WorkerType, }, data_connector::{ MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage, NoOpConversationStorage, NoOpResponseStorage, OracleConversationItemStorage, OracleConversationStorage, OracleResponseStorage, SharedConversationStorage, SharedResponseStorage, }, logging::{self, LoggingConfig}, metrics::{self, PrometheusConfig}, middleware::{self, AuthConfig, QueuedRequest, TokenBucket}, policies::PolicyRegistry, protocols::{ spec::{ ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesGetParams, ResponsesRequest, V1RerankReqInput, }, validated::ValidatedJson, worker_spec::{WorkerConfigRequest, WorkerErrorResponse, WorkerInfo}, }, reasoning_parser::ParserFactory as ReasoningParserFactory, routers::{router_manager::RouterManager, RouterTrait}, service_discovery::{start_service_discovery, ServiceDiscoveryConfig}, tokenizer::{factory as tokenizer_factory, traits::Tokenizer}, tool_parser::ParserFactory as ToolParserFactory, }; use axum::{ extract::{Path, Query, Request, State}, http::StatusCode, response::{IntoResponse, Response}, routing::{delete, get, post}, serve, Json, Router, }; use reqwest::Client; use serde::Deserialize; use serde_json::{json, Value}; use std::sync::OnceLock; use std::{ sync::atomic::{AtomicBool, Ordering}, sync::Arc, time::Duration, }; use tokio::{net::TcpListener, signal, spawn}; use tracing::{error, info, warn, Level}; // #[derive(Clone)] pub struct AppContext { pub client: Client, pub router_config: RouterConfig, pub rate_limiter: Option>, pub tokenizer: Option>, pub reasoning_parser_factory: Option, pub tool_parser_factory: Option, pub worker_registry: Arc, pub policy_registry: Arc, pub router_manager: Option>, pub response_storage: SharedResponseStorage, pub conversation_storage: SharedConversationStorage, pub conversation_item_storage: crate::data_connector::SharedConversationItemStorage, pub load_monitor: Option>, pub configured_reasoning_parser: Option, pub configured_tool_parser: Option, pub worker_job_queue: Arc>>, } impl AppContext { pub fn new( router_config: RouterConfig, client: Client, max_concurrent_requests: i32, rate_limit_tokens_per_second: Option, ) -> Result { let rate_limiter = match max_concurrent_requests { n if n <= 0 => None, n => { let rate_limit_tokens = rate_limit_tokens_per_second.filter(|&t| t > 0).unwrap_or(n); Some(Arc::new(TokenBucket::new( n as usize, rate_limit_tokens as usize, ))) } }; let (tokenizer, reasoning_parser_factory, tool_parser_factory) = if router_config .connection_mode == ConnectionMode::Grpc { let tokenizer_path = router_config .tokenizer_path .clone() .or_else(|| router_config.model_path.clone()) .ok_or_else(|| { "gRPC mode requires either --tokenizer-path or --model-path to be specified" .to_string() })?; let tokenizer = Some( tokenizer_factory::create_tokenizer_with_chat_template_blocking( &tokenizer_path, router_config.chat_template.as_deref(), ) .map_err(|e| { format!( "Failed to create tokenizer from '{}': {}. \ Ensure the path is valid and points to a tokenizer file (tokenizer.json) \ or a HuggingFace model ID. For directories, ensure they contain tokenizer files.", tokenizer_path, e ) })?, ); let reasoning_parser_factory = Some(crate::reasoning_parser::ParserFactory::new()); let tool_parser_factory = Some(crate::tool_parser::ParserFactory::new()); (tokenizer, reasoning_parser_factory, tool_parser_factory) } else { (None, None, None) }; let worker_registry = Arc::new(WorkerRegistry::new()); let policy_registry = Arc::new(PolicyRegistry::new(router_config.policy.clone())); let router_manager = None; let (response_storage, conversation_storage): ( SharedResponseStorage, SharedConversationStorage, ) = match router_config.history_backend { HistoryBackend::Memory => ( Arc::new(MemoryResponseStorage::new()), Arc::new(MemoryConversationStorage::new()), ), HistoryBackend::None => ( Arc::new(NoOpResponseStorage::new()), Arc::new(NoOpConversationStorage::new()), ), HistoryBackend::Oracle => { let oracle_cfg = router_config.oracle.clone().ok_or_else(|| { "oracle configuration is required when history_backend=oracle".to_string() })?; let response_storage = OracleResponseStorage::new(oracle_cfg.clone()).map_err(|err| { format!("failed to initialize Oracle response storage: {err}") })?; let conversation_storage = OracleConversationStorage::new(oracle_cfg.clone()) .map_err(|err| { format!("failed to initialize Oracle conversation storage: {err}") })?; (Arc::new(response_storage), Arc::new(conversation_storage)) } }; // Conversation items storage (memory-backed for now) let conversation_item_storage: crate::data_connector::SharedConversationItemStorage = match router_config.history_backend { HistoryBackend::Oracle => { let oracle_cfg = router_config.oracle.clone().ok_or_else(|| { "oracle configuration is required when history_backend=oracle".to_string() })?; Arc::new(OracleConversationItemStorage::new(oracle_cfg).map_err(|e| { format!("failed to initialize Oracle conversation item storage: {e}") })?) } _ => Arc::new(MemoryConversationItemStorage::new()), }; let load_monitor = Some(Arc::new(LoadMonitor::new( worker_registry.clone(), policy_registry.clone(), client.clone(), router_config.worker_startup_check_interval_secs, ))); let configured_reasoning_parser = router_config.reasoning_parser.clone(); let configured_tool_parser = router_config.tool_call_parser.clone(); // Create empty OnceLock for worker job queue (will be initialized in startup()) let worker_job_queue = Arc::new(OnceLock::new()); Ok(Self { client, router_config, rate_limiter, tokenizer, reasoning_parser_factory, tool_parser_factory, worker_registry, policy_registry, router_manager, response_storage, conversation_storage, conversation_item_storage, load_monitor, configured_reasoning_parser, configured_tool_parser, worker_job_queue, }) } } #[derive(Clone)] pub struct AppState { pub router: Arc, pub context: Arc, pub concurrency_queue_tx: Option>, pub router_manager: Option>, } async fn sink_handler() -> Response { StatusCode::NOT_FOUND.into_response() } async fn liveness() -> Response { (StatusCode::OK, "OK").into_response() } async fn readiness(State(state): State>) -> Response { let workers = state.context.worker_registry.get_all(); let healthy_workers: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect(); let is_ready = if state.context.router_config.enable_igw { !healthy_workers.is_empty() } else { match &state.context.router_config.mode { RoutingMode::PrefillDecode { .. } => { let has_prefill = healthy_workers .iter() .any(|w| matches!(w.worker_type(), WorkerType::Prefill { .. })); let has_decode = healthy_workers .iter() .any(|w| matches!(w.worker_type(), WorkerType::Decode)); has_prefill && has_decode } RoutingMode::Regular { .. } => !healthy_workers.is_empty(), RoutingMode::OpenAI { .. } => !healthy_workers.is_empty(), } }; if is_ready { ( StatusCode::OK, Json(json!({ "status": "ready", "healthy_workers": healthy_workers.len(), "total_workers": workers.len() })), ) .into_response() } else { ( StatusCode::SERVICE_UNAVAILABLE, Json(json!({ "status": "not ready", "reason": "insufficient healthy workers" })), ) .into_response() } } async fn health(_state: State>) -> Response { liveness().await } async fn health_generate(State(state): State>, req: Request) -> Response { state.router.health_generate(req).await } async fn get_server_info(State(state): State>, req: Request) -> Response { state.router.get_server_info(req).await } async fn v1_models(State(state): State>, req: Request) -> Response { state.router.get_models(req).await } async fn get_model_info(State(state): State>, req: Request) -> Response { state.router.get_model_info(req).await } async fn generate( State(state): State>, headers: http::HeaderMap, Json(body): Json, ) -> Response { state .router .route_generate(Some(&headers), &body, None) .await } async fn v1_chat_completions( State(state): State>, headers: http::HeaderMap, ValidatedJson(body): ValidatedJson, ) -> Response { state.router.route_chat(Some(&headers), &body, None).await } async fn v1_completions( State(state): State>, headers: http::HeaderMap, Json(body): Json, ) -> Response { state .router .route_completion(Some(&headers), &body, None) .await } async fn rerank( State(state): State>, headers: http::HeaderMap, Json(body): Json, ) -> Response { state.router.route_rerank(Some(&headers), &body, None).await } async fn v1_rerank( State(state): State>, headers: http::HeaderMap, Json(body): Json, ) -> Response { state .router .route_rerank(Some(&headers), &body.into(), None) .await } async fn v1_responses( State(state): State>, headers: http::HeaderMap, Json(body): Json, ) -> Response { state .router .route_responses(Some(&headers), &body, None) .await } async fn v1_embeddings( State(state): State>, headers: http::HeaderMap, Json(body): Json, ) -> Response { state .router .route_embeddings(Some(&headers), &body, None) .await } async fn v1_responses_get( State(state): State>, Path(response_id): Path, headers: http::HeaderMap, Query(params): Query, ) -> Response { state .router .get_response(Some(&headers), &response_id, ¶ms) .await } async fn v1_responses_cancel( State(state): State>, Path(response_id): Path, headers: http::HeaderMap, ) -> Response { state .router .cancel_response(Some(&headers), &response_id) .await } async fn v1_responses_delete( State(state): State>, Path(response_id): Path, headers: http::HeaderMap, ) -> Response { state .router .delete_response(Some(&headers), &response_id) .await } async fn v1_responses_list_input_items( State(state): State>, Path(response_id): Path, headers: http::HeaderMap, ) -> Response { state .router .list_response_input_items(Some(&headers), &response_id) .await } async fn v1_conversations_create( State(state): State>, headers: http::HeaderMap, Json(body): Json, ) -> Response { state .router .create_conversation(Some(&headers), &body) .await } async fn v1_conversations_get( State(state): State>, Path(conversation_id): Path, headers: http::HeaderMap, ) -> Response { state .router .get_conversation(Some(&headers), &conversation_id) .await } async fn v1_conversations_update( State(state): State>, Path(conversation_id): Path, headers: http::HeaderMap, Json(body): Json, ) -> Response { state .router .update_conversation(Some(&headers), &conversation_id, &body) .await } async fn v1_conversations_delete( State(state): State>, Path(conversation_id): Path, headers: http::HeaderMap, ) -> Response { state .router .delete_conversation(Some(&headers), &conversation_id) .await } #[derive(Deserialize, Default)] struct ListItemsQuery { limit: Option, order: Option, after: Option, } async fn v1_conversations_list_items( State(state): State>, Path(conversation_id): Path, Query(ListItemsQuery { limit, order, after, }): Query, headers: http::HeaderMap, ) -> Response { state .router .list_conversation_items(Some(&headers), &conversation_id, limit, order, after) .await } #[derive(Deserialize, Default)] struct GetItemQuery { /// Additional fields to include in response (not yet implemented) include: Option>, } async fn v1_conversations_create_items( State(state): State>, Path(conversation_id): Path, headers: http::HeaderMap, Json(body): Json, ) -> Response { state .router .create_conversation_items(Some(&headers), &conversation_id, &body) .await } async fn v1_conversations_get_item( State(state): State>, Path((conversation_id, item_id)): Path<(String, String)>, Query(query): Query, headers: http::HeaderMap, ) -> Response { state .router .get_conversation_item(Some(&headers), &conversation_id, &item_id, query.include) .await } async fn v1_conversations_delete_item( State(state): State>, Path((conversation_id, item_id)): Path<(String, String)>, headers: http::HeaderMap, ) -> Response { state .router .delete_conversation_item(Some(&headers), &conversation_id, &item_id) .await } #[derive(Deserialize)] struct AddWorkerQuery { url: String, api_key: Option, } async fn add_worker( State(state): State>, Query(AddWorkerQuery { url, api_key }): Query, ) -> Response { // Warn if router has API key but worker is being added without one if state.context.router_config.api_key.is_some() && api_key.is_none() { warn!( "Adding worker {} without API key while router has API key configured. \ Worker will be accessible without authentication. \ If the worker requires the same API key as the router, please specify it explicitly.", url ); } let result = WorkerManager::add_worker(&url, &api_key, &state.context).await; match result { Ok(message) => (StatusCode::OK, message).into_response(), Err(error) => (StatusCode::BAD_REQUEST, error).into_response(), } } async fn list_workers(State(state): State>) -> Response { let worker_list = WorkerManager::get_worker_urls(&state.context.worker_registry); Json(json!({ "urls": worker_list })).into_response() } async fn remove_worker( State(state): State>, Query(AddWorkerQuery { url, .. }): Query, ) -> Response { let result = WorkerManager::remove_worker(&url, &state.context); match result { Ok(message) => (StatusCode::OK, message).into_response(), Err(error) => (StatusCode::BAD_REQUEST, error).into_response(), } } async fn flush_cache(State(state): State>, _req: Request) -> Response { match WorkerManager::flush_cache_all(&state.context.worker_registry, &state.context.client) .await { Ok(result) => { if result.failed.is_empty() { ( StatusCode::OK, Json(json!({ "status": "success", "message": result.message, "workers_flushed": result.successful.len(), "total_http_workers": result.http_workers, "total_workers": result.total_workers })), ) .into_response() } else { ( StatusCode::PARTIAL_CONTENT, Json(json!({ "status": "partial_success", "message": result.message, "successful": result.successful, "failed": result.failed.into_iter().map(|(url, err)| json!({ "worker": url, "error": err })).collect::>(), "total_http_workers": result.http_workers, "total_workers": result.total_workers })), ) .into_response() } } Err(e) => { error!("Failed to flush cache: {}", e); ( StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "status": "error", "message": format!("Failed to flush cache: {}", e) })), ) .into_response() } } } async fn get_loads(State(state): State>, _req: Request) -> Response { let result = WorkerManager::get_all_worker_loads(&state.context.worker_registry, &state.context.client) .await; let loads: Vec = result .loads .iter() .map(|info| { json!({ "worker": &info.worker, "load": info.load }) }) .collect(); ( StatusCode::OK, Json(json!({ "workers": loads })), ) .into_response() } async fn create_worker( State(state): State>, Json(config): Json, ) -> Response { // Warn if router has API key but worker is being added without one if state.context.router_config.api_key.is_some() && config.api_key.is_none() { warn!( "Adding worker {} without API key while router has API key configured. \ Worker will be accessible without authentication. \ If the worker requires the same API key as the router, please specify it explicitly.", config.url ); } // Submit job for async processing let worker_url = config.url.clone(); let job = Job::AddWorker { config: Box::new(config), }; let job_queue = state .context .worker_job_queue .get() .expect("JobQueue not initialized"); match job_queue.submit(job).await { Ok(_) => { let response = json!({ "status": "accepted", "worker_id": worker_url, "message": "Worker addition queued for background processing" }); (StatusCode::ACCEPTED, Json(response)).into_response() } Err(error) => { let error_response = WorkerErrorResponse { error, code: "INTERNAL_SERVER_ERROR".to_string(), }; (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response)).into_response() } } } async fn list_workers_rest(State(state): State>) -> Response { let workers = state.context.worker_registry.get_all(); let worker_infos: Vec = workers.iter().map(worker_to_info).collect(); let response = json!({ "workers": worker_infos, "total": workers.len(), "stats": { "prefill_count": state.context.worker_registry.get_prefill_workers().len(), "decode_count": state.context.worker_registry.get_decode_workers().len(), "regular_count": state.context.worker_registry.get_by_type(&WorkerType::Regular).len(), } }); Json(response).into_response() } async fn get_worker(State(state): State>, Path(url): Path) -> Response { let job_queue = state .context .worker_job_queue .get() .expect("JobQueue not initialized"); if let Some(worker) = state.context.worker_registry.get_by_url(&url) { // Worker exists in registry, get its full info and attach job status if any let mut worker_info = worker_to_info(&worker); if let Some(status) = job_queue.get_status(&url) { worker_info.job_status = Some(status); } return Json(worker_info).into_response(); } // Worker not in registry, check job queue for its status if let Some(status) = job_queue.get_status(&url) { // Create a partial WorkerInfo to report the job status let worker_info = WorkerInfo { id: url.clone(), url: url.clone(), model_id: "unknown".to_string(), priority: 0, cost: 1.0, worker_type: "unknown".to_string(), is_healthy: false, load: 0, connection_mode: "unknown".to_string(), tokenizer_path: None, reasoning_parser: None, tool_parser: None, chat_template: None, bootstrap_port: None, metadata: std::collections::HashMap::new(), job_status: Some(status), }; return Json(worker_info).into_response(); } // Worker not found in registry or job queue let error = WorkerErrorResponse { error: format!("Worker {url} not found"), code: "WORKER_NOT_FOUND".to_string(), }; (StatusCode::NOT_FOUND, Json(error)).into_response() } async fn delete_worker(State(state): State>, Path(url): Path) -> Response { let worker_id = url.clone(); let job = Job::RemoveWorker { url }; let job_queue = state .context .worker_job_queue .get() .expect("JobQueue not initialized"); match job_queue.submit(job).await { Ok(_) => { let response = json!({ "status": "accepted", "worker_id": worker_id, "message": "Worker removal queued for background processing" }); (StatusCode::ACCEPTED, Json(response)).into_response() } Err(error) => { let error_response = WorkerErrorResponse { error, code: "INTERNAL_SERVER_ERROR".to_string(), }; (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response)).into_response() } } } pub struct ServerConfig { pub host: String, pub port: u16, pub router_config: RouterConfig, pub max_payload_size: usize, pub log_dir: Option, pub log_level: Option, pub service_discovery_config: Option, pub prometheus_config: Option, pub request_timeout_secs: u64, pub request_id_headers: Option>, } pub fn build_app( app_state: Arc, auth_config: AuthConfig, max_payload_size: usize, request_id_headers: Vec, cors_allowed_origins: Vec, ) -> Router { let protected_routes = Router::new() .route("/generate", post(generate)) .route("/v1/chat/completions", post(v1_chat_completions)) .route("/v1/completions", post(v1_completions)) .route("/rerank", post(rerank)) .route("/v1/rerank", post(v1_rerank)) .route("/v1/responses", post(v1_responses)) .route("/v1/embeddings", post(v1_embeddings)) .route("/v1/responses/{response_id}", get(v1_responses_get)) .route( "/v1/responses/{response_id}/cancel", post(v1_responses_cancel), ) .route("/v1/responses/{response_id}", delete(v1_responses_delete)) .route( "/v1/responses/{response_id}/input", get(v1_responses_list_input_items), ) .route("/v1/conversations", post(v1_conversations_create)) .route( "/v1/conversations/{conversation_id}", get(v1_conversations_get) .post(v1_conversations_update) .delete(v1_conversations_delete), ) .route( "/v1/conversations/{conversation_id}/items", get(v1_conversations_list_items).post(v1_conversations_create_items), ) .route( "/v1/conversations/{conversation_id}/items/{item_id}", get(v1_conversations_get_item).delete(v1_conversations_delete_item), ) .route_layer(axum::middleware::from_fn_with_state( app_state.clone(), middleware::concurrency_limit_middleware, )) .route_layer(axum::middleware::from_fn_with_state( auth_config.clone(), middleware::auth_middleware, )); let public_routes = Router::new() .route("/liveness", get(liveness)) .route("/readiness", get(readiness)) .route("/health", get(health)) .route("/health_generate", get(health_generate)) .route("/v1/models", get(v1_models)) .route("/get_model_info", get(get_model_info)) .route("/get_server_info", get(get_server_info)); let admin_routes = Router::new() .route("/add_worker", post(add_worker)) .route("/remove_worker", post(remove_worker)) .route("/list_workers", get(list_workers)) .route("/flush_cache", post(flush_cache)) .route("/get_loads", get(get_loads)) .route_layer(axum::middleware::from_fn_with_state( auth_config.clone(), middleware::auth_middleware, )); let worker_routes = Router::new() .route("/workers", post(create_worker)) .route("/workers", get(list_workers_rest)) .route("/workers/{url}", get(get_worker)) .route("/workers/{url}", delete(delete_worker)) .route_layer(axum::middleware::from_fn_with_state( auth_config.clone(), middleware::auth_middleware, )); Router::new() .merge(protected_routes) .merge(public_routes) .merge(admin_routes) .merge(worker_routes) .layer(axum::extract::DefaultBodyLimit::max(max_payload_size)) .layer(tower_http::limit::RequestBodyLimitLayer::new( max_payload_size, )) .layer(middleware::create_logging_layer()) .layer(middleware::RequestIdLayer::new(request_id_headers)) .layer(create_cors_layer(cors_allowed_origins)) .fallback(sink_handler) .with_state(app_state) } pub async fn startup(config: ServerConfig) -> Result<(), Box> { static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false); let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) { Some(logging::init_logging(LoggingConfig { level: config .log_level .as_deref() .and_then(|s| match s.to_uppercase().parse::() { Ok(l) => Some(l), Err(_) => { warn!("Invalid log level string: '{s}'. Defaulting to INFO."); None } }) .unwrap_or(Level::INFO), json_format: false, log_dir: config.log_dir.clone(), colorize: true, log_file_name: "sgl-router".to_string(), log_targets: None, })) } else { None }; if let Some(prometheus_config) = &config.prometheus_config { metrics::start_prometheus(prometheus_config.clone()); } info!( "Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB", config.host, config.port, config.router_config.mode, config.router_config.policy, config.max_payload_size / (1024 * 1024) ); let client = Client::builder() .pool_idle_timeout(Some(Duration::from_secs(50))) .pool_max_idle_per_host(500) .timeout(Duration::from_secs(config.request_timeout_secs)) .connect_timeout(Duration::from_secs(10)) .tcp_nodelay(true) .tcp_keepalive(Some(Duration::from_secs(30))) .build() .expect("Failed to create HTTP client"); let app_context = AppContext::new( config.router_config.clone(), client.clone(), config.router_config.max_concurrent_requests, config.router_config.rate_limit_tokens_per_second, )?; let app_context = Arc::new(app_context); let weak_context = Arc::downgrade(&app_context); let worker_job_queue = JobQueue::new(JobQueueConfig::default(), weak_context); app_context .worker_job_queue .set(worker_job_queue) .expect("JobQueue should only be initialized once"); info!( "Initializing workers for routing mode: {:?}", config.router_config.mode ); WorkerManager::initialize_workers( &config.router_config, &app_context.worker_registry, Some(&app_context.policy_registry), ) .await .map_err(|e| format!("Failed to initialize workers: {}", e))?; let worker_stats = app_context.worker_registry.stats(); info!( "Workers initialized: {} total, {} healthy", worker_stats.total_workers, worker_stats.healthy_workers ); let router_manager = RouterManager::from_config(&config, &app_context).await?; let router: Arc = router_manager.clone(); let _health_checker = app_context .worker_registry .start_health_checker(config.router_config.health_check.check_interval_secs); info!( "Started health checker for workers with {}s interval", config.router_config.health_check.check_interval_secs ); if let Some(ref load_monitor) = app_context.load_monitor { load_monitor.start().await; info!("Started LoadMonitor for PowerOfTwo policies"); } let (limiter, processor) = middleware::ConcurrencyLimiter::new( app_context.rate_limiter.clone(), config.router_config.queue_size, Duration::from_secs(config.router_config.queue_timeout_secs), ); if app_context.rate_limiter.is_none() { info!("Rate limiting is disabled (max_concurrent_requests = -1)"); } match processor { Some(proc) => { spawn(proc.run()); info!( "Started request queue (size: {}, timeout: {}s)", config.router_config.queue_size, config.router_config.queue_timeout_secs ); } None => { info!( "Rate limiting enabled (max_concurrent_requests = {}, queue disabled)", config.router_config.max_concurrent_requests ); } } let app_state = Arc::new(AppState { router, context: app_context.clone(), concurrency_queue_tx: limiter.queue_tx.clone(), router_manager: Some(router_manager), }); if let Some(service_discovery_config) = config.service_discovery_config { if service_discovery_config.enabled { let app_context_arc = Arc::clone(&app_state.context); match start_service_discovery(service_discovery_config, app_context_arc).await { Ok(handle) => { info!("Service discovery started"); spawn(async move { if let Err(e) = handle.await { error!("Service discovery task failed: {:?}", e); } }); } Err(e) => { error!("Failed to start service discovery: {e}"); warn!("Continuing without service discovery"); } } } } info!( "Router ready | workers: {:?}", WorkerManager::get_worker_urls(&app_state.context.worker_registry) ); let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| { vec![ "x-request-id".to_string(), "x-correlation-id".to_string(), "x-trace-id".to_string(), "request-id".to_string(), ] }); let auth_config = AuthConfig { api_key: config.router_config.api_key.clone(), }; let app = build_app( app_state, auth_config, config.max_payload_size, request_id_headers, config.router_config.cors_allowed_origins.clone(), ); // TcpListener::bind accepts &str and handles IPv4/IPv6 via ToSocketAddrs let bind_addr = format!("{}:{}", config.host, config.port); info!("Starting server on {}", bind_addr); let listener = TcpListener::bind(&bind_addr) .await .map_err(|e| format!("Failed to bind to {}: {}", bind_addr, e))?; serve(listener, app) .with_graceful_shutdown(shutdown_signal()) .await .map_err(|e| Box::new(e) as Box)?; Ok(()) } async fn shutdown_signal() { let ctrl_c = async { signal::ctrl_c() .await .expect("failed to install Ctrl+C handler"); }; #[cfg(unix)] let terminate = async { signal::unix::signal(signal::unix::SignalKind::terminate()) .expect("failed to install signal handler") .recv() .await; }; #[cfg(not(unix))] let terminate = std::future::pending::<()>(); tokio::select! { _ = ctrl_c => { info!("Received Ctrl+C, starting graceful shutdown"); }, _ = terminate => { info!("Received terminate signal, starting graceful shutdown"); }, } } fn create_cors_layer(allowed_origins: Vec) -> tower_http::cors::CorsLayer { use tower_http::cors::Any; let cors = if allowed_origins.is_empty() { tower_http::cors::CorsLayer::new() .allow_origin(Any) .allow_methods(Any) .allow_headers(Any) .expose_headers(Any) } else { let origins: Vec = allowed_origins .into_iter() .filter_map(|origin| origin.parse().ok()) .collect(); tower_http::cors::CorsLayer::new() .allow_origin(origins) .allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS]) .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION]) .expose_headers([http::header::HeaderName::from_static("x-request-id")]) }; cors.max_age(Duration::from_secs(3600)) }