use axum::{extract::Request, http::HeaderValue, response::Response}; use std::sync::Arc; use std::time::Instant; use tower::{Layer, Service}; use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer}; use tracing::{field::Empty, info_span, Span}; /// Generate OpenAI-compatible request ID based on endpoint fn generate_request_id(path: &str) -> String { let prefix = if path.contains("/chat/completions") { "chatcmpl-" } else if path.contains("/completions") { "cmpl-" } else if path.contains("/generate") { "gnt-" } else { "req-" }; // Generate a random string similar to OpenAI's format let random_part: String = (0..24) .map(|_| { let chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; chars .chars() .nth(rand::random::() % chars.len()) .unwrap() }) .collect(); format!("{}{}", prefix, random_part) } /// Extension type for storing request ID #[derive(Clone, Debug)] pub struct RequestId(pub String); /// Tower Layer for request ID middleware #[derive(Clone)] pub struct RequestIdLayer { headers: Arc>, } impl RequestIdLayer { pub fn new(headers: Vec) -> Self { Self { headers: Arc::new(headers), } } } impl Layer for RequestIdLayer { type Service = RequestIdMiddleware; fn layer(&self, inner: S) -> Self::Service { RequestIdMiddleware { inner, headers: self.headers.clone(), } } } /// Tower Service for request ID middleware #[derive(Clone)] pub struct RequestIdMiddleware { inner: S, headers: Arc>, } impl Service for RequestIdMiddleware where S: Service + Send + 'static, S::Future: Send + 'static, { type Response = S::Response; type Error = S::Error; type Future = std::pin::Pin< Box> + Send>, >; fn poll_ready( &mut self, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut req: Request) -> Self::Future { let headers = self.headers.clone(); // Extract request ID from headers or generate new one let mut request_id = None; for header_name in headers.iter() { if let Some(header_value) = req.headers().get(header_name) { if let Ok(value) = header_value.to_str() { request_id = Some(value.to_string()); break; } } } let request_id = request_id.unwrap_or_else(|| generate_request_id(req.uri().path())); // Insert request ID into request extensions req.extensions_mut().insert(RequestId(request_id.clone())); // Create a span with the request ID for this request let span = tracing::info_span!( "http_request", method = %req.method(), uri = %req.uri(), version = ?req.version(), request_id = %request_id ); // Log within the span let _enter = span.enter(); tracing::info!( target: "sglang_router_rs::request", "started processing request" ); drop(_enter); // Capture values we need in the async block let method = req.method().clone(); let uri = req.uri().clone(); let version = req.version(); // Call the inner service let future = self.inner.call(req); Box::pin(async move { let start_time = Instant::now(); let mut response = future.await?; let latency = start_time.elapsed(); // Add request ID to response headers response.headers_mut().insert( "x-request-id", HeaderValue::from_str(&request_id) .unwrap_or_else(|_| HeaderValue::from_static("invalid-request-id")), ); // Log the response with proper request ID in span let status = response.status(); let span = tracing::info_span!( "http_request", method = %method, uri = %uri, version = ?version, request_id = %request_id, status = %status, latency = ?latency ); let _enter = span.enter(); if status.is_server_error() { tracing::error!( target: "sglang_router_rs::response", "request failed with server error" ); } else if status.is_client_error() { tracing::warn!( target: "sglang_router_rs::response", "request failed with client error" ); } else { tracing::info!( target: "sglang_router_rs::response", "finished processing request" ); } Ok(response) }) } } // ============= Logging Middleware ============= /// Custom span maker that includes request ID #[derive(Clone, Debug)] pub struct RequestSpan; impl MakeSpan for RequestSpan { fn make_span(&mut self, request: &Request) -> Span { // Don't try to extract request ID here - it won't be available yet // The RequestIdLayer runs after TraceLayer creates the span info_span!( "http_request", method = %request.method(), uri = %request.uri(), version = ?request.version(), request_id = Empty, // Will be set later status_code = Empty, latency = Empty, error = Empty, ) } } /// Custom on_request handler #[derive(Clone, Debug)] pub struct RequestLogger; impl OnRequest for RequestLogger { fn on_request(&mut self, request: &Request, span: &Span) { let _enter = span.enter(); // Try to get the request ID from extensions // This will work if RequestIdLayer has already run if let Some(request_id) = request.extensions().get::() { span.record("request_id", &request_id.0.as_str()); } // Don't log here - we already log in RequestIdService with the proper request_id } } /// Custom on_response handler #[derive(Clone, Debug)] pub struct ResponseLogger { _start_time: Instant, } impl Default for ResponseLogger { fn default() -> Self { Self { _start_time: Instant::now(), } } } impl OnResponse for ResponseLogger { fn on_response(self, response: &Response, latency: std::time::Duration, span: &Span) { let status = response.status(); // Record these in the span for structured logging/observability tools span.record("status_code", status.as_u16()); span.record("latency", format!("{:?}", latency)); // Don't log here - RequestIdService handles all logging with proper request IDs } } /// Create a configured TraceLayer for HTTP logging /// Note: Actual request/response logging with request IDs is done in RequestIdService pub fn create_logging_layer() -> TraceLayer< tower_http::classify::SharedClassifier, RequestSpan, RequestLogger, ResponseLogger, > { TraceLayer::new_for_http() .make_span_with(RequestSpan) .on_request(RequestLogger) .on_response(ResponseLogger::default()) } /// Structured logging data for requests #[derive(Debug, serde::Serialize)] pub struct RequestLogEntry { pub timestamp: String, pub request_id: String, pub method: String, pub uri: String, pub status: u16, pub latency_ms: u64, pub user_agent: Option, pub remote_addr: Option, pub error: Option, } /// Log a request with structured data pub fn log_request(entry: RequestLogEntry) { if entry.status >= 500 { tracing::error!( target: "sglang_router_rs::http", request_id = %entry.request_id, method = %entry.method, uri = %entry.uri, status = entry.status, latency_ms = entry.latency_ms, user_agent = ?entry.user_agent, remote_addr = ?entry.remote_addr, error = ?entry.error, "HTTP request failed" ); } else if entry.status >= 400 { tracing::warn!( target: "sglang_router_rs::http", request_id = %entry.request_id, method = %entry.method, uri = %entry.uri, status = entry.status, latency_ms = entry.latency_ms, user_agent = ?entry.user_agent, remote_addr = ?entry.remote_addr, "HTTP request client error" ); } else { tracing::info!( target: "sglang_router_rs::http", request_id = %entry.request_id, method = %entry.method, uri = %entry.uri, status = entry.status, latency_ms = entry.latency_ms, user_agent = ?entry.user_agent, remote_addr = ?entry.remote_addr, "HTTP request completed" ); } }