use axum::{
    extract::{Request, State},
    http::{HeaderValue, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
};
use rand::Rng;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, oneshot};
use tower::{Layer, Service};
use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
use tracing::{debug, error, field::Empty, info, info_span, warn, Span};

pub use crate::core::token_bucket::TokenBucket;
use crate::server::AppState;

/// Generate an OpenAI-compatible request ID with a prefix derived from the endpoint path.
fn generate_request_id(path: &str) -> String {
    let prefix = if path.contains("/chat/completions") {
        "chatcmpl-"
    } else if path.contains("/completions") {
        "cmpl-"
    } else if path.contains("/generate") {
        "gnt-"
    } else {
        "req-"
    };

    // Generate a random alphanumeric suffix similar to OpenAI's format.
    let chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
    let mut rng = rand::rng();
    let random_part: String = (0..24)
        .map(|_| {
            let idx = rng.random_range(0..chars.len());
            chars.chars().nth(idx).unwrap()
        })
        .collect();

    format!("{}{}", prefix, random_part)
}

/// Extension type for storing the request ID.
#[derive(Clone, Debug)]
pub struct RequestId(pub String);

/// Tower `Layer` for the request ID middleware.
#[derive(Clone)]
pub struct RequestIdLayer {
    headers: Arc<Vec<String>>,
}

impl RequestIdLayer {
    pub fn new(headers: Vec<String>) -> Self {
        Self {
            headers: Arc::new(headers),
        }
    }
}

impl<S> Layer<S> for RequestIdLayer {
    type Service = RequestIdMiddleware<S>;

    fn layer(&self, inner: S) -> Self::Service {
        RequestIdMiddleware {
            inner,
            headers: self.headers.clone(),
        }
    }
}

/// Tower `Service` for the request ID middleware.
#[derive(Clone)]
pub struct RequestIdMiddleware<S> {
    inner: S,
    headers: Arc<Vec<String>>,
}

impl<S> Service<Request> for RequestIdMiddleware<S>
where
    S: Service<Request, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;

    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, mut req: Request) -> Self::Future {
        let headers = self.headers.clone();

        // Extract the request ID from the configured headers, or generate a new one.
        let mut request_id = None;
        for header_name in headers.iter() {
            if let Some(header_value) = req.headers().get(header_name) {
                if let Ok(value) = header_value.to_str() {
                    request_id = Some(value.to_string());
                    break;
                }
            }
        }
        let request_id = request_id.unwrap_or_else(|| generate_request_id(req.uri().path()));

        // Insert the request ID into the request extensions.
        req.extensions_mut().insert(RequestId(request_id.clone()));

        // Create a span carrying the request ID for this request.
        let span = tracing::info_span!(
            "http_request",
            method = %req.method(),
            uri = %req.uri(),
            version = ?req.version(),
            request_id = %request_id
        );

        // Log within the span.
        let _enter = span.enter();
        tracing::info!(
            target: "sglang_router_rs::request",
            "started processing request"
        );
        drop(_enter);

        // Capture the values we need in the async block.
        let method = req.method().clone();
        let uri = req.uri().clone();
        let version = req.version();

        // Call the inner service.
        let future = self.inner.call(req);

        Box::pin(async move {
            let start_time = Instant::now();
            let mut response = future.await?;
            let latency = start_time.elapsed();

            // Add the request ID to the response headers.
            response.headers_mut().insert(
                "x-request-id",
                HeaderValue::from_str(&request_id)
                    .unwrap_or_else(|_| HeaderValue::from_static("invalid-request-id")),
            );

            // Log the response with the request ID recorded in the span.
            let status = response.status();
            let span = tracing::info_span!(
                "http_request",
                method = %method,
                uri = %uri,
                version = ?version,
                request_id = %request_id,
                status = %status,
                latency = ?latency
            );
            let _enter = span.enter();

            if status.is_server_error() {
                tracing::error!(
                    target: "sglang_router_rs::response",
                    "request failed with server error"
                );
            } else if status.is_client_error() {
                tracing::warn!(
                    target: "sglang_router_rs::response",
                    "request failed with client error"
                );
            } else {
                tracing::info!(
                    target: "sglang_router_rs::response",
                    "finished processing request"
                );
            }

            Ok(response)
        })
    }
}
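
// --- Usage sketch (illustrative addition, not part of the original module) ---
// A minimal example of attaching `RequestIdLayer` to an axum `Router` and
// reading the `RequestId` extension from a handler. The `whoami` route and
// the header list are hypothetical; only `RequestIdLayer` and `RequestId`
// come from this module.
#[cfg(test)]
mod request_id_layer_example {
    use super::{RequestId, RequestIdLayer};
    use axum::{routing::get, Extension, Router};

    // Hypothetical handler: the `Extension` extractor picks up the
    // `RequestId` that `RequestIdMiddleware` inserted into the extensions.
    async fn whoami(Extension(RequestId(id)): Extension<RequestId>) -> String {
        format!("request id: {}", id)
    }

    #[test]
    fn builds_router_with_request_id_layer() {
        // Trust `x-request-id` first, then `x-correlation-id`; otherwise a
        // new ID is generated from the request path.
        let _app: Router = Router::new().route("/whoami", get(whoami)).layer(
            RequestIdLayer::new(vec![
                "x-request-id".to_string(),
                "x-correlation-id".to_string(),
            ]),
        );
    }
}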
// ============= Logging Middleware =============

/// Custom span maker that includes the request ID.
#[derive(Clone, Debug)]
pub struct RequestSpan;

impl<B> MakeSpan<B> for RequestSpan {
    fn make_span(&mut self, request: &Request<B>) -> Span {
        // Don't try to extract the request ID here: it isn't available yet.
        // The RequestIdLayer runs after TraceLayer creates the span.
        info_span!(
            "http_request",
            method = %request.method(),
            uri = %request.uri(),
            version = ?request.version(),
            request_id = Empty, // Recorded later by RequestLogger.
            status_code = Empty,
            latency = Empty,
            error = Empty,
        )
    }
}

/// Custom `on_request` handler.
#[derive(Clone, Debug)]
pub struct RequestLogger;

impl<B> OnRequest<B> for RequestLogger {
    fn on_request(&mut self, request: &Request<B>, span: &Span) {
        let _enter = span.enter();
        // Try to get the request ID from extensions; this only works if
        // RequestIdLayer has already run.
        if let Some(request_id) = request.extensions().get::<RequestId>() {
            span.record("request_id", request_id.0.as_str());
        }
        // Don't log here: RequestIdMiddleware already logs with the proper request_id.
    }
}

/// Custom `on_response` handler.
#[derive(Clone, Debug)]
pub struct ResponseLogger {
    _start_time: Instant,
}

impl Default for ResponseLogger {
    fn default() -> Self {
        Self {
            _start_time: Instant::now(),
        }
    }
}

impl<B> OnResponse<B> for ResponseLogger {
    fn on_response(self, response: &Response<B>, latency: Duration, span: &Span) {
        let status = response.status();
        // Record these in the span for structured logging/observability tools.
        span.record("status_code", status.as_u16());
        span.record("latency", format!("{:?}", latency).as_str());
        // Don't log here: RequestIdMiddleware handles all logging with proper request IDs.
    }
}

/// Create a configured `TraceLayer` for HTTP logging.
/// Note: actual request/response logging with request IDs is done in RequestIdMiddleware.
pub fn create_logging_layer() -> TraceLayer<
    tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>,
    RequestSpan,
    RequestLogger,
    ResponseLogger,
> {
    TraceLayer::new_for_http()
        .make_span_with(RequestSpan)
        .on_request(RequestLogger)
        .on_response(ResponseLogger::default())
}

/// Structured logging data for requests.
#[derive(Debug, serde::Serialize)]
pub struct RequestLogEntry {
    pub timestamp: String,
    pub request_id: String,
    pub method: String,
    pub uri: String,
    pub status: u16,
    pub latency_ms: u64,
    pub user_agent: Option<String>,
    pub remote_addr: Option<String>,
    pub error: Option<String>,
}

/// Log a request with structured data, choosing the level from the status code.
pub fn log_request(entry: RequestLogEntry) {
    if entry.status >= 500 {
        tracing::error!(
            target: "sglang_router_rs::http",
            request_id = %entry.request_id,
            method = %entry.method,
            uri = %entry.uri,
            status = entry.status,
            latency_ms = entry.latency_ms,
            user_agent = ?entry.user_agent,
            remote_addr = ?entry.remote_addr,
            error = ?entry.error,
            "HTTP request failed"
        );
    } else if entry.status >= 400 {
        tracing::warn!(
            target: "sglang_router_rs::http",
            request_id = %entry.request_id,
            method = %entry.method,
            uri = %entry.uri,
            status = entry.status,
            latency_ms = entry.latency_ms,
            user_agent = ?entry.user_agent,
            remote_addr = ?entry.remote_addr,
            "HTTP request client error"
        );
    } else {
        tracing::info!(
            target: "sglang_router_rs::http",
            request_id = %entry.request_id,
            method = %entry.method,
            uri = %entry.uri,
            status = entry.status,
            latency_ms = entry.latency_ms,
            user_agent = ?entry.user_agent,
            remote_addr = ?entry.remote_addr,
            "HTTP request completed"
        );
    }
}
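
// --- Usage sketch (illustrative addition, not part of the original module) ---
// How a caller might assemble a `RequestLogEntry` and hand it to
// `log_request`. All field values below are made up for the example; in the
// router they would come from the request, the response, and a timer.
#[cfg(test)]
mod log_request_example {
    use super::{log_request, RequestLogEntry};

    #[test]
    fn logs_a_completed_request() {
        // A 200 routes to `tracing::info!`; 4xx would hit `warn!`, 5xx `error!`.
        log_request(RequestLogEntry {
            timestamp: "2024-01-01T00:00:00Z".to_string(),
            request_id: "chatcmpl-AbC123".to_string(),
            method: "POST".to_string(),
            uri: "/v1/chat/completions".to_string(),
            status: 200,
            latency_ms: 42,
            user_agent: Some("curl/8.5.0".to_string()),
            remote_addr: None,
            error: None,
        });
    }
}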
// ============ Concurrency Limiting with Queue Support ============

/// A request waiting in the concurrency queue.
pub struct QueuedRequest {
    /// Time when the request was queued.
    queued_at: Instant,
    /// Channel used to hand the permit (or an error status) back to the waiter.
    permit_tx: oneshot::Sender<Result<(), StatusCode>>,
}

/// Queue metrics for monitoring.
#[derive(Debug, Default)]
pub struct QueueMetrics {
    pub total_queued: std::sync::atomic::AtomicU64,
    pub current_queued: std::sync::atomic::AtomicU64,
    pub total_timeout: std::sync::atomic::AtomicU64,
    pub total_rejected: std::sync::atomic::AtomicU64,
}

/// Queue processor that drains queued requests and hands out tokens.
pub struct QueueProcessor {
    token_bucket: Arc<TokenBucket>,
    queue_rx: mpsc::Receiver<QueuedRequest>,
    queue_timeout: Duration,
}

impl QueueProcessor {
    pub fn new(
        token_bucket: Arc<TokenBucket>,
        queue_rx: mpsc::Receiver<QueuedRequest>,
        queue_timeout: Duration,
    ) -> Self {
        Self {
            token_bucket,
            queue_rx,
            queue_timeout,
        }
    }

    pub async fn run(mut self) {
        info!("Starting concurrency queue processor");

        // Process requests in a single task to reduce overhead.
        while let Some(queued) = self.queue_rx.recv().await {
            // Check the timeout immediately.
            let elapsed = queued.queued_at.elapsed();
            if elapsed >= self.queue_timeout {
                warn!("Request already timed out in queue");
                let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
                continue;
            }
            let remaining_timeout = self.queue_timeout - elapsed;

            // Try to acquire a token for this request.
            if self.token_bucket.try_acquire(1.0).await.is_ok() {
                // Got a token immediately.
                debug!("Queue: acquired token immediately for queued request");
                let _ = queued.permit_tx.send(Ok(()));
            } else {
                // Need to wait for a token.
                let token_bucket = self.token_bucket.clone();
                // Spawn a task only when we actually need to wait.
                tokio::spawn(async move {
                    if token_bucket
                        .acquire_timeout(1.0, remaining_timeout)
                        .await
                        .is_ok()
                    {
                        debug!("Queue: acquired token after waiting");
                        let _ = queued.permit_tx.send(Ok(()));
                    } else {
                        warn!("Queue: request timed out waiting for token");
                        let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
                    }
                });
            }
        }

        warn!("Concurrency queue processor shutting down");
    }
}

/// State for the concurrency limiter.
pub struct ConcurrencyLimiter {
    pub queue_tx: Option<mpsc::Sender<QueuedRequest>>,
}

impl ConcurrencyLimiter {
    /// Create a new concurrency limiter with an optional queue.
    pub fn new(
        token_bucket: Arc<TokenBucket>,
        queue_size: usize,
        queue_timeout: Duration,
    ) -> (Self, Option<QueueProcessor>) {
        if queue_size > 0 {
            let (queue_tx, queue_rx) = mpsc::channel(queue_size);
            let processor = QueueProcessor::new(token_bucket, queue_rx, queue_timeout);
            (
                Self {
                    queue_tx: Some(queue_tx),
                },
                Some(processor),
            )
        } else {
            (Self { queue_tx: None }, None)
        }
    }
}
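
// --- Setup sketch (illustrative addition, not part of the original module) ---
// How the limiter and its queue processor might be wired at startup:
// `ConcurrencyLimiter::new` returns the processor only when queuing is
// enabled (queue_size > 0), and the processor must be spawned onto the
// runtime so it can drain the queue. Construction of the `TokenBucket`
// itself is left to `crate::core::token_bucket`.
#[allow(dead_code)]
fn spawn_concurrency_limiter(
    token_bucket: Arc<TokenBucket>,
    queue_size: usize,
    queue_timeout: Duration,
) -> ConcurrencyLimiter {
    let (limiter, processor) = ConcurrencyLimiter::new(token_bucket, queue_size, queue_timeout);
    if let Some(processor) = processor {
        // Run the queue processor for the lifetime of the server.
        tokio::spawn(processor.run());
    }
    limiter
}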
/// Middleware function for concurrency limiting with optional queuing.
pub async fn concurrency_limit_middleware(
    State(app_state): State<Arc<AppState>>,
    request: Request,
    next: Next,
) -> Response {
    let token_bucket = app_state.context.rate_limiter.clone();

    // Try to acquire a token immediately.
    if token_bucket.try_acquire(1.0).await.is_ok() {
        debug!("Acquired token immediately");
        let response = next.run(request).await;
        // Return the token to the bucket.
        token_bucket.return_tokens(1.0).await;
        response
    } else if let Some(queue_tx) = &app_state.concurrency_queue_tx {
        // No tokens available; try to queue if enabled.
        debug!("No tokens available, attempting to queue request");

        // Create a channel for the token response.
        let (permit_tx, permit_rx) = oneshot::channel();
        let queued = QueuedRequest {
            queued_at: Instant::now(),
            permit_tx,
        };

        // Try to send to the queue.
        match queue_tx.try_send(queued) {
            Ok(_) => {
                // Wait for a token from the queue processor.
                match permit_rx.await {
                    Ok(Ok(())) => {
                        debug!("Acquired token from queue");
                        let response = next.run(request).await;
                        // Return the token to the bucket.
                        token_bucket.return_tokens(1.0).await;
                        response
                    }
                    Ok(Err(status)) => {
                        warn!("Queue returned error status: {}", status);
                        status.into_response()
                    }
                    Err(_) => {
                        error!("Queue response channel closed");
                        StatusCode::INTERNAL_SERVER_ERROR.into_response()
                    }
                }
            }
            Err(_) => {
                warn!("Request queue is full, returning 429");
                StatusCode::TOO_MANY_REQUESTS.into_response()
            }
        }
    } else {
        warn!("No tokens available and queuing is disabled, returning 429");
        StatusCode::TOO_MANY_REQUESTS.into_response()
    }
}
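
// --- Wiring sketch (illustrative addition, not part of the original module) ---
// `concurrency_limit_middleware` has the signature expected by
// `axum::middleware::from_fn_with_state`. A hypothetical helper for
// attaching it, assuming `app_state` is the shared `Arc<AppState>` used
// elsewhere in the server:
#[allow(dead_code)]
fn with_concurrency_limit(router: axum::Router, app_state: Arc<AppState>) -> axum::Router {
    router.layer(axum::middleware::from_fn_with_state(
        app_state,
        concurrency_limit_middleware,
    ))
}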