Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
7b7e5615
Unverified
Commit
7b7e5615
authored
Aug 08, 2025
by
Simo Lin
Committed by
GitHub
Aug 08, 2025
Browse files
[router] fix radix tree integration issues in PD router (#8982)
parent
1a8706c8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
76 additions
and
165 deletions
+76
-165
sgl-router/src/policies/cache_aware.rs
sgl-router/src/policies/cache_aware.rs
+32
-13
sgl-router/src/routers/pd_router.rs
sgl-router/src/routers/pd_router.rs
+44
-152
No files found.
sgl-router/src/policies/cache_aware.rs
View file @
7b7e5615
...
...
@@ -112,7 +112,7 @@ impl CacheAwarePolicy {
}
}
/// Initialize the tree with worker URLs
/// Initialize the tree with worker URLs
(used only during initial setup)
pub
fn
init_workers
(
&
self
,
workers
:
&
[
Box
<
dyn
Worker
>
])
{
if
let
Ok
(
tree
)
=
self
.tree
.lock
()
{
for
worker
in
workers
{
...
...
@@ -121,6 +121,13 @@ impl CacheAwarePolicy {
}
}
/// Registers one worker URL in the policy's tree (incremental update).
///
/// The URL is inserted with an empty text key. If the tree lock cannot be
/// acquired (`lock()` returns `Err`), the call is a silent no-op.
pub fn add_worker(&self, url: &str) {
    let Ok(guard) = self.tree.lock() else {
        return;
    };
    guard.insert("", url);
}
/// Remove a worker from the tree
pub
fn
remove_worker
(
&
self
,
url
:
&
str
)
{
if
let
Ok
(
tree
)
=
self
.tree
.lock
()
{
...
...
@@ -178,6 +185,13 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
.min_by_key
(|
&&
idx
|
workers
[
idx
]
.load
())
.copied
()
?
;
// Even in imbalanced mode, update the tree to maintain cache state
if
let
Some
(
text
)
=
request_text
{
if
let
Ok
(
tree
)
=
self
.tree
.lock
()
{
tree
.insert
(
text
,
workers
[
min_load_idx
]
.url
());
}
}
// Increment processed counter
workers
[
min_load_idx
]
.increment_processed
();
RouterMetrics
::
record_processed_request
(
workers
[
min_load_idx
]
.url
());
...
...
@@ -206,21 +220,26 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
};
// Find the index of the selected worker
let
selected_idx
=
workers
.iter
()
.position
(|
w
|
w
.url
()
==
selected_url
)
?
;
if
let
Some
(
selected_idx
)
=
workers
.iter
()
.position
(|
w
|
w
.url
()
==
selected_url
)
{
// Only proceed if the worker is healthy
if
workers
[
selected_idx
]
.is_healthy
()
{
// Update the tree with this request
tree
.insert
(
text
,
&
selected_url
);
// Only proceed if the worker is healthy
if
!
workers
[
selected_idx
]
.is_healthy
()
{
return
healthy_indices
.first
()
.copied
();
}
// Update the tree with this request
tree
.insert
(
text
,
&
selected_url
);
// Increment processed counter
workers
[
selected_idx
]
.increment_processed
();
RouterMetrics
::
record_processed_request
(
&
selected_url
);
// Increment processed counter
workers
[
selected_idx
]
.increment_processed
();
RouterMetrics
::
record_processed_request
(
&
selected_url
);
return
Some
(
selected_idx
);
}
}
else
{
// Selected worker no longer exists, remove it from tree
tree
.remove_tenant
(
&
selected_url
);
debug!
(
"Removed stale worker {} from cache tree"
,
selected_url
);
}
return
Some
(
selected_idx
);
// Fallback to first healthy worker
return
healthy_indices
.first
()
.copied
();
}
// Fallback to first healthy worker if tree operations fail
...
...
sgl-router/src/routers/pd_router.rs
View file @
7b7e5615
...
...
@@ -7,7 +7,6 @@ use crate::metrics::RouterMetrics;
use
crate
::
openai_api_types
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
};
use
crate
::
policies
::
LoadBalancingPolicy
;
use
crate
::
routers
::{
RouterTrait
,
WorkerManagement
};
use
crate
::
tree
::
Tree
;
use
async_trait
::
async_trait
;
use
axum
::{
body
::
Body
,
...
...
@@ -20,7 +19,7 @@ use futures_util::StreamExt;
use
reqwest
::
Client
;
use
serde_json
::
Value
;
use
std
::
collections
::
HashMap
;
use
std
::
sync
::{
Arc
,
Mutex
,
RwLock
};
use
std
::
sync
::{
Arc
,
RwLock
};
use
std
::
time
::{
Duration
,
Instant
};
use
tokio_stream
::
wrappers
::
UnboundedReceiverStream
;
use
tracing
::{
debug
,
error
,
info
,
warn
};
...
...
@@ -31,8 +30,6 @@ pub struct PDRouter {
pub
decode_workers
:
Arc
<
RwLock
<
Vec
<
Box
<
dyn
Worker
>>>>
,
pub
prefill_policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
pub
decode_policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
pub
prefill_tree
:
Option
<
Arc
<
Mutex
<
Tree
>>>
,
pub
decode_tree
:
Option
<
Arc
<
Mutex
<
Tree
>>>
,
pub
timeout_secs
:
u64
,
pub
interval_secs
:
u64
,
pub
worker_loads
:
Arc
<
tokio
::
sync
::
watch
::
Receiver
<
HashMap
<
String
,
isize
>>>
,
...
...
@@ -91,9 +88,14 @@ impl PDRouter {
workers
.push
(
worker
);
// Add to cache tree if using cache-aware policy for prefill
if
let
Some
(
ref
tree
)
=
self
.prefill_tree
{
tree
.lock
()
.unwrap
()
.insert
(
""
,
&
url
);
// Update cache-aware policy if applicable
drop
(
workers
);
// Release write lock
if
let
Some
(
cache_policy
)
=
self
.prefill_policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
{
cache_policy
.add_worker
(
&
url
);
}
info!
(
"Added prefill server: {}"
,
url
);
...
...
@@ -125,9 +127,14 @@ impl PDRouter {
workers
.push
(
worker
);
// Add to cache tree if using cache-aware policy for decode
if
let
Some
(
ref
tree
)
=
self
.decode_tree
{
tree
.lock
()
.unwrap
()
.insert
(
""
,
&
url
);
// Update cache-aware policy if applicable
drop
(
workers
);
// Release write lock
if
let
Some
(
cache_policy
)
=
self
.decode_policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
{
cache_policy
.add_worker
(
&
url
);
}
info!
(
"Added decode server: {}"
,
url
);
...
...
@@ -152,9 +159,13 @@ impl PDRouter {
});
}
// Remove from cache tree if using cache-aware policy
if
let
Some
(
ref
tree
)
=
self
.prefill_tree
{
tree
.lock
()
.unwrap
()
.remove_tenant
(
url
);
// Remove from cache-aware policy if applicable
if
let
Some
(
cache_policy
)
=
self
.prefill_policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
{
cache_policy
.remove_worker
(
url
);
}
info!
(
"Removed prefill server: {}"
,
url
);
...
...
@@ -179,9 +190,13 @@ impl PDRouter {
});
}
// Remove from the cache tree if using cache-aware policy for decode
if
let
Some
(
ref
tree
)
=
self
.decode_tree
{
tree
.lock
()
.unwrap
()
.remove_tenant
(
url
);
// Remove from cache-aware policy if applicable
if
let
Some
(
cache_policy
)
=
self
.decode_policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
{
cache_policy
.remove_worker
(
url
);
}
info!
(
"Removed decode server: {}"
,
url
);
...
...
@@ -238,11 +253,20 @@ impl PDRouter {
)
?
;
}
// Initialize cache-aware components if needed for prefill policy
let
prefill_tree
=
Self
::
initialize_radix_tree
(
&
prefill_policy
,
&
prefill_workers
)
?
;
// Initialize cache-aware policies with workers
if
let
Some
(
cache_policy
)
=
prefill_policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
{
cache_policy
.init_workers
(
&
prefill_workers
);
}
// Initialize cache-aware components if needed for decode policy
let
decode_tree
=
Self
::
initialize_radix_tree
(
&
decode_policy
,
&
decode_workers
)
?
;
if
let
Some
(
cache_policy
)
=
decode_policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
{
cache_policy
.init_workers
(
&
decode_workers
);
}
// Set up background load monitoring for power-of-two selection
let
(
tx
,
rx
)
=
tokio
::
sync
::
watch
::
channel
(
HashMap
::
new
());
...
...
@@ -294,8 +318,6 @@ impl PDRouter {
decode_workers
,
prefill_policy
,
decode_policy
,
prefill_tree
,
decode_tree
,
timeout_secs
,
interval_secs
,
worker_loads
,
...
...
@@ -309,35 +331,6 @@ impl PDRouter {
})
}
/// Helper function to initialize a radix tree for cache-aware policies.
///
/// If `policy` downcasts to `CacheAwarePolicy`, the policy's own tree is
/// initialized via `init_workers(workers)`, and a separate `Tree` is created
/// with every worker URL inserted under the empty text key. That tree is
/// returned as `Ok(Some(..))`. For any other policy the result is `Ok(None)`.
///
/// The `Result` wrapper is kept so existing `?` call sites keep compiling;
/// this implementation never returns `Err`.
fn initialize_radix_tree(
    policy: &Arc<dyn LoadBalancingPolicy>,
    workers: &[Box<dyn Worker>],
) -> Result<Option<Arc<Mutex<Tree>>>, String> {
    let cache_policy = match policy
        .as_any()
        .downcast_ref::<crate::policies::CacheAwarePolicy>()
    {
        Some(cache_policy) => cache_policy,
        None => return Ok(None),
    };

    // Initialize the policy's internal tree with workers
    cache_policy.init_workers(workers);

    // Populate the tree before sharing it: a freshly created mutex cannot be
    // poisoned, so locking it here would only add an unreachable error path.
    let tree = Tree::new();
    for worker in workers {
        tree.insert("", worker.url());
    }

    Ok(Some(Arc::new(Mutex::new(tree))))
}
// Helper to handle server selection errors
fn
handle_server_selection_error
(
error
:
String
)
->
Response
{
error!
(
"Failed to select PD pair error={}"
,
error
);
...
...
@@ -1863,8 +1856,6 @@ mod tests {
decode_workers
:
Arc
::
new
(
RwLock
::
new
(
vec!
[])),
prefill_policy
,
decode_policy
,
prefill_tree
:
None
,
decode_tree
:
None
,
timeout_secs
:
5
,
interval_secs
:
1
,
worker_loads
:
Arc
::
new
(
tokio
::
sync
::
watch
::
channel
(
HashMap
::
new
())
.1
),
...
...
@@ -2002,105 +1993,6 @@ mod tests {
}
}
// ============= Cache Tree Integration Tests =============
/// A URL inserted into the shared prefill tree under the empty prefix is the
/// tenant reported by `prefix_match("")`.
#[tokio::test]
async fn test_cache_tree_operations() {
    const WORKER_URL: &str = "http://worker1";

    // Router configured with a cache-aware prefill policy.
    let policy = Arc::new(CacheAwarePolicy::new());
    let mut router = create_test_pd_router();
    router.prefill_policy = policy;

    // Attach a tree that the test keeps its own handle to.
    let shared_tree = Arc::new(Mutex::new(Tree::new()));
    router.prefill_tree = Some(Arc::clone(&shared_tree));

    // Register the worker by hand, then mirror it into the tree.
    let prefill_worker = create_test_worker(
        WORKER_URL.to_string(),
        WorkerType::Prefill {
            bootstrap_port: None,
        },
        true,
    );
    router.prefill_workers.write().unwrap().push(prefill_worker);
    shared_tree.lock().unwrap().insert("", WORKER_URL);

    // The empty-prefix lookup should resolve to the only inserted tenant.
    let guard = shared_tree.lock().unwrap();
    let (_prefix, tenant) = guard.prefix_match("");
    assert_eq!(tenant, WORKER_URL);
}
/// After removing one of two prefill servers, the empty-prefix lookup on the
/// shared tree is expected to report the remaining worker.
#[tokio::test]
async fn test_cache_tree_rebuild_on_remove() {
    const REMOVED_URL: &str = "http://worker1";
    const KEPT_URL: &str = "http://worker2";

    // Router configured with a cache-aware prefill policy and a shared tree.
    let policy = Arc::new(CacheAwarePolicy::new());
    let mut router = create_test_pd_router();
    router.prefill_policy = policy;

    let shared_tree = Arc::new(Mutex::new(Tree::new()));
    router.prefill_tree = Some(Arc::clone(&shared_tree));

    // Register both workers with the router.
    for url in [REMOVED_URL, KEPT_URL] {
        let prefill_worker = create_test_worker(
            url.to_string(),
            WorkerType::Prefill {
                bootstrap_port: None,
            },
            true,
        );
        router.prefill_workers.write().unwrap().push(prefill_worker);
    }

    // Seed the tree with both URLs; the guard is scoped so it is released
    // before the `.await` below.
    {
        let guard = shared_tree.lock().unwrap();
        for url in [REMOVED_URL, KEPT_URL] {
            guard.insert("", url);
        }
    }

    // Remove one worker
    let result = router.remove_prefill_server(REMOVED_URL).await;
    assert!(result.is_ok());

    // Only the surviving worker should be reported by the tree.
    let guard = shared_tree.lock().unwrap();
    let (_prefix, tenant) = guard.prefix_match("");
    assert_eq!(tenant, KEPT_URL);
}
/// Removing a prefill server succeeds when the router has no prefill tree.
#[tokio::test]
async fn test_no_cache_tree_operations() {
    const WORKER_URL: &str = "http://worker1";

    // The default test router carries no cache tree.
    let router = create_test_pd_router();
    assert!(router.prefill_tree.is_none());

    // Register a worker directly in the router's prefill list.
    router.prefill_workers.write().unwrap().push(create_test_worker(
        WORKER_URL.to_string(),
        WorkerType::Prefill {
            bootstrap_port: None,
        },
        true,
    ));

    // Removal must not depend on a tree being present.
    let result = router.remove_prefill_server(WORKER_URL).await;
    assert!(result.is_ok());
}
// ============= Bootstrap Injection Tests =============
// Note: These tests are commented out as we've moved to the optimized bootstrap injection
// approach that doesn't use the Bootstrap trait on GenerateReqInput anymore.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment