sglang / Commits / da53e13c

Commit da53e13c (Unverified)
[router] preserve original worker response header in router (#9236)
Authored Aug 15, 2025 by Simo Lin; committed by GitHub on Aug 15, 2025
Parent: d7e38b2f
Showing 4 changed files with 135 additions and 37 deletions (+135 -37):

    sgl-router/src/routers/header_utils.rs   +53   -0
    sgl-router/src/routers/mod.rs             +1   -0
    sgl-router/src/routers/pd_router.rs      +47  -16
    sgl-router/src/routers/router.rs         +34  -21
sgl-router/src/routers/header_utils.rs (new file, 0 → 100644)    View file @ da53e13c

use axum::body::Body;
use axum::extract::Request;
use axum::http::{HeaderMap, HeaderName, HeaderValue};

/// Copy request headers to a Vec of name-value string pairs
/// Used for forwarding headers to backend workers
pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
    req.headers()
        .iter()
        .filter_map(|(name, value)| {
            // Convert header value to string, skipping non-UTF8 headers
            value
                .to_str()
                .ok()
                .map(|v| (name.to_string(), v.to_string()))
        })
        .collect()
}

/// Convert headers from reqwest Response to axum HeaderMap
/// Filters out hop-by-hop headers that shouldn't be forwarded
pub fn preserve_response_headers(reqwest_headers: &HeaderMap) -> HeaderMap {
    let mut headers = HeaderMap::new();
    for (name, value) in reqwest_headers.iter() {
        // Skip hop-by-hop headers that shouldn't be forwarded
        let name_str = name.as_str().to_lowercase();
        if should_forward_header(&name_str) {
            // The original name and value are already valid, so we can just clone them
            headers.insert(name.clone(), value.clone());
        }
    }
    headers
}

/// Determine if a header should be forwarded from backend to client
fn should_forward_header(name: &str) -> bool {
    // List of headers that should NOT be forwarded (hop-by-hop headers)
    !matches!(
        name,
        "connection"
            | "keep-alive"
            | "proxy-authenticate"
            | "proxy-authorization"
            | "te"
            | "trailers"
            | "transfer-encoding"
            | "upgrade"
            | "content-encoding" // Let axum/hyper handle encoding
            | "host" // Should not forward the backend's host header
    )
}
sgl-router/src/routers/mod.rs    View file @ da53e13c

@@ -12,6 +12,7 @@ use std::fmt::Debug;
 use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};

 pub mod factory;
+pub mod header_utils;
 pub mod pd_router;
 pub mod pd_types;
 pub mod router;
...
sgl-router/src/routers/pd_router.rs    View file @ da53e13c

 // PD (Prefill-Decode) Router Implementation
 // This module handles routing for disaggregated prefill-decode systems

+use super::header_utils;
 use super::pd_types::{api_path, PDRouterError};
 use crate::config::types::{
     CircuitBreakerConfig as ConfigCircuitBreakerConfig,
...
@@ -170,17 +171,26 @@ impl PDRouter {
         }

         match request_builder.send().await {
-            Ok(res) if res.status().is_success() => match res.bytes().await {
-                Ok(body) => (StatusCode::OK, body).into_response(),
-                Err(e) => {
-                    error!("Failed to read response body: {}", e);
-                    (
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        format!("Failed to read response body: {}", e),
-                    )
-                        .into_response()
-                }
-            },
+            Ok(res) if res.status().is_success() => {
+                let response_headers = header_utils::preserve_response_headers(res.headers());
+                match res.bytes().await {
+                    Ok(body) => {
+                        let mut response = Response::new(axum::body::Body::from(body));
+                        *response.status_mut() = StatusCode::OK;
+                        *response.headers_mut() = response_headers;
+                        response
+                    }
+                    Err(e) => {
+                        error!("Failed to read response body: {}", e);
+                        (
+                            StatusCode::INTERNAL_SERVER_ERROR,
+                            format!("Failed to read response body: {}", e),
+                        )
+                            .into_response()
+                    }
+                }
+            }
             Ok(res) => {
                 let status = StatusCode::from_u16(res.status().as_u16())
                     .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
...
@@ -822,12 +832,16 @@ impl PDRouter {
                     json.pointer("/meta_info/input_token_logprobs").cloned()
                 });

+                let response_headers = header_utils::preserve_response_headers(res.headers());
                 Self::create_streaming_response(
                     res.bytes_stream(),
                     status,
                     prefill_logprobs,
                     return_logprob,
                     None,
+                    Some(response_headers),
                 )
             } else {
                 // Non-streaming response with logprobs
...
@@ -918,17 +932,30 @@ impl PDRouter {
         } else if is_stream {
             // Streaming response without logprobs - direct passthrough
             let decode_url = decode.url().to_string();
+            let response_headers = header_utils::preserve_response_headers(res.headers());
             Self::create_streaming_response(
                 res.bytes_stream(),
                 status,
                 None,
                 false,
                 Some(decode_url),
+                Some(response_headers),
             )
         } else {
             // Non-streaming response without logprobs - direct passthrough like fast version
+            let response_headers = header_utils::preserve_response_headers(res.headers());
             match res.bytes().await {
-                Ok(decode_body) => (status, decode_body).into_response(),
+                Ok(decode_body) => {
+                    let mut response = Response::new(axum::body::Body::from(decode_body));
+                    *response.status_mut() = status;
+                    *response.headers_mut() = response_headers;
+                    response
+                }
                 Err(e) => {
                     error!("Failed to read decode response: {}", e);
                     (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response")
...
@@ -1081,6 +1108,7 @@ impl PDRouter {
         prefill_logprobs: Option<Value>,
         return_logprob: bool,
         decode_url: Option<String>,
+        headers: Option<HeaderMap>,
     ) -> Response {
         let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
...
@@ -1118,9 +1146,12 @@ impl PDRouter {
         let mut response = Response::new(body);
         *response.status_mut() = status;
-        response
-            .headers_mut()
-            .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
+        // Use provided headers or create new ones, then ensure content-type is set for streaming
+        let mut headers = headers.unwrap_or_else(HeaderMap::new);
+        headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
+        *response.headers_mut() = headers;
         response
     }
...
@@ -1556,7 +1587,7 @@ impl RouterTrait for PDRouter {
     async fn get_models(&self, req: Request<Body>) -> Response {
         // Extract headers first to avoid Send issues
-        let headers = crate::routers::router::copy_request_headers(&req);
+        let headers = header_utils::copy_request_headers(&req);
         // Proxy to first prefill worker
         self.proxy_to_first_worker(&self.prefill_workers, "v1/models", "prefill", Some(headers))
...
@@ -1565,7 +1596,7 @@ impl RouterTrait for PDRouter {
     async fn get_model_info(&self, req: Request<Body>) -> Response {
         // Extract headers first to avoid Send issues
-        let headers = crate::routers::router::copy_request_headers(&req);
+        let headers = header_utils::copy_request_headers(&req);
         // Proxy to first prefill worker
         self.proxy_to_first_worker(
...
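The non-streaming branches in this file all repeat the same three steps: capture the worker's headers with preserve_response_headers before the body is consumed, wrap the body bytes in an axum Response, and overwrite the response's header map. A minimal sketch of that pattern as a standalone helper (the helper name and its placement are illustrative assumptions, not part of the commit):

use axum::body::Body;
use axum::http::{HeaderMap, StatusCode};
use axum::response::Response;
use bytes::Bytes;

// Illustrative helper: build an axum Response that carries the worker's status
// and its preserved (non-hop-by-hop) headers along with the body bytes.
fn build_response_with_headers(status: StatusCode, body: Bytes, headers: HeaderMap) -> Response {
    let mut response = Response::new(Body::from(body));
    *response.status_mut() = status;
    *response.headers_mut() = headers;
    response
}

With such a helper, the Ok(body) arms above would reduce to a single call; the commit instead inlines the three statements at each call site.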
sgl-router/src/routers/router.rs    View file @ da53e13c

+use super::header_utils;
 use crate::config::types::{
     CircuitBreakerConfig as ConfigCircuitBreakerConfig,
     HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
...
@@ -24,17 +25,6 @@ use std::sync::{Arc, RwLock};
 use std::time::{Duration, Instant};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, error, info, warn};

-pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
-    req.headers()
-        .iter()
-        .filter_map(|(name, value)| {
-            value
-                .to_str()
-                .ok()
-                .map(|v| (name.to_string(), v.to_string()))
-        })
-        .collect()
-}

 /// Regular router that uses injected load balancing policies
 #[derive(Debug)]
...
@@ -400,7 +390,7 @@ impl Router {
     // Helper method to proxy GET requests to the first available worker
     async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
-        let headers = copy_request_headers(&req);
+        let headers = super::header_utils::copy_request_headers(&req);
         match self.select_first_worker() {
             Ok(worker_url) => {
...
@@ -416,8 +406,18 @@ impl Router {
                 Ok(res) => {
                     let status = StatusCode::from_u16(res.status().as_u16())
                         .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
+                    // Preserve headers from backend
+                    let response_headers = header_utils::preserve_response_headers(res.headers());
                     match res.bytes().await {
-                        Ok(body) => (status, body).into_response(),
+                        Ok(body) => {
+                            let mut response = Response::new(axum::body::Body::from(body));
+                            *response.status_mut() = status;
+                            *response.headers_mut() = response_headers;
+                            response
+                        }
                         Err(e) => (
                             StatusCode::INTERNAL_SERVER_ERROR,
                             format!("Failed to read response: {}", e),
...
@@ -645,9 +645,16 @@ impl Router {
             .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

         if !is_stream {
-            // For non-streaming requests, get response first
+            // For non-streaming requests, preserve headers
+            let response_headers = super::header_utils::preserve_response_headers(res.headers());
             let response = match res.bytes().await {
-                Ok(body) => (status, body).into_response(),
+                Ok(body) => {
+                    let mut response = Response::new(axum::body::Body::from(body));
+                    *response.status_mut() = status;
+                    *response.headers_mut() = response_headers;
+                    response
+                }
                 Err(e) => {
                     let error_msg = format!("Failed to get response body: {}", e);
                     (StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
...
@@ -670,6 +677,11 @@ impl Router {
             let workers = Arc::clone(&self.workers);
             let worker_url = worker_url.to_string();
+            // Preserve headers for streaming response
+            let mut response_headers = header_utils::preserve_response_headers(res.headers());
+            // Ensure we set the correct content-type for SSE
+            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
             let stream = res.bytes_stream();
             let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
...
@@ -724,12 +736,15 @@ impl Router {
             let mut response = Response::new(body);
             *response.status_mut() = status;
-            response
-                .headers_mut()
-                .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
+            *response.headers_mut() = response_headers;
             response
         } else {
             // For requests without load tracking, just stream
+            // Preserve headers for streaming response
+            let mut response_headers = header_utils::preserve_response_headers(res.headers());
+            // Ensure we set the correct content-type for SSE
+            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
             let stream = res.bytes_stream();
             let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
...
@@ -756,9 +771,7 @@ impl Router {
             let mut response = Response::new(body);
             *response.status_mut() = status;
-            response
-                .headers_mut()
-                .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
+            *response.headers_mut() = response_headers;
             response
         }
     }
...
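On the request side, copy_request_headers returns plain (String, String) pairs, which the in-code comments note helps avoid Send issues when proxying; applying them to an outgoing reqwest request is then a simple loop. The sketch below is illustrative only: the forward_with_headers helper and the /v1/models target are assumptions, and the actual routers assemble their worker requests elsewhere.

use axum::body::Body;
use axum::extract::Request;

use crate::routers::header_utils;

// Illustrative only: forward the copied request headers to a worker with reqwest.
async fn forward_with_headers(
    client: &reqwest::Client,
    worker_url: &str,
    req: &Request<Body>,
) -> Result<reqwest::Response, reqwest::Error> {
    // Plain String pairs, so no borrowed HeaderMap is held across the await below
    let headers = header_utils::copy_request_headers(req);
    let mut builder = client.get(format!("{}/v1/models", worker_url));
    for (name, value) in &headers {
        builder = builder.header(name.as_str(), value.as_str());
    }
    builder.send().await
}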