Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2f173ea0
Unverified
Commit
2f173ea0
authored
Sep 12, 2025
by
Simo Lin
Committed by
GitHub
Sep 12, 2025
Browse files
[router] allow one router to support different model families and serving mode (#10244)
parent
321fecab
Changes
28
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1628 additions
and
212 deletions
+1628
-212
sgl-router/src/routers/http/router.rs
sgl-router/src/routers/http/router.rs
+269
-188
sgl-router/src/routers/mod.rs
sgl-router/src/routers/mod.rs
+16
-3
sgl-router/src/routers/router_manager.rs
sgl-router/src/routers/router_manager.rs
+766
-0
sgl-router/src/server.rs
sgl-router/src/server.rs
+262
-10
sgl-router/src/service_discovery.rs
sgl-router/src/service_discovery.rs
+8
-5
sgl-router/tests/cache_aware_backward_compat_test.rs
sgl-router/tests/cache_aware_backward_compat_test.rs
+129
-0
sgl-router/tests/policy_registry_integration.rs
sgl-router/tests/policy_registry_integration.rs
+168
-0
sgl-router/tests/test_openai_routing.rs
sgl-router/tests/test_openai_routing.rs
+10
-6
No files found.
sgl-router/src/routers/http/router.rs
View file @
2f173ea0
This diff is collapsed.
Click to expand it.
sgl-router/src/routers/mod.rs
View file @
2f173ea0
...
...
@@ -17,6 +17,7 @@ pub mod factory;
pub
mod
grpc
;
pub
mod
header_utils
;
pub
mod
http
;
pub
mod
router_manager
;
pub
use
factory
::
RouterFactory
;
// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
...
...
@@ -63,14 +64,19 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
async
fn
get_model_info
(
&
self
,
req
:
Request
<
Body
>
)
->
Response
;
/// Route a generate request
async
fn
route_generate
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
GenerateRequest
)
->
Response
;
async
fn
route_generate
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
GenerateRequest
,
model_id
:
Option
<&
str
>
,
)
->
Response
;
/// Route a chat completion request
async
fn
route_chat
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
ChatCompletionRequest
,
model_id
:
Option
<&
str
>
,
)
->
Response
;
/// Route a completion request
...
...
@@ -78,6 +84,7 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
CompletionRequest
,
model_id
:
Option
<&
str
>
,
)
->
Response
;
/// Route a responses request
...
...
@@ -85,11 +92,17 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
ResponsesRequest
,
model_id
:
Option
<&
str
>
,
)
->
Response
;
async
fn
route_embeddings
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
Body
)
->
Response
;
async
fn
route_rerank
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
RerankRequest
)
->
Response
;
async
fn
route_rerank
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
body
:
&
RerankRequest
,
model_id
:
Option
<&
str
>
,
)
->
Response
;
/// Flush cache on all workers
async
fn
flush_cache
(
&
self
)
->
Response
;
...
...
sgl-router/src/routers/router_manager.rs
0 → 100644
View file @
2f173ea0
This diff is collapsed.
Click to expand it.
sgl-router/src/server.rs
View file @
2f173ea0
use
crate
::
config
::
RouterConfig
;
use
crate
::
core
::
WorkerRegistry
;
use
crate
::
logging
::{
self
,
LoggingConfig
};
use
crate
::
metrics
::{
self
,
PrometheusConfig
};
use
crate
::
middleware
::
TokenBucket
;
use
crate
::
policies
::
PolicyRegistry
;
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
,
RerankRequest
,
ResponsesRequest
,
V1RerankReqInput
,
};
use
crate
::
protocols
::
worker_spec
::{
WorkerApiResponse
,
WorkerConfigRequest
,
WorkerErrorResponse
};
use
crate
::
reasoning_parser
::
ParserFactory
;
use
crate
::
routers
::
router_manager
::{
RouterId
,
RouterManager
};
use
crate
::
routers
::{
RouterFactory
,
RouterTrait
};
use
crate
::
service_discovery
::{
start_service_discovery
,
ServiceDiscoveryConfig
};
use
crate
::
tokenizer
::{
factory
as
tokenizer_factory
,
traits
::
Tokenizer
};
...
...
@@ -36,6 +40,9 @@ pub struct AppContext {
pub
tokenizer
:
Option
<
Arc
<
dyn
Tokenizer
>>
,
pub
reasoning_parser_factory
:
Option
<
ParserFactory
>
,
pub
tool_parser_registry
:
Option
<&
'static
ParserRegistry
>
,
pub
worker_registry
:
Arc
<
WorkerRegistry
>
,
// Shared worker registry
pub
policy_registry
:
Arc
<
PolicyRegistry
>
,
// Shared policy registry
pub
router_manager
:
Option
<
Arc
<
RouterManager
>>
,
// Only present when enable_igw=true
}
impl
AppContext
{
...
...
@@ -75,6 +82,15 @@ impl AppContext {
(
None
,
None
,
None
)
};
// Initialize shared registries
let
worker_registry
=
Arc
::
new
(
WorkerRegistry
::
new
());
let
policy_registry
=
Arc
::
new
(
PolicyRegistry
::
new
(
router_config
.policy
.clone
(),
// Use default policy from config
));
// Initialize RouterManager only when enable_igw is true
let
router_manager
=
None
;
// Will be initialized in startup() based on config
Ok
(
Self
{
client
,
router_config
,
...
...
@@ -82,6 +98,9 @@ impl AppContext {
tokenizer
,
reasoning_parser_factory
,
tool_parser_registry
,
worker_registry
,
policy_registry
,
router_manager
,
})
}
}
...
...
@@ -134,7 +153,10 @@ async fn generate(
headers
:
http
::
HeaderMap
,
Json
(
body
):
Json
<
GenerateRequest
>
,
)
->
Response
{
state
.router
.route_generate
(
Some
(
&
headers
),
&
body
)
.await
state
.router
.route_generate
(
Some
(
&
headers
),
&
body
,
None
)
.await
}
async
fn
v1_chat_completions
(
...
...
@@ -142,7 +164,7 @@ async fn v1_chat_completions(
headers
:
http
::
HeaderMap
,
Json
(
body
):
Json
<
ChatCompletionRequest
>
,
)
->
Response
{
state
.router
.route_chat
(
Some
(
&
headers
),
&
body
)
.await
state
.router
.route_chat
(
Some
(
&
headers
),
&
body
,
None
)
.await
}
async
fn
v1_completions
(
...
...
@@ -150,7 +172,10 @@ async fn v1_completions(
headers
:
http
::
HeaderMap
,
Json
(
body
):
Json
<
CompletionRequest
>
,
)
->
Response
{
state
.router
.route_completion
(
Some
(
&
headers
),
&
body
)
.await
state
.router
.route_completion
(
Some
(
&
headers
),
&
body
,
None
)
.await
}
async
fn
rerank
(
...
...
@@ -158,7 +183,7 @@ async fn rerank(
headers
:
http
::
HeaderMap
,
Json
(
body
):
Json
<
RerankRequest
>
,
)
->
Response
{
state
.router
.route_rerank
(
Some
(
&
headers
),
&
body
)
.await
state
.router
.route_rerank
(
Some
(
&
headers
),
&
body
,
None
)
.await
}
async
fn
v1_rerank
(
...
...
@@ -168,7 +193,7 @@ async fn v1_rerank(
)
->
Response
{
state
.router
.route_rerank
(
Some
(
&
headers
),
&
body
.into
())
.route_rerank
(
Some
(
&
headers
),
&
body
.into
()
,
None
)
.await
}
...
...
@@ -177,7 +202,10 @@ async fn v1_responses(
headers
:
http
::
HeaderMap
,
Json
(
body
):
Json
<
ResponsesRequest
>
,
)
->
Response
{
state
.router
.route_responses
(
Some
(
&
headers
),
&
body
)
.await
state
.router
.route_responses
(
Some
(
&
headers
),
&
body
,
None
)
.await
}
// Worker management endpoints
...
...
@@ -232,6 +260,137 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Respons
state
.router
.get_worker_loads
()
.await
}
// New RESTful worker management endpoints (when enable_igw=true)
/// POST /workers - Register a worker from a full configuration payload.
///
/// When the application context holds a `RouterManager`, the whole `config` is handed
/// to it and its own success/error payloads are returned. Otherwise only `config.url`
/// is forwarded to the single router and the response payload is built here.
/// Success maps to 200 OK and failure to 400 BAD_REQUEST, both with a JSON body.
async fn create_worker(
    State(state): State<Arc<AppState>>,
    Json(config): Json<WorkerConfigRequest>,
) -> Response {
    // Multi-router mode: the manager produces both the success and the error body.
    if let Some(manager) = state.context.router_manager.as_ref() {
        return match manager.add_worker(config).await {
            Ok(added) => (StatusCode::OK, Json(added)).into_response(),
            Err(failure) => (StatusCode::BAD_REQUEST, Json(failure)).into_response(),
        };
    }

    // Single-router mode: the router only accepts the worker URL, so the rest of
    // the configuration is not used on this path.
    match state.router.add_worker(&config.url).await {
        Ok(message) => {
            let body = WorkerApiResponse {
                success: true,
                message,
                worker: None,
            };
            (StatusCode::OK, Json(body)).into_response()
        }
        Err(error) => {
            let body = WorkerErrorResponse {
                error,
                code: "ADD_WORKER_FAILED".to_string(),
            };
            (StatusCode::BAD_REQUEST, Json(body)).into_response()
        }
    }
}
/// GET /workers - List all workers with details
async
fn
list_workers_rest
(
State
(
state
):
State
<
Arc
<
AppState
>>
)
->
Response
{
if
let
Some
(
router_manager
)
=
&
state
.context.router_manager
{
let
response
=
router_manager
.list_workers
();
Json
(
response
)
.into_response
()
}
else
{
// In single router mode, get detailed worker info from registry
let
workers
=
state
.context.worker_registry
.get_all
();
let
response
=
serde_json
::
json!
({
"workers"
:
workers
.iter
()
.map
(|
worker
|
{
let
mut
worker_info
=
serde_json
::
json!
({
"url"
:
worker
.url
(),
"model_id"
:
worker
.model_id
(),
"worker_type"
:
format!
(
"{:?}"
,
worker
.worker_type
()),
"is_healthy"
:
worker
.is_healthy
(),
"load"
:
worker
.load
(),
"connection_mode"
:
format!
(
"{:?}"
,
worker
.connection_mode
()),
"priority"
:
worker
.priority
(),
"cost"
:
worker
.cost
(),
});
// Add bootstrap_port for Prefill workers
if
let
crate
::
core
::
WorkerType
::
Prefill
{
bootstrap_port
}
=
worker
.worker_type
()
{
worker_info
[
"bootstrap_port"
]
=
serde_json
::
json!
(
bootstrap_port
);
}
worker_info
})
.collect
::
<
Vec
<
_
>>
(),
"total"
:
workers
.len
(),
"stats"
:
{
"prefill_count"
:
state
.context.worker_registry
.get_prefill_workers
()
.len
(),
"decode_count"
:
state
.context.worker_registry
.get_decode_workers
()
.len
(),
"regular_count"
:
state
.context.worker_registry
.get_by_type
(
&
crate
::
core
::
WorkerType
::
Regular
)
.len
(),
}
});
Json
(
response
)
.into_response
()
}
}
/// GET /workers/{url} - Get specific worker info
async
fn
get_worker
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
axum
::
extract
::
Path
(
url
):
axum
::
extract
::
Path
<
String
>
,
)
->
Response
{
if
let
Some
(
router_manager
)
=
&
state
.context.router_manager
{
if
let
Some
(
worker
)
=
router_manager
.get_worker
(
&
url
)
{
Json
(
worker
)
.into_response
()
}
else
{
let
error
=
WorkerErrorResponse
{
error
:
format!
(
"Worker {} not found"
,
url
),
code
:
"WORKER_NOT_FOUND"
.to_string
(),
};
(
StatusCode
::
NOT_FOUND
,
Json
(
error
))
.into_response
()
}
}
else
{
// In single router mode, check if worker exists
let
workers
=
state
.router
.get_worker_urls
();
if
workers
.contains
(
&
url
)
{
let
worker_info
=
serde_json
::
json!
({
"url"
:
url
,
"model_id"
:
"unknown"
,
"is_healthy"
:
true
});
Json
(
worker_info
)
.into_response
()
}
else
{
let
error
=
WorkerErrorResponse
{
error
:
format!
(
"Worker {} not found"
,
url
),
code
:
"WORKER_NOT_FOUND"
.to_string
(),
};
(
StatusCode
::
NOT_FOUND
,
Json
(
error
))
.into_response
()
}
}
}
/// DELETE /workers/{url} - Remove the worker addressed by `url`.
///
/// With a `RouterManager` in the application context, removal goes through
/// `remove_worker_from_registry`; its success payload is returned with 200 OK and its
/// error payload with 400 BAD_REQUEST. Without a manager, the single router's
/// `remove_worker` is called and a success payload is always returned.
async fn delete_worker(
    State(state): State<Arc<AppState>>,
    axum::extract::Path(url): axum::extract::Path<String>,
) -> Response {
    match state.context.router_manager.as_ref() {
        Some(manager) => match manager.remove_worker_from_registry(&url) {
            Ok(removed) => (StatusCode::OK, Json(removed)).into_response(),
            Err(failure) => (StatusCode::BAD_REQUEST, Json(failure)).into_response(),
        },
        None => {
            // Single-router mode: the router call yields no result to inspect here,
            // so the reply is unconditionally a success message.
            state.router.remove_worker(&url);
            let body = WorkerApiResponse {
                success: true,
                message: format!("Worker {} removed successfully", url),
                worker: None,
            };
            (StatusCode::OK, Json(body)).into_response()
        }
    }
}
pub
struct
ServerConfig
{
pub
host
:
String
,
pub
port
:
u16
,
...
...
@@ -281,11 +440,19 @@ pub fn build_app(
.route
(
"/flush_cache"
,
post
(
flush_cache
))
.route
(
"/get_loads"
,
get
(
get_loads
));
// Worker management routes
let
worker_routes
=
Router
::
new
()
.route
(
"/workers"
,
post
(
create_worker
))
.route
(
"/workers"
,
get
(
list_workers_rest
))
.route
(
"/workers/{url}"
,
get
(
get_worker
))
.route
(
"/workers/{url}"
,
axum
::
routing
::
delete
(
delete_worker
));
// Build app with all routes and middleware
Router
::
new
()
.merge
(
protected_routes
)
.merge
(
public_routes
)
.merge
(
admin_routes
)
.merge
(
worker_routes
)
// Request body size limiting
.layer
(
tower_http
::
limit
::
RequestBodyLimitLayer
::
new
(
max_payload_size
,
...
...
@@ -355,15 +522,100 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
.expect
(
"Failed to create HTTP client"
);
// Create the application context with all dependencies
let
app_context
=
Arc
::
new
(
AppContext
::
new
(
let
app_context
=
AppContext
::
new
(
config
.router_config
.clone
(),
client
.clone
(),
config
.router_config.max_concurrent_requests
,
config
.router_config.rate_limit_tokens_per_second
,
)
?
);
)
?
;
let
app_context
=
Arc
::
new
(
app_context
);
// Create the appropriate router based on enable_igw flag
let
router
:
Box
<
dyn
RouterTrait
>
=
if
config
.router_config.enable_igw
{
info!
(
"Multi-router mode enabled (enable_igw=true)"
);
// Create RouterManager with shared registries from AppContext
let
mut
router_manager
=
RouterManager
::
new
(
config
.router_config
.clone
(),
client
.clone
(),
app_context
.worker_registry
.clone
(),
app_context
.policy_registry
.clone
(),
);
// Create HTTP routers at startup (with empty worker lists)
// Workers will be added to these routers dynamically via RouterManager's worker registry
// 1. HTTP Regular Router
match
RouterFactory
::
create_regular_router
(
&
[],
// Empty worker list - workers added later
&
app_context
,
)
.await
{
Ok
(
http_regular
)
=>
{
info!
(
"Created HTTP Regular router"
);
router_manager
.register_router
(
RouterId
::
new
(
"http-regular"
.to_string
()),
Arc
::
from
(
http_regular
),
vec!
[],
// Models will be determined by workers
);
}
Err
(
e
)
=>
{
warn!
(
"Failed to create HTTP Regular router: {}"
,
e
);
}
}
// 2. HTTP PD Router
match
RouterFactory
::
create_pd_router
(
&
[],
// Empty prefill URLs
&
[],
// Empty decode URLs
None
,
// Use default prefill policy
None
,
// Use default decode policy
&
config
.router_config.policy
,
&
app_context
,
)
.await
{
Ok
(
http_pd
)
=>
{
info!
(
"Created HTTP PD router"
);
router_manager
.register_router
(
RouterId
::
new
(
"http-pd"
.to_string
()),
Arc
::
from
(
http_pd
),
vec!
[],
);
}
Err
(
e
)
=>
{
warn!
(
"Failed to create HTTP PD router: {}"
,
e
);
}
}
// TODO: Add gRPC routers once we have dynamic tokenizer loading
// Currently gRPC routers require tokenizer to be initialized first,
// but each model needs its own tokenizer. Once we implement dynamic
// tokenizer loading per model, we can enable gRPC routers here:
// - RouterType::GrpcRegular (RouterId: "grpc-regular")
// - RouterType::GrpcPd (RouterId: "grpc-pd")
// Create router with the context
let
router
=
RouterFactory
::
create_router
(
&
app_context
)
.await
?
;
info!
(
"RouterManager initialized with {} routers"
,
router_manager
.router_count
()
);
Box
::
new
(
router_manager
)
}
else
{
info!
(
"Single router mode (enable_igw=false)"
);
// Create single router with the context
RouterFactory
::
create_router
(
&
app_context
)
.await
?
};
// Start health checker for all workers in the registry
let
_
health_checker
=
app_context
.worker_registry
.start_health_checker
(
config
.router_config.health_check.check_interval_secs
);
info!
(
"Started health checker for workers with {}s interval"
,
config
.router_config.health_check.check_interval_secs
);
// Set up concurrency limiter with queue if configured
let
(
limiter
,
processor
)
=
crate
::
middleware
::
ConcurrencyLimiter
::
new
(
...
...
sgl-router/src/service_discovery.rs
View file @
2f173ea0
...
...
@@ -579,9 +579,8 @@ mod tests {
// Helper to create a Router instance for testing event handlers
async
fn
create_test_router
()
->
Arc
<
dyn
RouterTrait
>
{
use
crate
::
config
::
{
PolicyConfig
,
RouterConfig
}
;
use
crate
::
config
::
RouterConfig
;
use
crate
::
middleware
::
TokenBucket
;
use
crate
::
policies
::
PolicyFactory
;
use
crate
::
routers
::
http
::
router
::
Router
;
use
crate
::
server
::
AppContext
;
...
...
@@ -591,15 +590,19 @@ mod tests {
// Create AppContext with minimal components
let
app_context
=
Arc
::
new
(
AppContext
{
client
:
reqwest
::
Client
::
new
(),
router_config
,
router_config
:
router_config
.clone
()
,
rate_limiter
:
Arc
::
new
(
TokenBucket
::
new
(
1000
,
1000
)),
worker_registry
:
Arc
::
new
(
crate
::
core
::
WorkerRegistry
::
new
()),
policy_registry
:
Arc
::
new
(
crate
::
policies
::
PolicyRegistry
::
new
(
router_config
.policy
.clone
(),
)),
tokenizer
:
None
,
// HTTP mode doesn't need tokenizer
reasoning_parser_factory
:
None
,
// HTTP mode doesn't need reasoning parser
tool_parser_registry
:
None
,
// HTTP mode doesn't need tool parser
router_manager
:
None
,
// Test doesn't need router manager
});
let
policy
=
PolicyFactory
::
create_from_config
(
&
PolicyConfig
::
Random
);
let
router
=
Router
::
new
(
vec!
[],
policy
,
&
app_context
)
.await
.unwrap
();
let
router
=
Router
::
new
(
vec!
[],
&
app_context
)
.await
.unwrap
();
Arc
::
new
(
router
)
as
Arc
<
dyn
RouterTrait
>
}
...
...
sgl-router/tests/cache_aware_backward_compat_test.rs
0 → 100644
View file @
2f173ea0
use
sglang_router_rs
::
core
::{
BasicWorker
,
Worker
,
WorkerType
};
use
sglang_router_rs
::
policies
::{
CacheAwareConfig
,
CacheAwarePolicy
,
LoadBalancingPolicy
};
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
Arc
;
/// A cache-aware policy must accept, select from, and remove workers that carry no
/// `model_id` label or an explicit `"unknown"` one.
#[test]
fn test_backward_compatibility_with_empty_model_id() {
    // eviction_interval_secs = 0 disables background eviction for the test.
    let policy = CacheAwarePolicy::with_config(CacheAwareConfig {
        cache_threshold: 0.5,
        balance_abs_threshold: 2,
        balance_rel_threshold: 1.5,
        eviction_interval_secs: 0,
        max_tree_size: 100,
    });

    // Worker without any labels, standing in for workers created by existing routers.
    // The test expects its model_id to default to "unknown".
    let unlabeled = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);

    // Worker whose model_id label is spelled out as "unknown".
    let explicit_unknown =
        BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular).with_labels(
            HashMap::from([("model_id".to_string(), "unknown".to_string())]),
        );

    // Both registrations are expected to end up in the policy's "default" tree.
    policy.add_worker(&unlabeled);
    policy.add_worker(&explicit_unknown);

    let workers: Vec<Arc<dyn Worker>> = vec![
        Arc::new(unlabeled.clone()),
        Arc::new(explicit_unknown.clone()),
    ];

    // Selection has to succeed without errors.
    let selected = policy.select_worker(&workers, Some("test request"));
    assert!(selected.is_some(), "Should select a worker");

    // Removal has to succeed without errors as well.
    policy.remove_worker(&unlabeled);
    policy.remove_worker(&explicit_unknown);
}
/// Worker selection must succeed for pools of unlabeled/"unknown" workers, pools of
/// workers labelled with a concrete model, and a mix of both.
#[test]
fn test_mixed_model_ids() {
    let policy = CacheAwarePolicy::with_config(CacheAwareConfig {
        cache_threshold: 0.5,
        balance_abs_threshold: 2,
        balance_rel_threshold: 1.5,
        eviction_interval_secs: 0,
        max_tree_size: 100,
    });

    // Builds a regular worker whose only label is the given model_id.
    let labeled = |url: &str, model_id: &str| {
        BasicWorker::new(url.to_string(), WorkerType::Regular).with_labels(HashMap::from([(
            "model_id".to_string(),
            model_id.to_string(),
        )]))
    };

    // worker1 has no model_id label; the test expects it to be treated like "unknown".
    let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);
    let worker2 = labeled("http://worker2:8080", "llama-3");
    let worker3 = labeled("http://worker3:8080", "unknown");
    let worker4 = labeled("http://worker4:8080", "llama-3");

    // Register all four workers with the policy.
    for worker in [&worker1, &worker2, &worker3, &worker4] {
        policy.add_worker(worker);
    }

    // Pool made only of the unlabeled / "unknown" workers.
    let default_workers: Vec<Arc<dyn Worker>> =
        vec![Arc::new(worker1.clone()), Arc::new(worker3.clone())];
    assert!(
        policy
            .select_worker(&default_workers, Some("test request"))
            .is_some(),
        "Should select from default workers"
    );

    // Pool made only of the "llama-3" workers.
    let llama_workers: Vec<Arc<dyn Worker>> =
        vec![Arc::new(worker2.clone()), Arc::new(worker4.clone())];
    assert!(
        policy
            .select_worker(&llama_workers, Some("test request"))
            .is_some(),
        "Should select from llama-3 workers"
    );

    // Pool mixing every worker.
    let all_workers: Vec<Arc<dyn Worker>> = vec![
        Arc::new(worker1.clone()),
        Arc::new(worker2.clone()),
        Arc::new(worker3.clone()),
        Arc::new(worker4.clone()),
    ];
    assert!(
        policy
            .select_worker(&all_workers, Some("test request"))
            .is_some(),
        "Should select from all workers"
    );
}
/// Removing a worker by URL alone (without naming its model) must leave the policy able
/// to select from the remaining worker.
#[test]
fn test_remove_worker_by_url_backward_compat() {
    let policy = CacheAwarePolicy::with_config(CacheAwareConfig::default());

    // One worker labelled "llama-3" and one without any model_id label.
    let llama_worker = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular)
        .with_labels(HashMap::from([(
            "model_id".to_string(),
            "llama-3".to_string(),
        )]));
    let unlabeled_worker =
        BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular);

    policy.add_worker(&llama_worker);
    policy.add_worker(&unlabeled_worker);

    // URL-only removal: the caller does not say which model the worker belonged to.
    policy.remove_worker_by_url("http://worker1:8080");

    // Selecting from a pool holding just the second worker must return its index.
    let remaining: Vec<Arc<dyn Worker>> = vec![Arc::new(unlabeled_worker.clone())];
    let chosen = policy.select_worker(&remaining, Some("test"));
    assert_eq!(chosen, Some(0), "Should only have worker2 left");
}
sgl-router/tests/policy_registry_integration.rs
0 → 100644
View file @
2f173ea0
This diff is collapsed.
Click to expand it.
sgl-router/tests/test_openai_routing.rs
View file @
2f173ea0
...
...
@@ -197,12 +197,14 @@ async fn test_unsupported_endpoints() {
rid
:
None
,
};
let
response
=
router
.route_generate
(
None
,
&
generate_request
)
.await
;
let
response
=
router
.route_generate
(
None
,
&
generate_request
,
None
)
.await
;
assert_eq!
(
response
.status
(),
StatusCode
::
NOT_IMPLEMENTED
);
// Test completion endpoint (should also not be supported)
let
completion_request
=
create_minimal_completion_request
();
let
response
=
router
.route_completion
(
None
,
&
completion_request
)
.await
;
let
response
=
router
.route_completion
(
None
,
&
completion_request
,
None
)
.await
;
assert_eq!
(
response
.status
(),
StatusCode
::
NOT_IMPLEMENTED
);
}
...
...
@@ -228,7 +230,7 @@ async fn test_openai_router_chat_completion_with_mock() {
chat_request
.temperature
=
Some
(
0.7
);
// Route the request
let
response
=
router
.route_chat
(
None
,
&
chat_request
)
.await
;
let
response
=
router
.route_chat
(
None
,
&
chat_request
,
None
)
.await
;
// Should get a successful response from mock server
assert_eq!
(
response
.status
(),
StatusCode
::
OK
);
...
...
@@ -269,7 +271,9 @@ async fn test_openai_e2e_with_server() {
let
chat_request
:
ChatCompletionRequest
=
serde_json
::
from_str
(
&
body_str
)
.unwrap
();
router
.route_chat
(
Some
(
&
parts
.headers
),
&
chat_request
)
.await
router
.route_chat
(
Some
(
&
parts
.headers
),
&
chat_request
,
None
)
.await
}
}
}),
...
...
@@ -327,7 +331,7 @@ async fn test_openai_router_chat_streaming_with_mock() {
});
let
chat_request
:
ChatCompletionRequest
=
serde_json
::
from_value
(
val
)
.unwrap
();
let
response
=
router
.route_chat
(
None
,
&
chat_request
)
.await
;
let
response
=
router
.route_chat
(
None
,
&
chat_request
,
None
)
.await
;
assert_eq!
(
response
.status
(),
StatusCode
::
OK
);
// Should be SSE
...
...
@@ -371,7 +375,7 @@ async fn test_openai_router_circuit_breaker() {
// First few requests should fail and record failures
for
_
in
0
..
3
{
let
response
=
router
.route_chat
(
None
,
&
chat_request
)
.await
;
let
response
=
router
.route_chat
(
None
,
&
chat_request
,
None
)
.await
;
// Should get either an error or circuit breaker response
assert
!
(
response
.status
()
==
StatusCode
::
INTERNAL_SERVER_ERROR
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment