Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
828a4fe9
Unverified
Commit
828a4fe9
authored
Aug 02, 2025
by
Simo Lin
Committed by
GitHub
Aug 02, 2025
Browse files
[router] Implement HTTP Dependency Injection Pattern for Router System (#8714)
parent
8ada1ab6
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
197 additions
and
186 deletions
+197
-186
sgl-router/src/routers/factory.rs
sgl-router/src/routers/factory.rs
+21
-17
sgl-router/src/routers/mod.rs
sgl-router/src/routers/mod.rs
+9
-16
sgl-router/src/routers/pd_router.rs
sgl-router/src/routers/pd_router.rs
+36
-50
sgl-router/src/routers/router.rs
sgl-router/src/routers/router.rs
+45
-50
sgl-router/src/server.rs
sgl-router/src/server.rs
+38
-32
sgl-router/src/service_discovery.rs
sgl-router/src/service_discovery.rs
+2
-2
sgl-router/tests/api_endpoints_test.rs
sgl-router/tests/api_endpoints_test.rs
+7
-4
sgl-router/tests/common/mod.rs
sgl-router/tests/common/mod.rs
+13
-0
sgl-router/tests/common/test_app.rs
sgl-router/tests/common/test_app.rs
+10
-6
sgl-router/tests/request_formats_test.rs
sgl-router/tests/request_formats_test.rs
+6
-4
sgl-router/tests/streaming_tests.rs
sgl-router/tests/streaming_tests.rs
+6
-4
sgl-router/tests/test_pd_routing.rs
sgl-router/tests/test_pd_routing.rs
+4
-1
No files found.
sgl-router/src/routers/factory.rs
View file @
828a4fe9
//! Factory for creating router instances
use
super
::{
pd_router
::
PDRouter
,
router
::
Router
,
RouterTrait
};
use
crate
::
config
::{
PolicyConfig
,
RouterConfig
,
RoutingMode
};
use
crate
::
config
::{
PolicyConfig
,
RoutingMode
};
use
crate
::
policies
::
PolicyFactory
;
use
crate
::
server
::
AppContext
;
use
std
::
sync
::
Arc
;
/// Factory for creating router instances based on configuration
pub
struct
RouterFactory
;
impl
RouterFactory
{
/// Create a router instance from
configuration
pub
fn
create_router
(
c
onfig
:
&
RouterConfig
)
->
Result
<
Box
<
dyn
RouterTrait
>
,
String
>
{
match
&
config
.mode
{
/// Create a router instance from
application context
pub
fn
create_router
(
c
tx
:
&
Arc
<
AppContext
>
)
->
Result
<
Box
<
dyn
RouterTrait
>
,
String
>
{
match
&
ctx
.router_
config.mode
{
RoutingMode
::
Regular
{
worker_urls
}
=>
{
Self
::
create_regular_router
(
worker_urls
,
&
config
.policy
,
c
onfig
)
Self
::
create_regular_router
(
worker_urls
,
&
ctx
.router_
config.policy
,
c
tx
)
}
RoutingMode
::
PrefillDecode
{
prefill_urls
,
...
...
@@ -24,8 +26,8 @@ impl RouterFactory {
decode_urls
,
prefill_policy
.as_ref
(),
decode_policy
.as_ref
(),
&
config
.policy
,
c
onfig
,
&
ctx
.router_
config.policy
,
c
tx
,
),
}
}
...
...
@@ -34,19 +36,20 @@ impl RouterFactory {
fn
create_regular_router
(
worker_urls
:
&
[
String
],
policy_config
:
&
PolicyConfig
,
router_config
:
&
RouterConfig
,
ctx
:
&
Arc
<
AppContext
>
,
)
->
Result
<
Box
<
dyn
RouterTrait
>
,
String
>
{
// Create policy
let
policy
=
PolicyFactory
::
create_from_config
(
policy_config
);
// Create regular router with injected policy
// Create regular router with injected policy
and client
let
router
=
Router
::
new
(
worker_urls
.to_vec
(),
policy
,
router_config
.worker_startup_timeout_secs
,
router_config
.worker_startup_check_interval_secs
,
router_config
.dp_aware
,
router_config
.api_key
.clone
(),
ctx
.client
.clone
(),
ctx
.router_config.worker_startup_timeout_secs
,
ctx
.router_config.worker_startup_check_interval_secs
,
ctx
.router_config.dp_aware
,
ctx
.router_config.api_key
.clone
(),
)
?
;
Ok
(
Box
::
new
(
router
))
...
...
@@ -59,7 +62,7 @@ impl RouterFactory {
prefill_policy_config
:
Option
<&
PolicyConfig
>
,
decode_policy_config
:
Option
<&
PolicyConfig
>
,
main_policy_config
:
&
PolicyConfig
,
router_config
:
&
RouterConfig
,
ctx
:
&
Arc
<
AppContext
>
,
)
->
Result
<
Box
<
dyn
RouterTrait
>
,
String
>
{
// Create policies - use specific policies if provided, otherwise fall back to main policy
let
prefill_policy
=
...
...
@@ -67,14 +70,15 @@ impl RouterFactory {
let
decode_policy
=
PolicyFactory
::
create_from_config
(
decode_policy_config
.unwrap_or
(
main_policy_config
));
// Create PD router with separate policies
// Create PD router with separate policies
and client
let
router
=
PDRouter
::
new
(
prefill_urls
.to_vec
(),
decode_urls
.to_vec
(),
prefill_policy
,
decode_policy
,
router_config
.worker_startup_timeout_secs
,
router_config
.worker_startup_check_interval_secs
,
ctx
.client
.clone
(),
ctx
.router_config.worker_startup_timeout_secs
,
ctx
.router_config.worker_startup_check_interval_secs
,
)
?
;
Ok
(
Box
::
new
(
router
))
...
...
sgl-router/src/routers/mod.rs
View file @
828a4fe9
...
...
@@ -7,7 +7,6 @@ use axum::{
http
::{
HeaderMap
,
StatusCode
},
response
::{
IntoResponse
,
Response
},
};
use
reqwest
::
Client
;
use
std
::
fmt
::
Debug
;
use
crate
::
openai_api_types
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
};
...
...
@@ -46,32 +45,27 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
fn
as_any
(
&
self
)
->
&
dyn
std
::
any
::
Any
;
/// Route a health check request
async
fn
health
(
&
self
,
client
:
&
Client
,
req
:
Request
<
Body
>
)
->
Response
;
async
fn
health
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
;
/// Route a health generate request
async
fn
health_generate
(
&
self
,
client
:
&
Client
,
req
:
Request
<
Body
>
)
->
Response
;
async
fn
health_generate
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
;
/// Get server information
async
fn
get_server_info
(
&
self
,
client
:
&
Client
,
req
:
Request
<
Body
>
)
->
Response
;
async
fn
get_server_info
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
;
/// Get available models
async
fn
get_models
(
&
self
,
client
:
&
Client
,
req
:
Request
<
Body
>
)
->
Response
;
async
fn
get_models
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
;
/// Get model information
async
fn
get_model_info
(
&
self
,
client
:
&
Client
,
req
:
Request
<
Body
>
)
->
Response
;
async
fn
get_model_info
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
;
/// Route a generate request
async
fn
route_generate
(
&
self
,
client
:
&
Client
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
GenerateRequest
,
)
->
Response
;
async
fn
route_generate
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
GenerateRequest
)
->
Response
;
/// Route a chat completion request
async
fn
route_chat
(
&
self
,
client
:
&
Client
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
ChatCompletionRequest
,
)
->
Response
;
...
...
@@ -79,16 +73,15 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
/// Route a completion request
async
fn
route_completion
(
&
self
,
client
:
&
Client
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
CompletionRequest
,
)
->
Response
;
/// Flush cache on all workers
async
fn
flush_cache
(
&
self
,
client
:
&
Client
)
->
Response
;
async
fn
flush_cache
(
&
self
)
->
Response
;
/// Get worker loads (for monitoring)
async
fn
get_worker_loads
(
&
self
,
client
:
&
Client
)
->
Response
;
async
fn
get_worker_loads
(
&
self
)
->
Response
;
/// Get router type name
fn
router_type
(
&
self
)
->
&
'static
str
;
...
...
sgl-router/src/routers/pd_router.rs
View file @
828a4fe9
...
...
@@ -35,7 +35,7 @@ pub struct PDRouter {
pub
interval_secs
:
u64
,
pub
worker_loads
:
Arc
<
tokio
::
sync
::
watch
::
Receiver
<
HashMap
<
String
,
isize
>>>
,
pub
load_monitor_handle
:
Option
<
Arc
<
tokio
::
task
::
JoinHandle
<
()
>>>
,
pub
http_
client
:
Client
,
pub
client
:
Client
,
_
prefill_health_checker
:
Option
<
HealthChecker
>
,
_
decode_health_checker
:
Option
<
HealthChecker
>
,
}
...
...
@@ -177,6 +177,7 @@ impl PDRouter {
decode_urls
:
Vec
<
String
>
,
prefill_policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
decode_policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
client
:
Client
,
timeout_secs
:
u64
,
interval_secs
:
u64
,
)
->
Result
<
Self
,
String
>
{
...
...
@@ -215,17 +216,11 @@ impl PDRouter {
let
(
tx
,
rx
)
=
tokio
::
sync
::
watch
::
channel
(
HashMap
::
new
());
let
worker_loads
=
Arc
::
new
(
rx
);
// Create a shared HTTP client for all operations
let
http_client
=
Client
::
builder
()
.timeout
(
Duration
::
from_secs
(
timeout_secs
))
.build
()
.map_err
(|
e
|
format!
(
"Failed to create HTTP client: {}"
,
e
))
?
;
let
load_monitor_handle
=
if
prefill_policy
.name
()
==
"power_of_two"
||
decode_policy
.name
()
==
"power_of_two"
{
let
monitor_urls
=
all_urls
.clone
();
let
monitor_interval
=
interval_secs
;
let
monitor_client
=
http_
client
.clone
();
let
monitor_client
=
client
.clone
();
let
prefill_policy_clone
=
Arc
::
clone
(
&
prefill_policy
);
let
decode_policy_clone
=
Arc
::
clone
(
&
decode_policy
);
...
...
@@ -264,7 +259,7 @@ impl PDRouter {
interval_secs
,
worker_loads
,
load_monitor_handle
,
http_
client
,
client
,
_
prefill_health_checker
:
Some
(
prefill_health_checker
),
_
decode_health_checker
:
Some
(
decode_health_checker
),
})
...
...
@@ -302,7 +297,6 @@ impl PDRouter {
// Route a typed generate request
pub
async
fn
route_generate
(
&
self
,
client
:
&
Client
,
headers
:
Option
<&
HeaderMap
>
,
mut
typed_req
:
GenerateReqInput
,
route
:
&
str
,
...
...
@@ -371,7 +365,6 @@ impl PDRouter {
// Execute dual dispatch
self
.execute_dual_dispatch
(
client
,
headers
,
json_with_bootstrap
,
route
,
...
...
@@ -387,7 +380,6 @@ impl PDRouter {
// Route a typed chat request
pub
async
fn
route_chat
(
&
self
,
client
:
&
Client
,
headers
:
Option
<&
HeaderMap
>
,
mut
typed_req
:
ChatReqInput
,
route
:
&
str
,
...
...
@@ -459,7 +451,6 @@ impl PDRouter {
// Execute dual dispatch
self
.execute_dual_dispatch
(
client
,
headers
,
json_with_bootstrap
,
route
,
...
...
@@ -475,7 +466,6 @@ impl PDRouter {
// Route a completion request while preserving OpenAI format
pub
async
fn
route_completion
(
&
self
,
client
:
&
Client
,
headers
:
Option
<&
HeaderMap
>
,
mut
typed_req
:
CompletionRequest
,
route
:
&
str
,
...
...
@@ -540,7 +530,6 @@ impl PDRouter {
// Execute dual dispatch
self
.execute_dual_dispatch
(
client
,
headers
,
json_with_bootstrap
,
route
,
...
...
@@ -554,10 +543,8 @@ impl PDRouter {
}
// Execute the dual dispatch to prefill and decode servers
#[allow(clippy::too_many_arguments)]
async
fn
execute_dual_dispatch
(
&
self
,
client
:
&
Client
,
headers
:
Option
<&
HeaderMap
>
,
json_request
:
Value
,
route
:
&
str
,
...
...
@@ -571,11 +558,13 @@ impl PDRouter {
let
_
guard
=
WorkerLoadGuard
::
new_multi
(
vec!
[
prefill
,
decode
]);
// Build requests using .json() method
let
mut
prefill_request
=
client
let
mut
prefill_request
=
self
.client
.post
(
api_path
(
prefill
.url
(),
route
))
.json
(
&
json_request
);
let
mut
decode_request
=
client
let
mut
decode_request
=
self
.client
.post
(
api_path
(
decode
.url
(),
route
))
.json
(
&
json_request
);
...
...
@@ -987,7 +976,7 @@ async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option<i
// PD-specific endpoints
impl
PDRouter
{
pub
async
fn
health_generate
(
&
self
,
client
:
&
reqwest
::
Client
)
->
Response
{
pub
async
fn
health_generate
(
&
self
)
->
Response
{
// Test model generation capability by selecting a random pair and testing them
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
...
...
@@ -1005,11 +994,11 @@ impl PDRouter {
// Test prefill server's health_generate
let
prefill_url
=
format!
(
"{}/health_generate"
,
prefill
.url
());
let
prefill_result
=
client
.get
(
&
prefill_url
)
.send
()
.await
;
let
prefill_result
=
self
.
client
.get
(
&
prefill_url
)
.send
()
.await
;
// Test decode server's health_generate
let
decode_url
=
format!
(
"{}/health_generate"
,
decode
.url
());
let
decode_result
=
client
.get
(
&
decode_url
)
.send
()
.await
;
let
decode_result
=
self
.
client
.get
(
&
decode_url
)
.send
()
.await
;
// Check results
let
mut
errors
=
Vec
::
new
();
...
...
@@ -1068,7 +1057,7 @@ impl PDRouter {
}
}
pub
async
fn
get_server_info
(
&
self
,
client
:
&
reqwest
::
Client
)
->
Response
{
pub
async
fn
get_server_info
(
&
self
)
->
Response
{
// Get info from the first decode server to match sglang's server info format
let
first_decode_url
=
if
let
Ok
(
workers
)
=
self
.decode_workers
.read
()
{
workers
.first
()
.map
(|
w
|
w
.url
()
.to_string
())
...
...
@@ -1081,7 +1070,8 @@ impl PDRouter {
};
if
let
Some
(
worker_url
)
=
first_decode_url
{
match
client
match
self
.client
.get
(
format!
(
"{}/get_server_info"
,
worker_url
))
.send
()
.await
...
...
@@ -1130,7 +1120,7 @@ impl PDRouter {
}
}
pub
async
fn
get_models
(
&
self
,
client
:
&
reqwest
::
Client
,
req
:
Request
<
Body
>
)
->
Response
{
pub
async
fn
get_models
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
{
// Extract headers first to avoid Send issues
let
headers
=
crate
::
routers
::
router
::
copy_request_headers
(
&
req
);
...
...
@@ -1147,7 +1137,7 @@ impl PDRouter {
if
let
Some
(
worker_url
)
=
first_worker_url
{
// Send request directly without going through Router
let
mut
request_builder
=
client
.get
(
format!
(
"{}/v1/models"
,
worker_url
));
let
mut
request_builder
=
self
.
client
.get
(
format!
(
"{}/v1/models"
,
worker_url
));
for
(
name
,
value
)
in
headers
{
if
name
.to_lowercase
()
!=
"content-type"
&&
name
.to_lowercase
()
!=
"content-length"
{
...
...
@@ -1224,7 +1214,7 @@ impl PDRouter {
.into_response
()
}
pub
async
fn
get_model_info
(
&
self
,
client
:
&
reqwest
::
Client
,
req
:
Request
<
Body
>
)
->
Response
{
pub
async
fn
get_model_info
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
{
// Extract headers first to avoid Send issues
let
headers
=
crate
::
routers
::
router
::
copy_request_headers
(
&
req
);
...
...
@@ -1241,7 +1231,7 @@ impl PDRouter {
};
if
let
Some
(
worker_url
)
=
first_worker_url
{
let
mut
request_builder
=
client
.get
(
format!
(
"{}/get_model_info"
,
worker_url
));
let
mut
request_builder
=
self
.
client
.get
(
format!
(
"{}/get_model_info"
,
worker_url
));
for
(
name
,
value
)
in
headers
{
if
name
.to_lowercase
()
!=
"content-type"
&&
name
.to_lowercase
()
!=
"content-length"
{
...
...
@@ -1384,7 +1374,7 @@ impl RouterTrait for PDRouter {
self
}
async
fn
health
(
&
self
,
_
client
:
&
Client
,
_
req
:
Request
<
Body
>
)
->
Response
{
async
fn
health
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// This is a server readiness check - checking if we have healthy workers
// Workers handle their own health checks in the background
let
mut
all_healthy
=
true
;
...
...
@@ -1417,68 +1407,65 @@ impl RouterTrait for PDRouter {
}
}
async
fn
health_generate
(
&
self
,
client
:
&
Client
,
_
req
:
Request
<
Body
>
)
->
Response
{
async
fn
health_generate
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// Use the existing PDRouter health_generate method
PDRouter
::
health_generate
(
self
,
client
)
.await
PDRouter
::
health_generate
(
self
)
.await
}
async
fn
get_server_info
(
&
self
,
client
:
&
Client
,
_
req
:
Request
<
Body
>
)
->
Response
{
async
fn
get_server_info
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// Use the existing PDRouter get_server_info method
PDRouter
::
get_server_info
(
self
,
client
)
.await
PDRouter
::
get_server_info
(
self
)
.await
}
async
fn
get_models
(
&
self
,
client
:
&
Client
,
req
:
Request
<
Body
>
)
->
Response
{
async
fn
get_models
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
{
// Use the existing PDRouter get_models method
PDRouter
::
get_models
(
self
,
client
,
req
)
.await
PDRouter
::
get_models
(
self
,
req
)
.await
}
async
fn
get_model_info
(
&
self
,
client
:
&
Client
,
req
:
Request
<
Body
>
)
->
Response
{
async
fn
get_model_info
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
{
// Use the existing PDRouter get_model_info method
PDRouter
::
get_model_info
(
self
,
client
,
req
)
.await
PDRouter
::
get_model_info
(
self
,
req
)
.await
}
async
fn
route_generate
(
&
self
,
client
:
&
Client
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
GenerateRequest
,
)
->
Response
{
// Convert OpenAI format to PD format
let
pd_req
=
body
.clone
()
.to_pd_request
();
PDRouter
::
route_generate
(
self
,
client
,
headers
,
pd_req
,
"/generate"
)
.await
PDRouter
::
route_generate
(
self
,
headers
,
pd_req
,
"/generate"
)
.await
}
async
fn
route_chat
(
&
self
,
client
:
&
Client
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
ChatCompletionRequest
,
)
->
Response
{
// Convert OpenAI format to PD format
let
pd_req
=
body
.clone
()
.to_pd_request
();
PDRouter
::
route_chat
(
self
,
client
,
headers
,
pd_req
,
"/v1/chat/completions"
)
.await
PDRouter
::
route_chat
(
self
,
headers
,
pd_req
,
"/v1/chat/completions"
)
.await
}
async
fn
route_completion
(
&
self
,
client
:
&
Client
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
CompletionRequest
,
)
->
Response
{
// Use the new method that preserves OpenAI format
PDRouter
::
route_completion
(
self
,
client
,
headers
,
body
.clone
(),
"/v1/completions"
)
.await
PDRouter
::
route_completion
(
self
,
headers
,
body
.clone
(),
"/v1/completions"
)
.await
}
async
fn
flush_cache
(
&
self
,
client
:
&
Client
)
->
Response
{
async
fn
flush_cache
(
&
self
)
->
Response
{
// Use the existing PDRouter flush_cache method
PDRouter
::
flush_cache
(
self
,
client
)
.await
PDRouter
::
flush_cache
(
self
,
&
self
.
client
)
.await
}
async
fn
get_worker_loads
(
&
self
,
client
:
&
Client
)
->
Response
{
async
fn
get_worker_loads
(
&
self
)
->
Response
{
// Use the existing PDRouter get_loads method
PDRouter
::
get_loads
(
self
,
client
)
.await
PDRouter
::
get_loads
(
self
,
&
self
.
client
)
.await
}
fn
router_type
(
&
self
)
->
&
'static
str
{
...
...
@@ -1570,7 +1557,7 @@ mod tests {
interval_secs
:
1
,
worker_loads
:
Arc
::
new
(
tokio
::
sync
::
watch
::
channel
(
HashMap
::
new
())
.1
),
load_monitor_handle
:
None
,
http_
client
:
reqwest
::
Client
::
new
(),
client
:
Client
::
new
(),
_
prefill_health_checker
:
None
,
_
decode_health_checker
:
None
,
}
...
...
@@ -1959,11 +1946,10 @@ mod tests {
router
.decode_workers
.write
()
.unwrap
()
.push
(
decode_worker
);
// Test health endpoint
let
client
=
reqwest
::
Client
::
new
();
let
http_req
=
axum
::
http
::
Request
::
builder
()
.body
(
axum
::
body
::
Body
::
empty
())
.unwrap
();
let
response
=
router
.health
(
&
client
,
http_req
)
.await
;
let
response
=
router
.health
(
http_req
)
.await
;
assert_eq!
(
response
.status
(),
200
);
...
...
sgl-router/src/routers/router.rs
View file @
828a4fe9
...
...
@@ -34,6 +34,7 @@ pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
pub
struct
Router
{
workers
:
Arc
<
RwLock
<
Vec
<
Box
<
dyn
Worker
>>>>
,
policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
client
:
Client
,
timeout_secs
:
u64
,
interval_secs
:
u64
,
dp_aware
:
bool
,
...
...
@@ -44,10 +45,11 @@ pub struct Router {
}
impl
Router
{
/// Create a new router with injected policy
/// Create a new router with injected policy
and client
pub
fn
new
(
worker_urls
:
Vec
<
String
>
,
policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
client
:
Client
,
timeout_secs
:
u64
,
interval_secs
:
u64
,
dp_aware
:
bool
,
...
...
@@ -94,9 +96,17 @@ impl Router {
let
monitor_urls
=
worker_urls
.clone
();
let
monitor_interval
=
interval_secs
;
let
policy_clone
=
Arc
::
clone
(
&
policy
);
let
client_clone
=
client
.clone
();
Some
(
Arc
::
new
(
tokio
::
spawn
(
async
move
{
Self
::
monitor_worker_loads
(
monitor_urls
,
tx
,
monitor_interval
,
policy_clone
)
.await
;
Self
::
monitor_worker_loads
(
monitor_urls
,
tx
,
monitor_interval
,
policy_clone
,
client_clone
,
)
.await
;
})))
}
else
{
None
...
...
@@ -105,6 +115,7 @@ impl Router {
Ok
(
Router
{
workers
,
policy
,
client
,
timeout_secs
,
interval_secs
,
dp_aware
,
...
...
@@ -245,7 +256,7 @@ impl Router {
}
}
pub
async
fn
send_health_check
(
&
self
,
client
:
&
Client
,
worker_url
:
&
str
)
->
Response
{
pub
async
fn
send_health_check
(
&
self
,
worker_url
:
&
str
)
->
Response
{
let
health_url
=
if
self
.dp_aware
{
// Need to extract the URL from "http://host:port@dp_rank"
match
Self
::
extract_dp_rank
(
worker_url
)
{
...
...
@@ -263,7 +274,7 @@ impl Router {
worker_url
};
let
request_builder
=
client
.get
(
format!
(
"{}/health"
,
health_url
));
let
request_builder
=
self
.
client
.get
(
format!
(
"{}/health"
,
health_url
));
let
response
=
match
request_builder
.send
()
.await
{
Ok
(
res
)
=>
{
...
...
@@ -305,17 +316,12 @@ impl Router {
}
// Helper method to proxy GET requests to the first available worker
async
fn
proxy_get_request
(
&
self
,
client
:
&
Client
,
req
:
Request
<
Body
>
,
endpoint
:
&
str
,
)
->
Response
{
async
fn
proxy_get_request
(
&
self
,
req
:
Request
<
Body
>
,
endpoint
:
&
str
)
->
Response
{
let
headers
=
copy_request_headers
(
&
req
);
match
self
.select_first_worker
()
{
Ok
(
worker_url
)
=>
{
let
mut
request_builder
=
client
.get
(
format!
(
"{}/{}"
,
worker_url
,
endpoint
));
let
mut
request_builder
=
self
.
client
.get
(
format!
(
"{}/{}"
,
worker_url
,
endpoint
));
for
(
name
,
value
)
in
headers
{
if
name
.to_lowercase
()
!=
"content-type"
&&
name
.to_lowercase
()
!=
"content-length"
...
...
@@ -353,7 +359,6 @@ impl Router {
T
:
crate
::
openai_api_types
::
GenerationRequest
+
serde
::
Serialize
+
Clone
,
>
(
&
self
,
client
:
&
reqwest
::
Client
,
headers
:
Option
<&
HeaderMap
>
,
typed_req
:
&
T
,
route
:
&
str
,
...
...
@@ -397,7 +402,6 @@ impl Router {
// Send typed request directly
let
response
=
self
.send_typed_request
(
client
,
headers
,
typed_req
,
route
,
...
...
@@ -413,7 +417,7 @@ impl Router {
return
response
;
}
else
{
// if the worker is healthy, it means the request is bad, so return the error response
let
health_response
=
self
.send_health_check
(
client
,
&
worker_url
)
.await
;
let
health_response
=
self
.send_health_check
(
&
worker_url
)
.await
;
if
health_response
.status
()
.is_success
()
{
RouterMetrics
::
record_request_error
(
route
,
"request_failed"
);
return
response
;
...
...
@@ -483,7 +487,6 @@ impl Router {
// Send typed request directly without conversion
async
fn
send_typed_request
<
T
:
serde
::
Serialize
>
(
&
self
,
client
:
&
reqwest
::
Client
,
headers
:
Option
<&
HeaderMap
>
,
typed_req
:
&
T
,
route
:
&
str
,
...
...
@@ -536,11 +539,11 @@ impl Router {
.into_response
();
}
client
self
.
client
.post
(
format!
(
"{}{}"
,
worker_url_prefix
,
route
))
.json
(
&
json_val
)
}
else
{
client
self
.
client
.post
(
format!
(
"{}{}"
,
worker_url
,
route
))
.json
(
typed_req
)
// Use json() directly with typed request
};
...
...
@@ -866,7 +869,7 @@ impl Router {
}
}
async
fn
get_worker_load
(
&
self
,
client
:
&
reqwest
::
Client
,
worker_url
:
&
str
)
->
Option
<
isize
>
{
async
fn
get_worker_load
(
&
self
,
worker_url
:
&
str
)
->
Option
<
isize
>
{
let
worker_url
=
if
self
.dp_aware
{
// Need to extract the URL from "http://host:port@dp_rank"
let
(
worker_url_prefix
,
_
dp_rank
)
=
match
Self
::
extract_dp_rank
(
worker_url
)
{
...
...
@@ -881,7 +884,12 @@ impl Router {
worker_url
};
match
client
.get
(
&
format!
(
"{}/get_load"
,
worker_url
))
.send
()
.await
{
match
self
.client
.get
(
&
format!
(
"{}/get_load"
,
worker_url
))
.send
()
.await
{
Ok
(
res
)
if
res
.status
()
.is_success
()
=>
match
res
.bytes
()
.await
{
Ok
(
bytes
)
=>
match
serde_json
::
from_slice
::
<
serde_json
::
Value
>
(
&
bytes
)
{
Ok
(
data
)
=>
data
...
...
@@ -919,18 +927,8 @@ impl Router {
tx
:
tokio
::
sync
::
watch
::
Sender
<
HashMap
<
String
,
isize
>>
,
interval_secs
:
u64
,
policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
client
:
Client
,
)
{
let
client
=
match
reqwest
::
Client
::
builder
()
.timeout
(
Duration
::
from_secs
(
5
))
.build
()
{
Ok
(
c
)
=>
c
,
Err
(
e
)
=>
{
error!
(
"Failed to create HTTP client for load monitoring: {}"
,
e
);
return
;
}
};
let
mut
interval
=
tokio
::
time
::
interval
(
Duration
::
from_secs
(
interval_secs
));
loop
{
...
...
@@ -1028,7 +1026,7 @@ impl RouterTrait for Router {
self
}
async
fn
health
(
&
self
,
_
client
:
&
Client
,
_
req
:
Request
<
Body
>
)
->
Response
{
async
fn
health
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
let
workers
=
self
.workers
.read
()
.unwrap
();
let
unhealthy_servers
:
Vec
<
_
>
=
workers
.iter
()
...
...
@@ -1047,53 +1045,49 @@ impl RouterTrait for Router {
}
}
async
fn
health_generate
(
&
self
,
client
:
&
Client
,
req
:
Request
<
Body
>
)
->
Response
{
self
.proxy_get_request
(
client
,
req
,
"health_generate"
)
.await
async
fn
health_generate
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
{
self
.proxy_get_request
(
req
,
"health_generate"
)
.await
}
async
fn
get_server_info
(
&
self
,
client
:
&
Client
,
req
:
Request
<
Body
>
)
->
Response
{
self
.proxy_get_request
(
client
,
req
,
"get_server_info"
)
.await
async
fn
get_server_info
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
{
self
.proxy_get_request
(
req
,
"get_server_info"
)
.await
}
async
fn
get_models
(
&
self
,
client
:
&
Client
,
req
:
Request
<
Body
>
)
->
Response
{
self
.proxy_get_request
(
client
,
req
,
"v1/models"
)
.await
async
fn
get_models
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
{
self
.proxy_get_request
(
req
,
"v1/models"
)
.await
}
async
fn
get_model_info
(
&
self
,
client
:
&
Client
,
req
:
Request
<
Body
>
)
->
Response
{
self
.proxy_get_request
(
client
,
req
,
"get_model_info"
)
.await
async
fn
get_model_info
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
{
self
.proxy_get_request
(
req
,
"get_model_info"
)
.await
}
async
fn
route_generate
(
&
self
,
client
:
&
Client
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
GenerateRequest
,
)
->
Response
{
self
.route_typed_request
(
client
,
headers
,
body
,
"/generate"
)
.await
self
.route_typed_request
(
headers
,
body
,
"/generate"
)
.await
}
async
fn
route_chat
(
&
self
,
client
:
&
Client
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
ChatCompletionRequest
,
)
->
Response
{
self
.route_typed_request
(
client
,
headers
,
body
,
"/v1/chat/completions"
)
self
.route_typed_request
(
headers
,
body
,
"/v1/chat/completions"
)
.await
}
async
fn
route_completion
(
&
self
,
client
:
&
Client
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
CompletionRequest
,
)
->
Response
{
self
.route_typed_request
(
client
,
headers
,
body
,
"/v1/completions"
)
self
.route_typed_request
(
headers
,
body
,
"/v1/completions"
)
.await
}
async
fn
flush_cache
(
&
self
,
client
:
&
Client
)
->
Response
{
async
fn
flush_cache
(
&
self
)
->
Response
{
// Get all worker URLs
let
worker_urls
=
self
.get_worker_urls
();
...
...
@@ -1117,7 +1111,7 @@ impl RouterTrait for Router {
}
else
{
worker_url
};
let
request_builder
=
client
.post
(
format!
(
"{}/flush_cache"
,
worker_url
));
let
request_builder
=
self
.
client
.post
(
format!
(
"{}/flush_cache"
,
worker_url
));
tasks
.push
(
request_builder
.send
());
}
...
...
@@ -1142,13 +1136,13 @@ impl RouterTrait for Router {
}
}
async
fn
get_worker_loads
(
&
self
,
client
:
&
Client
)
->
Response
{
async
fn
get_worker_loads
(
&
self
)
->
Response
{
let
urls
=
self
.get_worker_urls
();
let
mut
loads
=
Vec
::
new
();
// Get loads from all workers
for
url
in
&
urls
{
let
load
=
self
.get_worker_load
(
client
,
url
)
.await
.unwrap_or
(
-
1
);
let
load
=
self
.get_worker_load
(
url
)
.await
.unwrap_or
(
-
1
);
loads
.push
(
serde_json
::
json!
({
"worker"
:
url
,
"load"
:
load
...
...
@@ -1215,6 +1209,7 @@ mod tests {
interval_secs
:
1
,
dp_aware
:
false
,
api_key
:
None
,
client
:
Client
::
new
(),
_
worker_loads
:
Arc
::
new
(
rx
),
_
load_monitor_handle
:
None
,
_
health_checker
:
None
,
...
...
sgl-router/src/server.rs
View file @
828a4fe9
...
...
@@ -22,29 +22,34 @@ use tokio::spawn;
use
tracing
::{
error
,
info
,
warn
,
Level
};
#[derive(Clone)]
pub
struct
AppState
{
pub
router
:
Arc
<
dyn
RouterTrait
>
,
pub
struct
AppContext
{
pub
client
:
Client
,
pub
_
concurrency_limiter
:
Arc
<
tokio
::
sync
::
Semaphore
>
,
pub
router_config
:
RouterConfig
,
pub
concurrency_limiter
:
Arc
<
tokio
::
sync
::
Semaphore
>
,
// Future dependencies can be added here
}
impl
App
State
{
impl
App
Context
{
pub
fn
new
(
router_config
:
RouterConfig
,
client
:
Client
,
max_concurrent_requests
:
usize
,
)
->
Result
<
Self
,
String
>
{
let
router
=
RouterFactory
::
create_router
(
&
router_config
)
?
;
let
router
=
Arc
::
from
(
router
);
)
->
Self
{
let
concurrency_limiter
=
Arc
::
new
(
tokio
::
sync
::
Semaphore
::
new
(
max_concurrent_requests
));
Ok
(
Self
{
router
,
Self
{
client
,
_
concurrency_limiter
:
concurrency_limiter
,
})
router_config
,
concurrency_limiter
,
}
}
}
#[derive(Clone)]
pub
struct
AppState
{
pub
router
:
Arc
<
dyn
RouterTrait
>
,
pub
context
:
Arc
<
AppContext
>
,
}
// Fallback handler for unmatched routes
async
fn
sink_handler
()
->
Response
{
StatusCode
::
NOT_FOUND
.into_response
()
...
...
@@ -60,23 +65,23 @@ async fn readiness(State(state): State<Arc<AppState>>) -> Response {
}
async
fn
health
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
req
:
Request
)
->
Response
{
state
.router
.health
(
&
state
.client
,
req
)
.await
state
.router
.health
(
req
)
.await
}
async
fn
health_generate
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
req
:
Request
)
->
Response
{
state
.router
.health_generate
(
&
state
.client
,
req
)
.await
state
.router
.health_generate
(
req
)
.await
}
async
fn
get_server_info
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
req
:
Request
)
->
Response
{
state
.router
.get_server_info
(
&
state
.client
,
req
)
.await
state
.router
.get_server_info
(
req
)
.await
}
async
fn
v1_models
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
req
:
Request
)
->
Response
{
state
.router
.get_models
(
&
state
.client
,
req
)
.await
state
.router
.get_models
(
req
)
.await
}
async
fn
get_model_info
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
req
:
Request
)
->
Response
{
state
.router
.get_model_info
(
&
state
.client
,
req
)
.await
state
.router
.get_model_info
(
req
)
.await
}
// Generation endpoints
...
...
@@ -86,10 +91,7 @@ async fn generate(
headers
:
http
::
HeaderMap
,
Json
(
body
):
Json
<
GenerateRequest
>
,
)
->
Response
{
state
.router
.route_generate
(
&
state
.client
,
Some
(
&
headers
),
&
body
)
.await
state
.router
.route_generate
(
Some
(
&
headers
),
&
body
)
.await
}
async
fn
v1_chat_completions
(
...
...
@@ -97,10 +99,7 @@ async fn v1_chat_completions(
headers
:
http
::
HeaderMap
,
Json
(
body
):
Json
<
ChatCompletionRequest
>
,
)
->
Response
{
state
.router
.route_chat
(
&
state
.client
,
Some
(
&
headers
),
&
body
)
.await
state
.router
.route_chat
(
Some
(
&
headers
),
&
body
)
.await
}
async
fn
v1_completions
(
...
...
@@ -108,10 +107,7 @@ async fn v1_completions(
headers
:
http
::
HeaderMap
,
Json
(
body
):
Json
<
CompletionRequest
>
,
)
->
Response
{
state
.router
.route_completion
(
&
state
.client
,
Some
(
&
headers
),
&
body
)
.await
state
.router
.route_completion
(
Some
(
&
headers
),
&
body
)
.await
}
// Worker management endpoints
...
...
@@ -159,11 +155,11 @@ async fn remove_worker(
}
async
fn
flush_cache
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
_
req
:
Request
)
->
Response
{
state
.router
.flush_cache
(
&
state
.client
)
.await
state
.router
.flush_cache
()
.await
}
async
fn
get_loads
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
_
req
:
Request
)
->
Response
{
state
.router
.get_worker_loads
(
&
state
.client
)
.await
state
.router
.get_worker_loads
()
.await
}
pub
struct
ServerConfig
{
...
...
@@ -281,11 +277,21 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
.build
()
.expect
(
"Failed to create HTTP client"
);
let
app_state
=
Arc
::
new
(
AppState
::
new
(
// Create the application context with all dependencies
let
app_context
=
Arc
::
new
(
AppContext
::
new
(
config
.router_config
.clone
(),
client
.clone
(),
config
.router_config.max_concurrent_requests
,
)
?
);
));
// Create router with the context
let
router
=
RouterFactory
::
create_router
(
&
app_context
)
?
;
// Create app state with router and context
let
app_state
=
Arc
::
new
(
AppState
{
router
:
Arc
::
from
(
router
),
context
:
app_context
.clone
(),
});
let
router_arc
=
Arc
::
clone
(
&
app_state
.router
);
// Start the service discovery if enabled
...
...
sgl-router/src/service_discovery.rs
View file @
828a4fe9
...
...
@@ -40,7 +40,6 @@ impl Default for ServiceDiscoveryConfig {
check_interval
:
Duration
::
from_secs
(
60
),
port
:
8000
,
// Standard port for modern services
namespace
:
None
,
// None means watch all namespaces
// PD mode defaults
pd_mode
:
false
,
prefill_selector
:
HashMap
::
new
(),
decode_selector
:
HashMap
::
new
(),
...
...
@@ -581,7 +580,8 @@ mod tests {
use
crate
::
routers
::
router
::
Router
;
let
policy
=
PolicyFactory
::
create_from_config
(
&
PolicyConfig
::
Random
);
let
router
=
Router
::
new
(
vec!
[],
policy
,
5
,
1
,
false
,
None
)
.unwrap
();
let
router
=
Router
::
new
(
vec!
[],
policy
,
reqwest
::
Client
::
new
(),
5
,
1
,
false
,
None
)
.unwrap
();
Arc
::
new
(
router
)
as
Arc
<
dyn
RouterTrait
>
}
...
...
sgl-router/tests/api_endpoints_test.rs
View file @
828a4fe9
...
...
@@ -83,12 +83,12 @@ impl TestContext {
.build
()
.unwrap
();
// Clone config for the closure
let
config_clone
=
config
.clone
();
// Create app context
let
app_context
=
common
::
create_test_context
(
config
.clone
()
)
;
// Create router using sync factory in a blocking context
let
router
=
tokio
::
task
::
spawn_blocking
(
move
||
RouterFactory
::
create_router
(
&
config_clone
))
tokio
::
task
::
spawn_blocking
(
move
||
RouterFactory
::
create_router
(
&
app_context
))
.await
.unwrap
()
.unwrap
();
...
...
@@ -1433,9 +1433,12 @@ mod pd_mode_tests {
cors_allowed_origins
:
vec!
[],
};
// Create app context
let
app_context
=
common
::
create_test_context
(
config
);
// Create router - this might fail due to health check issues
let
router_result
=
tokio
::
task
::
spawn_blocking
(
move
||
RouterFactory
::
create_router
(
&
config
))
tokio
::
task
::
spawn_blocking
(
move
||
RouterFactory
::
create_router
(
&
app_context
))
.await
.unwrap
();
...
...
sgl-router/tests/common/mod.rs
View file @
828a4fe9
pub
mod
mock_worker
;
pub
mod
test_app
;
use
sglang_router_rs
::
config
::
RouterConfig
;
use
sglang_router_rs
::
server
::
AppContext
;
use
std
::
sync
::
Arc
;
/// Builds a shared [`AppContext`] for integration tests.
///
/// The context is constructed from `config`, a freshly created default
/// `reqwest::Client`, and the `max_concurrent_requests` value read from
/// `config`. The result is wrapped in an [`Arc`] so it can be handed to the
/// router factory and the application state.
pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
    // Read the field before `config` is moved into the context, so the whole
    // configuration does not have to be cloned just to access one value.
    let max_concurrent_requests = config.max_concurrent_requests;
    Arc::new(AppContext::new(
        config,
        reqwest::Client::new(),
        max_concurrent_requests,
    ))
}
sgl-router/tests/common/test_app.rs
View file @
828a4fe9
...
...
@@ -3,7 +3,7 @@ use reqwest::Client;
use
sglang_router_rs
::{
config
::
RouterConfig
,
routers
::
RouterTrait
,
server
::{
build_app
,
AppState
},
server
::{
build_app
,
AppContext
,
AppState
},
};
use
std
::
sync
::
Arc
;
...
...
@@ -13,13 +13,17 @@ pub fn create_test_app(
client
:
Client
,
router_config
:
&
RouterConfig
,
)
->
Router
{
// Create AppState with the test router
// Create AppContext
let
app_context
=
Arc
::
new
(
AppContext
::
new
(
router_config
.clone
(),
client
,
router_config
.max_concurrent_requests
,
));
// Create AppState with the test router and context
let
app_state
=
Arc
::
new
(
AppState
{
router
,
client
,
_
concurrency_limiter
:
Arc
::
new
(
tokio
::
sync
::
Semaphore
::
new
(
router_config
.max_concurrent_requests
,
)),
context
:
app_context
,
});
// Configure request ID headers (use defaults if not specified)
...
...
sgl-router/tests/request_formats_test.rs
View file @
828a4fe9
...
...
@@ -53,10 +53,12 @@ impl TestContext {
config
.mode
=
RoutingMode
::
Regular
{
worker_urls
};
let
router
=
tokio
::
task
::
spawn_blocking
(
move
||
RouterFactory
::
create_router
(
&
config
))
.await
.unwrap
()
.unwrap
();
let
app_context
=
common
::
create_test_context
(
config
);
let
router
=
tokio
::
task
::
spawn_blocking
(
move
||
RouterFactory
::
create_router
(
&
app_context
))
.await
.unwrap
()
.unwrap
();
let
router
=
Arc
::
from
(
router
);
if
!
workers
.is_empty
()
{
...
...
sgl-router/tests/streaming_tests.rs
View file @
828a4fe9
...
...
@@ -54,10 +54,12 @@ impl TestContext {
config
.mode
=
RoutingMode
::
Regular
{
worker_urls
};
let
router
=
tokio
::
task
::
spawn_blocking
(
move
||
RouterFactory
::
create_router
(
&
config
))
.await
.unwrap
()
.unwrap
();
let
app_context
=
common
::
create_test_context
(
config
);
let
router
=
tokio
::
task
::
spawn_blocking
(
move
||
RouterFactory
::
create_router
(
&
app_context
))
.await
.unwrap
()
.unwrap
();
let
router
=
Arc
::
from
(
router
);
if
!
workers
.is_empty
()
{
...
...
sgl-router/tests/test_pd_routing.rs
View file @
828a4fe9
...
...
@@ -181,7 +181,10 @@ mod test_pd_routing {
};
// Router creation will fail due to health checks, but config should be valid
let
result
=
RouterFactory
::
create_router
(
&
config
);
let
app_context
=
sglang_router_rs
::
server
::
AppContext
::
new
(
config
,
reqwest
::
Client
::
new
(),
64
);
let
app_context
=
std
::
sync
::
Arc
::
new
(
app_context
);
let
result
=
RouterFactory
::
create_router
(
&
app_context
);
assert
!
(
result
.is_err
());
let
error_msg
=
result
.unwrap_err
();
// Error should be about health/timeout, not configuration
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment