Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
89971c4c
Unverified
Commit
89971c4c
authored
Sep 22, 2025
by
Simo Lin
Committed by
GitHub
Sep 22, 2025
Browse files
[router] refactor router and worker management 4/n (#10756)
Co-authored-by:
Chang Su
<
chang.s.su@oracle.com
>
parent
113f8f65
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
161 additions
and
196 deletions
+161
-196
sgl-router/src/routers/http/pd_router.rs
sgl-router/src/routers/http/pd_router.rs
+12
-2
sgl-router/src/routers/http/router.rs
sgl-router/src/routers/http/router.rs
+6
-1
sgl-router/src/routers/router_manager.rs
sgl-router/src/routers/router_manager.rs
+137
-84
sgl-router/src/server.rs
sgl-router/src/server.rs
+6
-109
No files found.
sgl-router/src/routers/http/pd_router.rs
View file @
89971c4c
...
@@ -41,6 +41,7 @@ pub struct PDRouter {
...
@@ -41,6 +41,7 @@ pub struct PDRouter {
pub
prefill_client
:
Client
,
pub
prefill_client
:
Client
,
pub
retry_config
:
RetryConfig
,
pub
retry_config
:
RetryConfig
,
pub
api_key
:
Option
<
String
>
,
pub
api_key
:
Option
<
String
>
,
pub
enable_igw
:
bool
,
prefill_drain_tx
:
mpsc
::
Sender
<
reqwest
::
Response
>
,
prefill_drain_tx
:
mpsc
::
Sender
<
reqwest
::
Response
>
,
}
}
...
@@ -317,6 +318,7 @@ impl PDRouter {
...
@@ -317,6 +318,7 @@ impl PDRouter {
prefill_drain_tx
,
prefill_drain_tx
,
retry_config
:
ctx
.router_config
.effective_retry_config
(),
retry_config
:
ctx
.router_config
.effective_retry_config
(),
api_key
:
ctx
.router_config.api_key
.clone
(),
api_key
:
ctx
.router_config.api_key
.clone
(),
enable_igw
:
ctx
.router_config.enable_igw
,
})
})
}
}
...
@@ -849,7 +851,14 @@ impl PDRouter {
...
@@ -849,7 +851,14 @@ impl PDRouter {
request_text
:
Option
<&
str
>
,
request_text
:
Option
<&
str
>
,
model_id
:
Option
<&
str
>
,
model_id
:
Option
<&
str
>
,
)
->
Result
<
(
Arc
<
dyn
Worker
>
,
Arc
<
dyn
Worker
>
),
String
>
{
)
->
Result
<
(
Arc
<
dyn
Worker
>
,
Arc
<
dyn
Worker
>
),
String
>
{
let
prefill_workers
=
if
let
Some
(
model
)
=
model_id
{
let
effective_model_id
=
if
!
self
.enable_igw
{
None
}
else
{
model_id
};
debug!
(
"Selecting PD pair: enable_igw={}, model_id={:?}, effective_model_id={:?}"
,
self
.enable_igw
,
model_id
,
effective_model_id
);
let
prefill_workers
=
if
let
Some
(
model
)
=
effective_model_id
{
self
.worker_registry
self
.worker_registry
.get_by_model_fast
(
model
)
.get_by_model_fast
(
model
)
.into_iter
()
.into_iter
()
...
@@ -859,7 +868,7 @@ impl PDRouter {
...
@@ -859,7 +868,7 @@ impl PDRouter {
self
.worker_registry
.get_prefill_workers
()
self
.worker_registry
.get_prefill_workers
()
};
};
let
decode_workers
=
if
let
Some
(
model
)
=
model_id
{
let
decode_workers
=
if
let
Some
(
model
)
=
effective_
model_id
{
self
.worker_registry
self
.worker_registry
.get_by_model_fast
(
model
)
.get_by_model_fast
(
model
)
.into_iter
()
.into_iter
()
...
@@ -1797,6 +1806,7 @@ mod tests {
...
@@ -1797,6 +1806,7 @@ mod tests {
prefill_drain_tx
:
mpsc
::
channel
(
100
)
.0
,
prefill_drain_tx
:
mpsc
::
channel
(
100
)
.0
,
retry_config
:
RetryConfig
::
default
(),
retry_config
:
RetryConfig
::
default
(),
api_key
:
Some
(
"test_api_key"
.to_string
()),
api_key
:
Some
(
"test_api_key"
.to_string
()),
enable_igw
:
false
,
}
}
}
}
...
...
sgl-router/src/routers/http/router.rs
View file @
89971c4c
...
@@ -35,6 +35,7 @@ pub struct Router {
...
@@ -35,6 +35,7 @@ pub struct Router {
policy_registry
:
Arc
<
PolicyRegistry
>
,
policy_registry
:
Arc
<
PolicyRegistry
>
,
client
:
Client
,
client
:
Client
,
dp_aware
:
bool
,
dp_aware
:
bool
,
enable_igw
:
bool
,
retry_config
:
RetryConfig
,
retry_config
:
RetryConfig
,
_
worker_loads
:
Arc
<
tokio
::
sync
::
watch
::
Receiver
<
HashMap
<
String
,
isize
>>>
,
_
worker_loads
:
Arc
<
tokio
::
sync
::
watch
::
Receiver
<
HashMap
<
String
,
isize
>>>
,
_
load_monitor_handle
:
Option
<
Arc
<
tokio
::
task
::
JoinHandle
<
()
>>>
,
_
load_monitor_handle
:
Option
<
Arc
<
tokio
::
task
::
JoinHandle
<
()
>>>
,
...
@@ -93,6 +94,7 @@ impl Router {
...
@@ -93,6 +94,7 @@ impl Router {
policy_registry
:
ctx
.policy_registry
.clone
(),
policy_registry
:
ctx
.policy_registry
.clone
(),
client
:
ctx
.client
.clone
(),
client
:
ctx
.client
.clone
(),
dp_aware
:
ctx
.router_config.dp_aware
,
dp_aware
:
ctx
.router_config.dp_aware
,
enable_igw
:
ctx
.router_config.enable_igw
,
retry_config
:
ctx
.router_config
.effective_retry_config
(),
retry_config
:
ctx
.router_config
.effective_retry_config
(),
_
worker_loads
:
worker_loads
,
_
worker_loads
:
worker_loads
,
_
load_monitor_handle
:
load_monitor_handle
,
_
load_monitor_handle
:
load_monitor_handle
,
...
@@ -162,9 +164,11 @@ impl Router {
...
@@ -162,9 +164,11 @@ impl Router {
model_id
:
Option
<&
str
>
,
model_id
:
Option
<&
str
>
,
text
:
Option
<&
str
>
,
text
:
Option
<&
str
>
,
)
->
Option
<
Arc
<
dyn
Worker
>>
{
)
->
Option
<
Arc
<
dyn
Worker
>>
{
let
effective_model_id
=
if
!
self
.enable_igw
{
None
}
else
{
model_id
};
// Get workers for the specified model O(1), filtered by connection mode
// Get workers for the specified model O(1), filtered by connection mode
let
workers
=
self
.worker_registry
.get_workers_filtered
(
let
workers
=
self
.worker_registry
.get_workers_filtered
(
model_id
,
effective_
model_id
,
Some
(
WorkerType
::
Regular
),
Some
(
WorkerType
::
Regular
),
Some
(
ConnectionMode
::
Http
),
Some
(
ConnectionMode
::
Http
),
false
,
// get all workers, we'll filter by is_available() next
false
,
// get all workers, we'll filter by is_available() next
...
@@ -1106,6 +1110,7 @@ mod tests {
...
@@ -1106,6 +1110,7 @@ mod tests {
retry_config
:
RetryConfig
::
default
(),
retry_config
:
RetryConfig
::
default
(),
_
worker_loads
:
Arc
::
new
(
rx
),
_
worker_loads
:
Arc
::
new
(
rx
),
_
load_monitor_handle
:
None
,
_
load_monitor_handle
:
None
,
enable_igw
:
false
,
}
}
}
}
...
...
sgl-router/src/routers/router_manager.rs
View file @
89971c4c
...
@@ -4,12 +4,14 @@
...
@@ -4,12 +4,14 @@
//! - Single Router Mode (enable_igw=false): Router owns workers directly
//! - Single Router Mode (enable_igw=false): Router owns workers directly
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
use
crate
::
core
::{
Worker
,
WorkerRegistry
,
WorkerType
};
use
crate
::
config
::{
ConnectionMode
,
RoutingMode
};
use
crate
::
core
::{
WorkerRegistry
,
WorkerType
};
use
crate
::
protocols
::
spec
::{
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
EmbeddingRequest
,
GenerateRequest
,
RerankRequest
,
ChatCompletionRequest
,
CompletionRequest
,
EmbeddingRequest
,
GenerateRequest
,
RerankRequest
,
ResponsesRequest
,
ResponsesRequest
,
};
};
use
crate
::
routers
::
RouterTrait
;
use
crate
::
routers
::
RouterTrait
;
use
crate
::
server
::{
AppContext
,
ServerConfig
};
use
async_trait
::
async_trait
;
use
async_trait
::
async_trait
;
use
axum
::{
use
axum
::{
body
::
Body
,
body
::
Body
,
...
@@ -19,9 +21,8 @@ use axum::{
...
@@ -19,9 +21,8 @@ use axum::{
};
};
use
dashmap
::
DashMap
;
use
dashmap
::
DashMap
;
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
use
tracing
::
info
;
use
tracing
::
{
debug
,
info
,
warn
}
;
/// Router identifier
#[derive(Debug,
Clone,
Hash,
Eq,
PartialEq)]
#[derive(Debug,
Clone,
Hash,
Eq,
PartialEq)]
pub
struct
RouterId
(
String
);
pub
struct
RouterId
(
String
);
...
@@ -35,30 +36,120 @@ impl RouterId {
...
@@ -35,30 +36,120 @@ impl RouterId {
}
}
}
}
/// Router Manager - Central coordinator for routers and workers
pub
struct
RouterManager
{
pub
struct
RouterManager
{
/// Worker registry (single source of truth in multi-router mode)
worker_registry
:
Arc
<
WorkerRegistry
>
,
worker_registry
:
Arc
<
WorkerRegistry
>
,
/// All routers managed by this manager
/// RouterId examples: "http-regular", "http-pd", "grpc-regular", "grpc-pd"
routers
:
Arc
<
DashMap
<
RouterId
,
Arc
<
dyn
RouterTrait
>>>
,
routers
:
Arc
<
DashMap
<
RouterId
,
Arc
<
dyn
RouterTrait
>>>
,
/// Default router for requests without specific routing
default_router
:
Arc
<
std
::
sync
::
RwLock
<
Option
<
RouterId
>>>
,
default_router
:
Arc
<
std
::
sync
::
RwLock
<
Option
<
RouterId
>>>
,
enable_igw
:
bool
,
}
}
impl
RouterManager
{
impl
RouterManager
{
/// Create a new router manager with shared registries
pub
fn
new
(
worker_registry
:
Arc
<
WorkerRegistry
>
)
->
Self
{
pub
fn
new
(
worker_registry
:
Arc
<
WorkerRegistry
>
)
->
Self
{
Self
{
Self
{
worker_registry
,
worker_registry
,
routers
:
Arc
::
new
(
DashMap
::
new
()),
routers
:
Arc
::
new
(
DashMap
::
new
()),
default_router
:
Arc
::
new
(
std
::
sync
::
RwLock
::
new
(
None
)),
default_router
:
Arc
::
new
(
std
::
sync
::
RwLock
::
new
(
None
)),
enable_igw
:
false
,
// Will be set properly in from_config
}
}
pub
async
fn
from_config
(
config
:
&
ServerConfig
,
app_context
:
&
Arc
<
AppContext
>
,
)
->
Result
<
Arc
<
Self
>
,
String
>
{
use
crate
::
routers
::
RouterFactory
;
let
mut
manager
=
Self
::
new
(
app_context
.worker_registry
.clone
());
manager
.enable_igw
=
config
.router_config.enable_igw
;
let
manager
=
Arc
::
new
(
manager
);
if
config
.router_config.enable_igw
{
info!
(
"Initializing RouterManager in multi-router mode (IGW)"
);
match
RouterFactory
::
create_regular_router
(
app_context
)
.await
{
Ok
(
http_regular
)
=>
{
info!
(
"Created HTTP Regular router"
);
manager
.register_router
(
RouterId
::
new
(
"http-regular"
.to_string
()),
Arc
::
from
(
http_regular
),
);
}
Err
(
e
)
=>
{
warn!
(
"Failed to create HTTP Regular router: {e}"
);
}
}
match
RouterFactory
::
create_pd_router
(
None
,
None
,
&
config
.router_config.policy
,
app_context
,
)
.await
{
Ok
(
http_pd
)
=>
{
info!
(
"Created HTTP PD router"
);
manager
.register_router
(
RouterId
::
new
(
"http-pd"
.to_string
()),
Arc
::
from
(
http_pd
));
}
Err
(
e
)
=>
{
warn!
(
"Failed to create HTTP PD router: {e}"
);
}
}
// TODO: Add gRPC routers once we have dynamic tokenizer loading
info!
(
"RouterManager initialized with {} routers for multi-router mode"
,
manager
.router_count
()
);
}
else
{
info!
(
"Initializing RouterManager in single-router mode"
);
let
single_router
=
Arc
::
from
(
RouterFactory
::
create_router
(
app_context
)
.await
?
);
let
router_id
=
Self
::
determine_router_id
(
&
config
.router_config.mode
,
&
config
.router_config.connection_mode
,
);
info!
(
"Created single router with ID: {}"
,
router_id
.as_str
());
manager
.register_router
(
router_id
.clone
(),
single_router
);
manager
.set_default_router
(
router_id
);
}
if
manager
.router_count
()
==
0
{
return
Err
(
"No routers could be initialized"
.to_string
());
}
Ok
(
manager
)
}
pub
fn
determine_router_id
(
routing_mode
:
&
RoutingMode
,
connection_mode
:
&
ConnectionMode
,
)
->
RouterId
{
match
(
connection_mode
,
routing_mode
)
{
(
ConnectionMode
::
Http
,
RoutingMode
::
Regular
{
..
})
=>
{
RouterId
::
new
(
"http-regular"
.to_string
())
}
(
ConnectionMode
::
Http
,
RoutingMode
::
PrefillDecode
{
..
})
=>
{
RouterId
::
new
(
"http-pd"
.to_string
())
}
(
ConnectionMode
::
Http
,
RoutingMode
::
OpenAI
{
..
})
=>
{
RouterId
::
new
(
"http-openai"
.to_string
())
}
(
ConnectionMode
::
Grpc
,
RoutingMode
::
Regular
{
..
})
=>
{
RouterId
::
new
(
"grpc-regular"
.to_string
())
}
(
ConnectionMode
::
Grpc
,
RoutingMode
::
PrefillDecode
{
..
})
=>
{
RouterId
::
new
(
"grpc-pd"
.to_string
())
}
(
ConnectionMode
::
Grpc
,
RoutingMode
::
OpenAI
{
..
})
=>
{
RouterId
::
new
(
"grpc-regular"
.to_string
())
}
}
}
}
}
/// Register a router with the manager
pub
fn
register_router
(
&
self
,
id
:
RouterId
,
router
:
Arc
<
dyn
RouterTrait
>
)
{
pub
fn
register_router
(
&
self
,
id
:
RouterId
,
router
:
Arc
<
dyn
RouterTrait
>
)
{
self
.routers
.insert
(
id
.clone
(),
router
);
self
.routers
.insert
(
id
.clone
(),
router
);
...
@@ -69,18 +160,15 @@ impl RouterManager {
...
@@ -69,18 +160,15 @@ impl RouterManager {
}
}
}
}
/// Set the default router
pub
fn
set_default_router
(
&
self
,
id
:
RouterId
)
{
pub
fn
set_default_router
(
&
self
,
id
:
RouterId
)
{
let
mut
default_router
=
self
.default_router
.write
()
.unwrap
();
let
mut
default_router
=
self
.default_router
.write
()
.unwrap
();
*
default_router
=
Some
(
id
);
*
default_router
=
Some
(
id
);
}
}
/// Get the number of registered routers
pub
fn
router_count
(
&
self
)
->
usize
{
pub
fn
router_count
(
&
self
)
->
usize
{
self
.routers
.len
()
self
.routers
.len
()
}
}
/// Get router for a specific model based on worker types
pub
fn
get_router_for_model
(
&
self
,
model_id
:
&
str
)
->
Option
<
Arc
<
dyn
RouterTrait
>>
{
pub
fn
get_router_for_model
(
&
self
,
model_id
:
&
str
)
->
Option
<
Arc
<
dyn
RouterTrait
>>
{
let
workers
=
self
.worker_registry
.get_by_model
(
model_id
);
let
workers
=
self
.worker_registry
.get_by_model
(
model_id
);
...
@@ -111,21 +199,25 @@ impl RouterManager {
...
@@ -111,21 +199,25 @@ impl RouterManager {
}
}
}
}
/// Get workers for routing decision
pub
fn
get_workers_for_request
(
&
self
,
model_id
:
Option
<&
str
>
)
->
Vec
<
Arc
<
dyn
Worker
>>
{
if
let
Some
(
model
)
=
model_id
{
self
.worker_registry
.get_by_model
(
model
)
}
else
{
self
.worker_registry
.get_all
()
}
}
/// Get the appropriate router for a request based on headers and request content
pub
fn
select_router_for_request
(
pub
fn
select_router_for_request
(
&
self
,
&
self
,
headers
:
Option
<&
HeaderMap
>
,
headers
:
Option
<&
HeaderMap
>
,
model_id
:
Option
<&
str
>
,
model_id
:
Option
<&
str
>
,
)
->
Option
<
Arc
<
dyn
RouterTrait
>>
{
)
->
Option
<
Arc
<
dyn
RouterTrait
>>
{
// In single-router mode (enable_igw=false), always use the default router
if
!
self
.enable_igw
{
let
default_router
=
self
.default_router
.read
()
.unwrap
();
if
let
Some
(
ref
default_id
)
=
*
default_router
{
debug!
(
"Single-router mode: using default router {} for model {:?}"
,
default_id
.as_str
(),
model_id
);
return
self
.routers
.get
(
default_id
)
.map
(|
r
|
r
.clone
());
}
}
// Multi-router mode logic follows
let
_
priority_threshold
=
headers
.and_then
(|
h
|
{
let
_
priority_threshold
=
headers
.and_then
(|
h
|
{
h
.get
(
"x-worker-priority"
)
h
.get
(
"x-worker-priority"
)
.and_then
(|
v
|
v
.to_str
()
.ok
())
.and_then
(|
v
|
v
.to_str
()
.ok
())
...
@@ -176,10 +268,6 @@ impl RouterManager {
...
@@ -176,10 +268,6 @@ impl RouterManager {
score
+=
1.0
;
score
+=
1.0
;
}
}
// Get workers for this router and evaluate based on priority/cost
// Note: This would require routers to expose their workers or stats
// For now, we'll use a simple selection based on router type
// TODO: Once routers expose worker stats, we can evaluate:
// TODO: Once routers expose worker stats, we can evaluate:
// - Average worker priority vs priority_threshold
// - Average worker priority vs priority_threshold
// - Average worker cost vs max_cost
// - Average worker cost vs max_cost
...
@@ -201,16 +289,11 @@ impl RouterTrait for RouterManager {
...
@@ -201,16 +289,11 @@ impl RouterTrait for RouterManager {
self
self
}
}
/// Health check - return 503 if no routers available
async
fn
health
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
async
fn
health
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// Health check should succeed if RouterManager exists, even without routers
// Individual router health can be checked via specific endpoints
(
StatusCode
::
OK
,
"RouterManager is healthy"
)
.into_response
()
(
StatusCode
::
OK
,
"RouterManager is healthy"
)
.into_response
()
}
}
/// Health generate - check if any router can handle generate requests
async
fn
health_generate
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
async
fn
health_generate
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// Return 503 since we have no routers with workers
// TODO: Should check if any router has healthy workers
// TODO: Should check if any router has healthy workers
(
(
StatusCode
::
SERVICE_UNAVAILABLE
,
StatusCode
::
SERVICE_UNAVAILABLE
,
...
@@ -219,10 +302,8 @@ impl RouterTrait for RouterManager {
...
@@ -219,10 +302,8 @@ impl RouterTrait for RouterManager {
.into_response
()
.into_response
()
}
}
/// Get server information - aggregate from all routers
async
fn
get_server_info
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
async
fn
get_server_info
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// TODO: Aggregate info from all routers with healthy workers
// TODO: Aggregate info from all routers with healthy workers
// For now, return basic info about the RouterManager
(
(
StatusCode
::
OK
,
StatusCode
::
OK
,
serde_json
::
json!
({
serde_json
::
json!
({
...
@@ -235,9 +316,7 @@ impl RouterTrait for RouterManager {
...
@@ -235,9 +316,7 @@ impl RouterTrait for RouterManager {
.into_response
()
.into_response
()
}
}
/// Get available models - query from worker registry
async
fn
get_models
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
async
fn
get_models
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// Get models from worker registry
let
models
=
self
.worker_registry
.get_models
();
let
models
=
self
.worker_registry
.get_models
();
if
models
.is_empty
()
{
if
models
.is_empty
()
{
...
@@ -254,10 +333,8 @@ impl RouterTrait for RouterManager {
...
@@ -254,10 +333,8 @@ impl RouterTrait for RouterManager {
}
}
}
}
/// Get model information
async
fn
get_model_info
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
async
fn
get_model_info
(
&
self
,
_
req
:
Request
<
Body
>
)
->
Response
{
// TODO: Extract model from request and route to appropriate router
// TODO: Extract model from request and route to appropriate router
// For now, return not implemented
(
(
StatusCode
::
NOT_IMPLEMENTED
,
StatusCode
::
NOT_IMPLEMENTED
,
"Model info endpoint not yet implemented in RouterManager"
,
"Model info endpoint not yet implemented in RouterManager"
,
...
@@ -265,22 +342,17 @@ impl RouterTrait for RouterManager {
...
@@ -265,22 +342,17 @@ impl RouterTrait for RouterManager {
.into_response
()
.into_response
()
}
}
/// Route a generate request
async
fn
route_generate
(
async
fn
route_generate
(
&
self
,
&
self
,
headers
:
Option
<&
HeaderMap
>
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
GenerateRequest
,
body
:
&
GenerateRequest
,
_
model_id
:
Option
<&
str
>
,
_
model_id
:
Option
<&
str
>
,
)
->
Response
{
)
->
Response
{
// Select router based on headers
// GenerateRequest doesn't have a model field
let
router
=
self
.select_router_for_request
(
headers
,
None
);
let
router
=
self
.select_router_for_request
(
headers
,
None
);
if
let
Some
(
router
)
=
router
{
if
let
Some
(
router
)
=
router
{
// In multi-model mode, pass None since GenerateRequest doesn't have model field
router
.route_generate
(
headers
,
body
,
None
)
.await
router
.route_generate
(
headers
,
body
,
None
)
.await
}
else
{
}
else
{
// Return 404 when no router is available for the request
(
(
StatusCode
::
NOT_FOUND
,
StatusCode
::
NOT_FOUND
,
"No router available for this request"
,
"No router available for this request"
,
...
@@ -289,7 +361,6 @@ impl RouterTrait for RouterManager {
...
@@ -289,7 +361,6 @@ impl RouterTrait for RouterManager {
}
}
}
}
/// Route a chat completion request
async
fn
route_chat
(
async
fn
route_chat
(
&
self
,
&
self
,
headers
:
Option
<&
HeaderMap
>
,
headers
:
Option
<&
HeaderMap
>
,
...
@@ -299,10 +370,8 @@ impl RouterTrait for RouterManager {
...
@@ -299,10 +370,8 @@ impl RouterTrait for RouterManager {
let
router
=
self
.select_router_for_request
(
headers
,
Some
(
&
body
.model
));
let
router
=
self
.select_router_for_request
(
headers
,
Some
(
&
body
.model
));
if
let
Some
(
router
)
=
router
{
if
let
Some
(
router
)
=
router
{
// In multi-model mode, pass the model_id to the router
router
.route_chat
(
headers
,
body
,
Some
(
&
body
.model
))
.await
router
.route_chat
(
headers
,
body
,
Some
(
&
body
.model
))
.await
}
else
{
}
else
{
// Return 404 when the specified model is not found
(
(
StatusCode
::
NOT_FOUND
,
StatusCode
::
NOT_FOUND
,
format!
(
"Model '{}' not found or no router available"
,
body
.model
),
format!
(
"Model '{}' not found or no router available"
,
body
.model
),
...
@@ -311,7 +380,6 @@ impl RouterTrait for RouterManager {
...
@@ -311,7 +380,6 @@ impl RouterTrait for RouterManager {
}
}
}
}
/// Route a completion request
async
fn
route_completion
(
async
fn
route_completion
(
&
self
,
&
self
,
headers
:
Option
<&
HeaderMap
>
,
headers
:
Option
<&
HeaderMap
>
,
...
@@ -321,12 +389,10 @@ impl RouterTrait for RouterManager {
...
@@ -321,12 +389,10 @@ impl RouterTrait for RouterManager {
let
router
=
self
.select_router_for_request
(
headers
,
Some
(
&
body
.model
));
let
router
=
self
.select_router_for_request
(
headers
,
Some
(
&
body
.model
));
if
let
Some
(
router
)
=
router
{
if
let
Some
(
router
)
=
router
{
// In multi-model mode, pass the model_id to the router
router
router
.route_completion
(
headers
,
body
,
Some
(
&
body
.model
))
.route_completion
(
headers
,
body
,
Some
(
&
body
.model
))
.await
.await
}
else
{
}
else
{
// Return 404 when the specified model is not found
(
(
StatusCode
::
NOT_FOUND
,
StatusCode
::
NOT_FOUND
,
format!
(
"Model '{}' not found or no router available"
,
body
.model
),
format!
(
"Model '{}' not found or no router available"
,
body
.model
),
...
@@ -348,26 +414,6 @@ impl RouterTrait for RouterManager {
...
@@ -348,26 +414,6 @@ impl RouterTrait for RouterManager {
.into_response
()
.into_response
()
}
}
async
fn
delete_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"responses api not yet implemented in inference gateway mode"
,
)
.into_response
()
}
async
fn
list_response_input_items
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
,
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"responses api not yet implemented in inference gateway mode"
,
)
.into_response
()
}
async
fn
get_response
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
response_id
:
&
str
)
->
Response
{
async
fn
get_response
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
response_id
:
&
str
)
->
Response
{
let
router
=
self
.select_router_for_request
(
headers
,
None
);
let
router
=
self
.select_router_for_request
(
headers
,
None
);
if
let
Some
(
router
)
=
router
{
if
let
Some
(
router
)
=
router
{
...
@@ -394,7 +440,26 @@ impl RouterTrait for RouterManager {
...
@@ -394,7 +440,26 @@ impl RouterTrait for RouterManager {
}
}
}
}
/// Route embeddings request
async
fn
delete_response
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"responses api not yet implemented in inference gateway mode"
,
)
.into_response
()
}
async
fn
list_response_input_items
(
&
self
,
_
headers
:
Option
<&
HeaderMap
>
,
_
response_id
:
&
str
,
)
->
Response
{
(
StatusCode
::
NOT_IMPLEMENTED
,
"responses api not yet implemented in inference gateway mode"
,
)
.into_response
()
}
async
fn
route_embeddings
(
async
fn
route_embeddings
(
&
self
,
&
self
,
headers
:
Option
<&
HeaderMap
>
,
headers
:
Option
<&
HeaderMap
>
,
...
@@ -408,7 +473,6 @@ impl RouterTrait for RouterManager {
...
@@ -408,7 +473,6 @@ impl RouterTrait for RouterManager {
.route_embeddings
(
headers
,
body
,
Some
(
&
body
.model
))
.route_embeddings
(
headers
,
body
,
Some
(
&
body
.model
))
.await
.await
}
else
{
}
else
{
// Return 404 when the specified model is not found
(
(
StatusCode
::
NOT_FOUND
,
StatusCode
::
NOT_FOUND
,
format!
(
"Model '{}' not found or no router available"
,
body
.model
),
format!
(
"Model '{}' not found or no router available"
,
body
.model
),
...
@@ -417,14 +481,12 @@ impl RouterTrait for RouterManager {
...
@@ -417,14 +481,12 @@ impl RouterTrait for RouterManager {
}
}
}
}
/// Route rerank request
async
fn
route_rerank
(
async
fn
route_rerank
(
&
self
,
&
self
,
headers
:
Option
<&
HeaderMap
>
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
RerankRequest
,
body
:
&
RerankRequest
,
model_id
:
Option
<&
str
>
,
model_id
:
Option
<&
str
>
,
)
->
Response
{
)
->
Response
{
// Try to select a router based on headers
let
router
=
self
.select_router_for_request
(
headers
,
None
);
let
router
=
self
.select_router_for_request
(
headers
,
None
);
if
let
Some
(
router
)
=
router
{
if
let
Some
(
router
)
=
router
{
...
@@ -438,10 +500,8 @@ impl RouterTrait for RouterManager {
...
@@ -438,10 +500,8 @@ impl RouterTrait for RouterManager {
}
}
}
}
/// Flush cache on all routers and workers
async
fn
flush_cache
(
&
self
)
->
Response
{
async
fn
flush_cache
(
&
self
)
->
Response
{
// TODO: Call flush_cache on all routers that have workers
// TODO: Call flush_cache on all routers that have workers
// For now, return success if we have any routers
if
self
.routers
.is_empty
()
{
if
self
.routers
.is_empty
()
{
(
StatusCode
::
SERVICE_UNAVAILABLE
,
"No routers configured"
)
.into_response
()
(
StatusCode
::
SERVICE_UNAVAILABLE
,
"No routers configured"
)
.into_response
()
}
else
{
}
else
{
...
@@ -450,9 +510,7 @@ impl RouterTrait for RouterManager {
...
@@ -450,9 +510,7 @@ impl RouterTrait for RouterManager {
}
}
}
}
/// Get worker loads from all routers
async
fn
get_worker_loads
(
&
self
)
->
Response
{
async
fn
get_worker_loads
(
&
self
)
->
Response
{
// Return worker loads from the registry
let
workers
=
self
.worker_registry
.get_all
();
let
workers
=
self
.worker_registry
.get_all
();
let
loads
:
Vec
<
serde_json
::
Value
>
=
workers
let
loads
:
Vec
<
serde_json
::
Value
>
=
workers
.iter
()
.iter
()
...
@@ -476,12 +534,10 @@ impl RouterTrait for RouterManager {
...
@@ -476,12 +534,10 @@ impl RouterTrait for RouterManager {
.into_response
()
.into_response
()
}
}
/// Get router type name
fn
router_type
(
&
self
)
->
&
'static
str
{
fn
router_type
(
&
self
)
->
&
'static
str
{
"manager"
"manager"
}
}
/// Server readiness check - check if any router is ready
fn
readiness
(
&
self
)
->
Response
{
fn
readiness
(
&
self
)
->
Response
{
if
self
.routers
.is_empty
()
{
if
self
.routers
.is_empty
()
{
(
StatusCode
::
SERVICE_UNAVAILABLE
,
"No routers configured"
)
.into_response
()
(
StatusCode
::
SERVICE_UNAVAILABLE
,
"No routers configured"
)
.into_response
()
...
@@ -492,9 +548,6 @@ impl RouterTrait for RouterManager {
...
@@ -492,9 +548,6 @@ impl RouterTrait for RouterManager {
}
}
}
}
// Note: get_first_available_router removed - we now properly handle
// router selection based on model and worker availability
impl
std
::
fmt
::
Debug
for
RouterManager
{
impl
std
::
fmt
::
Debug
for
RouterManager
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
f
.debug_struct
(
"RouterManager"
)
f
.debug_struct
(
"RouterManager"
)
...
...
sgl-router/src/server.rs
View file @
89971c4c
...
@@ -14,10 +14,7 @@ use crate::{
...
@@ -14,10 +14,7 @@ use crate::{
worker_spec
::{
WorkerApiResponse
,
WorkerConfigRequest
,
WorkerErrorResponse
},
worker_spec
::{
WorkerApiResponse
,
WorkerConfigRequest
,
WorkerErrorResponse
},
},
},
reasoning_parser
::
ParserFactory
,
reasoning_parser
::
ParserFactory
,
routers
::{
routers
::{
router_manager
::
RouterManager
,
RouterTrait
},
router_manager
::{
RouterId
,
RouterManager
},
RouterFactory
,
RouterTrait
,
},
service_discovery
::{
start_service_discovery
,
ServiceDiscoveryConfig
},
service_discovery
::{
start_service_discovery
,
ServiceDiscoveryConfig
},
tokenizer
::{
factory
as
tokenizer_factory
,
traits
::
Tokenizer
},
tokenizer
::{
factory
as
tokenizer_factory
,
traits
::
Tokenizer
},
tool_parser
::
ParserRegistry
,
tool_parser
::
ParserRegistry
,
...
@@ -64,10 +61,8 @@ impl AppContext {
...
@@ -64,10 +61,8 @@ impl AppContext {
let
rate_limit_tokens
=
rate_limit_tokens_per_second
.unwrap_or
(
max_concurrent_requests
);
let
rate_limit_tokens
=
rate_limit_tokens_per_second
.unwrap_or
(
max_concurrent_requests
);
let
rate_limiter
=
Arc
::
new
(
TokenBucket
::
new
(
max_concurrent_requests
,
rate_limit_tokens
));
let
rate_limiter
=
Arc
::
new
(
TokenBucket
::
new
(
max_concurrent_requests
,
rate_limit_tokens
));
// Initialize gRPC-specific components only when in gRPC mode
let
(
tokenizer
,
reasoning_parser_factory
,
tool_parser_registry
)
=
let
(
tokenizer
,
reasoning_parser_factory
,
tool_parser_registry
)
=
if
router_config
.connection_mode
==
ConnectionMode
::
Grpc
{
if
router_config
.connection_mode
==
ConnectionMode
::
Grpc
{
// Get tokenizer path (required for gRPC mode)
let
tokenizer_path
=
router_config
let
tokenizer_path
=
router_config
.tokenizer_path
.tokenizer_path
.clone
()
.clone
()
...
@@ -77,7 +72,6 @@ impl AppContext {
...
@@ -77,7 +72,6 @@ impl AppContext {
.to_string
()
.to_string
()
})
?
;
})
?
;
// Initialize all gRPC components
let
tokenizer
=
Some
(
let
tokenizer
=
Some
(
tokenizer_factory
::
create_tokenizer
(
&
tokenizer_path
)
tokenizer_factory
::
create_tokenizer
(
&
tokenizer_path
)
.map_err
(|
e
|
format!
(
"Failed to create tokenizer: {e}"
))
?
,
.map_err
(|
e
|
format!
(
"Failed to create tokenizer: {e}"
))
?
,
...
@@ -87,7 +81,6 @@ impl AppContext {
...
@@ -87,7 +81,6 @@ impl AppContext {
(
tokenizer
,
reasoning_parser_factory
,
tool_parser_registry
)
(
tokenizer
,
reasoning_parser_factory
,
tool_parser_registry
)
}
else
{
}
else
{
// HTTP mode doesn't need these components
(
None
,
None
,
None
)
(
None
,
None
,
None
)
};
};
...
@@ -96,7 +89,6 @@ impl AppContext {
...
@@ -96,7 +89,6 @@ impl AppContext {
let
router_manager
=
None
;
let
router_manager
=
None
;
// Initialize response storage based on configuration
let
response_storage
:
SharedResponseStorage
=
match
router_config
.history_backend
{
let
response_storage
:
SharedResponseStorage
=
match
router_config
.history_backend
{
HistoryBackend
::
Memory
=>
Arc
::
new
(
MemoryResponseStorage
::
new
()),
HistoryBackend
::
Memory
=>
Arc
::
new
(
MemoryResponseStorage
::
new
()),
HistoryBackend
::
None
=>
Arc
::
new
(
NoOpResponseStorage
::
new
()),
HistoryBackend
::
None
=>
Arc
::
new
(
NoOpResponseStorage
::
new
()),
...
@@ -125,12 +117,10 @@ pub struct AppState {
...
@@ -125,12 +117,10 @@ pub struct AppState {
pub
router_manager
:
Option
<
Arc
<
RouterManager
>>
,
pub
router_manager
:
Option
<
Arc
<
RouterManager
>>
,
}
}
// Fallback handler for unmatched routes
async
fn
sink_handler
()
->
Response
{
async
fn
sink_handler
()
->
Response
{
StatusCode
::
NOT_FOUND
.into_response
()
StatusCode
::
NOT_FOUND
.into_response
()
}
}
// Health check endpoints
async
fn
liveness
(
State
(
state
):
State
<
Arc
<
AppState
>>
)
->
Response
{
async
fn
liveness
(
State
(
state
):
State
<
Arc
<
AppState
>>
)
->
Response
{
state
.router
.liveness
()
state
.router
.liveness
()
}
}
...
@@ -257,7 +247,6 @@ async fn v1_responses_delete(
...
@@ -257,7 +247,6 @@ async fn v1_responses_delete(
Path
(
response_id
):
Path
<
String
>
,
Path
(
response_id
):
Path
<
String
>
,
headers
:
http
::
HeaderMap
,
headers
:
http
::
HeaderMap
,
)
->
Response
{
)
->
Response
{
// Python server does not support this yet
state
state
.router
.router
.delete_response
(
Some
(
&
headers
),
&
response_id
)
.delete_response
(
Some
(
&
headers
),
&
response_id
)
...
@@ -269,15 +258,12 @@ async fn v1_responses_list_input_items(
...
@@ -269,15 +258,12 @@ async fn v1_responses_list_input_items(
Path
(
response_id
):
Path
<
String
>
,
Path
(
response_id
):
Path
<
String
>
,
headers
:
http
::
HeaderMap
,
headers
:
http
::
HeaderMap
,
)
->
Response
{
)
->
Response
{
// Python server does not support this yet
state
state
.router
.router
.list_response_input_items
(
Some
(
&
headers
),
&
response_id
)
.list_response_input_items
(
Some
(
&
headers
),
&
response_id
)
.await
.await
}
}
// ---------- Worker management endpoints (Legacy) ----------
#[derive(Deserialize)]
#[derive(Deserialize)]
struct
AddWorkerQuery
{
struct
AddWorkerQuery
{
url
:
String
,
url
:
String
,
...
@@ -288,7 +274,6 @@ async fn add_worker(
...
@@ -288,7 +274,6 @@ async fn add_worker(
State
(
state
):
State
<
Arc
<
AppState
>>
,
State
(
state
):
State
<
Arc
<
AppState
>>
,
Query
(
AddWorkerQuery
{
url
,
api_key
}):
Query
<
AddWorkerQuery
>
,
Query
(
AddWorkerQuery
{
url
,
api_key
}):
Query
<
AddWorkerQuery
>
,
)
->
Response
{
)
->
Response
{
// Use centralized WorkerManager with full context
let
result
=
WorkerManager
::
add_worker
(
&
url
,
&
api_key
,
&
state
.context
)
.await
;
let
result
=
WorkerManager
::
add_worker
(
&
url
,
&
api_key
,
&
state
.context
)
.await
;
match
result
{
match
result
{
...
@@ -298,7 +283,6 @@ async fn add_worker(
...
@@ -298,7 +283,6 @@ async fn add_worker(
}
}
async
fn
list_workers
(
State
(
state
):
State
<
Arc
<
AppState
>>
)
->
Response
{
async
fn
list_workers
(
State
(
state
):
State
<
Arc
<
AppState
>>
)
->
Response
{
// Use centralized WorkerManager instead of router's get_worker_urls
let
worker_list
=
WorkerManager
::
get_worker_urls
(
&
state
.context.worker_registry
);
let
worker_list
=
WorkerManager
::
get_worker_urls
(
&
state
.context.worker_registry
);
Json
(
json!
({
"urls"
:
worker_list
}))
.into_response
()
Json
(
json!
({
"urls"
:
worker_list
}))
.into_response
()
}
}
...
@@ -307,7 +291,6 @@ async fn remove_worker(
...
@@ -307,7 +291,6 @@ async fn remove_worker(
State
(
state
):
State
<
Arc
<
AppState
>>
,
State
(
state
):
State
<
Arc
<
AppState
>>
,
Query
(
AddWorkerQuery
{
url
,
..
}):
Query
<
AddWorkerQuery
>
,
Query
(
AddWorkerQuery
{
url
,
..
}):
Query
<
AddWorkerQuery
>
,
)
->
Response
{
)
->
Response
{
// Use centralized WorkerManager with full context
let
result
=
WorkerManager
::
remove_worker
(
&
url
,
&
state
.context
);
let
result
=
WorkerManager
::
remove_worker
(
&
url
,
&
state
.context
);
match
result
{
match
result
{
...
@@ -324,14 +307,10 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Respons
...
@@ -324,14 +307,10 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Respons
state
.router
.get_worker_loads
()
.await
state
.router
.get_worker_loads
()
.await
}
}
// ---------- Worker management endpoints (RESTful) ----------
/// POST /workers - Add a new worker with full configuration
async
fn
create_worker
(
async
fn
create_worker
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
State
(
state
):
State
<
Arc
<
AppState
>>
,
Json
(
config
):
Json
<
WorkerConfigRequest
>
,
Json
(
config
):
Json
<
WorkerConfigRequest
>
,
)
->
Response
{
)
->
Response
{
// In single router mode, use centralized WorkerManager with full context
let
result
=
WorkerManager
::
add_worker_from_config
(
&
config
,
&
state
.context
)
.await
;
let
result
=
WorkerManager
::
add_worker_from_config
(
&
config
,
&
state
.context
)
.await
;
match
result
{
match
result
{
...
@@ -353,9 +332,7 @@ async fn create_worker(
...
@@ -353,9 +332,7 @@ async fn create_worker(
}
}
}
}
/// GET /workers - List all workers with details
async
fn
list_workers_rest
(
State
(
state
):
State
<
Arc
<
AppState
>>
)
->
Response
{
async
fn
list_workers_rest
(
State
(
state
):
State
<
Arc
<
AppState
>>
)
->
Response
{
// In single router mode, get detailed worker info from registry
let
workers
=
state
.context.worker_registry
.get_all
();
let
workers
=
state
.context.worker_registry
.get_all
();
let
response
=
serde_json
::
json!
({
let
response
=
serde_json
::
json!
({
"workers"
:
workers
.iter
()
.map
(|
worker
|
{
"workers"
:
workers
.iter
()
.map
(|
worker
|
{
...
@@ -374,7 +351,6 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
...
@@ -374,7 +351,6 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
"cost"
:
worker
.cost
(),
"cost"
:
worker
.cost
(),
});
});
// Add bootstrap_port for Prefill workers
if
let
WorkerType
::
Prefill
{
bootstrap_port
}
=
worker
.worker_type
()
{
if
let
WorkerType
::
Prefill
{
bootstrap_port
}
=
worker
.worker_type
()
{
worker_info
[
"bootstrap_port"
]
=
serde_json
::
json!
(
bootstrap_port
);
worker_info
[
"bootstrap_port"
]
=
serde_json
::
json!
(
bootstrap_port
);
}
}
...
@@ -391,7 +367,6 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
...
@@ -391,7 +367,6 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
Json
(
response
)
.into_response
()
Json
(
response
)
.into_response
()
}
}
/// GET /workers/{url} - Get specific worker info
async
fn
get_worker
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
Path
(
url
):
Path
<
String
>
)
->
Response
{
async
fn
get_worker
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
Path
(
url
):
Path
<
String
>
)
->
Response
{
let
workers
=
WorkerManager
::
get_worker_urls
(
&
state
.context.worker_registry
);
let
workers
=
WorkerManager
::
get_worker_urls
(
&
state
.context.worker_registry
);
if
workers
.contains
(
&
url
)
{
if
workers
.contains
(
&
url
)
{
...
@@ -410,9 +385,7 @@ async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>)
...
@@ -410,9 +385,7 @@ async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>)
}
}
}
}
/// DELETE /workers/{url} - Remove a worker
async
fn
delete_worker
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
Path
(
url
):
Path
<
String
>
)
->
Response
{
async
fn
delete_worker
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
Path
(
url
):
Path
<
String
>
)
->
Response
{
// In single router mode, use centralized WorkerManager with full context
let
result
=
WorkerManager
::
remove_worker
(
&
url
,
&
state
.context
);
let
result
=
WorkerManager
::
remove_worker
(
&
url
,
&
state
.context
);
match
result
{
match
result
{
...
@@ -447,14 +420,12 @@ pub struct ServerConfig {
...
@@ -447,14 +420,12 @@ pub struct ServerConfig {
pub
request_id_headers
:
Option
<
Vec
<
String
>>
,
pub
request_id_headers
:
Option
<
Vec
<
String
>>
,
}
}
/// Build the Axum application with all routes and middleware
pub
fn
build_app
(
pub
fn
build_app
(
app_state
:
Arc
<
AppState
>
,
app_state
:
Arc
<
AppState
>
,
max_payload_size
:
usize
,
max_payload_size
:
usize
,
request_id_headers
:
Vec
<
String
>
,
request_id_headers
:
Vec
<
String
>
,
cors_allowed_origins
:
Vec
<
String
>
,
cors_allowed_origins
:
Vec
<
String
>
,
)
->
Router
{
)
->
Router
{
// Create routes
let
protected_routes
=
Router
::
new
()
let
protected_routes
=
Router
::
new
()
.route
(
"/generate"
,
post
(
generate
))
.route
(
"/generate"
,
post
(
generate
))
.route
(
"/v1/chat/completions"
,
post
(
v1_chat_completions
))
.route
(
"/v1/chat/completions"
,
post
(
v1_chat_completions
))
...
@@ -494,20 +465,17 @@ pub fn build_app(
...
@@ -494,20 +465,17 @@ pub fn build_app(
.route
(
"/flush_cache"
,
post
(
flush_cache
))
.route
(
"/flush_cache"
,
post
(
flush_cache
))
.route
(
"/get_loads"
,
get
(
get_loads
));
.route
(
"/get_loads"
,
get
(
get_loads
));
// Worker management routes
let
worker_routes
=
Router
::
new
()
let
worker_routes
=
Router
::
new
()
.route
(
"/workers"
,
post
(
create_worker
))
.route
(
"/workers"
,
post
(
create_worker
))
.route
(
"/workers"
,
get
(
list_workers_rest
))
.route
(
"/workers"
,
get
(
list_workers_rest
))
.route
(
"/workers/{url}"
,
get
(
get_worker
))
.route
(
"/workers/{url}"
,
get
(
get_worker
))
.route
(
"/workers/{url}"
,
delete
(
delete_worker
));
.route
(
"/workers/{url}"
,
delete
(
delete_worker
));
// Build app with all routes and middleware
Router
::
new
()
Router
::
new
()
.merge
(
protected_routes
)
.merge
(
protected_routes
)
.merge
(
public_routes
)
.merge
(
public_routes
)
.merge
(
admin_routes
)
.merge
(
admin_routes
)
.merge
(
worker_routes
)
.merge
(
worker_routes
)
// Request body size limiting
.layer
(
tower_http
::
limit
::
RequestBodyLimitLayer
::
new
(
.layer
(
tower_http
::
limit
::
RequestBodyLimitLayer
::
new
(
max_payload_size
,
max_payload_size
,
))
))
...
@@ -519,7 +487,6 @@ pub fn build_app(
...
@@ -519,7 +487,6 @@ pub fn build_app(
}
}
pub
async
fn
startup
(
config
:
ServerConfig
)
->
Result
<
(),
Box
<
dyn
std
::
error
::
Error
>>
{
pub
async
fn
startup
(
config
:
ServerConfig
)
->
Result
<
(),
Box
<
dyn
std
::
error
::
Error
>>
{
// Only initialize logging if not already done (for Python bindings support)
static
LOGGING_INITIALIZED
:
AtomicBool
=
AtomicBool
::
new
(
false
);
static
LOGGING_INITIALIZED
:
AtomicBool
=
AtomicBool
::
new
(
false
);
let
_
log_guard
=
if
!
LOGGING_INITIALIZED
.swap
(
true
,
Ordering
::
SeqCst
)
{
let
_
log_guard
=
if
!
LOGGING_INITIALIZED
.swap
(
true
,
Ordering
::
SeqCst
)
{
...
@@ -545,9 +512,8 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
...
@@ -545,9 +512,8 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
None
None
};
};
// Initialize prometheus metrics exporter
if
let
Some
(
prometheus_config
)
=
&
config
.prometheus_config
{
if
let
Some
(
prometheus_config
)
=
config
.prometheus_config
{
metrics
::
start_prometheus
(
prometheus_config
.clone
());
metrics
::
start_prometheus
(
prometheus_config
);
}
}
info!
(
info!
(
...
@@ -569,7 +535,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
...
@@ -569,7 +535,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
.build
()
.build
()
.expect
(
"Failed to create HTTP client"
);
.expect
(
"Failed to create HTTP client"
);
// Create the application context with all dependencies
let
app_context
=
AppContext
::
new
(
let
app_context
=
AppContext
::
new
(
config
.router_config
.clone
(),
config
.router_config
.clone
(),
client
.clone
(),
client
.clone
(),
...
@@ -597,67 +562,9 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
...
@@ -597,67 +562,9 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
worker_stats
.total_workers
,
worker_stats
.healthy_workers
worker_stats
.total_workers
,
worker_stats
.healthy_workers
);
);
// Create the appropriate router based on enable_igw flag
let
router_manager
=
RouterManager
::
from_config
(
&
config
,
&
app_context
)
.await
?
;
let
(
router
,
router_manager
):
(
Arc
<
dyn
RouterTrait
>
,
Option
<
Arc
<
RouterManager
>>
)
=
let
router
:
Arc
<
dyn
RouterTrait
>
=
router_manager
.clone
();
if
config
.router_config.enable_igw
{
info!
(
"Multi-router mode enabled (enable_igw=true)"
);
// Create RouterManager with shared registries from AppContext
let
router_manager
=
Arc
::
new
(
RouterManager
::
new
(
app_context
.worker_registry
.clone
()));
// 1. HTTP Regular Router
match
RouterFactory
::
create_regular_router
(
&
app_context
)
.await
{
Ok
(
http_regular
)
=>
{
info!
(
"Created HTTP Regular router"
);
router_manager
.register_router
(
RouterId
::
new
(
"http-regular"
.to_string
()),
Arc
::
from
(
http_regular
),
);
}
Err
(
e
)
=>
{
warn!
(
"Failed to create HTTP Regular router: {e}"
);
}
}
// 2. HTTP PD Router
match
RouterFactory
::
create_pd_router
(
None
,
None
,
&
config
.router_config.policy
,
&
app_context
,
)
.await
{
Ok
(
http_pd
)
=>
{
info!
(
"Created HTTP PD router"
);
router_manager
.register_router
(
RouterId
::
new
(
"http-pd"
.to_string
()),
Arc
::
from
(
http_pd
));
}
Err
(
e
)
=>
{
warn!
(
"Failed to create HTTP PD router: {e}"
);
}
}
// TODO: Add gRPC routers once we have dynamic tokenizer loading
info!
(
"RouterManager initialized with {} routers"
,
router_manager
.router_count
()
);
(
router_manager
.clone
()
as
Arc
<
dyn
RouterTrait
>
,
Some
(
router_manager
),
)
}
else
{
info!
(
"Single router mode (enable_igw=false)"
);
// Create single router with the context
(
Arc
::
from
(
RouterFactory
::
create_router
(
&
app_context
)
.await
?
),
None
,
)
};
// Start health checker for all workers in the registry
let
_
health_checker
=
app_context
let
_
health_checker
=
app_context
.worker_registry
.worker_registry
.start_health_checker
(
config
.router_config.health_check.check_interval_secs
);
.start_health_checker
(
config
.router_config.health_check.check_interval_secs
);
...
@@ -666,14 +573,12 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
...
@@ -666,14 +573,12 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
config
.router_config.health_check.check_interval_secs
config
.router_config.health_check.check_interval_secs
);
);
// Set up concurrency limiter with queue if configured
let
(
limiter
,
processor
)
=
middleware
::
ConcurrencyLimiter
::
new
(
let
(
limiter
,
processor
)
=
middleware
::
ConcurrencyLimiter
::
new
(
app_context
.rate_limiter
.clone
(),
app_context
.rate_limiter
.clone
(),
config
.router_config.queue_size
,
config
.router_config.queue_size
,
Duration
::
from_secs
(
config
.router_config.queue_timeout_secs
),
Duration
::
from_secs
(
config
.router_config.queue_timeout_secs
),
);
);
// Start queue processor if enabled
if
let
Some
(
processor
)
=
processor
{
if
let
Some
(
processor
)
=
processor
{
spawn
(
processor
.run
());
spawn
(
processor
.run
());
info!
(
info!
(
...
@@ -682,21 +587,18 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
...
@@ -682,21 +587,18 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
);
);
}
}
// Create app state with router and context
let
app_state
=
Arc
::
new
(
AppState
{
let
app_state
=
Arc
::
new
(
AppState
{
router
,
router
,
context
:
app_context
.clone
(),
context
:
app_context
.clone
(),
concurrency_queue_tx
:
limiter
.queue_tx
.clone
(),
concurrency_queue_tx
:
limiter
.queue_tx
.clone
(),
router_manager
,
router_manager
:
Some
(
router_manager
)
,
});
});
// Start the service discovery if enabled
if
let
Some
(
service_discovery_config
)
=
config
.service_discovery_config
{
if
let
Some
(
service_discovery_config
)
=
config
.service_discovery_config
{
if
service_discovery_config
.enabled
{
if
service_discovery_config
.enabled
{
let
app_context_arc
=
Arc
::
clone
(
&
app_state
.context
);
let
app_context_arc
=
Arc
::
clone
(
&
app_state
.context
);
match
start_service_discovery
(
service_discovery_config
,
app_context_arc
)
.await
{
match
start_service_discovery
(
service_discovery_config
,
app_context_arc
)
.await
{
Ok
(
handle
)
=>
{
Ok
(
handle
)
=>
{
info!
(
"Service discovery started"
);
info!
(
"Service discovery started"
);
// Spawn a task to handle the service discovery thread
spawn
(
async
move
{
spawn
(
async
move
{
if
let
Err
(
e
)
=
handle
.await
{
if
let
Err
(
e
)
=
handle
.await
{
error!
(
"Service discovery task failed: {:?}"
,
e
);
error!
(
"Service discovery task failed: {:?}"
,
e
);
...
@@ -725,7 +627,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
...
@@ -725,7 +627,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
]
]
});
});
// Build the application
let
app
=
build_app
(
let
app
=
build_app
(
app_state
,
app_state
,
config
.max_payload_size
,
config
.max_payload_size
,
...
@@ -744,7 +645,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
...
@@ -744,7 +645,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
Ok
(())
Ok
(())
}
}
// Graceful shutdown handler
async
fn
shutdown_signal
()
{
async
fn
shutdown_signal
()
{
let
ctrl_c
=
async
{
let
ctrl_c
=
async
{
signal
::
ctrl_c
()
signal
::
ctrl_c
()
...
@@ -773,19 +673,16 @@ async fn shutdown_signal() {
...
@@ -773,19 +673,16 @@ async fn shutdown_signal() {
}
}
}
}
// CORS Layer Creation
fn
create_cors_layer
(
allowed_origins
:
Vec
<
String
>
)
->
tower_http
::
cors
::
CorsLayer
{
fn
create_cors_layer
(
allowed_origins
:
Vec
<
String
>
)
->
tower_http
::
cors
::
CorsLayer
{
use
tower_http
::
cors
::
Any
;
use
tower_http
::
cors
::
Any
;
let
cors
=
if
allowed_origins
.is_empty
()
{
let
cors
=
if
allowed_origins
.is_empty
()
{
// Allow all origins if none specified
tower_http
::
cors
::
CorsLayer
::
new
()
tower_http
::
cors
::
CorsLayer
::
new
()
.allow_origin
(
Any
)
.allow_origin
(
Any
)
.allow_methods
(
Any
)
.allow_methods
(
Any
)
.allow_headers
(
Any
)
.allow_headers
(
Any
)
.expose_headers
(
Any
)
.expose_headers
(
Any
)
}
else
{
}
else
{
// Restrict to specific origins
let
origins
:
Vec
<
http
::
HeaderValue
>
=
allowed_origins
let
origins
:
Vec
<
http
::
HeaderValue
>
=
allowed_origins
.into_iter
()
.into_iter
()
.filter_map
(|
origin
|
origin
.parse
()
.ok
())
.filter_map
(|
origin
|
origin
.parse
()
.ok
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment