sglang — commit 2f173ea0 (unverified), authored Sep 12, 2025 by Simo Lin, committed by GitHub on Sep 12, 2025.

[router] allow one router to support different model families and serving mode (#10244)
Parent: 321fecab
Changes: 28

Showing 20 changed files with 1895 additions and 620 deletions (+1895 / −620).
sgl-router/py_src/sglang_router/router.py          +3    −0
sgl-router/py_src/sglang_router/router_args.py     +6    −0
sgl-router/py_test/e2e/conftest.py                  +3    −0
sgl-router/py_test/integration/test_retries.py      +1    −1
sgl-router/src/core/mod.rs                           +2    −0
sgl-router/src/core/worker.rs                      +105    −5
sgl-router/src/core/worker_registry.rs             +526    −0
sgl-router/src/policies/cache_aware.rs             +185   −57
sgl-router/src/policies/mod.rs                      +12    −8
sgl-router/src/policies/power_of_two.rs              +8    −8
sgl-router/src/policies/random.rs                   +10    −9
sgl-router/src/policies/registry.rs                +333    −0
sgl-router/src/policies/round_robin.rs              +13   −12
sgl-router/src/protocols/mod.rs                      +1    −0
sgl-router/src/protocols/worker_spec.rs            +198    −0
sgl-router/src/routers/factory.rs                   +13   −32
sgl-router/src/routers/grpc/pd_router.rs            +11    −6
sgl-router/src/routers/grpc/router.rs                +8    −3
sgl-router/src/routers/http/openai_router.rs        +10    −1
sgl-router/src/routers/http/pd_router.rs           +447  −478
sgl-router/py_src/sglang_router/router.py

@@ -46,6 +46,9 @@ class Router:
         max_payload_size: Maximum payload size in bytes. Default: 256MB
         max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
         dp_aware: Enable data parallelism aware schedule. Default: False
+        enable_igw: Enable IGW (Inference-Gateway) mode for multi-model support. When enabled,
+            the router can manage multiple models simultaneously with per-model load balancing
+            policies. Default: False
         api_key: The api key used for the authorization with the worker.
             Useful when the dp aware scheduling strategy is enabled.
             Default: None
sgl-router/py_src/sglang_router/router_args.py

@@ -34,6 +34,7 @@ class RouterArgs:
     max_tree_size: int = 2**26
     max_payload_size: int = 512 * 1024 * 1024  # 512MB default for large batches
     dp_aware: bool = False
+    enable_igw: bool = False  # Enable IGW (Inter-Gateway) mode for multi-model support
     api_key: Optional[str] = None
     log_dir: Optional[str] = None
     log_level: Optional[str] = None

@@ -227,6 +228,11 @@ class RouterArgs:
             action="store_true",
             help="Enable data parallelism aware schedule",
         )
+        parser.add_argument(
+            f"--{prefix}enable-igw",
+            action="store_true",
+            help="Enable IGW (Inference-Gateway) mode for multi-model support",
+        )
         parser.add_argument(
             f"--{prefix}api-key",
             type=str,
sgl-router/py_test/e2e/conftest.py

@@ -128,6 +128,7 @@ def _popen_launch_router_only(
     timeout: float = 120.0,
     *,
     dp_aware: bool = False,
+    enable_igw: bool = False,
     api_key: str | None = None,
 ) -> subprocess.Popen:
     host, port = _parse_url(base_url)

@@ -146,6 +147,8 @@ def _popen_launch_router_only(
     ]
     if dp_aware:
         cmd += ["--dp-aware"]
+    if enable_igw:
+        cmd += ["--enable-igw"]
     if api_key is not None:
         cmd += ["--api-key", api_key]
     cmd += [
sgl-router/py_test/integration/test_retries.py

@@ -35,7 +35,7 @@ def test_retry_reroutes_to_healthy_worker(router_manager, mock_workers):
     )
     assert r.status_code == 200
     wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id")
-    assert wid == id_b  # should have retried onto healthy worker
+    assert wid in [id_b, id_c]  # should have retried onto a healthy worker (B or C)
     # mock_workers fixture handles cleanup
sgl-router/src/core/mod.rs

@@ -11,6 +11,7 @@ pub mod error;
 pub mod retry;
 pub mod token_bucket;
 pub mod worker;
+pub mod worker_registry;

 // Re-export commonly used types at the module level
 pub use circuit_breaker::{

@@ -22,3 +23,4 @@ pub use worker::{
     start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig,
     Worker, WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType,
 };
+pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats};
sgl-router/src/core/worker.rs

@@ -155,6 +155,82 @@ pub trait Worker: Send + Sync + fmt::Debug {
     fn can_handle(&self, _req: &serde_json::Value) -> bool {
         true
     }

+    // === Multi-router support ===
+    // TODO: - Enhanced Worker Discovery
+    // The Worker trait should handle async discovery of metadata from the worker itself
+    // rather than having service discovery or other components query /get_server_info.
+    // This keeps service discovery decoupled from worker-specific APIs.
+    //
+    // Proposed additions:
+    // - async fn discover_metadata(&mut self) -> Result<(), Error>
+    //   Query /get_server_info and populate metadata labels with model_id, priority, cost, etc.
+    // - async fn validate_configuration(&self) -> Result<(), Error>
+    //   Ensure worker has required configuration for its mode (e.g., tokenizer for gRPC)
+    // - Make worker creation async to allow metadata discovery during initialization
+    //
+    // This way service discovery just calls router.add_worker() and the worker
+    // handles its own metadata discovery internally.
+
+    /// Get the model ID this worker serves
+    fn model_id(&self) -> &str {
+        self.metadata()
+            .labels
+            .get("model_id")
+            .map(|s| s.as_str())
+            .unwrap_or("unknown")
+    }
+
+    /// Get the priority of this worker (higher value = higher priority)
+    fn priority(&self) -> u32 {
+        self.metadata()
+            .labels
+            .get("priority")
+            .and_then(|s| s.parse().ok())
+            .unwrap_or(50) // Default priority is 50 (mid-range)
+    }
+
+    /// Get the cost factor of this worker (1.0 = baseline)
+    fn cost(&self) -> f32 {
+        self.metadata()
+            .labels
+            .get("cost")
+            .and_then(|s| s.parse().ok())
+            .unwrap_or(1.0)
+    }
+
+    /// Get the tokenizer path for this worker (gRPC mode only)
+    fn tokenizer_path(&self) -> Option<&str> {
+        self.metadata()
+            .labels
+            .get("tokenizer_path")
+            .map(|s| s.as_str())
+    }
+
+    /// Get the reasoning parser type for this worker (gRPC mode only)
+    fn reasoning_parser(&self) -> Option<&str> {
+        self.metadata()
+            .labels
+            .get("reasoning_parser")
+            .map(|s| s.as_str())
+    }
+
+    /// Get the tool parser type for this worker (gRPC mode only)
+    fn tool_parser(&self) -> Option<&str> {
+        self.metadata()
+            .labels
+            .get("tool_parser")
+            .map(|s| s.as_str())
+    }
+
+    /// Get the chat template for this worker (gRPC mode only)
+    fn chat_template(&self) -> Option<&str> {
+        self.metadata()
+            .labels
+            .get("chat_template")
+            .map(|s| s.as_str())
+    }
 }

 /// Connection mode for worker communication

@@ -724,6 +800,21 @@ impl WorkerFactory {
         )
     }

+    /// Create a regular worker with custom labels (for multi-router support)
+    pub fn create_regular_with_labels(
+        url: String,
+        labels: std::collections::HashMap<String, String>,
+        circuit_breaker_config: CircuitBreakerConfig,
+    ) -> Box<dyn Worker> {
+        let mut worker = BasicWorker::new(url.clone(), WorkerType::Regular)
+            .with_circuit_breaker_config(circuit_breaker_config);
+        // Add labels to metadata
+        worker.metadata.labels = labels;
+        Box::new(worker)
+    }
+
     /// Create a DP-aware worker of specified type
     pub fn create_dp_aware(
         base_url: String,

@@ -941,6 +1032,11 @@ impl fmt::Debug for HealthChecker {
 }

 impl HealthChecker {
+    /// Create a new HealthChecker
+    pub fn new(handle: tokio::task::JoinHandle<()>, shutdown: Arc<AtomicBool>) -> Self {
+        Self { handle, shutdown }
+    }
+
     /// Shutdown the health checker gracefully
     pub async fn shutdown(self) {
         self.shutdown.store(true, Ordering::Release);

@@ -950,7 +1046,7 @@ impl HealthChecker {
 /// Start an async background health checker for a collection of workers
 pub fn start_health_checker(
-    workers: std::sync::Arc<std::sync::RwLock<Vec<Box<dyn Worker>>>>,
+    workers: std::sync::Arc<std::sync::RwLock<Vec<std::sync::Arc<dyn Worker>>>>,
     check_interval_secs: u64,
 ) -> HealthChecker {
     let shutdown = Arc::new(AtomicBool::new(false));

@@ -1602,9 +1698,11 @@ mod tests {
     // Test HealthChecker background task
     #[tokio::test]
     async fn test_health_checker_startup() {
-        let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular(
+        let worker = Arc::new(BasicWorker::new(
             "http://w1:8080".to_string(),
-        )]));
+            WorkerType::Regular,
+        )) as Arc<dyn Worker>;
+        let workers = Arc::new(RwLock::new(vec![worker]));
         let checker = start_health_checker(workers.clone(), 60);

@@ -1617,9 +1715,11 @@ mod tests {
     #[tokio::test]
     async fn test_health_checker_shutdown() {
-        let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular(
+        let worker = Arc::new(BasicWorker::new(
             "http://w1:8080".to_string(),
-        )]));
+            WorkerType::Regular,
+        )) as Arc<dyn Worker>;
+        let workers = Arc::new(RwLock::new(vec![worker]));
         let checker = start_health_checker(workers.clone(), 60);
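The new trait methods above all follow the same pattern: read an optional string label from the worker's metadata and fall back to a default (50 for priority, 1.0 for cost, "unknown" for model_id). The following is a minimal, self-contained sketch of that lookup-with-fallback idea; the `Labels` type is hypothetical and stands in for the repository's metadata labels map, it is not the actual sgl-router type.

    use std::collections::HashMap;

    // Hypothetical stand-in for the worker metadata labels map.
    struct Labels(HashMap<String, String>);

    impl Labels {
        // Priority falls back to 50 when the label is absent or unparsable,
        // mirroring the unwrap_or(50) in the trait methods above.
        fn priority(&self) -> u32 {
            self.0.get("priority").and_then(|s| s.parse().ok()).unwrap_or(50)
        }
        // Cost falls back to the 1.0 baseline.
        fn cost(&self) -> f32 {
            self.0.get("cost").and_then(|s| s.parse().ok()).unwrap_or(1.0)
        }
    }

    fn main() {
        let mut m = HashMap::new();
        m.insert("priority".to_string(), "80".to_string());
        let labels = Labels(m);
        assert_eq!(labels.priority(), 80); // present label is parsed
        assert_eq!(labels.cost(), 1.0);    // missing label uses the baseline
    }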
sgl-router/src/core/worker_registry.rs (new file, 0 → 100644)

//! Worker Registry for multi-router support
//!
//! Provides centralized registry for workers with model-based indexing

use crate::core::{ConnectionMode, Worker, WorkerType};
use dashmap::DashMap;
use std::sync::{Arc, RwLock};
use uuid::Uuid;

/// Unique identifier for a worker
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct WorkerId(String);

impl WorkerId {
    /// Create a new worker ID
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string())
    }

    /// Create a worker ID from a string
    pub fn from_string(s: String) -> Self {
        Self(s)
    }

    /// Get the ID as a string
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl Default for WorkerId {
    fn default() -> Self {
        Self::new()
    }
}

/// Type alias for the model index to reduce complexity
type ModelIndex = Arc<DashMap<String, Arc<RwLock<Vec<Arc<dyn Worker>>>>>>;

/// Worker registry with model-based indexing
#[derive(Debug)]
pub struct WorkerRegistry {
    /// All workers indexed by ID
    workers: Arc<DashMap<WorkerId, Arc<dyn Worker>>>,
    /// Workers indexed by model ID (stores WorkerId for reference)
    model_workers: Arc<DashMap<String, Vec<WorkerId>>>,
    /// Optimized model index for O(1) lookups (stores Arc<dyn Worker> directly)
    model_index: ModelIndex,
    /// Workers indexed by worker type
    type_workers: Arc<DashMap<WorkerType, Vec<WorkerId>>>,
    /// Workers indexed by connection mode
    connection_workers: Arc<DashMap<ConnectionMode, Vec<WorkerId>>>,
    /// URL to worker ID mapping (for backward compatibility)
    url_to_id: Arc<DashMap<String, WorkerId>>,
}

impl WorkerRegistry {
    /// Create a new worker registry
    pub fn new() -> Self {
        Self {
            workers: Arc::new(DashMap::new()),
            model_workers: Arc::new(DashMap::new()),
            model_index: Arc::new(DashMap::new()),
            type_workers: Arc::new(DashMap::new()),
            connection_workers: Arc::new(DashMap::new()),
            url_to_id: Arc::new(DashMap::new()),
        }
    }

    /// Register a new worker
    pub fn register(&self, worker: Arc<dyn Worker>) -> WorkerId {
        let worker_id = if let Some(existing_id) = self.url_to_id.get(worker.url()) {
            // Worker with this URL already exists, update it
            existing_id.clone()
        } else {
            WorkerId::new()
        };

        // Store worker
        self.workers.insert(worker_id.clone(), worker.clone());

        // Update URL mapping
        self.url_to_id.insert(worker.url().to_string(), worker_id.clone());

        // Update model index (both ID-based and optimized)
        let model_id = worker.model_id().to_string();
        self.model_workers
            .entry(model_id.clone())
            .or_default()
            .push(worker_id.clone());

        // Update optimized model index for O(1) lookups
        self.model_index
            .entry(model_id)
            .or_insert_with(|| Arc::new(RwLock::new(Vec::new())))
            .write()
            .expect("RwLock for model_index is poisoned")
            .push(worker.clone());

        // Update type index
        self.type_workers
            .entry(worker.worker_type())
            .or_default()
            .push(worker_id.clone());

        // Update connection mode index
        self.connection_workers
            .entry(worker.connection_mode())
            .or_default()
            .push(worker_id.clone());

        worker_id
    }

    /// Remove a worker by ID
    pub fn remove(&self, worker_id: &WorkerId) -> Option<Arc<dyn Worker>> {
        if let Some((_, worker)) = self.workers.remove(worker_id) {
            // Remove from URL mapping
            self.url_to_id.remove(worker.url());

            // Remove from model index (both ID-based and optimized)
            if let Some(mut model_workers) = self.model_workers.get_mut(worker.model_id()) {
                model_workers.retain(|id| id != worker_id);
            }

            // Remove from optimized model index
            if let Some(model_index_entry) = self.model_index.get(worker.model_id()) {
                let worker_url = worker.url();
                model_index_entry
                    .write()
                    .expect("RwLock for model_index is poisoned")
                    .retain(|w| w.url() != worker_url);
            }

            // Remove from type index
            if let Some(mut type_workers) = self.type_workers.get_mut(&worker.worker_type()) {
                type_workers.retain(|id| id != worker_id);
            }

            // Remove from connection mode index
            if let Some(mut conn_workers) = self.connection_workers.get_mut(&worker.connection_mode()) {
                conn_workers.retain(|id| id != worker_id);
            }

            Some(worker)
        } else {
            None
        }
    }

    /// Remove a worker by URL
    pub fn remove_by_url(&self, url: &str) -> Option<Arc<dyn Worker>> {
        if let Some((_, worker_id)) = self.url_to_id.remove(url) {
            self.remove(&worker_id)
        } else {
            None
        }
    }

    /// Get a worker by ID
    pub fn get(&self, worker_id: &WorkerId) -> Option<Arc<dyn Worker>> {
        self.workers.get(worker_id).map(|entry| entry.clone())
    }

    /// Get a worker by URL
    pub fn get_by_url(&self, url: &str) -> Option<Arc<dyn Worker>> {
        self.url_to_id.get(url).and_then(|id| self.get(&id))
    }

    /// Get all workers for a model
    pub fn get_by_model(&self, model_id: &str) -> Vec<Arc<dyn Worker>> {
        self.model_workers
            .get(model_id)
            .map(|ids| ids.iter().filter_map(|id| self.get(id)).collect())
            .unwrap_or_default()
    }

    /// Get all workers for a model (O(1) optimized version)
    /// This method uses the pre-indexed model_index for fast lookups
    pub fn get_by_model_fast(&self, model_id: &str) -> Vec<Arc<dyn Worker>> {
        self.model_index
            .get(model_id)
            .map(|workers| {
                workers
                    .read()
                    .expect("RwLock for model_index is poisoned")
                    .clone()
            })
            .unwrap_or_default()
    }

    /// Get all workers by worker type
    pub fn get_by_type(&self, worker_type: &WorkerType) -> Vec<Arc<dyn Worker>> {
        self.type_workers
            .get(worker_type)
            .map(|ids| ids.iter().filter_map(|id| self.get(id)).collect())
            .unwrap_or_default()
    }

    /// Get all prefill workers (regardless of bootstrap_port)
    pub fn get_prefill_workers(&self) -> Vec<Arc<dyn Worker>> {
        self.workers
            .iter()
            .filter_map(|entry| {
                let worker = entry.value();
                match worker.worker_type() {
                    WorkerType::Prefill { .. } => Some(worker.clone()),
                    _ => None,
                }
            })
            .collect()
    }

    /// Get all decode workers
    pub fn get_decode_workers(&self) -> Vec<Arc<dyn Worker>> {
        self.get_by_type(&WorkerType::Decode)
    }

    /// Get all workers by connection mode
    pub fn get_by_connection(&self, connection_mode: &ConnectionMode) -> Vec<Arc<dyn Worker>> {
        self.connection_workers
            .get(connection_mode)
            .map(|ids| ids.iter().filter_map(|id| self.get(id)).collect())
            .unwrap_or_default()
    }

    /// Get all workers
    pub fn get_all(&self) -> Vec<Arc<dyn Worker>> {
        self.workers.iter().map(|entry| entry.value().clone()).collect()
    }

    /// Get all workers with their IDs
    pub fn get_all_with_ids(&self) -> Vec<(WorkerId, Arc<dyn Worker>)> {
        self.workers
            .iter()
            .map(|entry| (entry.key().clone(), entry.value().clone()))
            .collect()
    }

    /// Get all worker URLs
    pub fn get_all_urls(&self) -> Vec<String> {
        self.workers
            .iter()
            .map(|entry| entry.value().url().to_string())
            .collect()
    }

    /// Get all model IDs with workers
    pub fn get_models(&self) -> Vec<String> {
        self.model_workers
            .iter()
            .filter(|entry| !entry.value().is_empty())
            .map(|entry| entry.key().clone())
            .collect()
    }

    /// Get workers filtered by multiple criteria
    ///
    /// This method allows flexible filtering of workers based on:
    /// - model_id: Filter by specific model
    /// - worker_type: Filter by worker type (Regular, Prefill, Decode)
    /// - connection_mode: Filter by connection mode (Http, Grpc)
    /// - healthy_only: Only return healthy workers
    pub fn get_workers_filtered(
        &self,
        model_id: Option<&str>,
        worker_type: Option<WorkerType>,
        connection_mode: Option<ConnectionMode>,
        healthy_only: bool,
    ) -> Vec<Arc<dyn Worker>> {
        // Start with the most efficient collection based on filters
        // Use model index when possible as it's O(1) lookup
        let workers = if let Some(model) = model_id {
            self.get_by_model_fast(model)
        } else {
            self.get_all()
        };

        // Apply remaining filters
        workers
            .into_iter()
            .filter(|w| {
                // Check worker_type if specified
                if let Some(ref wtype) = worker_type {
                    if w.worker_type() != *wtype {
                        return false;
                    }
                }
                // Check connection_mode if specified
                if let Some(ref conn) = connection_mode {
                    if w.connection_mode() != *conn {
                        return false;
                    }
                }
                // Check health if required
                if healthy_only && !w.is_healthy() {
                    return false;
                }
                true
            })
            .collect()
    }

    /// Get worker statistics
    pub fn stats(&self) -> WorkerRegistryStats {
        let total_workers = self.workers.len();
        let total_models = self.get_models().len();

        let mut healthy_count = 0;
        let mut total_load = 0;
        let mut regular_count = 0;
        let mut prefill_count = 0;
        let mut decode_count = 0;

        for worker in self.get_all() {
            if worker.is_healthy() {
                healthy_count += 1;
            }
            total_load += worker.load();

            match worker.worker_type() {
                WorkerType::Regular => regular_count += 1,
                WorkerType::Prefill { .. } => prefill_count += 1,
                WorkerType::Decode => decode_count += 1,
            }
        }

        WorkerRegistryStats {
            total_workers,
            total_models,
            healthy_workers: healthy_count,
            total_load,
            regular_workers: regular_count,
            prefill_workers: prefill_count,
            decode_workers: decode_count,
        }
    }

    /// Start a health checker for all workers in the registry
    /// This should be called once after the registry is populated with workers
    pub fn start_health_checker(&self, check_interval_secs: u64) -> crate::core::HealthChecker {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;

        let shutdown = Arc::new(AtomicBool::new(false));
        let shutdown_clone = shutdown.clone();
        let workers_ref = self.workers.clone();

        let handle = tokio::spawn(async move {
            let mut interval =
                tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs));

            // Counter for periodic load reset (every 10 health check cycles)
            let mut check_count = 0u64;
            const LOAD_RESET_INTERVAL: u64 = 10;

            loop {
                interval.tick().await;

                // Check for shutdown signal
                if shutdown_clone.load(Ordering::Acquire) {
                    tracing::debug!("Registry health checker shutting down");
                    break;
                }

                // Get all workers from registry
                let workers: Vec<Arc<dyn crate::core::Worker>> =
                    workers_ref.iter().map(|entry| entry.value().clone()).collect();

                // Perform health checks
                for worker in &workers {
                    let _ = worker.check_health_async().await; // Use async version directly
                }

                // Reset loads periodically
                check_count += 1;
                if check_count % LOAD_RESET_INTERVAL == 0 {
                    tracing::debug!("Resetting worker loads (cycle {})", check_count);
                    for worker in &workers {
                        worker.reset_load();
                    }
                }
            }
        });

        crate::core::HealthChecker::new(handle, shutdown)
    }
}

impl Default for WorkerRegistry {
    fn default() -> Self {
        Self::new()
    }
}

/// Statistics for the worker registry
#[derive(Debug, Clone)]
pub struct WorkerRegistryStats {
    pub total_workers: usize,
    pub total_models: usize,
    pub healthy_workers: usize,
    pub total_load: usize,
    pub regular_workers: usize,
    pub prefill_workers: usize,
    pub decode_workers: usize,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{CircuitBreakerConfig, WorkerFactory};
    use std::collections::HashMap;

    #[test]
    fn test_worker_registry() {
        let registry = WorkerRegistry::new();

        // Create a worker with labels
        let mut labels = HashMap::new();
        labels.insert("model_id".to_string(), "llama-3-8b".to_string());
        labels.insert("priority".to_string(), "50".to_string());
        labels.insert("cost".to_string(), "0.8".to_string());

        let worker = WorkerFactory::create_regular_with_labels(
            "http://worker1:8080".to_string(),
            labels,
            CircuitBreakerConfig::default(),
        );

        // Register worker (WorkerFactory returns Box<dyn Worker>, convert to Arc)
        let worker_id = registry.register(Arc::from(worker));

        // Verify registration
        assert!(registry.get(&worker_id).is_some());
        assert!(registry.get_by_url("http://worker1:8080").is_some());
        assert_eq!(registry.get_by_model("llama-3-8b").len(), 1);
        assert_eq!(registry.get_by_type(&WorkerType::Regular).len(), 1);
        assert_eq!(registry.get_by_connection(&ConnectionMode::Http).len(), 1);

        // Test stats
        let stats = registry.stats();
        assert_eq!(stats.total_workers, 1);
        assert_eq!(stats.total_models, 1);

        // Remove worker
        registry.remove(&worker_id);
        assert!(registry.get(&worker_id).is_none());
    }

    #[test]
    fn test_model_index_fast_lookup() {
        let registry = WorkerRegistry::new();

        // Create workers for different models
        let mut labels1 = HashMap::new();
        labels1.insert("model_id".to_string(), "llama-3".to_string());
        let worker1 = WorkerFactory::create_regular_with_labels(
            "http://worker1:8080".to_string(),
            labels1,
            CircuitBreakerConfig::default(),
        );

        let mut labels2 = HashMap::new();
        labels2.insert("model_id".to_string(), "llama-3".to_string());
        let worker2 = WorkerFactory::create_regular_with_labels(
            "http://worker2:8080".to_string(),
            labels2,
            CircuitBreakerConfig::default(),
        );

        let mut labels3 = HashMap::new();
        labels3.insert("model_id".to_string(), "gpt-4".to_string());
        let worker3 = WorkerFactory::create_regular_with_labels(
            "http://worker3:8080".to_string(),
            labels3,
            CircuitBreakerConfig::default(),
        );

        // Register workers
        registry.register(Arc::from(worker1));
        registry.register(Arc::from(worker2));
        registry.register(Arc::from(worker3));

        // Test get_by_model_fast for llama-3
        let llama_workers = registry.get_by_model_fast("llama-3");
        assert_eq!(llama_workers.len(), 2);
        let urls: Vec<String> = llama_workers.iter().map(|w| w.url().to_string()).collect();
        assert!(urls.contains(&"http://worker1:8080".to_string()));
        assert!(urls.contains(&"http://worker2:8080".to_string()));

        // Test get_by_model_fast for gpt-4
        let gpt_workers = registry.get_by_model_fast("gpt-4");
        assert_eq!(gpt_workers.len(), 1);
        assert_eq!(gpt_workers[0].url(), "http://worker3:8080");

        // Test get_by_model_fast for non-existent model
        let unknown_workers = registry.get_by_model_fast("unknown-model");
        assert_eq!(unknown_workers.len(), 0);

        // Test that both get_by_model and get_by_model_fast return same results
        let llama_workers_slow = registry.get_by_model("llama-3");
        assert_eq!(llama_workers.len(), llama_workers_slow.len());

        // Test removal updates the model index
        registry.remove_by_url("http://worker1:8080");
        let llama_workers_after = registry.get_by_model_fast("llama-3");
        assert_eq!(llama_workers_after.len(), 1);
        assert_eq!(llama_workers_after[0].url(), "http://worker2:8080");
    }
}
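The registry above keeps one primary map of workers plus several secondary indexes (per model, per type, per connection mode, per URL) that must be kept consistent on register and remove. The following self-contained sketch shows the core double-index idea with plain HashMaps and a toy Worker struct; it is an illustration of the bookkeeping, not the repository's actual types.

    use std::collections::HashMap;
    use std::sync::Arc;

    // Toy worker record; the real registry stores Arc<dyn Worker> trait objects.
    struct Worker { url: String, model_id: String }

    fn main() {
        // Primary store keyed by URL plus a per-model index, mirroring the
        // workers / model_index split in WorkerRegistry above.
        let mut by_url: HashMap<String, Arc<Worker>> = HashMap::new();
        let mut by_model: HashMap<String, Vec<Arc<Worker>>> = HashMap::new();

        for (url, model) in [("http://w1:8080", "llama-3"), ("http://w2:8080", "llama-3")] {
            let w = Arc::new(Worker { url: url.into(), model_id: model.into() });
            by_url.insert(w.url.clone(), w.clone());
            by_model.entry(w.model_id.clone()).or_default().push(w);
        }

        // Model lookup is a single map access rather than a scan over all workers.
        assert_eq!(by_model.get("llama-3").map(|v| v.len()), Some(2));

        // Removal must update both indexes, as remove() does in the registry.
        if let Some(w) = by_url.remove("http://w1:8080") {
            if let Some(v) = by_model.get_mut(&w.model_id) {
                v.retain(|x| x.url != w.url);
            }
        }
        assert_eq!(by_model.get("llama-3").map(|v| v.len()), Some(1));
    }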
sgl-router/src/policies/cache_aware.rs

@@ -63,6 +63,7 @@ use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
 use crate::core::Worker;
 use crate::metrics::RouterMetrics;
 use crate::tree::Tree;
+use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 use std::thread;
 use std::time::Duration;

@@ -72,10 +73,11 @@ use tracing::debug;
 ///
 /// Routes requests based on cache affinity when load is balanced,
 /// switches to shortest-queue routing when load is imbalanced.
+/// Maintains separate trees per model for multi-model support.
 #[derive(Debug)]
 pub struct CacheAwarePolicy {
     config: CacheAwareConfig,
-    tree: Arc<Mutex<Tree>>,
+    trees: Arc<Mutex<HashMap<String, Tree>>>, // model_id -> Tree
     eviction_handle: Option<thread::JoinHandle<()>>,
 }

@@ -85,20 +87,26 @@ impl CacheAwarePolicy {
     }

     pub fn with_config(config: CacheAwareConfig) -> Self {
-        let tree = Arc::new(Mutex::new(Tree::new()));
+        let trees = Arc::new(Mutex::new(HashMap::<String, Tree>::new()));

         // Start background eviction thread if configured
         let eviction_handle = if config.eviction_interval_secs > 0 {
-            let tree_clone = Arc::clone(&tree);
+            let trees_clone = Arc::clone(&trees);
             let max_tree_size = config.max_tree_size;
             let interval = config.eviction_interval_secs;

             Some(thread::spawn(move || loop {
                 thread::sleep(Duration::from_secs(interval));
-                if let Ok(tree_guard) = tree_clone.lock() {
-                    tree_guard.evict_tenant_by_size(max_tree_size);
-                    debug!("Cache eviction completed, max_size: {}", max_tree_size);
+                if let Ok(mut trees_guard) = trees_clone.lock() {
+                    // Evict for all model trees
+                    for (model_id, tree) in trees_guard.iter_mut() {
+                        tree.evict_tenant_by_size(max_tree_size);
+                        debug!(
+                            "Cache eviction completed for model {}, max_size: {}",
+                            model_id, max_tree_size
+                        );
+                    }
                 }
             }))
         } else {

@@ -107,38 +115,97 @@ impl CacheAwarePolicy {
         Self {
             config,
-            tree,
+            trees,
             eviction_handle,
         }
     }

     /// Initialize the tree with worker URLs (used only during initial setup)
-    pub fn init_workers(&self, workers: &[Box<dyn Worker>]) {
-        if let Ok(tree) = self.tree.lock() {
+    pub fn init_workers(&self, workers: &[Arc<dyn Worker>]) {
+        if let Ok(mut trees) = self.trees.lock() {
+            // Group workers by model
+            let mut model_workers: HashMap<String, Vec<&Arc<dyn Worker>>> = HashMap::new();
             for worker in workers {
-                tree.insert("", worker.url());
+                // Use "default" for unknown/empty model_ids for backward compatibility
+                let model_id = worker.model_id();
+                let tree_key = if model_id.is_empty() || model_id == "unknown" {
+                    "default".to_string()
+                } else {
+                    model_id.to_string()
+                };
+                model_workers.entry(tree_key).or_default().push(worker);
+            }
+
+            // Initialize tree for each model
+            for (tree_key, model_workers) in model_workers {
+                let tree = trees.entry(tree_key).or_insert_with(Tree::new);
+                for worker in model_workers {
+                    tree.insert("", worker.url());
+                }
+            }
         }
     }

     /// Add a single worker to the tree (incremental update)
-    pub fn add_worker(&self, url: &str) {
-        if let Ok(tree) = self.tree.lock() {
+    pub fn add_worker(&self, worker: &dyn Worker) {
+        if let Ok(mut trees) = self.trees.lock() {
+            // For backward compatibility: if model_id is "unknown" or empty,
+            // use a default tree. This preserves existing behavior for single-model routers.
+            let model_id = worker.model_id();
+            let tree_key = if model_id.is_empty() || model_id == "unknown" {
+                "default".to_string()
+            } else {
+                model_id.to_string()
+            };
+            let tree = trees.entry(tree_key).or_insert_with(Tree::new);
             tree.insert("", worker.url());
         }
     }

+    /// Add a worker by URL and model (for backward compatibility)
+    pub fn add_worker_by_url(&self, url: &str, model_id: &str) {
+        if let Ok(mut trees) = self.trees.lock() {
+            let tree = trees.entry(model_id.to_string()).or_insert_with(Tree::new);
+            tree.insert("", url);
+        }
+    }
+
     /// Remove a worker from the tree
-    pub fn remove_worker(&self, url: &str) {
-        if let Ok(tree) = self.tree.lock() {
-            tree.remove_tenant(url);
+    pub fn remove_worker(&self, worker: &dyn Worker) {
+        if let Ok(mut trees) = self.trees.lock() {
+            // Use same logic as add_worker for consistency
+            let model_id = worker.model_id();
+            let tree_key = if model_id.is_empty() || model_id == "unknown" {
+                "default".to_string()
+            } else {
+                model_id.to_string()
+            };
+            if let Some(tree) = trees.get_mut(&tree_key) {
+                tree.remove_tenant(worker.url());
+            }
         }
     }

+    /// Remove a worker by URL (removes from all model trees for backward compatibility)
+    pub fn remove_worker_by_url(&self, url: &str) {
+        if let Ok(mut trees) = self.trees.lock() {
+            // Remove from all trees since we don't know which model it belongs to
+            for (_model_id, tree) in trees.iter_mut() {
+                tree.remove_tenant(url);
+            }
+        }
+    }
+
     /// Run cache eviction to prevent unbounded growth
     pub fn evict_cache(&self, max_size: usize) {
-        if let Ok(tree) = self.tree.lock() {
-            tree.evict_tenant_by_size(max_size);
+        if let Ok(mut trees) = self.trees.lock() {
+            for (model_id, tree) in trees.iter_mut() {
+                tree.evict_tenant_by_size(max_size);
+                debug!("Cache eviction for model {}, max_size: {}", model_id, max_size);
+            }
         }
     }

@@ -146,7 +213,7 @@ impl CacheAwarePolicy {
 impl LoadBalancingPolicy for CacheAwarePolicy {
     fn select_worker(
         &self,
-        workers: &[Box<dyn Worker>],
+        workers: &[Arc<dyn Worker>],
         request_text: Option<&str>,
     ) -> Option<usize> {
         let healthy_indices = get_healthy_worker_indices(workers);

@@ -155,6 +222,18 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
             return None;
         }

+        // Group workers by model (using "default" for unknown/empty model_ids)
+        let mut model_workers: HashMap<String, Vec<usize>> = HashMap::new();
+        for idx in &healthy_indices {
+            let model_id = workers[*idx].model_id();
+            let tree_key = if model_id.is_empty() || model_id == "unknown" {
+                "default".to_string()
+            } else {
+                model_id.to_string()
+            };
+            model_workers.entry(tree_key).or_default().push(*idx);
+        }
+
         // Get current load statistics
         let loads: Vec<usize> = workers.iter().map(|w| w.load()).collect();
         let max_load = *loads.iter().max().unwrap_or(&0);

@@ -187,7 +266,14 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
             // Even in imbalanced mode, update the tree to maintain cache state
             if let Some(text) = request_text {
-                if let Ok(tree) = self.tree.lock() {
+                if let Ok(mut trees) = self.trees.lock() {
+                    let model_id = workers[min_load_idx].model_id();
+                    let tree_key = if model_id.is_empty() || model_id == "unknown" {
+                        "default".to_string()
+                    } else {
+                        model_id.to_string()
+                    };
+                    let tree = trees.entry(tree_key).or_insert_with(Tree::new);
                     tree.insert(text, workers[min_load_idx].url());
                 }
             }

@@ -203,43 +289,85 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
         // Use cache-aware routing when balanced
         let text = request_text.unwrap_or("");

-        if let Ok(tree) = self.tree.lock() {
-            let (matched_text, matched_worker) = tree.prefix_match(text);
-            let match_rate = if text.is_empty() {
-                0.0
-            } else {
-                matched_text.chars().count() as f32 / text.chars().count() as f32
-            };
+        if let Ok(mut trees) = self.trees.lock() {
+            let mut best_match_idx: Option<usize> = None;
+            let mut best_match_rate: f32 = 0.0;
+
+            // Find best match across all models
+            for (model_id, worker_indices) in &model_workers {
+                let tree = trees.entry(model_id.clone()).or_insert_with(Tree::new);
+                let (matched_text, matched_worker) = tree.prefix_match(text);
+                let match_rate = if text.is_empty() {
+                    0.0
+                } else {
+                    matched_text.chars().count() as f32 / text.chars().count() as f32
+                };
+
+                // Check if this model has the best match
+                if match_rate > best_match_rate {
+                    // Find the worker index for this URL
+                    if let Some(idx) = worker_indices
+                        .iter()
+                        .find(|&&idx| workers[idx].url() == matched_worker)
+                    {
+                        best_match_idx = Some(*idx);
+                        best_match_rate = match_rate;
+                    }
+                }
+            }

-            let selected_url = if match_rate > self.config.cache_threshold {
+            // Select worker based on cache threshold
+            let selected_idx = if let (Some(idx), true) = (
+                best_match_idx,
+                best_match_rate > self.config.cache_threshold,
+            ) {
                 RouterMetrics::record_cache_hit();
-                matched_worker.to_string()
+                idx
             } else {
                 RouterMetrics::record_cache_miss();
-                tree.get_smallest_tenant()
-            };
-
-            // Find the index of the selected worker
-            if let Some(selected_idx) = workers.iter().position(|w| w.url() == selected_url) {
-                // Only proceed if the worker is healthy
-                if workers[selected_idx].is_healthy() {
-                    // Update the tree with this request
-                    tree.insert(text, &selected_url);
-
-                    // Increment processed counter
-                    workers[selected_idx].increment_processed();
-                    RouterMetrics::record_processed_request(&selected_url);
-
-                    return Some(selected_idx);
+                // Find model with smallest tree (most cache capacity)
+                let mut smallest_tree_model = String::new();
+                let mut smallest_tree_size = usize::MAX;
+                for model_id in model_workers.keys() {
+                    let tree = trees.entry(model_id.clone()).or_insert_with(Tree::new);
+                    let size = tree.get_used_size_per_tenant().values().sum::<usize>();
+                    if size < smallest_tree_size {
+                        smallest_tree_size = size;
+                        smallest_tree_model = model_id.clone();
+                    }
                 }
-
-                // Selected worker no longer exists, remove it from tree
-                tree.remove_tenant(&selected_url);
-                debug!("Removed stale worker {} from cache tree", selected_url);
-            }
-
-            // Fallback to first healthy worker
-            return healthy_indices.first().copied();
+                // Select least loaded worker from model with most cache capacity
+                if let Some(worker_indices) = model_workers.get(&smallest_tree_model) {
+                    worker_indices
+                        .iter()
+                        .min_by_key(|&&idx| workers[idx].load())
+                        .copied()
+                        .unwrap_or(healthy_indices[0])
+                } else {
+                    healthy_indices[0]
+                }
+            };
+
+            // Update the tree with this request
+            let model_id = workers[selected_idx].model_id();
+            let tree_key = if model_id.is_empty() || model_id == "unknown" {
+                "default".to_string()
+            } else {
+                model_id.to_string()
+            };
+            let tree = trees.entry(tree_key).or_insert_with(Tree::new);
+            tree.insert(text, workers[selected_idx].url());
+
+            // Increment processed counter
+            workers[selected_idx].increment_processed();
+            RouterMetrics::record_processed_request(workers[selected_idx].url());
+            RouterMetrics::record_policy_decision(self.name(), workers[selected_idx].url());
+
+            return Some(selected_idx);
         }

         // Fallback to first healthy worker if tree operations fail

@@ -272,8 +400,8 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
     fn select_worker_pair(
         &self,
-        prefill_workers: &[Box<dyn Worker>],
-        decode_workers: &[Box<dyn Worker>],
+        prefill_workers: &[Arc<dyn Worker>],
+        decode_workers: &[Arc<dyn Worker>],
         request_text: Option<&str>,
     ) -> Option<(usize, usize)> {
         // DEPRECATED: This method is no longer used when separate policies are configured.

@@ -333,12 +461,12 @@ mod tests {
             ..Default::default()
         };
         let policy = CacheAwarePolicy::with_config(config);
-        let workers: Vec<Box<dyn Worker>> = vec![
-            Box::new(BasicWorker::new(
+        let workers: Vec<Arc<dyn Worker>> = vec![
+            Arc::new(BasicWorker::new(
                 "http://w1:8000".to_string(),
                 WorkerType::Regular,
             )),
-            Box::new(BasicWorker::new(
+            Arc::new(BasicWorker::new(
                 "http://w2:8000".to_string(),
                 WorkerType::Regular,
             )),

@@ -378,7 +506,7 @@ mod tests {
         }
         // worker2 has load 0

-        let workers: Vec<Box<dyn Worker>> = vec![Box::new(worker1), Box::new(worker2)];
+        let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker1), Arc::new(worker2)];
         policy.init_workers(&workers);

         // Should select worker2 (lower load) despite cache affinity

@@ -395,12 +523,12 @@ mod tests {
             ..Default::default()
         };
         let policy = CacheAwarePolicy::with_config(config);
-        let workers: Vec<Box<dyn Worker>> = vec![
-            Box::new(BasicWorker::new(
+        let workers: Vec<Arc<dyn Worker>> = vec![
+            Arc::new(BasicWorker::new(
                 "http://w1:8000".to_string(),
                 WorkerType::Regular,
             )),
-            Box::new(BasicWorker::new(
+            Arc::new(BasicWorker::new(
                 "http://w2:8000".to_string(),
                 WorkerType::Regular,
             )),

@@ -413,7 +541,7 @@ mod tests {
         policy.select_worker(&workers, Some("test2"));

         // Remove a worker
-        policy.remove_worker("http://w1:8000");
+        policy.remove_worker_by_url("http://w1:8000");
         workers[0].set_healthy(false);

         // All requests should now go to worker2
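The cache-aware selection above compares a per-model prefix match rate (matched prefix length divided by request length, in characters) against config.cache_threshold to decide between cache affinity and a load-based fallback. The sketch below isolates just that calculation and decision; the plain string prefix and the 0.5 threshold are illustrative stand-ins, not the router's radix tree or its default configuration.

    // Match rate = matched prefix chars / request chars, as in select_worker above.
    fn match_rate(matched: &str, text: &str) -> f32 {
        if text.is_empty() {
            0.0
        } else {
            matched.chars().count() as f32 / text.chars().count() as f32
        }
    }

    fn main() {
        let cache_threshold = 0.5; // illustrative value only
        let text = "system prompt + user question";
        let matched = "system prompt + ";
        let rate = match_rate(matched, text);
        // Keep cache affinity only when enough of the request was matched;
        // otherwise fall back to the least-loaded / smallest-tree path.
        if rate > cache_threshold {
            println!("cache hit ({rate:.2}), route to matched worker");
        } else {
            println!("cache miss ({rate:.2}), route by load instead");
        }
    }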
sgl-router/src/policies/mod.rs

@@ -5,17 +5,20 @@
 use crate::core::Worker;
 use std::fmt::Debug;
+use std::sync::Arc;

 mod cache_aware;
 mod factory;
 mod power_of_two;
 mod random;
+mod registry;
 mod round_robin;

 pub use cache_aware::CacheAwarePolicy;
 pub use factory::PolicyFactory;
 pub use power_of_two::PowerOfTwoPolicy;
 pub use random::RandomPolicy;
+pub use registry::PolicyRegistry;
 pub use round_robin::RoundRobinPolicy;

 /// Core trait for load balancing policies

@@ -26,9 +29,10 @@ pub trait LoadBalancingPolicy: Send + Sync + Debug {
     /// Select a single worker from the available workers
     ///
     /// This is used for regular routing mode where requests go to a single worker.
+    /// Now uses Arc<dyn Worker> for better performance and to avoid unnecessary cloning.
     fn select_worker(
         &self,
-        workers: &[Box<dyn Worker>],
+        workers: &[Arc<dyn Worker>],
         request_text: Option<&str>,
     ) -> Option<usize>;

@@ -38,8 +42,8 @@ pub trait LoadBalancingPolicy: Send + Sync + Debug {
     /// Default implementation uses select_worker for each array independently.
     fn select_worker_pair(
         &self,
-        prefill_workers: &[Box<dyn Worker>],
-        decode_workers: &[Box<dyn Worker>],
+        prefill_workers: &[Arc<dyn Worker>],
+        decode_workers: &[Arc<dyn Worker>],
         request_text: Option<&str>,
     ) -> Option<(usize, usize)> {
         // Default implementation: independently select from each pool

@@ -105,7 +109,7 @@ impl Default for CacheAwareConfig {
 }

 /// Helper function to filter healthy workers and return their indices
-pub(crate) fn get_healthy_worker_indices(workers: &[Box<dyn Worker>]) -> Vec<usize> {
+pub(crate) fn get_healthy_worker_indices(workers: &[Arc<dyn Worker>]) -> Vec<usize> {
     workers
         .iter()
         .enumerate()

@@ -121,16 +125,16 @@ mod tests {
     #[test]
     fn test_get_healthy_worker_indices() {
-        let workers: Vec<Box<dyn Worker>> = vec![
-            Box::new(BasicWorker::new(
+        let workers: Vec<Arc<dyn Worker>> = vec![
+            Arc::new(BasicWorker::new(
                 "http://w1:8000".to_string(),
                 WorkerType::Regular,
             )),
-            Box::new(BasicWorker::new(
+            Arc::new(BasicWorker::new(
                 "http://w2:8000".to_string(),
                 WorkerType::Regular,
             )),
-            Box::new(BasicWorker::new(
+            Arc::new(BasicWorker::new(
                 "http://w3:8000".to_string(),
                 WorkerType::Regular,
             )),
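The signature change in this module switches every policy to borrowing shared worker handles, so the same Arc<dyn Worker> values can be held by the registry and passed to any per-model policy without cloning worker state. A minimal, self-contained sketch of a policy over such a slice follows; the Worker and LoadBalancingPolicy traits here are simplified stand-ins for the real ones in crate::core and this module, and FirstHealthy is a hypothetical policy used only for illustration.

    use std::sync::Arc;

    // Simplified stand-ins for the real traits.
    trait Worker { fn is_healthy(&self) -> bool; }
    struct W(bool);
    impl Worker for W { fn is_healthy(&self) -> bool { self.0 } }

    trait LoadBalancingPolicy {
        // Same shape as select_worker above: borrow a shared slice, return an index.
        fn select_worker(&self, workers: &[Arc<dyn Worker>]) -> Option<usize>;
    }

    struct FirstHealthy;
    impl LoadBalancingPolicy for FirstHealthy {
        fn select_worker(&self, workers: &[Arc<dyn Worker>]) -> Option<usize> {
            workers.iter().position(|w| w.is_healthy())
        }
    }

    fn main() {
        // Arc lets the registry and each policy share the same worker handles.
        let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(W(false)), Arc::new(W(true))];
        assert_eq!(FirstHealthy.select_worker(&workers), Some(1));
    }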
sgl-router/src/policies/power_of_two.rs

@@ -5,7 +5,7 @@ use crate::core::Worker;
 use crate::metrics::RouterMetrics;
 use rand::Rng;
 use std::collections::HashMap;
-use std::sync::RwLock;
+use std::sync::{Arc, RwLock};
 use tracing::info;

 /// Power-of-two choices policy

@@ -41,7 +41,7 @@ impl PowerOfTwoPolicy {
 impl LoadBalancingPolicy for PowerOfTwoPolicy {
     fn select_worker(
         &self,
-        workers: &[Box<dyn Worker>],
+        workers: &[Arc<dyn Worker>],
         _request_text: Option<&str>,
     ) -> Option<usize> {
         let healthy_indices = get_healthy_worker_indices(workers);

@@ -137,8 +137,8 @@ mod tests {
         }
         // worker3 has load 0

-        let workers: Vec<Box<dyn Worker>> =
-            vec![Box::new(worker1), Box::new(worker2), Box::new(worker3)];
+        let workers: Vec<Arc<dyn Worker>> =
+            vec![Arc::new(worker1), Arc::new(worker2), Arc::new(worker3)];

         // Run multiple selections
         let mut selected_counts = [0; 3];

@@ -156,12 +156,12 @@ mod tests {
     #[test]
     fn test_power_of_two_with_cached_loads() {
         let policy = PowerOfTwoPolicy::new();
-        let workers: Vec<Box<dyn Worker>> = vec![
-            Box::new(BasicWorker::new(
+        let workers: Vec<Arc<dyn Worker>> = vec![
+            Arc::new(BasicWorker::new(
                 "http://w1:8000".to_string(),
                 WorkerType::Regular,
             )),
-            Box::new(BasicWorker::new(
+            Arc::new(BasicWorker::new(
                 "http://w2:8000".to_string(),
                 WorkerType::Regular,
             )),

@@ -190,7 +190,7 @@ mod tests {
     #[test]
     fn test_power_of_two_single_worker() {
         let policy = PowerOfTwoPolicy::new();
-        let workers: Vec<Box<dyn Worker>> = vec![Box::new(BasicWorker::new(
+        let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(BasicWorker::new(
             "http://w1:8000".to_string(),
             WorkerType::Regular,
         ))];
sgl-router/src/policies/random.rs

@@ -4,6 +4,7 @@ use super::{get_healthy_worker_indices, LoadBalancingPolicy};
 use crate::core::Worker;
 use crate::metrics::RouterMetrics;
 use rand::Rng;
+use std::sync::Arc;

 /// Random selection policy
 ///

@@ -20,7 +21,7 @@ impl RandomPolicy {
 impl LoadBalancingPolicy for RandomPolicy {
     fn select_worker(
         &self,
-        workers: &[Box<dyn Worker>],
+        workers: &[Arc<dyn Worker>],
         _request_text: Option<&str>,
     ) -> Option<usize> {
         let healthy_indices = get_healthy_worker_indices(workers);

@@ -56,16 +57,16 @@ mod tests {
     #[test]
     fn test_random_selection() {
         let policy = RandomPolicy::new();
-        let workers: Vec<Box<dyn Worker>> = vec![
-            Box::new(BasicWorker::new(
+        let workers: Vec<Arc<dyn Worker>> = vec![
+            Arc::new(BasicWorker::new(
                 "http://w1:8000".to_string(),
                 WorkerType::Regular,
             )),
-            Box::new(BasicWorker::new(
+            Arc::new(BasicWorker::new(
                 "http://w2:8000".to_string(),
                 WorkerType::Regular,
             )),
-            Box::new(BasicWorker::new(
+            Arc::new(BasicWorker::new(
                 "http://w3:8000".to_string(),
                 WorkerType::Regular,
             )),

@@ -87,12 +88,12 @@ mod tests {
     #[test]
     fn test_random_with_unhealthy_workers() {
         let policy = RandomPolicy::new();
-        let workers: Vec<Box<dyn Worker>> = vec![
-            Box::new(BasicWorker::new(
+        let workers: Vec<Arc<dyn Worker>> = vec![
+            Arc::new(BasicWorker::new(
                 "http://w1:8000".to_string(),
                 WorkerType::Regular,
             )),
-            Box::new(BasicWorker::new(
+            Arc::new(BasicWorker::new(
                 "http://w2:8000".to_string(),
                 WorkerType::Regular,
             )),

@@ -110,7 +111,7 @@ mod tests {
     #[test]
     fn test_random_no_healthy_workers() {
         let policy = RandomPolicy::new();
-        let workers: Vec<Box<dyn Worker>> = vec![Box::new(BasicWorker::new(
+        let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(BasicWorker::new(
             "http://w1:8000".to_string(),
             WorkerType::Regular,
         ))];
sgl-router/src/policies/registry.rs (new file, 0 → 100644)

/// Policy Registry for managing model-to-policy mappings
///
/// This registry manages the dynamic assignment of load balancing policies to models.
/// When the first worker of a new model is added, it determines the policy for that model.
/// All subsequent workers of the same model use the established policy.
/// When the last worker of a model is removed, the policy mapping is cleaned up.
use super::{
    CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
    RoundRobinPolicy,
};
use crate::config::types::PolicyConfig;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::{debug, info, warn};

/// Registry for managing model-to-policy mappings
#[derive(Clone)]
pub struct PolicyRegistry {
    /// Model ID -> Policy instance mapping
    model_policies: Arc<RwLock<HashMap<String, Arc<dyn LoadBalancingPolicy>>>>,
    /// Model ID -> Worker count for cleanup tracking
    model_worker_counts: Arc<RwLock<HashMap<String, usize>>>,
    /// Default policy instance (cached)
    default_policy: Arc<dyn LoadBalancingPolicy>,
    /// Prefill policy for PD mode
    prefill_policy: Arc<RwLock<Option<Arc<dyn LoadBalancingPolicy>>>>,
    /// Decode policy for PD mode
    decode_policy: Arc<RwLock<Option<Arc<dyn LoadBalancingPolicy>>>>,
}

impl PolicyRegistry {
    /// Create a new PolicyRegistry with a default policy
    pub fn new(default_policy_config: PolicyConfig) -> Self {
        let default_policy = Self::create_policy_from_config(&default_policy_config);
        Self {
            model_policies: Arc::new(RwLock::new(HashMap::new())),
            model_worker_counts: Arc::new(RwLock::new(HashMap::new())),
            default_policy,
            prefill_policy: Arc::new(RwLock::new(None)),
            decode_policy: Arc::new(RwLock::new(None)),
        }
    }

    /// Called when a worker is added
    /// Returns the policy that should be used for this worker's model
    pub fn on_worker_added(
        &self,
        model_id: &str,
        policy_hint: Option<&str>,
    ) -> Arc<dyn LoadBalancingPolicy> {
        // Increment worker count
        {
            let mut counts = self.model_worker_counts.write().unwrap();
            *counts.entry(model_id.to_string()).or_insert(0) += 1;
            debug!(
                "Worker added for model {}, count: {}",
                model_id,
                counts.get(model_id).unwrap()
            );
        }

        // Check if model already has a policy
        {
            let policies = self.model_policies.read().unwrap();
            if let Some(existing_policy) = policies.get(model_id) {
                debug!(
                    "Model {} already has policy: {}",
                    model_id,
                    existing_policy.name()
                );
                return Arc::clone(existing_policy);
            }
        }

        // New model - determine policy
        let policy = self.determine_policy_for_model(model_id, policy_hint);
        info!("Assigning policy {} to new model {}", policy.name(), model_id);

        // Store policy for this model
        {
            let mut policies = self.model_policies.write().unwrap();
            policies.insert(model_id.to_string(), Arc::clone(&policy));
        }

        policy
    }

    /// Called when a worker is removed
    pub fn on_worker_removed(&self, model_id: &str) {
        let should_cleanup = {
            let mut counts = self.model_worker_counts.write().unwrap();
            if let Some(count) = counts.get_mut(model_id) {
                *count = count.saturating_sub(1);
                debug!("Worker removed for model {}, count: {}", model_id, *count);
                if *count == 0 {
                    counts.remove(model_id);
                    true
                } else {
                    false
                }
            } else {
                warn!(
                    "Attempted to remove worker for model {} with no registered workers",
                    model_id
                );
                false
            }
        };

        // Clean up policy if this was the last worker
        if should_cleanup {
            let mut policies = self.model_policies.write().unwrap();
            if let Some(policy) = policies.remove(model_id) {
                info!(
                    "Removed policy {} for model {} (last worker removed)",
                    policy.name(),
                    model_id
                );
                // Policy will be dropped here, cleaning up any resources
                drop(policy);
            }
        }
    }

    /// Get the policy for a model
    pub fn get_policy(&self, model_id: &str) -> Option<Arc<dyn LoadBalancingPolicy>> {
        self.model_policies.read().unwrap().get(model_id).cloned()
    }

    /// Get the default policy
    pub fn get_default_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
        Arc::clone(&self.default_policy)
    }

    /// Get policy for a model, or default if not found
    pub fn get_policy_or_default(&self, model_id: &str) -> Arc<dyn LoadBalancingPolicy> {
        self.get_policy(model_id)
            .unwrap_or_else(|| self.get_default_policy())
    }

    /// Determine policy for a new model
    fn determine_policy_for_model(
        &self,
        model_id: &str,
        policy_hint: Option<&str>,
    ) -> Arc<dyn LoadBalancingPolicy> {
        // 1. Check policy hint from worker
        if let Some(policy_type) = policy_hint {
            debug!("Using policy hint '{}' for model {}", policy_type, model_id);
            return self.create_policy_from_type(policy_type);
        }

        // 2. Use default policy
        debug!("Using default policy for model {}", model_id);
        Arc::clone(&self.default_policy)
    }

    /// Create a policy from a type string
    fn create_policy_from_type(&self, policy_type: &str) -> Arc<dyn LoadBalancingPolicy> {
        match policy_type {
            "round_robin" => Arc::new(RoundRobinPolicy::new()),
            "random" => Arc::new(RandomPolicy::new()),
            "cache_aware" => Arc::new(CacheAwarePolicy::new()),
            "power_of_two" => Arc::new(PowerOfTwoPolicy::new()),
            _ => {
                warn!("Unknown policy type '{}', using default", policy_type);
                Arc::clone(&self.default_policy)
            }
        }
    }

    /// Create a policy from a PolicyConfig
    fn create_policy_from_config(config: &PolicyConfig) -> Arc<dyn LoadBalancingPolicy> {
        match config {
            PolicyConfig::RoundRobin => Arc::new(RoundRobinPolicy::new()),
            PolicyConfig::Random => Arc::new(RandomPolicy::new()),
            PolicyConfig::CacheAware {
                cache_threshold,
                balance_abs_threshold,
                balance_rel_threshold,
                eviction_interval_secs,
                max_tree_size,
            } => {
                let cache_config = CacheAwareConfig {
                    cache_threshold: *cache_threshold,
                    balance_abs_threshold: *balance_abs_threshold,
                    balance_rel_threshold: *balance_rel_threshold,
                    eviction_interval_secs: *eviction_interval_secs,
                    max_tree_size: *max_tree_size,
                };
                Arc::new(CacheAwarePolicy::with_config(cache_config))
            }
            PolicyConfig::PowerOfTwo { .. } => Arc::new(PowerOfTwoPolicy::new()),
        }
    }

    /// Get current model->policy mappings (for debugging/monitoring)
    pub fn get_all_mappings(&self) -> HashMap<String, String> {
        let policies = self.model_policies.read().unwrap();
        policies
            .iter()
            .map(|(model, policy)| (model.clone(), policy.name().to_string()))
            .collect()
    }

    /// Get worker counts per model
    pub fn get_worker_counts(&self) -> HashMap<String, usize> {
        self.model_worker_counts.read().unwrap().clone()
    }

    /// Clear all policies (useful for testing)
    pub fn clear(&self) {
        let mut policies = self.model_policies.write().unwrap();
        policies.clear();
        let mut counts = self.model_worker_counts.write().unwrap();
        counts.clear();
    }

    /// Set the prefill policy for PD mode
    pub fn set_prefill_policy(&self, policy: Arc<dyn LoadBalancingPolicy>) {
        let mut prefill_policy = self.prefill_policy.write().unwrap();
        *prefill_policy = Some(policy);
    }

    /// Set the decode policy for PD mode
    pub fn set_decode_policy(&self, policy: Arc<dyn LoadBalancingPolicy>) {
        let mut decode_policy = self.decode_policy.write().unwrap();
        *decode_policy = Some(policy);
    }

    /// Get the prefill policy for PD mode, or default if not set
    pub fn get_prefill_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
        let prefill_policy = self.prefill_policy.read().unwrap();
        prefill_policy
            .as_ref()
            .map(Arc::clone)
            .unwrap_or_else(|| self.get_default_policy())
    }

    /// Get the decode policy for PD mode, or default if not set
    pub fn get_decode_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
        let decode_policy = self.decode_policy.read().unwrap();
        decode_policy
            .as_ref()
            .map(Arc::clone)
            .unwrap_or_else(|| self.get_default_policy())
    }
}

impl std::fmt::Debug for PolicyRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PolicyRegistry")
            .field("model_policies", &self.model_policies)
            .field("model_worker_counts", &self.model_worker_counts)
            .field("default_policy", &self.default_policy.name())
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_policy_registry_basic() {
        let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);

        // First worker of a model sets the policy
        let policy1 = registry.on_worker_added("llama-3", Some("cache_aware"));
        assert_eq!(policy1.name(), "cache_aware");

        // Second worker of same model uses existing policy
        let policy2 = registry.on_worker_added("llama-3", Some("round_robin"));
        assert_eq!(policy2.name(), "cache_aware"); // Ignores hint, uses existing

        // Different model can have different policy
        let policy3 = registry.on_worker_added("gpt-4", Some("random"));
        assert_eq!(policy3.name(), "random");

        // Check mappings
        let mappings = registry.get_all_mappings();
        assert_eq!(mappings.get("llama-3").unwrap(), "cache_aware");
        assert_eq!(mappings.get("gpt-4").unwrap(), "random");

        // Check worker counts
        let counts = registry.get_worker_counts();
        assert_eq!(*counts.get("llama-3").unwrap(), 2);
        assert_eq!(*counts.get("gpt-4").unwrap(), 1);
    }

    #[test]
    fn test_policy_registry_cleanup() {
        let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);

        // Add workers
        registry.on_worker_added("llama-3", Some("cache_aware"));
        registry.on_worker_added("llama-3", None);
        assert_eq!(registry.get_worker_counts().get("llama-3"), Some(&2));

        // Remove one worker - policy should remain
        registry.on_worker_removed("llama-3");
        assert!(registry.get_policy("llama-3").is_some());
        assert_eq!(registry.get_worker_counts().get("llama-3"), Some(&1));

        // Remove last worker - policy should be cleaned up
        registry.on_worker_removed("llama-3");
        assert!(registry.get_policy("llama-3").is_none());
        assert_eq!(registry.get_worker_counts().get("llama-3"), None);
    }

    #[test]
    fn test_default_policy() {
        let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);

        // No hint, no template - uses default
        let policy = registry.on_worker_added("unknown-model", None);
        assert_eq!(policy.name(), "round_robin");

        // Get default directly
        let default = registry.get_default_policy();
        assert_eq!(default.name(), "round_robin");
    }
}
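The key behavior of the PolicyRegistry above is that the first worker of a model pins that model's policy (optionally via a hint), later hints for the same model are ignored, and the mapping is dropped when the last worker goes away. The self-contained sketch below reduces that to plain HashMap bookkeeping with policy names as strings; it is an illustration of the semantics, not the registry itself.

    use std::collections::HashMap;

    // First worker of a model pins the policy; later hints for the same model
    // are ignored, mirroring on_worker_added above.
    fn on_worker_added(
        policies: &mut HashMap<String, String>,
        model: &str,
        hint: Option<&str>,
        default: &str,
    ) -> String {
        policies
            .entry(model.to_string())
            .or_insert_with(|| hint.unwrap_or(default).to_string())
            .clone()
    }

    fn main() {
        let mut policies = HashMap::new();
        // First worker's hint establishes the policy for llama-3.
        assert_eq!(on_worker_added(&mut policies, "llama-3", Some("cache_aware"), "round_robin"), "cache_aware");
        // Second worker's hint is ignored; the established policy wins.
        assert_eq!(on_worker_added(&mut policies, "llama-3", Some("random"), "round_robin"), "cache_aware");
        // A different model with no hint falls back to the default policy.
        assert_eq!(on_worker_added(&mut policies, "gpt-4", None, "round_robin"), "round_robin");
    }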
sgl-router/src/policies/round_robin.rs
View file @
2f173ea0
@@ -4,6 +4,7 @@ use super::{get_healthy_worker_indices, LoadBalancingPolicy};
 use crate::core::Worker;
 use crate::metrics::RouterMetrics;
 use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;

 /// Round-robin selection policy
 ///
@@ -24,7 +25,7 @@ impl RoundRobinPolicy {
 impl LoadBalancingPolicy for RoundRobinPolicy {
     fn select_worker(
         &self,
-        workers: &[Box<dyn Worker>],
+        workers: &[Arc<dyn Worker>],
         _request_text: Option<&str>,
     ) -> Option<usize> {
         let healthy_indices = get_healthy_worker_indices(workers);
@@ -64,16 +65,16 @@ mod tests {
     #[test]
     fn test_round_robin_selection() {
         let policy = RoundRobinPolicy::new();
-        let workers: Vec<Box<dyn Worker>> = vec![
-            Box::new(BasicWorker::new(
+        let workers: Vec<Arc<dyn Worker>> = vec![
+            Arc::new(BasicWorker::new(
                 "http://w1:8000".to_string(),
                 WorkerType::Regular,
             )),
-            Box::new(BasicWorker::new(
+            Arc::new(BasicWorker::new(
                 "http://w2:8000".to_string(),
                 WorkerType::Regular,
             )),
-            Box::new(BasicWorker::new(
+            Arc::new(BasicWorker::new(
                 "http://w3:8000".to_string(),
                 WorkerType::Regular,
             )),
@@ -90,16 +91,16 @@ mod tests {
     #[test]
     fn test_round_robin_with_unhealthy_workers() {
         let policy = RoundRobinPolicy::new();
-        let workers: Vec<Box<dyn Worker>> = vec![
-            Box::new(BasicWorker::new(
+        let workers: Vec<Arc<dyn Worker>> = vec![
+            Arc::new(BasicWorker::new(
                 "http://w1:8000".to_string(),
                 WorkerType::Regular,
             )),
-            Box::new(BasicWorker::new(
+            Arc::new(BasicWorker::new(
                 "http://w2:8000".to_string(),
                 WorkerType::Regular,
             )),
-            Box::new(BasicWorker::new(
+            Arc::new(BasicWorker::new(
                 "http://w3:8000".to_string(),
                 WorkerType::Regular,
             )),
@@ -118,12 +119,12 @@ mod tests {
     #[test]
     fn test_round_robin_reset() {
         let policy = RoundRobinPolicy::new();
-        let workers: Vec<Box<dyn Worker>> = vec![
-            Box::new(BasicWorker::new(
+        let workers: Vec<Arc<dyn Worker>> = vec![
+            Arc::new(BasicWorker::new(
                 "http://w1:8000".to_string(),
                 WorkerType::Regular,
             )),
-            Box::new(BasicWorker::new(
+            Arc::new(BasicWorker::new(
                 "http://w2:8000".to_string(),
                 WorkerType::Regular,
             )),
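The mechanical Box<dyn Worker> to Arc<dyn Worker> change here (repeated in the gRPC routers below) is what lets a single worker object be owned by several collections at once — for example a model-keyed worker registry and the flat slice handed to a load-balancing policy — without copying any worker state. A small illustration reusing BasicWorker from the tests above; the function itself is only a demonstration, not part of this commit:

use std::sync::Arc;

fn shared_worker_demo() {
    // One worker instance, two views of it: cloning the Arc bumps a reference
    // count instead of duplicating the worker.
    let worker: Arc<dyn Worker> = Arc::new(BasicWorker::new(
        "http://w1:8000".to_string(),
        WorkerType::Regular,
    ));
    let for_policy: Vec<Arc<dyn Worker>> = vec![Arc::clone(&worker)];
    let for_registry: Vec<Arc<dyn Worker>> = vec![worker];
    // Both collections point at the same allocation.
    assert!(Arc::ptr_eq(&for_policy[0], &for_registry[0]));
}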
sgl-router/src/protocols/mod.rs
View file @
2f173ea0
@@ -3,3 +3,4 @@
 pub mod spec;
 pub mod validation;
+pub mod worker_spec;
sgl-router/src/protocols/worker_spec.rs
0 → 100644
View file @
2f173ea0
//! Worker management API specifications
//!
//! Defines the request/response structures for worker management endpoints

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Worker configuration for API requests
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WorkerConfigRequest {
    /// Worker URL (required)
    pub url: String,

    /// Model ID (optional, will query from server if not provided)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_id: Option<String>,

    /// Worker priority (optional, default: 50, higher = preferred)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub priority: Option<u32>,

    /// Worker cost factor (optional, default: 1.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost: Option<f32>,

    /// Worker type (optional: "regular", "prefill", "decode")
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker_type: Option<String>,

    /// Bootstrap port for prefill workers (optional)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap_port: Option<u16>,

    // gRPC-specific configuration (optional, ignored in HTTP mode)
    /// Tokenizer path for gRPC mode
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokenizer_path: Option<String>,

    /// Reasoning parser type for gRPC mode
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_parser: Option<String>,

    /// Tool parser type for gRPC mode
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_parser: Option<String>,

    /// Chat template for gRPC mode
    #[serde(skip_serializing_if = "Option::is_none")]
    pub chat_template: Option<String>,

    /// Additional labels (optional)
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub labels: HashMap<String, String>,
}

/// Worker information for API responses
#[derive(Debug, Clone, Serialize)]
pub struct WorkerInfo {
    /// Worker unique identifier
    pub id: String,
    /// Worker URL
    pub url: String,
    /// Model ID this worker serves
    pub model_id: String,
    /// Worker priority
    pub priority: u32,
    /// Worker cost factor
    pub cost: f32,
    /// Worker type
    pub worker_type: String,
    /// Whether the worker is healthy
    pub is_healthy: bool,
    /// Current load on the worker
    pub load: usize,
    /// Connection mode (http or grpc)
    pub connection_mode: String,

    // gRPC-specific fields (None for HTTP workers)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokenizer_path: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_parser: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_parser: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub chat_template: Option<String>,

    /// Additional metadata
    #[serde(skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, String>,
}

/// Worker list response
#[derive(Debug, Clone, Serialize)]
pub struct WorkerListResponse {
    /// List of workers
    pub workers: Vec<WorkerInfo>,
    /// Total count
    pub total: usize,
    /// Statistics
    pub stats: WorkerStats,
}

/// Worker statistics
#[derive(Debug, Clone, Serialize)]
pub struct WorkerStats {
    pub total_workers: usize,
    pub healthy_workers: usize,
    pub total_models: usize,
    pub total_load: usize,
    pub by_type: WorkerTypeStats,
}

/// Worker statistics by type
#[derive(Debug, Clone, Serialize)]
pub struct WorkerTypeStats {
    pub regular: usize,
    pub prefill: usize,
    pub decode: usize,
}

/// Worker update request
#[derive(Debug, Clone, Deserialize)]
pub struct WorkerUpdateRequest {
    /// Update priority
    #[serde(skip_serializing_if = "Option::is_none")]
    pub priority: Option<u32>,
    /// Update cost
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost: Option<f32>,
    /// Update labels
    #[serde(skip_serializing_if = "Option::is_none")]
    pub labels: Option<HashMap<String, String>>,
}

/// Generic API response
#[derive(Debug, Clone, Serialize)]
pub struct WorkerApiResponse {
    pub success: bool,
    pub message: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker: Option<WorkerInfo>,
}

/// Error response
#[derive(Debug, Clone, Serialize)]
pub struct WorkerErrorResponse {
    pub error: String,
    pub code: String,
}

/// Server info response from /get_server_info endpoint
#[derive(Debug, Clone, Deserialize)]
pub struct ServerInfo {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_path: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub priority: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker_type: Option<String>,
    // gRPC-specific
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokenizer_path: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_parser: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_parser: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub chat_template: Option<String>,
}
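Because every field except url is optional and skip_serializing_if trims absent values, the JSON body for registering a worker stays small. Below is a sketch of what such a payload could look like on the wire, assuming serde_json is available in the crate; the URLs, model name, and numeric values are illustrative, and the endpoint that would accept this request lives outside this file:

use std::collections::HashMap;

fn example_worker_payload() -> serde_json::Result<String> {
    let req = WorkerConfigRequest {
        url: "http://prefill-0:8000".to_string(),
        model_id: Some("llama-3".to_string()),
        priority: Some(80),
        cost: Some(1.0),
        worker_type: Some("prefill".to_string()),
        bootstrap_port: Some(9000),
        tokenizer_path: None,
        reasoning_parser: None,
        tool_parser: None,
        chat_template: None,
        labels: HashMap::new(),
    };
    // None fields and the empty label map are omitted from the serialized JSON.
    serde_json::to_string_pretty(&req)
}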
sgl-router/src/routers/factory.rs
View file @
2f173ea0
@@ -15,11 +15,6 @@ pub struct RouterFactory;
 impl RouterFactory {
     /// Create a router instance from application context
     pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
-        // Check if IGW mode is enabled
-        if ctx.router_config.enable_igw {
-            return Self::create_igw_router(ctx).await;
-        }
-
         // Check connection mode and route to appropriate implementation
         match ctx.router_config.connection_mode {
             ConnectionMode::Grpc => {
@@ -53,8 +48,7 @@ impl RouterFactory {
         // Route to HTTP implementation based on routing mode
         match &ctx.router_config.mode {
             RoutingMode::Regular { worker_urls } => {
-                Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx)
-                    .await
+                Self::create_regular_router(worker_urls, ctx).await
             }
             RoutingMode::PrefillDecode {
                 prefill_urls,
@@ -80,23 +74,19 @@
         }
     }

-    /// Create a regular router with injected policy
-    async fn create_regular_router(
+    /// Create a regular router
+    pub async fn create_regular_router(
         worker_urls: &[String],
-        policy_config: &PolicyConfig,
         ctx: &Arc<AppContext>,
     ) -> Result<Box<dyn RouterTrait>, String> {
-        // Create policy
-        let policy = PolicyFactory::create_from_config(policy_config);
-
-        // Create regular router with injected policy and context
-        let router = Router::new(worker_urls.to_vec(), policy, ctx).await?;
+        // Create regular router with context
+        let router = Router::new(worker_urls.to_vec(), ctx).await?;

         Ok(Box::new(router))
     }

     /// Create a PD router with injected policy
-    async fn create_pd_router(
+    pub async fn create_pd_router(
         prefill_urls: &[(String, Option<u16>)],
         decode_urls: &[String],
         prefill_policy_config: Option<&PolicyConfig>,
@@ -104,21 +94,18 @@
         main_policy_config: &PolicyConfig,
         ctx: &Arc<AppContext>,
     ) -> Result<Box<dyn RouterTrait>, String> {
-        // Create policies - use specific policies if provided, otherwise fall back to main policy
+        // Initialize policies in PolicyRegistry - use specific policies if provided, otherwise fall back to main policy
         let prefill_policy =
             PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
         let decode_policy =
             PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));

-        // Create PD router with separate policies and context
-        let router = PDRouter::new(
-            prefill_urls.to_vec(),
-            decode_urls.to_vec(),
-            prefill_policy,
-            decode_policy,
-            ctx,
-        )
-        .await?;
+        // Set the prefill and decode policies in the registry
+        ctx.policy_registry.set_prefill_policy(prefill_policy);
+        ctx.policy_registry.set_decode_policy(decode_policy);
+
+        // Create PD router with context (policies are in PolicyRegistry)
+        let router = PDRouter::new(prefill_urls.to_vec(), decode_urls.to_vec(), ctx).await?;

         Ok(Box::new(router))
     }
@@ -186,10 +173,4 @@
         Ok(Box::new(router))
     }
-
-    /// Create an IGW router (placeholder for future implementation)
-    async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
-        // For now, return an error indicating IGW is not yet implemented
-        Err("IGW mode is not yet implemented".to_string())
-    }
 }
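With the policy parameters removed from these constructors, callers hand the factory nothing but the shared AppContext; per-mode policies now live in ctx.policy_registry rather than inside each router instance. A rough sketch of a call site under that assumption — the worker URLs and the surrounding function are illustrative, not taken from this commit:

use std::sync::Arc;

async fn build_regular_router(ctx: Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
    // The router resolves its load-balancing policy through ctx.policy_registry.
    let worker_urls = vec![
        "http://w1:8000".to_string(),
        "http://w2:8000".to_string(),
    ];
    RouterFactory::create_regular_router(&worker_urls, &ctx).await
}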
sgl-router/src/routers/grpc/pd_router.rs
View file @
2f173ea0
@@ -27,9 +27,9 @@ use tracing::{info, warn};
 #[allow(dead_code)] // Fields will be used once implementation is complete
 pub struct GrpcPDRouter {
     /// Prefill worker connections
-    prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
+    prefill_workers: Arc<RwLock<Vec<Arc<dyn Worker>>>>,
     /// Decode worker connections
-    decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
+    decode_workers: Arc<RwLock<Vec<Arc<dyn Worker>>>>,
     /// gRPC clients for prefill workers
     prefill_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
     /// gRPC clients for decode workers
@@ -127,7 +127,7 @@ impl GrpcPDRouter {
         }

         // Create Prefill Worker trait objects with gRPC connection mode
-        let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
+        let prefill_workers: Vec<Arc<dyn Worker>> = prefill_urls
             .iter()
             .map(|(url, bootstrap_port)| {
                 let worker = BasicWorker::with_connection_mode(
@@ -147,12 +147,12 @@
                     failure_threshold: ctx.router_config.health_check.failure_threshold,
                     success_threshold: ctx.router_config.health_check.success_threshold,
                 });
-                Box::new(worker) as Box<dyn Worker>
+                Arc::new(worker) as Arc<dyn Worker>
             })
             .collect();

         // Create Decode Worker trait objects with gRPC connection mode
-        let decode_workers: Vec<Box<dyn Worker>> = decode_urls
+        let decode_workers: Vec<Arc<dyn Worker>> = decode_urls
             .iter()
             .map(|url| {
                 let worker = BasicWorker::with_connection_mode(
@@ -168,7 +168,7 @@
                     failure_threshold: ctx.router_config.health_check.failure_threshold,
                     success_threshold: ctx.router_config.health_check.success_threshold,
                 });
-                Box::new(worker) as Box<dyn Worker>
+                Arc::new(worker) as Arc<dyn Worker>
             })
             .collect();
@@ -269,6 +269,7 @@ impl RouterTrait for GrpcPDRouter {
         &self,
         _headers: Option<&HeaderMap>,
         _body: &crate::protocols::spec::GenerateRequest,
+        _model_id: Option<&str>,
     ) -> Response {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }
@@ -277,6 +278,7 @@ impl RouterTrait for GrpcPDRouter {
         &self,
         _headers: Option<&HeaderMap>,
         _body: &crate::protocols::spec::ChatCompletionRequest,
+        _model_id: Option<&str>,
     ) -> Response {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }
@@ -285,6 +287,7 @@ impl RouterTrait for GrpcPDRouter {
         &self,
         _headers: Option<&HeaderMap>,
         _body: &crate::protocols::spec::CompletionRequest,
+        _model_id: Option<&str>,
     ) -> Response {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }
@@ -293,6 +296,7 @@ impl RouterTrait for GrpcPDRouter {
         &self,
         _headers: Option<&HeaderMap>,
         _body: &crate::protocols::spec::ResponsesRequest,
+        _model_id: Option<&str>,
     ) -> Response {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }
@@ -305,6 +309,7 @@ impl RouterTrait for GrpcPDRouter {
         &self,
         _headers: Option<&HeaderMap>,
         _body: &crate::protocols::spec::RerankRequest,
+        _model_id: Option<&str>,
     ) -> Response {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }
sgl-router/src/routers/grpc/router.rs
View file @
2f173ea0
@@ -27,7 +27,7 @@ use tracing::{info, warn};
 #[allow(dead_code)] // Fields will be used once implementation is complete
 pub struct GrpcRouter {
     /// Worker connections
-    workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
+    workers: Arc<RwLock<Vec<Arc<dyn Worker>>>>,
     /// gRPC clients for each worker
     grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
     /// Load balancing policy
@@ -103,7 +103,7 @@ impl GrpcRouter {
         }

         // Create Worker trait objects with gRPC connection mode
-        let mut workers: Vec<Box<dyn Worker>> = Vec::new();
+        let mut workers: Vec<Arc<dyn Worker>> = Vec::new();

         // Move clients from the HashMap to the workers
         for url in &worker_urls {
@@ -123,7 +123,7 @@ impl GrpcRouter {
                 })
                 .with_grpc_client(client);
-                workers.push(Box::new(worker) as Box<dyn Worker>);
+                workers.push(Arc::new(worker) as Arc<dyn Worker>);
             } else {
                 warn!("No gRPC client for worker {}, skipping", url);
             }
@@ -202,6 +202,7 @@ impl RouterTrait for GrpcRouter {
         &self,
         _headers: Option<&HeaderMap>,
         _body: &crate::protocols::spec::GenerateRequest,
+        _model_id: Option<&str>,
     ) -> Response {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }
@@ -210,6 +211,7 @@ impl RouterTrait for GrpcRouter {
         &self,
         _headers: Option<&HeaderMap>,
         _body: &crate::protocols::spec::ChatCompletionRequest,
+        _model_id: Option<&str>,
     ) -> Response {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }
@@ -218,6 +220,7 @@ impl RouterTrait for GrpcRouter {
         &self,
         _headers: Option<&HeaderMap>,
         _body: &crate::protocols::spec::CompletionRequest,
+        _model_id: Option<&str>,
     ) -> Response {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }
@@ -226,6 +229,7 @@ impl RouterTrait for GrpcRouter {
         &self,
         _headers: Option<&HeaderMap>,
         _body: &crate::protocols::spec::ResponsesRequest,
+        _model_id: Option<&str>,
     ) -> Response {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }
@@ -238,6 +242,7 @@ impl RouterTrait for GrpcRouter {
         &self,
         _headers: Option<&HeaderMap>,
         _body: &crate::protocols::spec::RerankRequest,
+        _model_id: Option<&str>,
     ) -> Response {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }
sgl-router/src/routers/http/openai_router.rs
View file @
2f173ea0
@@ -186,6 +186,7 @@ impl super::super::RouterTrait for OpenAIRouter {
         &self,
         _headers: Option<&HeaderMap>,
         _body: &GenerateRequest,
+        _model_id: Option<&str>,
     ) -> Response {
         // Generate endpoint is SGLang-specific, not supported for OpenAI backend
         (
@@ -199,6 +200,7 @@ impl super::super::RouterTrait for OpenAIRouter {
         &self,
         headers: Option<&HeaderMap>,
         body: &ChatCompletionRequest,
+        _model_id: Option<&str>,
     ) -> Response {
         if !self.circuit_breaker.can_execute() {
             return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response();
@@ -326,6 +328,7 @@ impl super::super::RouterTrait for OpenAIRouter {
         &self,
         _headers: Option<&HeaderMap>,
         _body: &CompletionRequest,
+        _model_id: Option<&str>,
     ) -> Response {
         // Completion endpoint not implemented for OpenAI backend
         (
@@ -339,6 +342,7 @@ impl super::super::RouterTrait for OpenAIRouter {
         &self,
         _headers: Option<&HeaderMap>,
         _body: &crate::protocols::spec::ResponsesRequest,
+        _model_id: Option<&str>,
     ) -> Response {
         (
             StatusCode::NOT_IMPLEMENTED,
@@ -383,7 +387,12 @@ impl super::super::RouterTrait for OpenAIRouter {
             .into_response()
     }

-    async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: &RerankRequest) -> Response {
+    async fn route_rerank(
+        &self,
+        _headers: Option<&HeaderMap>,
+        _body: &RerankRequest,
+        _model_id: Option<&str>,
+    ) -> Response {
         (
             StatusCode::NOT_IMPLEMENTED,
             "Rerank endpoint not implemented for OpenAI backend",
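The common thread across the HTTP and gRPC routers in this commit is the extra _model_id: Option<&str> parameter on every route_* method: a single router instance can now be asked to serve a specific model family, with None meaning "use the default". A hedged sketch of a caller, using the route_rerank signature shown above; the handler function and the RouterTrait import path are assumptions for illustration:

use axum::http::HeaderMap;
use axum::response::Response;

use crate::protocols::spec::RerankRequest; // path as used in this diff
use crate::routers::RouterTrait; // path assumed

// Hypothetical handler: forward the model resolved from the request so the
// router can pick among workers registered for that model.
async fn handle_rerank(
    router: &dyn RouterTrait,
    headers: Option<&HeaderMap>,
    body: &RerankRequest,
    model_id: Option<&str>,
) -> Response {
    router.route_rerank(headers, body, model_id).await
}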
sgl-router/src/routers/http/pd_router.rs
View file @
2f173ea0
This diff is collapsed.