Unverified commit da53e13c authored by Simo Lin, committed by GitHub

[router] preserve original worker response header in router (#9236)

parent d7e38b2f
use axum::body::Body;
use axum::extract::Request;
use axum::http::{HeaderMap, HeaderName, HeaderValue};

/// Copy request headers to a Vec of name-value string pairs
/// Used for forwarding headers to backend workers
pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
req.headers()
.iter()
.filter_map(|(name, value)| {
// Convert header value to string, skipping non-UTF8 headers
value
.to_str()
.ok()
.map(|v| (name.to_string(), v.to_string()))
})
.collect()
}

/// Convert headers from reqwest Response to axum HeaderMap
/// Filters out hop-by-hop headers that shouldn't be forwarded
pub fn preserve_response_headers(reqwest_headers: &HeaderMap) -> HeaderMap {
let mut headers = HeaderMap::new();
for (name, value) in reqwest_headers.iter() {
// Skip hop-by-hop headers that shouldn't be forwarded
let name_str = name.as_str().to_lowercase();
if should_forward_header(&name_str) {
// The original name and value are already valid, so we can just clone them
headers.insert(name.clone(), value.clone());
}
}
headers
}

/// Determine if a header should be forwarded from backend to client
fn should_forward_header(name: &str) -> bool {
// List of headers that should NOT be forwarded (hop-by-hop headers)
!matches!(
name,
"connection" |
"keep-alive" |
"proxy-authenticate" |
"proxy-authorization" |
"te" |
"trailers" |
"transfer-encoding" |
"upgrade" |
"content-encoding" | // Let axum/hyper handle encoding
"host" // Should not forward the backend's host header
)
}
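
Illustrative sketch (not part of this commit): a hypothetical unit test for these helpers, assuming it sits in the same module so `preserve_response_headers` is in scope; the `x-request-id` header is just an example name.

// Hypothetical test module, for illustration only: hop-by-hop headers are
// dropped, end-to-end headers are forwarded unchanged.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderMap, HeaderValue};

    #[test]
    fn hop_by_hop_headers_are_dropped() {
        let mut upstream = HeaderMap::new();
        upstream.insert("x-request-id", HeaderValue::from_static("abc123"));
        upstream.insert("content-type", HeaderValue::from_static("application/json"));
        upstream.insert("transfer-encoding", HeaderValue::from_static("chunked"));
        upstream.insert("connection", HeaderValue::from_static("keep-alive"));

        let forwarded = preserve_response_headers(&upstream);

        // End-to-end headers survive the copy...
        assert_eq!(
            forwarded.get("x-request-id").unwrap().to_str().unwrap(),
            "abc123"
        );
        assert!(forwarded.contains_key("content-type"));
        // ...while hop-by-hop headers are filtered out.
        assert!(!forwarded.contains_key("transfer-encoding"));
        assert!(!forwarded.contains_key("connection"));
    }
}
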
@@ -12,6 +12,7 @@ use std::fmt::Debug;
 use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};

 pub mod factory;
+pub mod header_utils;
 pub mod pd_router;
 pub mod pd_types;
 pub mod router;
...
 // PD (Prefill-Decode) Router Implementation
 // This module handles routing for disaggregated prefill-decode systems
+use super::header_utils;
 use super::pd_types::{api_path, PDRouterError};
 use crate::config::types::{
     CircuitBreakerConfig as ConfigCircuitBreakerConfig,
@@ -170,17 +171,26 @@ impl PDRouter {
         }

         match request_builder.send().await {
-            Ok(res) if res.status().is_success() => match res.bytes().await {
-                Ok(body) => (StatusCode::OK, body).into_response(),
-                Err(e) => {
-                    error!("Failed to read response body: {}", e);
-                    (
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        format!("Failed to read response body: {}", e),
-                    )
-                    .into_response()
-                }
-            },
+            Ok(res) if res.status().is_success() => {
+                let response_headers = header_utils::preserve_response_headers(res.headers());
+
+                match res.bytes().await {
+                    Ok(body) => {
+                        let mut response = Response::new(axum::body::Body::from(body));
+                        *response.status_mut() = StatusCode::OK;
+                        *response.headers_mut() = response_headers;
+                        response
+                    }
+                    Err(e) => {
+                        error!("Failed to read response body: {}", e);
+                        (
+                            StatusCode::INTERNAL_SERVER_ERROR,
+                            format!("Failed to read response body: {}", e),
+                        )
+                        .into_response()
+                    }
+                }
+            }
             Ok(res) => {
                 let status = StatusCode::from_u16(res.status().as_u16())
                     .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
@@ -822,12 +832,16 @@ impl PDRouter {
                         json.pointer("/meta_info/input_token_logprobs").cloned()
                     });

+                    let response_headers =
+                        header_utils::preserve_response_headers(res.headers());
+
                     Self::create_streaming_response(
                         res.bytes_stream(),
                         status,
                         prefill_logprobs,
                         return_logprob,
                         None,
+                        Some(response_headers),
                     )
                 } else {
                     // Non-streaming response with logprobs
@@ -918,17 +932,30 @@ impl PDRouter {
                 } else if is_stream {
                     // Streaming response without logprobs - direct passthrough
                     let decode_url = decode.url().to_string();
+                    let response_headers =
+                        header_utils::preserve_response_headers(res.headers());
+
                     Self::create_streaming_response(
                         res.bytes_stream(),
                         status,
                         None,
                         false,
                         Some(decode_url),
+                        Some(response_headers),
                     )
                 } else {
                     // Non-streaming response without logprobs - direct passthrough like fast version
+                    let response_headers =
+                        header_utils::preserve_response_headers(res.headers());
+
                     match res.bytes().await {
-                        Ok(decode_body) => (status, decode_body).into_response(),
+                        Ok(decode_body) => {
+                            let mut response =
+                                Response::new(axum::body::Body::from(decode_body));
+                            *response.status_mut() = status;
+                            *response.headers_mut() = response_headers;
+                            response
+                        }
                         Err(e) => {
                             error!("Failed to read decode response: {}", e);
                             (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response")
@@ -1081,6 +1108,7 @@ impl PDRouter {
         prefill_logprobs: Option<Value>,
         return_logprob: bool,
         decode_url: Option<String>,
+        headers: Option<HeaderMap>,
     ) -> Response {
         let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
@@ -1118,9 +1146,12 @@ impl PDRouter {
         let mut response = Response::new(body);
         *response.status_mut() = status;
-        response
-            .headers_mut()
-            .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
+
+        // Use provided headers or create new ones, then ensure content-type is set for streaming
+        let mut headers = headers.unwrap_or_else(HeaderMap::new);
+        headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
+        *response.headers_mut() = headers;
+
         response
     }
@@ -1556,7 +1587,7 @@ impl RouterTrait for PDRouter {
     async fn get_models(&self, req: Request<Body>) -> Response {
         // Extract headers first to avoid Send issues
-        let headers = crate::routers::router::copy_request_headers(&req);
+        let headers = header_utils::copy_request_headers(&req);

         // Proxy to first prefill worker
         self.proxy_to_first_worker(&self.prefill_workers, "v1/models", "prefill", Some(headers))
@@ -1565,7 +1596,7 @@ impl RouterTrait for PDRouter {
     async fn get_model_info(&self, req: Request<Body>) -> Response {
         // Extract headers first to avoid Send issues
-        let headers = crate::routers::router::copy_request_headers(&req);
+        let headers = header_utils::copy_request_headers(&req);

         // Proxy to first prefill worker
         self.proxy_to_first_worker(
...
+use super::header_utils;
 use crate::config::types::{
     CircuitBreakerConfig as ConfigCircuitBreakerConfig,
     HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
@@ -24,17 +25,6 @@ use std::sync::{Arc, RwLock};
 use std::time::{Duration, Instant};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, error, info, warn};

-pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
-    req.headers()
-        .iter()
-        .filter_map(|(name, value)| {
-            value
-                .to_str()
-                .ok()
-                .map(|v| (name.to_string(), v.to_string()))
-        })
-        .collect()
-}
 /// Regular router that uses injected load balancing policies
 #[derive(Debug)]
@@ -400,7 +390,7 @@ impl Router {
     // Helper method to proxy GET requests to the first available worker
     async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
-        let headers = copy_request_headers(&req);
+        let headers = super::header_utils::copy_request_headers(&req);

         match self.select_first_worker() {
             Ok(worker_url) => {
@@ -416,8 +406,18 @@ impl Router {
                     Ok(res) => {
                         let status = StatusCode::from_u16(res.status().as_u16())
                             .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
+
+                        // Preserve headers from backend
+                        let response_headers =
+                            header_utils::preserve_response_headers(res.headers());
+
                         match res.bytes().await {
-                            Ok(body) => (status, body).into_response(),
+                            Ok(body) => {
+                                let mut response = Response::new(axum::body::Body::from(body));
+                                *response.status_mut() = status;
+                                *response.headers_mut() = response_headers;
+                                response
+                            }
                             Err(e) => (
                                 StatusCode::INTERNAL_SERVER_ERROR,
                                 format!("Failed to read response: {}", e),
@@ -645,9 +645,16 @@ impl Router {
             .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

         if !is_stream {
-            // For non-streaming requests, get response first
+            // For non-streaming requests, preserve headers
+            let response_headers = super::header_utils::preserve_response_headers(res.headers());
+
             let response = match res.bytes().await {
-                Ok(body) => (status, body).into_response(),
+                Ok(body) => {
+                    let mut response = Response::new(axum::body::Body::from(body));
+                    *response.status_mut() = status;
+                    *response.headers_mut() = response_headers;
+                    response
+                }
                 Err(e) => {
                     let error_msg = format!("Failed to get response body: {}", e);
                     (StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
@@ -670,6 +677,11 @@ impl Router {
             let workers = Arc::clone(&self.workers);
             let worker_url = worker_url.to_string();

+            // Preserve headers for streaming response
+            let mut response_headers = header_utils::preserve_response_headers(res.headers());
+            // Ensure we set the correct content-type for SSE
+            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
+
             let stream = res.bytes_stream();
             let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
@@ -724,12 +736,15 @@ impl Router {
             let mut response = Response::new(body);
             *response.status_mut() = status;
-            response
-                .headers_mut()
-                .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
+            *response.headers_mut() = response_headers;

             response
         } else {
             // For requests without load tracking, just stream
+            // Preserve headers for streaming response
+            let mut response_headers = header_utils::preserve_response_headers(res.headers());
+            // Ensure we set the correct content-type for SSE
+            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
+
             let stream = res.bytes_stream();
             let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
@@ -756,9 +771,7 @@ impl Router {
             let mut response = Response::new(body);
             *response.status_mut() = status;
-            response
-                .headers_mut()
-                .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
+            *response.headers_mut() = response_headers;

             response
         }
     }
...