Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2f173ea0
Unverified
Commit
2f173ea0
authored
Sep 12, 2025
by
Simo Lin
Committed by
GitHub
Sep 12, 2025
Browse files
[router] allow one router to support different model families and serving mode (#10244)
parent
321fecab
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1628 additions
and
212 deletions
+1628
-212
sgl-router/src/routers/http/router.rs
sgl-router/src/routers/http/router.rs
+269
-188
sgl-router/src/routers/mod.rs
sgl-router/src/routers/mod.rs
+16
-3
sgl-router/src/routers/router_manager.rs
sgl-router/src/routers/router_manager.rs
+766
-0
sgl-router/src/server.rs
sgl-router/src/server.rs
+262
-10
sgl-router/src/service_discovery.rs
sgl-router/src/service_discovery.rs
+8
-5
sgl-router/tests/cache_aware_backward_compat_test.rs
sgl-router/tests/cache_aware_backward_compat_test.rs
+129
-0
sgl-router/tests/policy_registry_integration.rs
sgl-router/tests/policy_registry_integration.rs
+168
-0
sgl-router/tests/test_openai_routing.rs
sgl-router/tests/test_openai_routing.rs
+10
-6
No files found.
sgl-router/src/routers/http/router.rs
View file @
2f173ea0
use
crate
::
config
::
types
::
RetryConfig
;
use
crate
::
core
::{
is_retryable_status
,
BasicWorker
,
CircuitBreakerConfig
,
HealthC
hecker
,
HealthConfig
,
RetryExecutor
,
Worker
,
WorkerFacto
ry
,
WorkerType
,
is_retryable_status
,
BasicWorker
,
CircuitBreakerConfig
,
HealthC
onfig
,
RetryExecutor
,
Worker
,
WorkerRegist
ry
,
WorkerType
,
};
use
crate
::
metrics
::
RouterMetrics
;
use
crate
::
policies
::
LoadBalancingPolicy
;
use
crate
::
policies
::
{
LoadBalancingPolicy
,
PolicyRegistry
}
;
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
,
GenerationRequest
,
RerankRequest
,
RerankResponse
,
RerankResult
,
ResponsesRequest
,
...
...
@@ -22,7 +22,7 @@ use axum::{
use
futures_util
::
StreamExt
;
use
reqwest
::
Client
;
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
{
Arc
,
RwLock
}
;
use
std
::
sync
::
Arc
;
use
std
::
time
::{
Duration
,
Instant
};
use
tokio_stream
::
wrappers
::
UnboundedReceiverStream
;
use
tracing
::{
debug
,
error
,
info
,
warn
};
...
...
@@ -30,8 +30,8 @@ use tracing::{debug, error, info, warn};
/// Regular router that uses injected load balancing policies
#[derive(Debug)]
pub
struct
Router
{
worker
s
:
Arc
<
RwLock
<
Vec
<
Box
<
dyn
Worker
>>>
>
,
policy
:
Arc
<
dyn
LoadBalancingPolic
y
>
,
worker
_registry
:
Arc
<
WorkerRegistry
>
,
policy
_registry
:
Arc
<
PolicyRegistr
y
>
,
client
:
Client
,
worker_startup_timeout_secs
:
u64
,
worker_startup_check_interval_secs
:
u64
,
...
...
@@ -41,7 +41,6 @@ pub struct Router {
circuit_breaker_config
:
CircuitBreakerConfig
,
_
worker_loads
:
Arc
<
tokio
::
sync
::
watch
::
Receiver
<
HashMap
<
String
,
isize
>>>
,
_
load_monitor_handle
:
Option
<
Arc
<
tokio
::
task
::
JoinHandle
<
()
>>>
,
_
health_checker
:
Option
<
HealthChecker
>
,
}
impl
Router
{
...
...
@@ -49,7 +48,6 @@ impl Router {
#[allow(clippy::too_many_arguments)]
pub
async
fn
new
(
worker_urls
:
Vec
<
String
>
,
policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
ctx
:
&
Arc
<
crate
::
server
::
AppContext
>
,
)
->
Result
<
Self
,
String
>
{
// Update active workers gauge
...
...
@@ -82,45 +80,51 @@ impl Router {
window_duration
:
Duration
::
from_secs
(
circuit_breaker_config
.window_duration_secs
),
};
// Create Worker trait objects from URLs with health check config
let
workers
:
Vec
<
Box
<
dyn
Worker
>>
=
worker_urls
.iter
()
.map
(|
url
|
{
let
worker
=
BasicWorker
::
new
(
url
.clone
(),
WorkerType
::
Regular
)
.with_circuit_breaker_config
(
core_cb_config
.clone
())
.with_health_config
(
HealthConfig
{
timeout_secs
:
ctx
.router_config.health_check.timeout_secs
,
check_interval_secs
:
ctx
.router_config.health_check.check_interval_secs
,
endpoint
:
ctx
.router_config.health_check.endpoint
.clone
(),
failure_threshold
:
ctx
.router_config.health_check.failure_threshold
,
success_threshold
:
ctx
.router_config.health_check.success_threshold
,
});
Box
::
new
(
worker
)
as
Box
<
dyn
Worker
>
})
.collect
();
// Register workers in the registry
// In IGW mode, we need to fetch model info from workers
for
url
in
&
worker_urls
{
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
// For now, create worker without model_id
let
worker
=
BasicWorker
::
new
(
url
.clone
(),
WorkerType
::
Regular
)
.with_circuit_breaker_config
(
core_cb_config
.clone
())
.with_health_config
(
HealthConfig
{
timeout_secs
:
ctx
.router_config.health_check.timeout_secs
,
check_interval_secs
:
ctx
.router_config.health_check.check_interval_secs
,
endpoint
:
ctx
.router_config.health_check.endpoint
.clone
(),
failure_threshold
:
ctx
.router_config.health_check.failure_threshold
,
success_threshold
:
ctx
.router_config.health_check.success_threshold
,
});
// Initialize policy with workers if needed (e.g., for cache-aware)
if
let
Some
(
cache_aware
)
=
policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
{
cache_aware
.init_workers
(
&
workers
);
let
worker_arc
=
Arc
::
new
(
worker
);
ctx
.worker_registry
.register
(
worker_arc
.clone
());
// Notify PolicyRegistry about the new worker
let
model_id
=
worker_arc
.model_id
();
let
policy
=
ctx
.policy_registry
.on_worker_added
(
model_id
,
None
);
// If this is a cache-aware policy and it's the first worker for this model,
// initialize it with the worker
if
policy
.name
()
==
"cache_aware"
{
if
let
Some
(
cache_aware
)
=
policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
{
let
worker_dyn
:
Arc
<
dyn
Worker
>
=
worker_arc
.clone
();
cache_aware
.init_workers
(
std
::
slice
::
from_ref
(
&
worker_dyn
));
}
}
}
let
workers
=
Arc
::
new
(
RwLock
::
new
(
workers
));
let
health_checker
=
crate
::
core
::
start_health_checker
(
Arc
::
clone
(
&
workers
),
ctx
.router_config.worker_startup_check_interval_secs
,
);
// Setup load monitoring for PowerOfTwo policy
let
(
tx
,
rx
)
=
tokio
::
sync
::
watch
::
channel
(
HashMap
::
new
());
let
worker_loads
=
Arc
::
new
(
rx
);
let
load_monitor_handle
=
if
policy
.name
()
==
"power_of_two"
{
// Check if default policy is power_of_two for load monitoring
let
default_policy
=
ctx
.policy_registry
.get_default_policy
();
let
load_monitor_handle
=
if
default_policy
.name
()
==
"power_of_two"
{
let
monitor_urls
=
worker_urls
.clone
();
let
monitor_interval
=
ctx
.router_config.worker_startup_check_interval_secs
;
let
policy_clone
=
Arc
::
clone
(
&
policy
);
let
policy_clone
=
default_policy
.clone
(
);
let
client_clone
=
ctx
.client
.clone
();
Some
(
Arc
::
new
(
tokio
::
spawn
(
async
move
{
...
...
@@ -138,8 +142,8 @@ impl Router {
};
Ok
(
Router
{
worker
s
,
policy
,
worker
_registry
:
ctx
.worker_registry
.clone
()
,
policy
_registry
:
ctx
.policy_registry
.clone
()
,
client
:
ctx
.client
.clone
(),
worker_startup_timeout_secs
:
ctx
.router_config.worker_startup_timeout_secs
,
worker_startup_check_interval_secs
:
ctx
...
...
@@ -151,18 +155,21 @@ impl Router {
circuit_breaker_config
:
core_cb_config
,
_
worker_loads
:
worker_loads
,
_
load_monitor_handle
:
load_monitor_handle
,
_
health_checker
:
Some
(
health_checker
),
})
}
/// Get the current list of worker URLs
pub
fn
get_worker_urls
(
&
self
)
->
Vec
<
String
>
{
self
.workers
.read
()
.unwrap
()
.iter
()
.map
(|
w
|
w
.url
()
.to_string
())
.collect
()
self
.worker_registry
.get_all_urls
()
}
/// Get worker URLs for a specific model
pub
fn
get_worker_urls_for_model
(
&
self
,
model_id
:
Option
<&
str
>
)
->
Vec
<
String
>
{
let
workers
=
match
model_id
{
Some
(
model
)
=>
self
.worker_registry
.get_by_model_fast
(
model
),
None
=>
self
.worker_registry
.get_all
(),
};
workers
.iter
()
.map
(|
w
|
w
.url
()
.to_string
())
.collect
()
}
pub
async
fn
wait_for_healthy_workers
(
...
...
@@ -332,11 +339,27 @@ impl Router {
}
fn
select_first_worker
(
&
self
)
->
Result
<
String
,
String
>
{
let
workers
_guard
=
self
.worker
s
.read
()
.unwrap
();
if
workers
_guard
.is_empty
()
{
let
workers
=
self
.worker
_registry
.get_all
();
if
workers
.is_empty
()
{
Err
(
"No workers are available"
.to_string
())
}
else
{
Ok
(
workers_guard
[
0
]
.url
()
.to_string
())
Ok
(
workers
[
0
]
.url
()
.to_string
())
}
}
#[allow(dead_code)]
fn
select_first_worker_for_model
(
&
self
,
model_id
:
Option
<&
str
>
)
->
Result
<
String
,
String
>
{
let
workers
=
match
model_id
{
Some
(
model
)
=>
self
.worker_registry
.get_by_model_fast
(
model
),
None
=>
self
.worker_registry
.get_all
(),
};
if
workers
.is_empty
()
{
Err
(
format!
(
"No workers are available for model: {:?}"
,
model_id
))
}
else
{
Ok
(
workers
[
0
]
.url
()
.to_string
())
}
}
...
...
@@ -447,20 +470,35 @@ impl Router {
}
}
// New method to route typed requests directly
/// Select worker considering circuit breaker state
fn
select_worker_with_circuit_breaker
(
&
self
,
text
:
Option
<&
str
>
)
->
Option
<
Box
<
dyn
Worker
>>
{
let
workers
=
self
.workers
.read
()
.ok
()
?
;
let
available
:
Vec
<
Box
<
dyn
Worker
>>
=
workers
/// Select worker for a specific model considering circuit breaker state
fn
select_worker_for_model
(
&
self
,
model_id
:
Option
<&
str
>
,
text
:
Option
<&
str
>
,
)
->
Option
<
Arc
<
dyn
Worker
>>
{
// Get workers for the specified model (O(1) lookup if model_id is provided)
let
workers
=
match
model_id
{
Some
(
model
)
=>
self
.worker_registry
.get_by_model_fast
(
model
),
None
=>
self
.worker_registry
.get_all
(),
};
let
available
:
Vec
<
Arc
<
dyn
Worker
>>
=
workers
.iter
()
.filter
(|
w
|
w
.is_available
())
.
map
(|
w
|
w
.clone_worker
()
)
.
cloned
()
.collect
();
if
available
.is_empty
()
{
return
None
;
}
let
idx
=
self
.policy
.select_worker
(
&
available
,
text
)
?
;
Some
(
available
[
idx
]
.clone_worker
())
// Get the appropriate policy for this model
let
policy
=
match
model_id
{
Some
(
model
)
=>
self
.policy_registry
.get_policy_or_default
(
model
),
None
=>
self
.policy_registry
.get_default_policy
(),
};
let
idx
=
policy
.select_worker
(
&
available
,
text
)
?
;
Some
(
available
[
idx
]
.clone
())
}
pub
async
fn
route_typed_request
<
T
:
GenerationRequest
+
serde
::
Serialize
+
Clone
>
(
...
...
@@ -468,6 +506,7 @@ impl Router {
headers
:
Option
<&
HeaderMap
>
,
typed_req
:
&
T
,
route
:
&
str
,
model_id
:
Option
<&
str
>
,
)
->
Response
{
let
start
=
Instant
::
now
();
let
is_stream
=
typed_req
.is_stream
();
...
...
@@ -477,7 +516,7 @@ impl Router {
&
self
.retry_config
,
// operation per attempt
|
_
:
u32
|
async
{
let
worker
=
match
self
.select_worker_
with_circuit_breaker
(
Some
(
&
text
))
{
let
worker
=
match
self
.select_worker_
for_model
(
model_id
,
Some
(
&
text
))
{
Some
(
w
)
=>
w
,
None
=>
{
RouterMetrics
::
record_request_error
(
route
,
"no_available_workers"
);
...
...
@@ -490,7 +529,13 @@ impl Router {
};
// Optional load tracking for cache-aware policy
let
load_incremented
=
if
self
.policy
.name
()
==
"cache_aware"
{
// Get the policy for this model to check if it's cache-aware
let
policy
=
match
model_id
{
Some
(
model
)
=>
self
.policy_registry
.get_policy_or_default
(
model
),
None
=>
self
.policy_registry
.get_default_policy
(),
};
let
load_incremented
=
if
policy
.name
()
==
"cache_aware"
{
worker
.increment_load
();
RouterMetrics
::
set_running_requests
(
worker
.url
(),
worker
.load
());
true
...
...
@@ -654,11 +699,9 @@ impl Router {
// Decrement load on error if it was incremented
if
load_incremented
{
if
let
Ok
(
workers_guard
)
=
self
.workers
.read
()
{
if
let
Some
(
worker
)
=
workers_guard
.iter
()
.find
(|
w
|
w
.url
()
==
worker_url
)
{
worker
.decrement_load
();
RouterMetrics
::
set_running_requests
(
worker_url
,
worker
.load
());
}
if
let
Some
(
worker
)
=
self
.worker_registry
.get_by_url
(
worker_url
)
{
worker
.decrement_load
();
RouterMetrics
::
set_running_requests
(
worker_url
,
worker
.load
());
}
}
...
...
@@ -687,13 +730,9 @@ impl Router {
Err
(
e
)
=>
{
// IMPORTANT: Decrement load on error before returning
if
load_incremented
{
if
let
Ok
(
workers_guard
)
=
self
.workers
.read
()
{
if
let
Some
(
worker
)
=
workers_guard
.iter
()
.find
(|
w
|
w
.url
()
==
worker_url
)
{
worker
.decrement_load
();
RouterMetrics
::
set_running_requests
(
worker_url
,
worker
.load
());
}
if
let
Some
(
worker
)
=
self
.worker_registry
.get_by_url
(
worker_url
)
{
worker
.decrement_load
();
RouterMetrics
::
set_running_requests
(
worker_url
,
worker
.load
());
}
}
...
...
@@ -704,18 +743,16 @@ impl Router {
// Decrement load counter for non-streaming requests if it was incremented
if
load_incremented
{
if
let
Ok
(
workers_guard
)
=
self
.workers
.read
()
{
if
let
Some
(
worker
)
=
workers_guard
.iter
()
.find
(|
w
|
w
.url
()
==
worker_url
)
{
worker
.decrement_load
();
RouterMetrics
::
set_running_requests
(
worker_url
,
worker
.load
());
}
if
let
Some
(
worker
)
=
self
.worker_registry
.get_by_url
(
worker_url
)
{
worker
.decrement_load
();
RouterMetrics
::
set_running_requests
(
worker_url
,
worker
.load
());
}
}
response
}
else
if
load_incremented
{
// For streaming with load tracking, we need to manually decrement when done
let
workers
=
Arc
::
clone
(
&
self
.worker
s
);
let
registry
=
Arc
::
clone
(
&
self
.worker
_registry
);
let
worker_url
=
worker_url
.to_string
();
// Preserve headers for streaming response
...
...
@@ -739,17 +776,10 @@ impl Router {
.windows
(
12
)
.any
(|
window
|
window
==
b
"data: [DONE]"
)
{
if
let
Ok
(
workers_guard
)
=
workers
.read
()
{
if
let
Some
(
worker
)
=
workers_guard
.iter
()
.find
(|
w
|
w
.url
()
==
worker_url
)
{
worker
.decrement_load
();
RouterMetrics
::
set_running_requests
(
&
worker_url
,
worker
.load
(),
);
decremented
=
true
;
}
if
let
Some
(
worker
)
=
registry
.get_by_url
(
&
worker_url
)
{
worker
.decrement_load
();
RouterMetrics
::
set_running_requests
(
&
worker_url
,
worker
.load
());
decremented
=
true
;
}
}
if
tx
.send
(
Ok
(
bytes
))
.is_err
()
{
...
...
@@ -763,11 +793,9 @@ impl Router {
}
}
if
!
decremented
{
if
let
Ok
(
workers_guard
)
=
workers
.read
()
{
if
let
Some
(
worker
)
=
workers_guard
.iter
()
.find
(|
w
|
w
.url
()
==
worker_url
)
{
worker
.decrement_load
();
RouterMetrics
::
set_running_requests
(
&
worker_url
,
worker
.load
());
}
if
let
Some
(
worker
)
=
registry
.get_by_url
(
&
worker_url
)
{
worker
.decrement_load
();
RouterMetrics
::
set_running_requests
(
&
worker_url
,
worker
.load
());
}
}
});
...
...
@@ -839,7 +867,6 @@ impl Router {
match
client
.get
(
format!
(
"{}/health"
,
worker_url
))
.send
()
.await
{
Ok
(
res
)
=>
{
if
res
.status
()
.is_success
()
{
let
mut
workers_guard
=
self
.workers
.write
()
.unwrap
();
if
self
.dp_aware
{
// Need to contact the worker to extract the dp_size,
// and add them as multiple workers
...
...
@@ -848,47 +875,78 @@ impl Router {
.map_err
(|
e
|
format!
(
"Failed to get dp-aware workers: {}"
,
e
))
?
;
let
mut
worker_added
:
bool
=
false
;
for
dp_url
in
&
dp_url_vec
{
if
worker
s_guard
.iter
()
.any
(|
w
|
w
.url
()
==
dp_url
)
{
if
self
.
worker
_registry
.get_by_url
(
dp_url
)
.is_some
(
)
{
warn!
(
"Worker {} already exists"
,
dp_url
);
continue
;
}
info!
(
"Added worker: {}"
,
dp_url
);
let
new_worker
=
WorkerFactory
::
create_regular_with_config
(
dp_url
.to_string
(),
self
.circuit_breaker_config
.clone
(),
);
workers_guard
.push
(
new_worker
);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let
new_worker
=
BasicWorker
::
new
(
dp_url
.to_string
(),
WorkerType
::
Regular
)
.with_circuit_breaker_config
(
self
.circuit_breaker_config
.clone
(),
);
let
worker_arc
=
Arc
::
new
(
new_worker
);
self
.worker_registry
.register
(
worker_arc
.clone
());
// Notify PolicyRegistry about the new worker
let
model_id
=
worker_arc
.model_id
();
let
policy
=
self
.policy_registry
.on_worker_added
(
model_id
,
None
);
// If this is a cache-aware policy, update it with all workers for this model
if
policy
.name
()
==
"cache_aware"
{
if
let
Some
(
cache_aware
)
=
policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
(
)
{
let
model_workers
=
self
.worker_registry
.get_by_model_fast
(
model_id
);
cache_aware
.init_workers
(
&
model_workers
);
}
}
worker_added
=
true
;
}
if
!
worker_added
{
return
Err
(
format!
(
"No worker added for {}"
,
worker_url
));
}
}
else
{
if
worker
s_guard
.iter
()
.any
(|
w
|
w
.url
()
==
worker_url
)
{
if
self
.
worker
_registry
.get_by_url
(
worker_url
)
.is_some
(
)
{
return
Err
(
format!
(
"Worker {} already exists"
,
worker_url
));
}
info!
(
"Added worker: {}"
,
worker_url
);
let
new_worker
=
WorkerFactory
::
create_regular_with_config
(
worker_url
.to_string
(),
self
.circuit_breaker_config
.clone
(),
);
workers_guard
.push
(
new_worker
);
}
RouterMetrics
::
set_active_workers
(
workers_guard
.len
());
// If cache aware policy, initialize the worker in the tree
if
let
Some
(
cache_aware
)
=
self
.policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
{
// Get updated workers after adding
drop
(
workers_guard
);
let
workers_guard
=
self
.workers
.read
()
.unwrap
();
cache_aware
.init_workers
(
&
workers_guard
);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let
new_worker
=
BasicWorker
::
new
(
worker_url
.to_string
(),
WorkerType
::
Regular
)
.with_circuit_breaker_config
(
self
.circuit_breaker_config
.clone
(),
);
let
worker_arc
=
Arc
::
new
(
new_worker
);
self
.worker_registry
.register
(
worker_arc
.clone
());
// Notify PolicyRegistry about the new worker
let
model_id
=
worker_arc
.model_id
();
let
policy
=
self
.policy_registry
.on_worker_added
(
model_id
,
None
);
// If this is a cache-aware policy, add this worker to it
if
policy
.name
()
==
"cache_aware"
{
if
let
Some
(
cache_aware
)
=
policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
(
)
{
// Get all workers for this model
let
model_workers
=
self
.worker_registry
.get_by_model_fast
(
model_id
);
cache_aware
.init_workers
(
&
model_workers
);
}
}
}
RouterMetrics
::
set_active_workers
(
self
.worker_registry
.get_all
()
.len
());
return
Ok
(
format!
(
"Successfully added worker: {}"
,
worker_url
));
}
else
{
debug!
(
...
...
@@ -931,66 +989,73 @@ impl Router {
if
self
.dp_aware
{
// remove dp-aware workers in a prefix-matching fashion
// without contacting the remote worker
let
mut
candidate_workers
:
Vec
<
String
>
=
Vec
::
new
();
let
mut
removed_workers
:
Vec
<
String
>
=
Vec
::
new
();
let
worker_url_prefix
=
format!
(
"{}@"
,
worker_url
);
{
// find the candidate workers to be removed
let
workers_guard
=
self
.workers
.read
()
.unwrap
();
for
w
in
workers_guard
.iter
()
{
if
w
.url
()
.starts_with
(
&
worker_url_prefix
)
{
candidate_workers
.push
(
w
.url
()
.to_string
());
}
}
}
// Find and remove all workers with matching prefix
let
all_workers
=
self
.worker_registry
.get_all
();
for
w
in
all_workers
.iter
()
{
if
w
.url
()
.starts_with
(
&
worker_url_prefix
)
{
// Get model_id before removing
let
model_id
=
w
.model_id
()
.to_string
();
if
self
.worker_registry
.remove_by_url
(
w
.url
())
.is_some
()
{
info!
(
"Removed worker: {}"
,
w
.url
());
removed_workers
.push
(
w
.url
()
.to_string
());
{
// do the removing on the worker_urls
let
mut
workers_guard
=
self
.workers
.write
()
.unwrap
();
for
dp_url
in
candidate_workers
.iter
()
{
if
let
Some
(
index
)
=
workers_guard
.iter
()
.position
(|
w
|
w
.url
()
==
dp_url
)
{
workers_guard
.remove
(
index
);
info!
(
"Removed worker: {}"
,
dp_url
);
removed_workers
.push
(
dp_url
.to_string
());
// Notify PolicyRegistry about the removed worker
self
.policy_registry
.on_worker_removed
(
&
model_id
);
}
else
{
warn!
(
"Worker {} not found, skipping removal"
,
dp_url
);
continue
;
warn!
(
"Worker {} not found, skipping removal"
,
w
.url
());
}
}
RouterMetrics
::
set_active_workers
(
workers_guard
.len
());
}
// If cache aware policy, remove the workers from the tree
if
let
Some
(
cache_aware
)
=
self
.policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
{
for
dp_url
in
removed_workers
.iter
()
{
cache_aware
.remove_worker
(
dp_url
);
info!
(
"Removed worker from tree: {}"
,
dp_url
);
RouterMetrics
::
set_active_workers
(
self
.worker_registry
.get_all
()
.len
());
// If any models are using cache aware policy, remove the workers from the tree
// Check each removed worker's model and get its policy
for
dp_url
in
removed_workers
.iter
()
{
if
let
Some
(
worker
)
=
self
.worker_registry
.get_by_url
(
dp_url
)
{
let
model_id
=
worker
.model_id
();
if
let
Some
(
policy
)
=
self
.policy_registry
.get_policy
(
model_id
)
{
if
let
Some
(
cache_aware
)
=
policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
{
cache_aware
.remove_worker_by_url
(
dp_url
);
info!
(
"Removed worker from cache-aware tree: {}"
,
dp_url
);
}
}
}
}
}
else
{
let
mut
workers_guard
=
self
.workers
.write
()
.unwrap
();
if
let
Some
(
index
)
=
workers_guard
.iter
()
.position
(|
w
|
w
.url
()
==
worker_url
)
{
workers_guard
.remove
(
index
);
info!
(
"Removed worker: {}"
,
worker_url
);
RouterMetrics
::
set_active_workers
(
workers_guard
.len
());
// Get the worker first to extract model_id
let
model_id
=
if
let
Some
(
worker
)
=
self
.worker_registry
.get_by_url
(
worker_url
)
{
worker
.model_id
()
.to_string
()
}
else
{
warn!
(
"Worker {} not found, skipping removal"
,
worker_url
);
return
;
};
if
self
.worker_registry
.remove_by_url
(
worker_url
)
.is_some
()
{
info!
(
"Removed worker: {}"
,
worker_url
);
// Notify PolicyRegistry about the removed worker
self
.policy_registry
.on_worker_removed
(
&
model_id
);
RouterMetrics
::
set_active_workers
(
self
.worker_registry
.get_all
()
.len
());
}
// If cache aware policy, remove the workers from the tree
if
let
Some
(
cache_aware
)
=
self
.policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
{
cache_aware
.remove_worker
(
worker_url
);
info!
(
"Removed worker from tree: {}"
,
worker_url
);
// If the model is using cache aware policy, remove the worker from the tree
if
let
Some
(
policy
)
=
self
.policy_registry
.get_policy
(
&
model_id
)
{
if
let
Some
(
cache_aware
)
=
policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
{
cache_aware
.remove_worker_by_url
(
worker_url
);
info!
(
"Removed worker from cache-aware tree: {}"
,
worker_url
);
}
}
}
}
...
...
@@ -1171,7 +1236,7 @@ impl RouterTrait for Router {
}
async
fn
health
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
let
workers
=
self
.worker
s
.read
()
.unwrap
();
let
workers
=
self
.worker
_registry
.get_all
();
let
unhealthy_servers
:
Vec
<
_
>
=
workers
.iter
()
.filter
(|
w
|
!
w
.is_healthy
())
...
...
@@ -1209,16 +1274,19 @@ impl RouterTrait for Router {
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
GenerateRequest
,
model_id
:
Option
<&
str
>
,
)
->
Response
{
self
.route_typed_request
(
headers
,
body
,
"/generate"
)
.await
self
.route_typed_request
(
headers
,
body
,
"/generate"
,
model_id
)
.await
}
async
fn
route_chat
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
ChatCompletionRequest
,
model_id
:
Option
<&
str
>
,
)
->
Response
{
self
.route_typed_request
(
headers
,
body
,
"/v1/chat/completions"
)
self
.route_typed_request
(
headers
,
body
,
"/v1/chat/completions"
,
model_id
)
.await
}
...
...
@@ -1226,8 +1294,9 @@ impl RouterTrait for Router {
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
CompletionRequest
,
model_id
:
Option
<&
str
>
,
)
->
Response
{
self
.route_typed_request
(
headers
,
body
,
"/v1/completions"
)
self
.route_typed_request
(
headers
,
body
,
"/v1/completions"
,
model_id
)
.await
}
...
...
@@ -1235,8 +1304,9 @@ impl RouterTrait for Router {
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
ResponsesRequest
,
model_id
:
Option
<&
str
>
,
)
->
Response
{
self
.route_typed_request
(
headers
,
body
,
"/v1/responses"
)
self
.route_typed_request
(
headers
,
body
,
"/v1/responses"
,
model_id
)
.await
}
...
...
@@ -1244,11 +1314,18 @@ impl RouterTrait for Router {
todo!
()
}
async
fn
route_rerank
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
RerankRequest
)
->
Response
{
async
fn
route_rerank
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
RerankRequest
,
model_id
:
Option
<&
str
>
,
)
->
Response
{
if
let
Err
(
e
)
=
body
.validate
()
{
return
(
StatusCode
::
BAD_REQUEST
,
e
)
.into_response
();
}
let
response
=
self
.route_typed_request
(
headers
,
body
,
"/v1/rerank"
)
.await
;
let
response
=
self
.route_typed_request
(
headers
,
body
,
"/v1/rerank"
,
model_id
)
.await
;
if
response
.status
()
.is_success
()
{
match
Self
::
build_rerank_response
(
body
,
response
)
.await
{
Ok
(
rerank_response
)
=>
rerank_response
,
...
...
@@ -1340,19 +1417,15 @@ impl RouterTrait for Router {
fn
readiness
(
&
self
)
->
Response
{
// Regular router is ready if it has at least one healthy worker
let
healthy_count
=
self
.workers
.read
()
.unwrap
()
.iter
()
.filter
(|
w
|
w
.is_healthy
())
.count
();
let
workers
=
self
.worker_registry
.get_all
();
let
healthy_count
=
workers
.iter
()
.filter
(|
w
|
w
.is_healthy
())
.count
();
let
total_workers
=
workers
.len
();
if
healthy_count
>
0
{
Json
(
serde_json
::
json!
({
"status"
:
"ready"
,
"healthy_workers"
:
healthy_count
,
"total_workers"
:
self
.
workers
.read
()
.unwrap
()
.len
()
"total_workers"
:
total_
workers
}))
.into_response
()
}
else
{
...
...
@@ -1361,7 +1434,7 @@ impl RouterTrait for Router {
Json
(
serde_json
::
json!
({
"status"
:
"not_ready"
,
"reason"
:
"no healthy workers available"
,
"total_workers"
:
self
.
workers
.read
()
.unwrap
()
.len
()
"total_workers"
:
total_
workers
})),
)
.into_response
()
...
...
@@ -1372,18 +1445,25 @@ impl RouterTrait for Router {
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
crate
::
policies
::
RandomPolicy
;
use
std
::
collections
::
HashMap
;
fn
create_test_regular_router
()
->
Router
{
let
workers
=
vec!
[
WorkerFactory
::
create_regular
(
"http://worker1:8080"
.to_string
()),
WorkerFactory
::
create_regular
(
"http://worker2:8080"
.to_string
()),
];
// Create registries
let
worker_registry
=
Arc
::
new
(
WorkerRegistry
::
new
());
let
policy_registry
=
Arc
::
new
(
PolicyRegistry
::
new
(
crate
::
config
::
types
::
PolicyConfig
::
RoundRobin
,
));
// Register test workers
let
worker1
=
BasicWorker
::
new
(
"http://worker1:8080"
.to_string
(),
WorkerType
::
Regular
);
let
worker2
=
BasicWorker
::
new
(
"http://worker2:8080"
.to_string
(),
WorkerType
::
Regular
);
worker_registry
.register
(
Arc
::
new
(
worker1
));
worker_registry
.register
(
Arc
::
new
(
worker2
));
let
(
_
,
rx
)
=
tokio
::
sync
::
watch
::
channel
(
HashMap
::
new
());
Router
{
worker
s
:
Arc
::
new
(
RwLock
::
new
(
workers
))
,
policy
:
Arc
::
new
(
RandomPolicy
::
new
())
,
worker
_registry
,
policy
_registry
,
worker_startup_timeout_secs
:
5
,
worker_startup_check_interval_secs
:
1
,
dp_aware
:
false
,
...
...
@@ -1393,7 +1473,6 @@ mod tests {
circuit_breaker_config
:
CircuitBreakerConfig
::
default
(),
_
worker_loads
:
Arc
::
new
(
rx
),
_
load_monitor_handle
:
None
,
_
health_checker
:
None
,
}
}
...
...
@@ -1413,7 +1492,9 @@ mod tests {
let
result
=
router
.select_first_worker
();
assert
!
(
result
.is_ok
());
assert_eq!
(
result
.unwrap
(),
"http://worker1:8080"
);
let
url
=
result
.unwrap
();
// DashMap doesn't guarantee order, so just check we get one of the workers
assert
!
(
url
==
"http://worker1:8080"
||
url
==
"http://worker2:8080"
);
}
#[tokio::test]
...
...
sgl-router/src/routers/mod.rs
View file @
2f173ea0
...
...
@@ -17,6 +17,7 @@ pub mod factory;
pub
mod
grpc
;
pub
mod
header_utils
;
pub
mod
http
;
pub
mod
router_manager
;
pub
use
factory
::
RouterFactory
;
// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
...
...
@@ -63,14 +64,19 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
async
fn
get_model_info
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
;
/// Route a generate request
async
fn
route_generate
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
GenerateRequest
)
->
Response
;
async
fn
route_generate
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
GenerateRequest
,
model_id
:
Option
<&
str
>
,
)
->
Response
;
/// Route a chat completion request
async
fn
route_chat
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
ChatCompletionRequest
,
model_id
:
Option
<&
str
>
,
)
->
Response
;
/// Route a completion request
...
...
@@ -78,6 +84,7 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
CompletionRequest
,
model_id
:
Option
<&
str
>
,
)
->
Response
;
/// Route a responses request
...
...
@@ -85,11 +92,17 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
ResponsesRequest
,
model_id
:
Option
<&
str
>
,
)
->
Response
;
async
fn
route_embeddings
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
Body
)
->
Response
;
async
fn
route_rerank
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
RerankRequest
)
->
Response
;
async
fn
route_rerank
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
RerankRequest
,
model_id
:
Option
<&
str
>
,
)
->
Response
;
/// Flush cache on all workers
async
fn
flush_cache
(
&
self
)
->
Response
;
...
...
sgl-router/src/routers/router_manager.rs
0 → 100644
View file @
2f173ea0
//! Router Manager for coordinating multiple routers and workers
//!
//! Provides centralized management based on enable_igw flag:
//! - Single Router Mode (enable_igw=false): Router owns workers directly
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
use
crate
::
config
::
RouterConfig
;
use
crate
::
core
::{
CircuitBreakerConfig
,
Worker
,
WorkerFactory
,
WorkerRegistry
};
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
,
RerankRequest
,
ResponsesRequest
,
};
use
crate
::
protocols
::
worker_spec
::{
ServerInfo
,
WorkerApiResponse
,
WorkerConfigRequest
,
WorkerErrorResponse
,
WorkerInfo
,
WorkerListResponse
,
WorkerStats
,
WorkerTypeStats
,
};
use
crate
::
routers
::{
RouterTrait
,
WorkerManagement
};
use
async_trait
::
async_trait
;
use
axum
::{
body
::
Body
,
extract
::
Request
,
http
::{
HeaderMap
,
StatusCode
},
response
::{
IntoResponse
,
Response
},
};
use
dashmap
::
DashMap
;
use
std
::
sync
::
Arc
;
use
tracing
::{
info
,
warn
};
/// Router identifier
#[derive(Debug,
Clone,
Hash,
Eq,
PartialEq)]
pub
struct
RouterId
(
String
);
impl
RouterId
{
pub
fn
new
(
id
:
String
)
->
Self
{
Self
(
id
)
}
pub
fn
as_str
(
&
self
)
->
&
str
{
&
self
.0
}
}
/// Router Manager - Central coordinator for routers and workers
/// Only created when enable_igw=true
pub
struct
RouterManager
{
/// Worker registry (single source of truth in multi-router mode)
worker_registry
:
Arc
<
WorkerRegistry
>
,
/// Policy registry for managing model-to-policy mappings
policy_registry
:
Arc
<
crate
::
policies
::
PolicyRegistry
>
,
/// All routers managed by this manager (max 4 routers in Phase 2)
/// RouterId examples: "http-regular", "http-pd", "grpc-regular", "grpc-pd"
routers
:
Arc
<
DashMap
<
RouterId
,
Arc
<
dyn
RouterTrait
>>>
,
/// Default router for requests without specific routing
default_router
:
Option
<
RouterId
>
,
/// Model to router mapping for model-aware routing
/// Multiple models can be served by the same router
model_routers
:
Arc
<
DashMap
<
String
,
Vec
<
RouterId
>>>
,
/// HTTP client for querying worker info
client
:
reqwest
::
Client
,
/// Configuration
#[allow(dead_code)]
// May be used in future enhancements
config
:
RouterConfig
,
}
impl
RouterManager
{
/// Create a new router manager with shared registries
pub
fn
new
(
config
:
RouterConfig
,
client
:
reqwest
::
Client
,
worker_registry
:
Arc
<
WorkerRegistry
>
,
policy_registry
:
Arc
<
crate
::
policies
::
PolicyRegistry
>
,
)
->
Self
{
Self
{
worker_registry
,
policy_registry
,
routers
:
Arc
::
new
(
DashMap
::
new
()),
default_router
:
None
,
model_routers
:
Arc
::
new
(
DashMap
::
new
()),
client
,
config
,
}
}
/// Register a router with the manager
pub
fn
register_router
(
&
mut
self
,
id
:
RouterId
,
router
:
Arc
<
dyn
RouterTrait
>
,
models
:
Vec
<
String
>
,
)
{
// Store router
self
.routers
.insert
(
id
.clone
(),
router
);
// Update model mappings
for
model
in
models
{
self
.model_routers
.entry
(
model
)
.or_default
()
.push
(
id
.clone
());
}
// Set as default if first router
if
self
.default_router
.is_none
()
{
self
.default_router
=
Some
(
id
.clone
());
info!
(
"Set default router to {}"
,
id
.as_str
());
}
}
/// Set the default router
pub
fn
set_default_router
(
&
mut
self
,
id
:
RouterId
)
{
self
.default_router
=
Some
(
id
);
}
/// Get the number of registered routers
pub
fn
router_count
(
&
self
)
->
usize
{
self
.routers
.len
()
}
/// Get router for a specific model
pub
fn
get_router_for_model
(
&
self
,
model_id
:
&
str
)
->
Option
<
Arc
<
dyn
RouterTrait
>>
{
// First try model-specific routers
if
let
Some
(
router_ids
)
=
self
.model_routers
.get
(
model_id
)
{
if
let
Some
(
router_id
)
=
router_ids
.first
()
{
if
let
Some
(
router
)
=
self
.routers
.get
(
router_id
)
{
return
Some
(
router
.clone
());
}
}
}
// Fall back to default router
if
let
Some
(
ref
default_id
)
=
self
.default_router
{
self
.routers
.get
(
default_id
)
.map
(|
r
|
r
.clone
())
}
else
{
None
}
}
/// Get workers for routing decision
pub
fn
get_workers_for_request
(
&
self
,
model_id
:
Option
<&
str
>
)
->
Vec
<
Arc
<
dyn
Worker
>>
{
if
let
Some
(
model
)
=
model_id
{
self
.worker_registry
.get_by_model
(
model
)
}
else
{
self
.worker_registry
.get_all
()
}
}
    /// Add a worker to the registry.
    ///
    /// Resolves the worker's model id (querying the worker's
    /// `/get_server_info` endpoint when the caller did not supply one),
    /// folds the request's optional settings into the worker's label map,
    /// creates the worker via `WorkerFactory`, registers it, and notifies
    /// the policy registry so a load-balancing policy is assigned.
    ///
    /// Returns a `WorkerApiResponse` describing the new worker, or a
    /// `WorkerErrorResponse` on failure.
    pub async fn add_worker(
        &self,
        config: WorkerConfigRequest,
    ) -> Result<WorkerApiResponse, WorkerErrorResponse> {
        // Build labels from configuration
        let mut labels = config.labels.clone();

        // Query server info if model_id not provided
        let model_id = if let Some(model_id) = config.model_id {
            model_id
        } else {
            match self.query_server_info(&config.url).await {
                Ok(info) => {
                    // Extract model_id from server info; fall back to the last
                    // path segment of model_path, then to "unknown".
                    info.model_id
                        .or_else(|| {
                            info.model_path
                                .as_ref()
                                .and_then(|path| path.split('/').next_back().map(|s| s.to_string()))
                        })
                        .unwrap_or_else(|| "unknown".to_string())
                }
                Err(e) => {
                    // Query failure is non-fatal: register the worker anyway
                    // under the "unknown" model.
                    warn!("Failed to query server info from {}: {}", config.url, e);
                    "unknown".to_string()
                }
            }
        };

        // Add configuration to labels
        labels.insert("model_id".to_string(), model_id.clone());

        if let Some(priority) = config.priority {
            labels.insert("priority".to_string(), priority.to_string());
        }

        if let Some(cost) = config.cost {
            labels.insert("cost".to_string(), cost.to_string());
        }

        // Add gRPC-specific configuration if provided
        if let Some(tokenizer_path) = config.tokenizer_path {
            labels.insert("tokenizer_path".to_string(), tokenizer_path);
        }
        if let Some(reasoning_parser) = config.reasoning_parser {
            labels.insert("reasoning_parser".to_string(), reasoning_parser);
        }
        if let Some(tool_parser) = config.tool_parser {
            labels.insert("tool_parser".to_string(), tool_parser);
        }
        if let Some(chat_template) = config.chat_template {
            labels.insert("chat_template".to_string(), chat_template);
        }

        // Create worker based on type
        // Note: For prefill and decode workers, we can't easily add labels after creation
        // since they return Box<dyn Worker>. We'll need to enhance WorkerFactory in the future.
        let worker = match config.worker_type.as_deref() {
            Some("prefill") => {
                // For now, prefill workers won't have custom labels
                // TODO: Enhance WorkerFactory to accept labels for prefill workers
                WorkerFactory::create_prefill(config.url.clone(), config.bootstrap_port)
            }
            Some("decode") => {
                // For now, decode workers won't have custom labels
                // TODO: Enhance WorkerFactory to accept labels for decode workers
                WorkerFactory::create_decode(config.url.clone())
            }
            _ => {
                // Regular workers can have labels
                WorkerFactory::create_regular_with_labels(
                    config.url.clone(),
                    labels.clone(),
                    CircuitBreakerConfig::default(),
                )
            }
        };

        // Register worker
        let worker_id = self.worker_registry.register(Arc::from(worker));

        // Notify PolicyRegistry about the new worker
        // Extract policy hint from labels if provided
        let policy_hint = labels.get("policy").map(|s| s.as_str());
        let policy = self.policy_registry.on_worker_added(&model_id, policy_hint);

        info!(
            "Added worker {} with URL {} for model {} using policy {}",
            worker_id.as_str(),
            config.url,
            model_id,
            policy.name()
        );

        // Return worker info; unwrap is safe because we just registered this id.
        let worker_arc = self.worker_registry.get(&worker_id).unwrap();
        let worker_info = self.worker_to_info(worker_id.as_str(), &worker_arc);

        Ok(WorkerApiResponse {
            success: true,
            message: format!("Worker {} added successfully", worker_id.as_str()),
            worker: Some(worker_info),
        })
    }
/// Remove a worker from the registry
pub
fn
remove_worker_from_registry
(
&
self
,
url
:
&
str
,
)
->
Result
<
WorkerApiResponse
,
WorkerErrorResponse
>
{
// Get worker to extract model_id before removing
let
model_id
=
self
.worker_registry
.get_by_url
(
url
)
.map
(|
worker
|
worker
.model_id
()
.to_string
());
if
let
Some
(
_
worker
)
=
self
.worker_registry
.remove_by_url
(
url
)
{
// Notify PolicyRegistry about worker removal
if
let
Some
(
model_id
)
=
model_id
{
self
.policy_registry
.on_worker_removed
(
&
model_id
);
info!
(
"Removed worker with URL {} for model {}"
,
url
,
model_id
);
}
else
{
info!
(
"Removed worker with URL {}"
,
url
);
}
Ok
(
WorkerApiResponse
{
success
:
true
,
message
:
format!
(
"Worker {} removed successfully"
,
url
),
worker
:
None
,
})
}
else
{
Err
(
WorkerErrorResponse
{
error
:
format!
(
"Worker with URL {} not found"
,
url
),
code
:
"WORKER_NOT_FOUND"
.to_string
(),
})
}
}
    /// List all workers with per-worker details plus aggregate statistics
    /// taken from the worker registry.
    pub fn list_workers(&self) -> WorkerListResponse {
        let workers = self.worker_registry.get_all_with_ids();
        let worker_infos: Vec<WorkerInfo> = workers
            .iter()
            .map(|(id, w)| self.worker_to_info(id.as_str(), w))
            .collect();
        let total = worker_infos.len();

        // Get stats from the worker registry
        let registry_stats = self.worker_registry.stats();

        // Convert WorkerRegistryStats to the API-facing WorkerStats shape
        let stats = WorkerStats {
            total_workers: registry_stats.total_workers,
            healthy_workers: registry_stats.healthy_workers,
            total_models: registry_stats.total_models,
            total_load: registry_stats.total_load,
            by_type: WorkerTypeStats {
                regular: registry_stats.regular_workers,
                prefill: registry_stats.prefill_workers,
                decode: registry_stats.decode_workers,
            },
        };

        WorkerListResponse {
            workers: worker_infos,
            total,
            stats,
        }
    }
/// Get worker by URL
pub
fn
get_worker
(
&
self
,
url
:
&
str
)
->
Option
<
WorkerInfo
>
{
self
.worker_registry
.get_by_url
(
url
)
.map
(|
w
|
self
.worker_to_info
(
"unknown"
,
&
w
))
}
/// Query server info from a worker URL
async
fn
query_server_info
(
&
self
,
url
:
&
str
)
->
Result
<
ServerInfo
,
String
>
{
let
info_url
=
format!
(
"{}/get_server_info"
,
url
.trim_end_matches
(
'/'
));
match
self
.client
.get
(
&
info_url
)
.send
()
.await
{
Ok
(
response
)
=>
{
if
response
.status
()
.is_success
()
{
response
.json
::
<
ServerInfo
>
()
.await
.map_err
(|
e
|
format!
(
"Failed to parse server info: {}"
,
e
))
}
else
{
Err
(
format!
(
"Server returned status: {}"
,
response
.status
()))
}
}
Err
(
e
)
=>
Err
(
format!
(
"Failed to connect to server: {}"
,
e
)),
}
}
    /// Convert Worker to WorkerInfo (the API-facing representation).
    ///
    /// `id` is supplied by the caller because some lookup paths (e.g. by URL)
    /// cannot recover the registry id. Enum-like fields (worker type,
    /// connection mode) are rendered via their `Debug` representation.
    fn worker_to_info(&self, id: &str, worker: &Arc<dyn Worker>) -> WorkerInfo {
        let metadata = worker.metadata();
        WorkerInfo {
            id: id.to_string(),
            url: worker.url().to_string(),
            model_id: worker.model_id().to_string(),
            priority: worker.priority(),
            cost: worker.cost(),
            worker_type: format!("{:?}", worker.worker_type()),
            is_healthy: worker.is_healthy(),
            load: worker.load(),
            connection_mode: format!("{:?}", worker.connection_mode()),
            tokenizer_path: worker.tokenizer_path().map(|s| s.to_string()),
            reasoning_parser: worker.reasoning_parser().map(|s| s.to_string()),
            tool_parser: worker.tool_parser().map(|s| s.to_string()),
            chat_template: worker.chat_template().map(|s| s.to_string()),
            // Expose the worker's label map as opaque metadata.
            metadata: metadata.labels.clone(),
        }
    }
    // Note: calculate_stats removed - using WorkerRegistry::stats() instead

    // === Phase 2: Router Management ===

    // Note: Dynamic router creation removed - routers are created and registered externally

    /// Get the appropriate router for a request based on headers and request content.
    ///
    /// Candidates are the routers mapped to `model_id` (or all routers when no
    /// model is given). Each candidate is scored: a base score of 1.0, +2.0
    /// when the `x-prefer-pd` header matches a PD router, +1.0 when neither
    /// side wants PD. The highest-scoring candidate wins; ties keep the first
    /// candidate seen (strict `>` comparison). Returns `None` when no
    /// candidate exists.
    pub fn select_router_for_request(
        &self,
        headers: Option<&HeaderMap>,
        model_id: Option<&str>,
    ) -> Option<Arc<dyn RouterTrait>> {
        // Extract priority and cost preferences from headers if available.
        // Both are currently unused (see TODO below) but parsed so the header
        // contract is already established.
        let _priority_threshold = headers.and_then(|h| {
            h.get("x-worker-priority")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u32>().ok())
        });

        let _max_cost = headers.and_then(|h| {
            h.get("x-max-cost")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<f32>().ok())
        });

        // Check if PD (prefill-decode) mode is preferred from headers
        let prefer_pd = headers
            .and_then(|h| {
                h.get("x-prefer-pd")
                    .and_then(|v| v.to_str().ok())
                    .map(|s| s == "true" || s == "1")
            })
            .unwrap_or(false);

        // If model specified, find routers serving that model
        let candidate_routers = if let Some(model) = model_id {
            // Get routers for specific model
            if let Some(router_ids) = self.model_routers.get(model) {
                router_ids
                    .iter()
                    .filter_map(|id| self.routers.get(id).map(|r| r.clone()))
                    .collect::<Vec<_>>()
            } else {
                Vec::new()
            }
        } else {
            // No model specified, consider all routers
            self.routers
                .iter()
                .map(|entry| entry.value().clone())
                .collect::<Vec<_>>()
        };

        if candidate_routers.is_empty() {
            // No routers found for the specified model
            return None;
        }

        // Score routers based on worker attributes and request preferences
        let mut best_router = None;
        let mut best_score = 0.0;

        for router in candidate_routers {
            let mut score = 1.0;

            // Check if this is a PD router
            let is_pd = router.is_pd_mode();

            if prefer_pd && is_pd {
                score += 2.0; // Bonus for matching PD preference
            } else if !prefer_pd && !is_pd {
                score += 1.0; // Bonus for matching regular preference
            }

            // Get workers for this router and evaluate based on priority/cost
            // Note: This would require routers to expose their workers or stats
            // For now, we'll use a simple selection based on router type
            // TODO: Once routers expose worker stats, we can evaluate:
            // - Average worker priority vs priority_threshold
            // - Average worker cost vs max_cost
            // - Current load and health status

            if score > best_score {
                best_score = score;
                best_router = Some(router);
            }
        }

        best_router
    }
}
// Note: Default implementation removed as RouterManager now requires AppContext
// which cannot be defaulted. RouterManager must be created with explicit context.

// === Phase 2: RouterManager as RouterTrait ===

/// RouterManager implements RouterTrait to act as a meta-router
/// that delegates requests to the appropriate underlying router
#[async_trait]
impl WorkerManagement for RouterManager {
    /// Add a worker - in multi-router mode, this adds to the registry.
    ///
    /// Only the URL is known here, so a minimal `WorkerConfigRequest` is
    /// built and delegated to the inherent `RouterManager::add_worker`,
    /// which will query the worker for its model id.
    async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
        // Create a basic worker config request
        let config = WorkerConfigRequest {
            url: worker_url.to_string(),
            model_id: None,
            worker_type: None,
            priority: None,
            cost: None,
            labels: std::collections::HashMap::new(),
            bootstrap_port: None,
            tokenizer_path: None,
            reasoning_parser: None,
            tool_parser: None,
            chat_template: None,
        };

        match self.add_worker(config).await {
            Ok(response) => Ok(response.message),
            Err(e) => Err(e.error),
        }
    }

    /// Remove a worker from the registry.
    /// The trait method returns nothing, so a "not found" result is ignored.
    fn remove_worker(&self, worker_url: &str) {
        let _ = self.remove_worker_from_registry(worker_url);
    }

    /// Get all worker URLs from the registry.
    fn get_worker_urls(&self) -> Vec<String> {
        self.worker_registry.get_all_urls()
    }
}
#[async_trait]
impl
RouterTrait
for
RouterManager
{
fn
as_any
(
&
self
)
->
&
dyn
std
::
any
::
Any
{
self
}
/// Health check - return 503 if no routers available
async
fn
health
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// Health check should succeed if RouterManager exists, even without routers
// Individual router health can be checked via specific endpoints
(
StatusCode
::
OK
,
"RouterManager is healthy"
)
.into_response
()
}
/// Health generate - check if any router can handle generate requests
async
fn
health_generate
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// Return 503 since we have no routers with workers
// TODO: Should check if any router has healthy workers
(
StatusCode
::
SERVICE_UNAVAILABLE
,
"No routers with healthy workers available"
,
)
.into_response
()
}
/// Get server information - aggregate from all routers
async
fn
get_server_info
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// TODO: Aggregate info from all routers with healthy workers
// For now, return basic info about the RouterManager
(
StatusCode
::
OK
,
serde_json
::
json!
({
"router_manager"
:
true
,
"routers_count"
:
self
.routers
.len
(),
"workers_count"
:
self
.worker_registry
.get_all
()
.len
()
})
.to_string
(),
)
.into_response
()
}
/// Get available models - aggregate from all routers
async
fn
get_models
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// Return models that have registered routers
let
models
=
self
.model_routers
.iter
()
.map
(|
entry
|
entry
.key
()
.clone
())
.collect
::
<
Vec
<
_
>>
();
if
models
.is_empty
()
{
(
StatusCode
::
SERVICE_UNAVAILABLE
,
"No models available"
)
.into_response
()
}
else
{
(
StatusCode
::
OK
,
serde_json
::
json!
({
"models"
:
models
})
.to_string
(),
)
.into_response
()
}
}
/// Get model information
async
fn
get_model_info
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// TODO: Extract model from request and route to appropriate router
// For now, return not implemented
(
StatusCode
::
NOT_IMPLEMENTED
,
"Model info endpoint not yet implemented in RouterManager"
,
)
.into_response
()
}
/// Route a generate request
async
fn
route_generate
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
GenerateRequest
,
_
model_id
:
Option
<&
str
>
,
)
->
Response
{
// Select router based on headers
// GenerateRequest doesn't have a model field
let
router
=
self
.select_router_for_request
(
headers
,
None
);
if
let
Some
(
router
)
=
router
{
// In multi-model mode, pass None since GenerateRequest doesn't have model field
router
.route_generate
(
headers
,
body
,
None
)
.await
}
else
{
// Return 404 when no router is available for the request
(
StatusCode
::
NOT_FOUND
,
"No router available for this request"
,
)
.into_response
()
}
}
/// Route a chat completion request
async
fn
route_chat
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
ChatCompletionRequest
,
_
model_id
:
Option
<&
str
>
,
)
->
Response
{
// Select router based on headers and model
let
router
=
self
.select_router_for_request
(
headers
,
Some
(
&
body
.model
));
if
let
Some
(
router
)
=
router
{
// In multi-model mode, pass the model_id to the router
router
.route_chat
(
headers
,
body
,
Some
(
&
body
.model
))
.await
}
else
{
// Return 404 when the specified model is not found
(
StatusCode
::
NOT_FOUND
,
format!
(
"Model '{}' not found or no router available"
,
body
.model
),
)
.into_response
()
}
}
/// Route a completion request
async
fn
route_completion
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
CompletionRequest
,
_
model_id
:
Option
<&
str
>
,
)
->
Response
{
// Select router based on headers and model
let
router
=
self
.select_router_for_request
(
headers
,
Some
(
&
body
.model
));
if
let
Some
(
router
)
=
router
{
// In multi-model mode, pass the model_id to the router
router
.route_completion
(
headers
,
body
,
Some
(
&
body
.model
))
.await
}
else
{
// Return 404 when the specified model is not found
(
StatusCode
::
NOT_FOUND
,
format!
(
"Model '{}' not found or no router available"
,
body
.model
),
)
.into_response
()
}
}
async
fn
route_responses
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
body
:
&
ResponsesRequest
,
_
model_id
:
Option
<&
str
>
,
)
->
Response
{
todo!
()
}
/// Route embeddings request
async
fn
route_embeddings
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
Body
)
->
Response
{
// Try to select a router based on headers
let
router
=
self
.select_router_for_request
(
headers
,
None
);
if
let
Some
(
router
)
=
router
{
router
.route_embeddings
(
headers
,
body
)
.await
}
else
{
(
StatusCode
::
NOT_FOUND
,
"No router available for embeddings request"
,
)
.into_response
()
}
}
/// Route rerank request
async
fn
route_rerank
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
RerankRequest
,
model_id
:
Option
<&
str
>
,
)
->
Response
{
// Try to select a router based on headers
let
router
=
self
.select_router_for_request
(
headers
,
None
);
if
let
Some
(
router
)
=
router
{
router
.route_rerank
(
headers
,
body
,
model_id
)
.await
}
else
{
(
StatusCode
::
NOT_FOUND
,
"No router available for rerank request"
,
)
.into_response
()
}
}
/// Flush cache on all routers and workers
async
fn
flush_cache
(
&
self
)
->
Response
{
// TODO: Call flush_cache on all routers that have workers
// For now, return success if we have any routers
if
self
.routers
.is_empty
()
{
(
StatusCode
::
SERVICE_UNAVAILABLE
,
"No routers configured"
)
.into_response
()
}
else
{
// TODO: Actually flush cache on all routers
(
StatusCode
::
OK
,
"Cache flush requested"
)
.into_response
()
}
}
/// Get worker loads from all routers
async
fn
get_worker_loads
(
&
self
)
->
Response
{
// Return worker loads from the registry
let
workers
=
self
.worker_registry
.get_all
();
let
loads
:
Vec
<
serde_json
::
Value
>
=
workers
.iter
()
.map
(|
w
|
{
serde_json
::
json!
({
"url"
:
w
.url
(),
"model"
:
w
.model_id
(),
"load"
:
w
.load
(),
"is_healthy"
:
w
.is_healthy
()
})
})
.collect
();
(
StatusCode
::
OK
,
serde_json
::
json!
({
"workers"
:
loads
})
.to_string
(),
)
.into_response
()
}
/// Get router type name
fn
router_type
(
&
self
)
->
&
'static
str
{
"manager"
}
/// Server readiness check - check if any router is ready
fn
readiness
(
&
self
)
->
Response
{
if
self
.routers
.is_empty
()
{
(
StatusCode
::
SERVICE_UNAVAILABLE
,
"No routers configured"
)
.into_response
()
}
else
{
// TODO: Check readiness of all routers
(
StatusCode
::
OK
,
"Ready"
)
.into_response
()
}
}
}
// Note: get_first_available_router removed - we now properly handle
// router selection based on model and worker availability

impl std::fmt::Debug for RouterManager {
    /// Summarize the manager: router count, worker count, and the default
    /// router id (the registries themselves are not dumped).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("RouterManager");
        dbg.field("routers_count", &self.routers.len());
        dbg.field("workers_count", &self.worker_registry.get_all().len());
        dbg.field("default_router", &self.default_router);
        dbg.finish()
    }
}
sgl-router/src/server.rs
View file @
2f173ea0
use
crate
::
config
::
RouterConfig
;
use
crate
::
core
::
WorkerRegistry
;
use
crate
::
logging
::{
self
,
LoggingConfig
};
use
crate
::
metrics
::{
self
,
PrometheusConfig
};
use
crate
::
middleware
::
TokenBucket
;
use
crate
::
policies
::
PolicyRegistry
;
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
,
RerankRequest
,
ResponsesRequest
,
V1RerankReqInput
,
};
use
crate
::
protocols
::
worker_spec
::{
WorkerApiResponse
,
WorkerConfigRequest
,
WorkerErrorResponse
};
use
crate
::
reasoning_parser
::
ParserFactory
;
use
crate
::
routers
::
router_manager
::{
RouterId
,
RouterManager
};
use
crate
::
routers
::{
RouterFactory
,
RouterTrait
};
use
crate
::
service_discovery
::{
start_service_discovery
,
ServiceDiscoveryConfig
};
use
crate
::
tokenizer
::{
factory
as
tokenizer_factory
,
traits
::
Tokenizer
};
...
...
@@ -36,6 +40,9 @@ pub struct AppContext {
pub
tokenizer
:
Option
<
Arc
<
dyn
Tokenizer
>>
,
pub
reasoning_parser_factory
:
Option
<
ParserFactory
>
,
pub
tool_parser_registry
:
Option
<&
'static
ParserRegistry
>
,
pub
worker_registry
:
Arc
<
WorkerRegistry
>
,
// Shared worker registry
pub
policy_registry
:
Arc
<
PolicyRegistry
>
,
// Shared policy registry
pub
router_manager
:
Option
<
Arc
<
RouterManager
>>
,
// Only present when enable_igw=true
}
impl
AppContext
{
...
...
@@ -75,6 +82,15 @@ impl AppContext {
(
None
,
None
,
None
)
};
// Initialize shared registries
let
worker_registry
=
Arc
::
new
(
WorkerRegistry
::
new
());
let
policy_registry
=
Arc
::
new
(
PolicyRegistry
::
new
(
router_config
.policy
.clone
(),
// Use default policy from config
));
// Initialize RouterManager only when enable_igw is true
let
router_manager
=
None
;
// Will be initialized in startup() based on config
Ok
(
Self
{
client
,
router_config
,
...
...
@@ -82,6 +98,9 @@ impl AppContext {
tokenizer
,
reasoning_parser_factory
,
tool_parser_registry
,
worker_registry
,
policy_registry
,
router_manager
,
})
}
}
...
...
@@ -134,7 +153,10 @@ async fn generate(
headers
:
http
::
HeaderMap
,
Json
(
body
):
Json
<
GenerateRequest
>
,
)
->
Response
{
state
.router
.route_generate
(
Some
(
&
headers
),
&
body
)
.await
state
.router
.route_generate
(
Some
(
&
headers
),
&
body
,
None
)
.await
}
async
fn
v1_chat_completions
(
...
...
@@ -142,7 +164,7 @@ async fn v1_chat_completions(
headers
:
http
::
HeaderMap
,
Json
(
body
):
Json
<
ChatCompletionRequest
>
,
)
->
Response
{
state
.router
.route_chat
(
Some
(
&
headers
),
&
body
)
.await
state
.router
.route_chat
(
Some
(
&
headers
),
&
body
,
None
)
.await
}
async
fn
v1_completions
(
...
...
@@ -150,7 +172,10 @@ async fn v1_completions(
headers
:
http
::
HeaderMap
,
Json
(
body
):
Json
<
CompletionRequest
>
,
)
->
Response
{
state
.router
.route_completion
(
Some
(
&
headers
),
&
body
)
.await
state
.router
.route_completion
(
Some
(
&
headers
),
&
body
,
None
)
.await
}
async
fn
rerank
(
...
...
@@ -158,7 +183,7 @@ async fn rerank(
headers
:
http
::
HeaderMap
,
Json
(
body
):
Json
<
RerankRequest
>
,
)
->
Response
{
state
.router
.route_rerank
(
Some
(
&
headers
),
&
body
)
.await
state
.router
.route_rerank
(
Some
(
&
headers
),
&
body
,
None
)
.await
}
async
fn
v1_rerank
(
...
...
@@ -168,7 +193,7 @@ async fn v1_rerank(
)
->
Response
{
state
.router
.route_rerank
(
Some
(
&
headers
),
&
body
.into
())
.route_rerank
(
Some
(
&
headers
),
&
body
.into
()
,
None
)
.await
}
...
...
@@ -177,7 +202,10 @@ async fn v1_responses(
headers
:
http
::
HeaderMap
,
Json
(
body
):
Json
<
ResponsesRequest
>
,
)
->
Response
{
state
.router
.route_responses
(
Some
(
&
headers
),
&
body
)
.await
state
.router
.route_responses
(
Some
(
&
headers
),
&
body
,
None
)
.await
}
// Worker management endpoints
...
...
@@ -232,6 +260,137 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Respons
state
.router
.get_worker_loads
()
.await
}
// New RESTful worker management endpoints (when enable_igw=true)
/// POST /workers - Add a new worker with full configuration
async
fn
create_worker
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
Json
(
config
):
Json
<
WorkerConfigRequest
>
,
)
->
Response
{
// Check if RouterManager is available (enable_igw=true)
if
let
Some
(
router_manager
)
=
&
state
.context.router_manager
{
match
router_manager
.add_worker
(
config
)
.await
{
Ok
(
response
)
=>
(
StatusCode
::
OK
,
Json
(
response
))
.into_response
(),
Err
(
error
)
=>
(
StatusCode
::
BAD_REQUEST
,
Json
(
error
))
.into_response
(),
}
}
else
{
// In single router mode, use the router's add_worker with basic config
match
state
.router
.add_worker
(
&
config
.url
)
.await
{
Ok
(
message
)
=>
{
let
response
=
WorkerApiResponse
{
success
:
true
,
message
,
worker
:
None
,
};
(
StatusCode
::
OK
,
Json
(
response
))
.into_response
()
}
Err
(
error
)
=>
{
let
error_response
=
WorkerErrorResponse
{
error
,
code
:
"ADD_WORKER_FAILED"
.to_string
(),
};
(
StatusCode
::
BAD_REQUEST
,
Json
(
error_response
))
.into_response
()
}
}
}
}
/// GET /workers - List all workers with details
async
fn
list_workers_rest
(
State
(
state
):
State
<
Arc
<
AppState
>>
)
->
Response
{
if
let
Some
(
router_manager
)
=
&
state
.context.router_manager
{
let
response
=
router_manager
.list_workers
();
Json
(
response
)
.into_response
()
}
else
{
// In single router mode, get detailed worker info from registry
let
workers
=
state
.context.worker_registry
.get_all
();
let
response
=
serde_json
::
json!
({
"workers"
:
workers
.iter
()
.map
(|
worker
|
{
let
mut
worker_info
=
serde_json
::
json!
({
"url"
:
worker
.url
(),
"model_id"
:
worker
.model_id
(),
"worker_type"
:
format!
(
"{:?}"
,
worker
.worker_type
()),
"is_healthy"
:
worker
.is_healthy
(),
"load"
:
worker
.load
(),
"connection_mode"
:
format!
(
"{:?}"
,
worker
.connection_mode
()),
"priority"
:
worker
.priority
(),
"cost"
:
worker
.cost
(),
});
// Add bootstrap_port for Prefill workers
if
let
crate
::
core
::
WorkerType
::
Prefill
{
bootstrap_port
}
=
worker
.worker_type
()
{
worker_info
[
"bootstrap_port"
]
=
serde_json
::
json!
(
bootstrap_port
);
}
worker_info
})
.collect
::
<
Vec
<
_
>>
(),
"total"
:
workers
.len
(),
"stats"
:
{
"prefill_count"
:
state
.context.worker_registry
.get_prefill_workers
()
.len
(),
"decode_count"
:
state
.context.worker_registry
.get_decode_workers
()
.len
(),
"regular_count"
:
state
.context.worker_registry
.get_by_type
(
&
crate
::
core
::
WorkerType
::
Regular
)
.len
(),
}
});
Json
(
response
)
.into_response
()
}
}
/// GET /workers/{url} - Get specific worker info
async
fn
get_worker
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
axum
::
extract
::
Path
(
url
):
axum
::
extract
::
Path
<
String
>
,
)
->
Response
{
if
let
Some
(
router_manager
)
=
&
state
.context.router_manager
{
if
let
Some
(
worker
)
=
router_manager
.get_worker
(
&
url
)
{
Json
(
worker
)
.into_response
()
}
else
{
let
error
=
WorkerErrorResponse
{
error
:
format!
(
"Worker {} not found"
,
url
),
code
:
"WORKER_NOT_FOUND"
.to_string
(),
};
(
StatusCode
::
NOT_FOUND
,
Json
(
error
))
.into_response
()
}
}
else
{
// In single router mode, check if worker exists
let
workers
=
state
.router
.get_worker_urls
();
if
workers
.contains
(
&
url
)
{
let
worker_info
=
serde_json
::
json!
({
"url"
:
url
,
"model_id"
:
"unknown"
,
"is_healthy"
:
true
});
Json
(
worker_info
)
.into_response
()
}
else
{
let
error
=
WorkerErrorResponse
{
error
:
format!
(
"Worker {} not found"
,
url
),
code
:
"WORKER_NOT_FOUND"
.to_string
(),
};
(
StatusCode
::
NOT_FOUND
,
Json
(
error
))
.into_response
()
}
}
}
/// DELETE /workers/{url} - Remove a worker
async
fn
delete_worker
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
axum
::
extract
::
Path
(
url
):
axum
::
extract
::
Path
<
String
>
,
)
->
Response
{
if
let
Some
(
router_manager
)
=
&
state
.context.router_manager
{
match
router_manager
.remove_worker_from_registry
(
&
url
)
{
Ok
(
response
)
=>
(
StatusCode
::
OK
,
Json
(
response
))
.into_response
(),
Err
(
error
)
=>
(
StatusCode
::
BAD_REQUEST
,
Json
(
error
))
.into_response
(),
}
}
else
{
// In single router mode, use router's remove_worker
state
.router
.remove_worker
(
&
url
);
let
response
=
WorkerApiResponse
{
success
:
true
,
message
:
format!
(
"Worker {} removed successfully"
,
url
),
worker
:
None
,
};
(
StatusCode
::
OK
,
Json
(
response
))
.into_response
()
}
}
pub
struct
ServerConfig
{
pub
host
:
String
,
pub
port
:
u16
,
...
...
@@ -281,11 +440,19 @@ pub fn build_app(
.route
(
"/flush_cache"
,
post
(
flush_cache
))
.route
(
"/get_loads"
,
get
(
get_loads
));
// Worker management routes
let
worker_routes
=
Router
::
new
()
.route
(
"/workers"
,
post
(
create_worker
))
.route
(
"/workers"
,
get
(
list_workers_rest
))
.route
(
"/workers/{url}"
,
get
(
get_worker
))
.route
(
"/workers/{url}"
,
axum
::
routing
::
delete
(
delete_worker
));
// Build app with all routes and middleware
Router
::
new
()
.merge
(
protected_routes
)
.merge
(
public_routes
)
.merge
(
admin_routes
)
.merge
(
worker_routes
)
// Request body size limiting
.layer
(
tower_http
::
limit
::
RequestBodyLimitLayer
::
new
(
max_payload_size
,
...
...
@@ -355,15 +522,100 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
.expect
(
"Failed to create HTTP client"
);
// Create the application context with all dependencies
let
app_context
=
Arc
::
new
(
AppContext
::
new
(
let
app_context
=
AppContext
::
new
(
config
.router_config
.clone
(),
client
.clone
(),
config
.router_config.max_concurrent_requests
,
config
.router_config.rate_limit_tokens_per_second
,
)
?
);
)
?
;
let
app_context
=
Arc
::
new
(
app_context
);
// Create the appropriate router based on enable_igw flag
let
router
:
Box
<
dyn
RouterTrait
>
=
if
config
.router_config.enable_igw
{
info!
(
"Multi-router mode enabled (enable_igw=true)"
);
// Create RouterManager with shared registries from AppContext
let
mut
router_manager
=
RouterManager
::
new
(
config
.router_config
.clone
(),
client
.clone
(),
app_context
.worker_registry
.clone
(),
app_context
.policy_registry
.clone
(),
);
// Create HTTP routers at startup (with empty worker lists)
// Workers will be added to these routers dynamically via RouterManager's worker registry
// 1. HTTP Regular Router
match
RouterFactory
::
create_regular_router
(
&
[],
// Empty worker list - workers added later
&
app_context
,
)
.await
{
Ok
(
http_regular
)
=>
{
info!
(
"Created HTTP Regular router"
);
router_manager
.register_router
(
RouterId
::
new
(
"http-regular"
.to_string
()),
Arc
::
from
(
http_regular
),
vec!
[],
// Models will be determined by workers
);
}
Err
(
e
)
=>
{
warn!
(
"Failed to create HTTP Regular router: {}"
,
e
);
}
}
// 2. HTTP PD Router
match
RouterFactory
::
create_pd_router
(
&
[],
// Empty prefill URLs
&
[],
// Empty decode URLs
None
,
// Use default prefill policy
None
,
// Use default decode policy
&
config
.router_config.policy
,
&
app_context
,
)
.await
{
Ok
(
http_pd
)
=>
{
info!
(
"Created HTTP PD router"
);
router_manager
.register_router
(
RouterId
::
new
(
"http-pd"
.to_string
()),
Arc
::
from
(
http_pd
),
vec!
[],
);
}
Err
(
e
)
=>
{
warn!
(
"Failed to create HTTP PD router: {}"
,
e
);
}
}
// TODO: Add gRPC routers once we have dynamic tokenizer loading
// Currently gRPC routers require tokenizer to be initialized first,
// but each model needs its own tokenizer. Once we implement dynamic
// tokenizer loading per model, we can enable gRPC routers here:
// - RouterType::GrpcRegular (RouterId: "grpc-regular")
// - RouterType::GrpcPd (RouterId: "grpc-pd")
// Create router with the context
let
router
=
RouterFactory
::
create_router
(
&
app_context
)
.await
?
;
info!
(
"RouterManager initialized with {} routers"
,
router_manager
.router_count
()
);
Box
::
new
(
router_manager
)
}
else
{
info!
(
"Single router mode (enable_igw=false)"
);
// Create single router with the context
RouterFactory
::
create_router
(
&
app_context
)
.await
?
};
// Start health checker for all workers in the registry
let
_
health_checker
=
app_context
.worker_registry
.start_health_checker
(
config
.router_config.health_check.check_interval_secs
);
info!
(
"Started health checker for workers with {}s interval"
,
config
.router_config.health_check.check_interval_secs
);
// Set up concurrency limiter with queue if configured
let
(
limiter
,
processor
)
=
crate
::
middleware
::
ConcurrencyLimiter
::
new
(
...
...
sgl-router/src/service_discovery.rs
View file @
2f173ea0
...
...
@@ -579,9 +579,8 @@ mod tests {
// Helper to create a Router instance for testing event handlers
async
fn
create_test_router
()
->
Arc
<
dyn
RouterTrait
>
{
use
crate
::
config
::
{
PolicyConfig
,
RouterConfig
}
;
use
crate
::
config
::
RouterConfig
;
use
crate
::
middleware
::
TokenBucket
;
use
crate
::
policies
::
PolicyFactory
;
use
crate
::
routers
::
http
::
router
::
Router
;
use
crate
::
server
::
AppContext
;
...
...
@@ -591,15 +590,19 @@ mod tests {
// Create AppContext with minimal components
let
app_context
=
Arc
::
new
(
AppContext
{
client
:
reqwest
::
Client
::
new
(),
router_config
,
router_config
:
router_config
.clone
()
,
rate_limiter
:
Arc
::
new
(
TokenBucket
::
new
(
1000
,
1000
)),
worker_registry
:
Arc
::
new
(
crate
::
core
::
WorkerRegistry
::
new
()),
policy_registry
:
Arc
::
new
(
crate
::
policies
::
PolicyRegistry
::
new
(
router_config
.policy
.clone
(),
)),
tokenizer
:
None
,
// HTTP mode doesn't need tokenizer
reasoning_parser_factory
:
None
,
// HTTP mode doesn't need reasoning parser
tool_parser_registry
:
None
,
// HTTP mode doesn't need tool parser
router_manager
:
None
,
// Test doesn't need router manager
});
let
policy
=
PolicyFactory
::
create_from_config
(
&
PolicyConfig
::
Random
);
let
router
=
Router
::
new
(
vec!
[],
policy
,
&
app_context
)
.await
.unwrap
();
let
router
=
Router
::
new
(
vec!
[],
&
app_context
)
.await
.unwrap
();
Arc
::
new
(
router
)
as
Arc
<
dyn
RouterTrait
>
}
...
...
sgl-router/tests/cache_aware_backward_compat_test.rs
0 → 100644
View file @
2f173ea0
use
sglang_router_rs
::
core
::{
BasicWorker
,
Worker
,
WorkerType
};
use
sglang_router_rs
::
policies
::{
CacheAwareConfig
,
CacheAwarePolicy
,
LoadBalancingPolicy
};
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
Arc
;
#[test]
fn test_backward_compatibility_with_empty_model_id() {
    // Policy under test; background eviction is disabled so the test stays
    // deterministic and needs no timing assumptions.
    let policy = CacheAwarePolicy::with_config(CacheAwareConfig {
        cache_threshold: 0.5,
        balance_abs_threshold: 2,
        balance_rel_threshold: 1.5,
        eviction_interval_secs: 0, // Disable background eviction for testing
        max_tree_size: 100,
    });

    // A worker carrying no model_id label at all (simulating existing routers);
    // it should default to "unknown".
    let unlabeled = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);

    // A worker explicitly labeled with the "unknown" model id.
    let mut unknown_labels = HashMap::new();
    unknown_labels.insert("model_id".to_string(), "unknown".to_string());
    let labeled = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular)
        .with_labels(unknown_labels);

    // Both should end up in the same "default" tree.
    policy.add_worker(&unlabeled);
    policy.add_worker(&labeled);

    let pool: Vec<Arc<dyn Worker>> =
        vec![Arc::new(unlabeled.clone()), Arc::new(labeled.clone())];

    // Selection must succeed without errors for this mixed labeling.
    assert!(
        policy.select_worker(&pool, Some("test request")).is_some(),
        "Should select a worker"
    );

    // Removal must likewise work without errors.
    policy.remove_worker(&unlabeled);
    policy.remove_worker(&labeled);
}
#[test]
fn test_mixed_model_ids() {
    // Background eviction disabled for a deterministic test.
    let policy = CacheAwarePolicy::with_config(CacheAwareConfig {
        cache_threshold: 0.5,
        balance_abs_threshold: 2,
        balance_rel_threshold: 1.5,
        eviction_interval_secs: 0,
        max_tree_size: 100,
    });

    // Small builder for a Regular worker carrying a single model_id label.
    let labeled_worker = |url: &str, model: &str| {
        let mut labels = HashMap::new();
        labels.insert("model_id".to_string(), model.to_string());
        BasicWorker::new(url.to_string(), WorkerType::Regular).with_labels(labels)
    };

    // worker1 has no model_id label - defaults to "unknown" which goes to the
    // "default" tree; the others cover explicit "llama-3" and "unknown" labels.
    let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);
    let worker2 = labeled_worker("http://worker2:8080", "llama-3");
    let worker3 = labeled_worker("http://worker3:8080", "unknown");
    let worker4 = labeled_worker("http://worker4:8080", "llama-3");

    policy.add_worker(&worker1);
    policy.add_worker(&worker2);
    policy.add_worker(&worker3);
    policy.add_worker(&worker4);

    // Selection restricted to the "default"-tree workers only.
    let default_pool: Vec<Arc<dyn Worker>> =
        vec![Arc::new(worker1.clone()), Arc::new(worker3.clone())];
    assert!(
        policy.select_worker(&default_pool, Some("test request")).is_some(),
        "Should select from default workers"
    );

    // Selection restricted to the llama-3 workers only.
    let llama_pool: Vec<Arc<dyn Worker>> =
        vec![Arc::new(worker2.clone()), Arc::new(worker4.clone())];
    assert!(
        policy.select_worker(&llama_pool, Some("test request")).is_some(),
        "Should select from llama-3 workers"
    );

    // Selection across every worker at once.
    let mixed_pool: Vec<Arc<dyn Worker>> = vec![
        Arc::new(worker1.clone()),
        Arc::new(worker2.clone()),
        Arc::new(worker3.clone()),
        Arc::new(worker4.clone()),
    ];
    assert!(
        policy.select_worker(&mixed_pool, Some("test request")).is_some(),
        "Should select from all workers"
    );
}
#[test]
fn test_remove_worker_by_url_backward_compat() {
    let policy = CacheAwarePolicy::with_config(CacheAwareConfig::default());

    // One worker labeled with a concrete model id...
    let mut llama_labels = HashMap::new();
    llama_labels.insert("model_id".to_string(), "llama-3".to_string());
    let llama_worker = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular)
        .with_labels(llama_labels);

    // ...and one with no model_id label, defaulting to "unknown".
    let unlabeled_worker =
        BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular);

    policy.add_worker(&llama_worker);
    policy.add_worker(&unlabeled_worker);

    // Backward-compatibility path: removing by URL alone should purge the
    // worker from all trees, since the model is not known at removal time.
    policy.remove_worker_by_url("http://worker1:8080");

    // Only the unlabeled worker should remain selectable.
    let remaining: Vec<Arc<dyn Worker>> = vec![Arc::new(unlabeled_worker.clone())];
    assert_eq!(
        policy.select_worker(&remaining, Some("test")),
        Some(0),
        "Should only have worker2 left"
    );
}
sgl-router/tests/policy_registry_integration.rs
0 → 100644
View file @
2f173ea0
//! Integration tests for PolicyRegistry with RouterManager
use
sglang_router_rs
::
config
::{
PolicyConfig
,
RouterConfig
};
use
sglang_router_rs
::
core
::
WorkerRegistry
;
use
sglang_router_rs
::
policies
::
PolicyRegistry
;
use
sglang_router_rs
::
protocols
::
worker_spec
::
WorkerConfigRequest
;
use
sglang_router_rs
::
routers
::
router_manager
::
RouterManager
;
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
Arc
;
#[tokio::test]
async
fn
test_policy_registry_with_router_manager
()
{
// Create RouterConfig
let
config
=
RouterConfig
{
enable_igw
:
true
,
policy
:
PolicyConfig
::
RoundRobin
,
..
Default
::
default
()
};
// Create HTTP client
let
client
=
reqwest
::
Client
::
new
();
// Create shared registries
let
worker_registry
=
Arc
::
new
(
WorkerRegistry
::
new
());
let
policy_registry
=
Arc
::
new
(
PolicyRegistry
::
new
(
PolicyConfig
::
RoundRobin
));
// Create RouterManager with shared registries
let
_
router_manager
=
RouterManager
::
new
(
config
,
client
,
worker_registry
.clone
(),
policy_registry
.clone
(),
);
// Test adding workers with different models and policies
// Add first worker for llama-3 with cache_aware policy hint
let
mut
labels1
=
HashMap
::
new
();
labels1
.insert
(
"policy"
.to_string
(),
"cache_aware"
.to_string
());
let
_
worker1_config
=
WorkerConfigRequest
{
url
:
"http://worker1:8000"
.to_string
(),
model_id
:
Some
(
"llama-3"
.to_string
()),
worker_type
:
None
,
priority
:
None
,
cost
:
None
,
labels
:
labels1
,
bootstrap_port
:
None
,
tokenizer_path
:
None
,
reasoning_parser
:
None
,
tool_parser
:
None
,
chat_template
:
None
,
};
// This would normally connect to a real worker, but for testing we'll just verify the structure
// In a real test, we'd need to mock the worker or use a test server
// Verify PolicyRegistry has the correct policy for llama-3
let
_
llama_policy
=
policy_registry
.get_policy
(
"llama-3"
);
// After first worker is added, llama-3 should have a policy
// Add second worker for llama-3 with different policy hint (should be ignored)
let
mut
labels2
=
HashMap
::
new
();
labels2
.insert
(
"policy"
.to_string
(),
"random"
.to_string
());
let
_
worker2_config
=
WorkerConfigRequest
{
url
:
"http://worker2:8000"
.to_string
(),
model_id
:
Some
(
"llama-3"
.to_string
()),
worker_type
:
None
,
priority
:
None
,
cost
:
None
,
labels
:
labels2
,
bootstrap_port
:
None
,
tokenizer_path
:
None
,
reasoning_parser
:
None
,
tool_parser
:
None
,
chat_template
:
None
,
};
// The second worker should use the same policy as the first (cache_aware)
// Add worker for different model (gpt-4) with random policy
let
mut
labels3
=
HashMap
::
new
();
labels3
.insert
(
"policy"
.to_string
(),
"random"
.to_string
());
let
_
worker3_config
=
WorkerConfigRequest
{
url
:
"http://worker3:8000"
.to_string
(),
model_id
:
Some
(
"gpt-4"
.to_string
()),
worker_type
:
None
,
priority
:
None
,
cost
:
None
,
labels
:
labels3
,
bootstrap_port
:
None
,
tokenizer_path
:
None
,
reasoning_parser
:
None
,
tool_parser
:
None
,
chat_template
:
None
,
};
// Verify gpt-4 has random policy
let
_
gpt_policy
=
policy_registry
.get_policy
(
"gpt-4"
);
// Test removing workers
// When we remove both llama-3 workers, the policy should be cleaned up
println!
(
"PolicyRegistry integration test structure created"
);
println!
(
"Note: This test requires mocking or test servers to fully execute"
);
}
#[test]
fn test_policy_registry_cleanup() {
    use sglang_router_rs::config::PolicyConfig;
    use sglang_router_rs::policies::PolicyRegistry;

    let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);

    // The first worker registered for a model decides its policy.
    let first = registry.on_worker_added("model-1", Some("cache_aware"));
    assert_eq!(first.name(), "cache_aware");

    // A later worker with a conflicting hint keeps the established policy.
    let second = registry.on_worker_added("model-1", Some("random"));
    assert_eq!(second.name(), "cache_aware"); // Should still be cache_aware

    assert!(registry.get_policy("model-1").is_some());

    // With one of two workers removed, the policy must survive.
    registry.on_worker_removed("model-1");
    assert!(registry.get_policy("model-1").is_some());

    // With the last worker removed, the policy must be cleaned up.
    registry.on_worker_removed("model-1");
    assert!(registry.get_policy("model-1").is_none());

    println!("✓ PolicyRegistry cleanup test passed");
}
#[test]
fn test_policy_registry_multiple_models() {
    use sglang_router_rs::config::PolicyConfig;
    use sglang_router_rs::policies::PolicyRegistry;

    let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);

    // Each model carries its own policy hint; a missing hint falls back to
    // the registry's default.
    let llama = registry.on_worker_added("llama-3", Some("cache_aware"));
    let gpt = registry.on_worker_added("gpt-4", Some("random"));
    let mistral = registry.on_worker_added("mistral", None); // Uses default

    assert_eq!(llama.name(), "cache_aware");
    assert_eq!(gpt.name(), "random");
    assert_eq!(mistral.name(), "round_robin"); // Default

    // Every model should now have a stored policy.
    for model in ["llama-3", "gpt-4", "mistral"] {
        assert!(registry.get_policy(model).is_some());
    }

    // The mapping snapshot reports each model with its policy name.
    let mappings = registry.get_all_mappings();
    assert_eq!(mappings.len(), 3);
    assert_eq!(mappings.get("llama-3").unwrap(), "cache_aware");
    assert_eq!(mappings.get("gpt-4").unwrap(), "random");
    assert_eq!(mappings.get("mistral").unwrap(), "round_robin");

    println!("✓ PolicyRegistry multiple models test passed");
}
sgl-router/tests/test_openai_routing.rs
View file @
2f173ea0
...
...
@@ -197,12 +197,14 @@ async fn test_unsupported_endpoints() {
rid
:
None
,
};
let
response
=
router
.route_generate
(
None
,
&
generate_request
)
.await
;
let
response
=
router
.route_generate
(
None
,
&
generate_request
,
None
)
.await
;
assert_eq!
(
response
.status
(),
StatusCode
::
NOT_IMPLEMENTED
);
// Test completion endpoint (should also not be supported)
let
completion_request
=
create_minimal_completion_request
();
let
response
=
router
.route_completion
(
None
,
&
completion_request
)
.await
;
let
response
=
router
.route_completion
(
None
,
&
completion_request
,
None
)
.await
;
assert_eq!
(
response
.status
(),
StatusCode
::
NOT_IMPLEMENTED
);
}
...
...
@@ -228,7 +230,7 @@ async fn test_openai_router_chat_completion_with_mock() {
chat_request
.temperature
=
Some
(
0.7
);
// Route the request
let
response
=
router
.route_chat
(
None
,
&
chat_request
)
.await
;
let
response
=
router
.route_chat
(
None
,
&
chat_request
,
None
)
.await
;
// Should get a successful response from mock server
assert_eq!
(
response
.status
(),
StatusCode
::
OK
);
...
...
@@ -269,7 +271,9 @@ async fn test_openai_e2e_with_server() {
let
chat_request
:
ChatCompletionRequest
=
serde_json
::
from_str
(
&
body_str
)
.unwrap
();
router
.route_chat
(
Some
(
&
parts
.headers
),
&
chat_request
)
.await
router
.route_chat
(
Some
(
&
parts
.headers
),
&
chat_request
,
None
)
.await
}
}
}),
...
...
@@ -327,7 +331,7 @@ async fn test_openai_router_chat_streaming_with_mock() {
});
let
chat_request
:
ChatCompletionRequest
=
serde_json
::
from_value
(
val
)
.unwrap
();
let
response
=
router
.route_chat
(
None
,
&
chat_request
)
.await
;
let
response
=
router
.route_chat
(
None
,
&
chat_request
,
None
)
.await
;
assert_eq!
(
response
.status
(),
StatusCode
::
OK
);
// Should be SSE
...
...
@@ -371,7 +375,7 @@ async fn test_openai_router_circuit_breaker() {
// First few requests should fail and record failures
for
_
in
0
..
3
{
let
response
=
router
.route_chat
(
None
,
&
chat_request
)
.await
;
let
response
=
router
.route_chat
(
None
,
&
chat_request
,
None
)
.await
;
// Should get either an error or circuit breaker response
assert
!
(
response
.status
()
==
StatusCode
::
INTERNAL_SERVER_ERROR
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment