// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use std::collections::HashMap; use std::env::var; use std::path::PathBuf; use std::sync::Arc; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; use std::time::Duration; use axum::body::Body; use axum::http::Response; use super::Metrics; use super::RouteDoc; use super::metrics; use super::metrics::register_worker_timing_metrics; use crate::discovery::{ModelManager, register_worker_load_metrics}; use crate::endpoint_type::EndpointType; use crate::request_template::RequestTemplate; use anyhow::Result; use axum_server::tls_rustls::RustlsConfig; use derive_builder::Builder; use dynamo_runtime::config::environment_names::llm as env_llm; use dynamo_runtime::discovery::{Discovery, KVStoreDiscovery}; use dynamo_runtime::logging::make_request_span; use dynamo_runtime::storage::kv; use std::net::SocketAddr; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tower_http::trace::TraceLayer; /// HTTP service shared state pub struct State { metrics: Arc, manager: Arc, store: kv::Manager, discovery_client: Arc, flags: StateFlags, cancel_token: CancellationToken, } #[derive(Default, Debug)] struct StateFlags { chat_endpoints_enabled: AtomicBool, cmpl_endpoints_enabled: AtomicBool, embeddings_endpoints_enabled: AtomicBool, images_endpoints_enabled: AtomicBool, responses_endpoints_enabled: AtomicBool, } impl StateFlags { pub fn get(&self, endpoint_type: &EndpointType) -> bool { match endpoint_type { EndpointType::Chat => self.chat_endpoints_enabled.load(Ordering::Relaxed), EndpointType::Completion => self.cmpl_endpoints_enabled.load(Ordering::Relaxed), EndpointType::Embedding => self.embeddings_endpoints_enabled.load(Ordering::Relaxed), EndpointType::Images => self.images_endpoints_enabled.load(Ordering::Relaxed), EndpointType::Responses => self.responses_endpoints_enabled.load(Ordering::Relaxed), } } pub fn set(&self, endpoint_type: &EndpointType, enabled: bool) { match endpoint_type { EndpointType::Chat => self .chat_endpoints_enabled .store(enabled, Ordering::Relaxed), EndpointType::Completion => self .cmpl_endpoints_enabled .store(enabled, Ordering::Relaxed), EndpointType::Embedding => self .embeddings_endpoints_enabled .store(enabled, Ordering::Relaxed), EndpointType::Images => self .images_endpoints_enabled .store(enabled, Ordering::Relaxed), EndpointType::Responses => self .responses_endpoints_enabled .store(enabled, Ordering::Relaxed), } } } impl State { pub fn new( manager: Arc, store: kv::Manager, cancel_token: CancellationToken, ) -> Self { // Initialize discovery backed by KV store // Create a cancellation token for the discovery's watch streams let discovery_client = { let discovery_cancel_token = cancel_token.child_token(); Arc::new(KVStoreDiscovery::new(store.clone(), discovery_cancel_token)) as Arc }; Self { manager, metrics: Arc::new(Metrics::default()), store, discovery_client, flags: StateFlags { chat_endpoints_enabled: AtomicBool::new(false), cmpl_endpoints_enabled: AtomicBool::new(false), embeddings_endpoints_enabled: AtomicBool::new(false), images_endpoints_enabled: AtomicBool::new(false), responses_endpoints_enabled: AtomicBool::new(false), }, cancel_token, } } /// Get the Prometheus [`Metrics`] object which tracks request counts and inflight requests pub fn metrics_clone(&self) -> Arc { self.metrics.clone() } pub fn manager(&self) -> &ModelManager { Arc::as_ref(&self.manager) } pub fn manager_clone(&self) -> Arc { self.manager.clone() } pub fn store(&self) -> &kv::Manager { &self.store } pub fn discovery(&self) -> Arc { self.discovery_client.clone() } /// Check if the service is shutting down pub fn is_cancelled(&self) -> bool { self.cancel_token.is_cancelled() } /// Get the cancellation token pub fn cancel_token(&self) -> &CancellationToken { &self.cancel_token } // TODO pub fn sse_keep_alive(&self) -> Option { None } } #[derive(Clone)] pub struct HttpService { // The state we share with every request handler state: Arc, router: axum::Router, port: u16, host: String, enable_tls: bool, tls_cert_path: Option, tls_key_path: Option, route_docs: Vec, } #[derive(Clone, Builder)] #[builder(pattern = "owned", build_fn(private, name = "build_internal"))] pub struct HttpServiceConfig { #[builder(default = "8787")] port: u16, #[builder(setter(into), default = "String::from(\"0.0.0.0\")")] host: String, #[builder(default = "false")] enable_tls: bool, #[builder(default = "None")] tls_cert_path: Option, #[builder(default = "None")] tls_key_path: Option, // #[builder(default)] // custom: Vec #[builder(default = "false")] enable_chat_endpoints: bool, #[builder(default = "false")] enable_cmpl_endpoints: bool, #[builder(default = "true")] enable_embeddings_endpoints: bool, #[builder(default = "true")] enable_responses_endpoints: bool, #[builder(default = "None")] request_template: Option, #[builder(default)] store: kv::Manager, } impl HttpService { pub fn builder() -> HttpServiceConfigBuilder { HttpServiceConfigBuilder::default() } pub fn state_clone(&self) -> Arc { self.state.clone() } pub fn state(&self) -> &State { Arc::as_ref(&self.state) } pub fn model_manager(&self) -> &ModelManager { self.state().manager() } pub async fn spawn(&self, cancel_token: CancellationToken) -> JoinHandle> { let this = self.clone(); tokio::spawn(async move { this.run(cancel_token).await }) } pub async fn run(&self, cancel_token: CancellationToken) -> Result<()> { let address = format!("{}:{}", self.host, self.port); let protocol = if self.enable_tls { "HTTPS" } else { "HTTP" }; tracing::info!(protocol, address, "Starting HTTP(S) service"); let router = self.router.clone(); let observer = cancel_token.child_token(); let state_cancel = self.state.cancel_token().clone(); let addr: SocketAddr = address .parse() .map_err(|e| anyhow::anyhow!("Invalid address '{}': {}", address, e))?; if self.enable_tls { let cert_path = self .tls_cert_path .as_ref() .ok_or_else(|| anyhow::anyhow!("TLS certificate path not provided"))?; let key_path = self .tls_key_path .as_ref() .ok_or_else(|| anyhow::anyhow!("TLS private key path not provided"))?; // aws_lc_rs is the default but other crates pull in `ring` also, // so rustls doesn't know which one to use. Tell it. if let Err(e) = rustls::crypto::aws_lc_rs::default_provider().install_default() { tracing::debug!("TLS crypto provider already installed: {e:?}"); } let config = RustlsConfig::from_pem_file(cert_path, key_path) .await .map_err(|e| anyhow::anyhow!("Failed to create TLS config: {}", e))?; let handle = axum_server::Handle::new(); let server = axum_server::bind_rustls(addr, config) .handle(handle.clone()) .serve(router.into_make_service()); tokio::select! { result = server => { result.map_err(|e| anyhow::anyhow!("HTTPS server error: {}", e))?; } _ = observer.cancelled() => { state_cancel.cancel(); tracing::info!("HTTPS server shutdown requested"); // accepting requests for 5 more seconds, to allow incorrectly routed requests to arrive handle.graceful_shutdown(Some(Duration::from_secs(get_graceful_shutdown_timeout() as u64))); // no longer accepting requests, draining all existing connections } } } else { let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| { tracing::error!( protocol = %protocol, address = %address, error = %e, "Failed to bind server to address" ); match e.kind() { std::io::ErrorKind::AddrInUse => anyhow::anyhow!( "Failed to start {} server: port {} already in use. Use --http-port to specify a different port.", protocol, self.port ), _ => anyhow::anyhow!( "Failed to start {} server on {}: {}", protocol, address, e ), } })?; axum::serve(listener, router) .with_graceful_shutdown(async move { observer.cancelled_owned().await; state_cancel.cancel(); tracing::info!("HTTP server shutdown requested"); // accepting requests for 5 more seconds, to allow incorrectly routed requests to arrive tokio::time::sleep(Duration::from_secs(get_graceful_shutdown_timeout() as u64)) .await; // no longer accepting requests, draining all existing connections }) .await .inspect_err(|_| cancel_token.cancel())?; } Ok(()) } /// Documentation of exposed HTTP endpoints pub fn route_docs(&self) -> &[RouteDoc] { &self.route_docs } pub fn enable_model_endpoint(&self, endpoint_type: EndpointType, enable: bool) { self.state.flags.set(&endpoint_type, enable); tracing::info!( "{} endpoints {}", endpoint_type.as_str(), if enable { "enabled" } else { "disabled" } ); } } fn get_graceful_shutdown_timeout() -> usize { std::env::var(env_llm::DYN_HTTP_GRACEFUL_SHUTDOWN_TIMEOUT_SECS) .ok() .and_then(|s| s.parse::().ok()) .unwrap_or(5) } /// Environment variable to set the metrics endpoint path (default: `/metrics`) static HTTP_SVC_METRICS_PATH_ENV: &str = "DYN_HTTP_SVC_METRICS_PATH"; /// Environment variable to set the models endpoint path (default: `/v1/models`) static HTTP_SVC_MODELS_PATH_ENV: &str = "DYN_HTTP_SVC_MODELS_PATH"; /// Environment variable to set the health endpoint path (default: `/health`) static HTTP_SVC_HEALTH_PATH_ENV: &str = "DYN_HTTP_SVC_HEALTH_PATH"; /// Environment variable to set the live endpoint path (default: `/live`) static HTTP_SVC_LIVE_PATH_ENV: &str = "DYN_HTTP_SVC_LIVE_PATH"; /// Environment variable to set the chat completions endpoint path (default: `/v1/chat/completions`) static HTTP_SVC_CHAT_PATH_ENV: &str = "DYN_HTTP_SVC_CHAT_PATH"; /// Environment variable to set the completions endpoint path (default: `/v1/completions`) static HTTP_SVC_CMP_PATH_ENV: &str = "DYN_HTTP_SVC_CMP_PATH"; /// Environment variable to set the embeddings endpoint path (default: `/v1/embeddings`) static HTTP_SVC_EMB_PATH_ENV: &str = "DYN_HTTP_SVC_EMB_PATH"; /// Environment variable to set the responses endpoint path (default: `/v1/responses`) static HTTP_SVC_RESPONSES_PATH_ENV: &str = "DYN_HTTP_SVC_RESPONSES_PATH"; impl HttpServiceConfigBuilder { pub fn build(self) -> Result { let config: HttpServiceConfig = self.build_internal()?; let model_manager = Arc::new(ModelManager::new()); // Create a temporary cancel token for building - will be replaced in spawn/run let temp_cancel_token = CancellationToken::new(); let state = Arc::new(State::new(model_manager, config.store, temp_cancel_token)); state .flags .set(&EndpointType::Chat, config.enable_chat_endpoints); state .flags .set(&EndpointType::Completion, config.enable_cmpl_endpoints); state .flags .set(&EndpointType::Embedding, config.enable_embeddings_endpoints); state .flags .set(&EndpointType::Responses, config.enable_responses_endpoints); // enable prometheus metrics let registry = metrics::Registry::new(); state.metrics_clone().register(®istry)?; // Register worker load metrics (active_decode_blocks, active_prefill_tokens per worker) // These are updated by KvWorkerMonitor when receiving ActiveLoad events if let Err(e) = register_worker_load_metrics(®istry) { tracing::warn!("Failed to register worker load metrics: {}", e); } // Register worker timing metrics (last_ttft, last_itl per worker) // These are updated by ResponseMetricCollector when observing TTFT/ITL if let Err(e) = register_worker_timing_metrics(®istry) { tracing::warn!("Failed to register worker timing metrics: {}", e); } let mut router = axum::Router::new(); let mut all_docs = Vec::new(); let mut routes = vec![ metrics::router(registry, var(HTTP_SVC_METRICS_PATH_ENV).ok()), super::openai::list_models_router(state.clone(), var(HTTP_SVC_MODELS_PATH_ENV).ok()), super::health::health_check_router(state.clone(), var(HTTP_SVC_HEALTH_PATH_ENV).ok()), super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()), super::busy_threshold::busy_threshold_router(state.clone(), None), ]; let endpoint_routes = HttpServiceConfigBuilder::get_endpoints_router(state.clone(), &config.request_template); routes.extend(endpoint_routes); for (route_docs, route) in routes { router = router.merge(route); all_docs.extend(route_docs); } // Add OpenAPI documentation routes (must be after all other routes so it can document them) // Note: The path parameter is currently unused as SwaggerUi requires static paths let (openapi_docs, openapi_route) = super::openapi_docs::openapi_router(all_docs.clone(), None); router = router.merge(openapi_route); all_docs.extend(openapi_docs); // Add span for tracing // Add on_response callback for logging response status code router = router.layer( TraceLayer::new_for_http() .make_span_with(make_request_span) .on_response( |response: &Response, latency: Duration, _span: &tracing::Span| { let status = response.status(); let latency_ms = latency.as_millis(); if status.is_server_error() { tracing::error!( status = %status.as_u16(), latency_ms = %latency_ms, "request completed with server error" ); } else if status.is_client_error() { tracing::warn!( status = %status.as_u16(), latency_ms = %latency_ms, "request completed with client request error" ); } else { tracing::debug!( status = %status.as_u16(), latency_ms = %latency_ms, "request completed" ); } }, ), ); Ok(HttpService { state, router, port: config.port, host: config.host, enable_tls: config.enable_tls, tls_cert_path: config.tls_cert_path, tls_key_path: config.tls_key_path, route_docs: all_docs, }) } pub fn with_request_template(mut self, request_template: Option) -> Self { self.request_template = Some(request_template); self } fn get_endpoints_router( state: Arc, request_template: &Option, ) -> Vec<(Vec, axum::Router)> { let mut routes = Vec::new(); // Add chat completions route with conditional middleware let (chat_docs, chat_route) = super::openai::chat_completions_router( state.clone(), request_template.clone(), var(HTTP_SVC_CHAT_PATH_ENV).ok(), ); let (cmpl_docs, cmpl_route) = super::openai::completions_router(state.clone(), var(HTTP_SVC_CMP_PATH_ENV).ok()); let (embed_docs, embed_route) = super::openai::embeddings_router(state.clone(), var(HTTP_SVC_EMB_PATH_ENV).ok()); let (images_docs, images_route) = super::openai::images_router(state.clone(), None); let (responses_docs, responses_route) = super::openai::responses_router( state.clone(), request_template.clone(), var(HTTP_SVC_RESPONSES_PATH_ENV).ok(), ); let mut endpoint_routes = HashMap::new(); endpoint_routes.insert(EndpointType::Chat, (chat_docs, chat_route)); endpoint_routes.insert(EndpointType::Completion, (cmpl_docs, cmpl_route)); endpoint_routes.insert(EndpointType::Embedding, (embed_docs, embed_route)); endpoint_routes.insert(EndpointType::Images, (images_docs, images_route)); endpoint_routes.insert(EndpointType::Responses, (responses_docs, responses_route)); for endpoint_type in EndpointType::all() { let state_route = state.clone(); if !endpoint_routes.contains_key(&endpoint_type) { tracing::debug!("{} endpoints are disabled", endpoint_type.as_str()); continue; } let (docs, route) = endpoint_routes.get(&endpoint_type).cloned().unwrap(); let route = route.route_layer(axum::middleware::from_fn( move |req: axum::http::Request, next: axum::middleware::Next| { let state: Arc = state_route.clone(); async move { // Check if the endpoint is enabled let enabled = state.flags.get(&endpoint_type); if enabled { Ok(next.run(req).await) } else { tracing::debug!("{} endpoints are disabled", endpoint_type.as_str()); Err(axum::http::StatusCode::NOT_FOUND) } } }, )); routes.push((docs, route)); } routes } } #[cfg(test)] mod tests { use super::*; use serial_test::serial; use std::sync::Arc; use tokio_util::sync::CancellationToken; #[tokio::test] #[serial] async fn test_liveness_endpoint_reflects_cancellation() { // 1. Setup service & token let cancel_token = Arc::new(CancellationToken::new()); let service = HttpService::builder().build().unwrap(); let port = service.port; // 2. Spawn service with shared token let service_token = cancel_token.clone(); let handle = tokio::spawn(async move { service.run((*service_token).clone()).await.unwrap(); }); tokio::time::sleep(std::time::Duration::from_millis(1)).await; // 3. Cancel the token cancel_token.cancel(); // 4. Wait a tiny bit for propagation tokio::time::sleep(std::time::Duration::from_millis(20)).await; // 5. Hit the endpoint let client = reqwest::Client::new(); let resp = client .get(format!("http://localhost:{}/live", port)) .send() .await .expect("Request failed"); // 6. ASSERTION: Should be 503 Service Unavailable assert_eq!(resp.status(), reqwest::StatusCode::SERVICE_UNAVAILABLE); // Clean up handle.abort(); } }