Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
0f8cee8c
Unverified
Commit
0f8cee8c
authored
Aug 21, 2025
by
Simo Lin
Committed by
GitHub
Aug 21, 2025
Browse files
[router] fix router load guard tracking for streaming (#9491)
parent
816c4c85
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
124 additions
and
4 deletions
+124
-4
sgl-router/src/routers/pd_router.rs
sgl-router/src/routers/pd_router.rs
+124
-4
No files found.
sgl-router/src/routers/pd_router.rs
View file @
0f8cee8c
...
@@ -821,8 +821,13 @@ impl PDRouter {
...
@@ -821,8 +821,13 @@ impl PDRouter {
decode
:
&
dyn
Worker
,
decode
:
&
dyn
Worker
,
start_time
:
Instant
,
start_time
:
Instant
,
)
->
Response
{
)
->
Response
{
// Update load tracking for both workers
// For non-streaming: use guard for automatic load management
let
_
guard
=
WorkerLoadGuard
::
new_multi
(
vec!
[
prefill
,
decode
]);
// For streaming: load will be managed in create_streaming_response
let
_
guard
=
if
!
context
.is_stream
{
Some
(
WorkerLoadGuard
::
new_multi
(
vec!
[
prefill
,
decode
]))
}
else
{
None
};
// Build decode request with shared client
// Build decode request with shared client
let
decode_request
=
self
.build_post_with_headers
(
let
decode_request
=
self
.build_post_with_headers
(
...
@@ -916,13 +921,15 @@ impl PDRouter {
...
@@ -916,13 +921,15 @@ impl PDRouter {
let
response_headers
=
let
response_headers
=
header_utils
::
preserve_response_headers
(
res
.headers
());
header_utils
::
preserve_response_headers
(
res
.headers
());
S
elf
::
create_streaming_response
(
s
elf
.
create_streaming_response
(
res
.bytes_stream
(),
res
.bytes_stream
(),
status
,
status
,
prefill_logprobs
,
prefill_logprobs
,
context
.return_logprob
,
context
.return_logprob
,
None
,
None
,
Some
(
response_headers
),
Some
(
response_headers
),
prefill
,
decode
,
)
)
}
else
{
}
else
{
// Non-streaming response with logprobs
// Non-streaming response with logprobs
...
@@ -1043,13 +1050,15 @@ impl PDRouter {
...
@@ -1043,13 +1050,15 @@ impl PDRouter {
let
response_headers
=
let
response_headers
=
header_utils
::
preserve_response_headers
(
res
.headers
());
header_utils
::
preserve_response_headers
(
res
.headers
());
S
elf
::
create_streaming_response
(
s
elf
.
create_streaming_response
(
res
.bytes_stream
(),
res
.bytes_stream
(),
status
,
status
,
None
,
None
,
false
,
false
,
Some
(
decode_url
),
Some
(
decode_url
),
Some
(
response_headers
),
Some
(
response_headers
),
prefill
,
decode
,
)
)
}
else
{
}
else
{
// Non-streaming response without logprobs - direct passthrough like fast version
// Non-streaming response without logprobs - direct passthrough like fast version
...
@@ -1210,16 +1219,32 @@ impl PDRouter {
...
@@ -1210,16 +1219,32 @@ impl PDRouter {
}
}
// Helper to create a streaming response
// Helper to create a streaming response
#[allow(clippy::too_many_arguments)]
fn
create_streaming_response
(
fn
create_streaming_response
(
&
self
,
stream
:
impl
futures_util
::
Stream
<
Item
=
Result
<
bytes
::
Bytes
,
reqwest
::
Error
>>
+
Send
+
'static
,
stream
:
impl
futures_util
::
Stream
<
Item
=
Result
<
bytes
::
Bytes
,
reqwest
::
Error
>>
+
Send
+
'static
,
status
:
StatusCode
,
status
:
StatusCode
,
prefill_logprobs
:
Option
<
Value
>
,
prefill_logprobs
:
Option
<
Value
>
,
return_logprob
:
bool
,
return_logprob
:
bool
,
decode_url
:
Option
<
String
>
,
decode_url
:
Option
<
String
>
,
headers
:
Option
<
HeaderMap
>
,
headers
:
Option
<
HeaderMap
>
,
prefill
:
&
dyn
Worker
,
decode
:
&
dyn
Worker
,
)
->
Response
{
)
->
Response
{
// For streaming, increment load now - will be decremented when streaming completes
prefill
.increment_load
();
decode
.increment_load
();
// Store URLs to find workers later for decrementing
let
prefill_url
=
prefill
.url
()
.to_string
();
let
decode_url_str
=
decode
.url
()
.to_string
();
let
(
tx
,
rx
)
=
tokio
::
sync
::
mpsc
::
unbounded_channel
();
let
(
tx
,
rx
)
=
tokio
::
sync
::
mpsc
::
unbounded_channel
();
// Clone the worker collections for the spawned task
let
prefill_workers
=
self
.prefill_workers
.clone
();
let
decode_workers
=
self
.decode_workers
.clone
();
tokio
::
spawn
(
async
move
{
tokio
::
spawn
(
async
move
{
futures_util
::
pin_mut!
(
stream
);
futures_util
::
pin_mut!
(
stream
);
while
let
Some
(
chunk_result
)
=
stream
.next
()
.await
{
while
let
Some
(
chunk_result
)
=
stream
.next
()
.await
{
...
@@ -1247,6 +1272,25 @@ impl PDRouter {
...
@@ -1247,6 +1272,25 @@ impl PDRouter {
}
}
}
}
}
}
// Decrement load after streaming is complete
if
let
Ok
(
prefill_workers_guard
)
=
prefill_workers
.read
()
{
for
worker
in
prefill_workers_guard
.iter
()
{
if
worker
.url
()
==
prefill_url
.as_str
()
{
worker
.decrement_load
();
break
;
}
}
}
if
let
Ok
(
decode_workers_guard
)
=
decode_workers
.read
()
{
for
worker
in
decode_workers_guard
.iter
()
{
if
worker
.url
()
==
decode_url_str
.as_str
()
{
worker
.decrement_load
();
break
;
}
}
}
});
});
let
stream
=
UnboundedReceiverStream
::
new
(
rx
);
let
stream
=
UnboundedReceiverStream
::
new
(
rx
);
...
@@ -2279,6 +2323,82 @@ mod tests {
...
@@ -2279,6 +2323,82 @@ mod tests {
assert_eq!
(
decode_worker
.load
(),
0
);
assert_eq!
(
decode_worker
.load
(),
0
);
}
}
#[tokio::test]
async
fn
test_streaming_load_tracking
()
{
use
futures_util
::
StreamExt
;
use
tokio
::
time
::{
sleep
,
Duration
};
let
router
=
create_test_pd_router
();
// Add workers
let
prefill_worker
=
create_test_worker
(
"http://prefill"
.to_string
(),
WorkerType
::
Prefill
{
bootstrap_port
:
None
,
},
true
,
);
let
decode_worker
=
create_test_worker
(
"http://decode"
.to_string
(),
WorkerType
::
Decode
,
true
);
router
.prefill_workers
.write
()
.unwrap
()
.push
(
prefill_worker
);
router
.decode_workers
.write
()
.unwrap
()
.push
(
decode_worker
);
// Get references to the workers - clone to avoid holding lock across await
let
(
prefill_ref
,
decode_ref
)
=
{
let
workers
=
router
.prefill_workers
.read
()
.unwrap
();
let
prefill
=
workers
[
0
]
.clone_worker
();
drop
(
workers
);
let
workers
=
router
.decode_workers
.read
()
.unwrap
();
let
decode
=
workers
[
0
]
.clone_worker
();
(
prefill
,
decode
)
};
// Initially load should be 0
assert_eq!
(
prefill_ref
.load
(),
0
);
assert_eq!
(
decode_ref
.load
(),
0
);
// Create a mock streaming response
let
(
tx
,
rx
)
=
tokio
::
sync
::
mpsc
::
unbounded_channel
();
let
stream
=
tokio_stream
::
wrappers
::
UnboundedReceiverStream
::
new
(
rx
);
// Call create_streaming_response which should increment load
let
_
response
=
router
.create_streaming_response
(
stream
.map
(
Ok
),
StatusCode
::
OK
,
None
,
false
,
None
,
None
,
prefill_ref
.as_ref
(),
decode_ref
.as_ref
(),
);
// Load should be incremented immediately
assert_eq!
(
prefill_ref
.load
(),
1
);
assert_eq!
(
decode_ref
.load
(),
1
);
// Send some data through the stream
tx
.send
(
bytes
::
Bytes
::
from
(
"test data"
))
.unwrap
();
// Give time for the spawned task to process
sleep
(
Duration
::
from_millis
(
10
))
.await
;
// Load should still be 1 (streaming in progress)
assert_eq!
(
prefill_ref
.load
(),
1
);
assert_eq!
(
decode_ref
.load
(),
1
);
// Close the stream
drop
(
tx
);
// Give time for cleanup
sleep
(
Duration
::
from_millis
(
100
))
.await
;
// Load should be decremented after streaming completes
assert_eq!
(
prefill_ref
.load
(),
0
);
assert_eq!
(
decode_ref
.load
(),
0
);
}
// ============= Concurrent Operations Tests =============
// ============= Concurrent Operations Tests =============
#[tokio::test]
#[tokio::test]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment