Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
6f81a710
Unverified
Commit
6f81a710
authored
Aug 11, 2025
by
Simo Lin
Committed by
GitHub
Aug 11, 2025
Browse files
[pd-router] add retry and circuit breaker for pd router (#9051)
parent
a6452b71
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
235 additions
and
176 deletions
+235
-176
sgl-router/src/core/mod.rs
sgl-router/src/core/mod.rs
+1
-1
sgl-router/src/core/retry.rs
sgl-router/src/core/retry.rs
+16
-2
sgl-router/src/routers/pd_router.rs
sgl-router/src/routers/pd_router.rs
+211
-152
sgl-router/src/routers/router.rs
sgl-router/src/routers/router.rs
+7
-21
No files found.
sgl-router/src/core/mod.rs
View file @
6f81a710
...
...
@@ -16,7 +16,7 @@ pub use circuit_breaker::{
CircuitBreaker
,
CircuitBreakerConfig
,
CircuitBreakerStats
,
CircuitState
,
};
pub
use
error
::{
WorkerError
,
WorkerResult
};
pub
use
retry
::{
BackoffCalculator
,
RetryError
,
RetryExecutor
};
pub
use
retry
::{
is_retryable_status
,
BackoffCalculator
,
RetryError
,
RetryExecutor
};
pub
use
worker
::{
start_health_checker
,
BasicWorker
,
DPAwareWorker
,
HealthChecker
,
Worker
,
WorkerCollection
,
WorkerFactory
,
WorkerLoadGuard
,
WorkerType
,
...
...
sgl-router/src/core/retry.rs
View file @
6f81a710
use
crate
::
config
::
types
::
RetryConfig
;
use
axum
::
http
::
StatusCode
;
use
axum
::
response
::
Response
;
use
rand
::
Rng
;
use
std
::
time
::
Duration
;
use
tracing
::
debug
;
/// Returns `true` when `status` is one of the HTTP status codes treated as retryable.
///
/// The retryable codes are 408 Request Timeout, 429 Too Many Requests,
/// 500 Internal Server Error, 502 Bad Gateway, 503 Service Unavailable and
/// 504 Gateway Timeout. Every other status yields `false`.
pub fn is_retryable_status(status: StatusCode) -> bool {
    // Kept as a table so the full set of retryable codes is visible in one place.
    const RETRYABLE_STATUSES: [StatusCode; 6] = [
        StatusCode::REQUEST_TIMEOUT,
        StatusCode::TOO_MANY_REQUESTS,
        StatusCode::INTERNAL_SERVER_ERROR,
        StatusCode::BAD_GATEWAY,
        StatusCode::SERVICE_UNAVAILABLE,
        StatusCode::GATEWAY_TIMEOUT,
    ];

    RETRYABLE_STATUSES.contains(&status)
}
/// Computes exponential backoff with optional jitter.
///
/// This is a unit struct: it declares no fields, so a value of this type
/// carries no state of its own.
#[derive(Debug,
Clone)]
pub
struct
BackoffCalculator
;
...
...
@@ -21,8 +35,8 @@ impl BackoffCalculator {
// Apply jitter in range [-j, +j]
let
jitter
=
config
.jitter_factor
.max
(
0.0
)
.min
(
1.0
);
if
jitter
>
0.0
{
let
mut
rng
=
rand
::
thread_
rng
();
let
jitter_scale
:
f32
=
rng
.
gen
_range
(
-
jitter
..=
jitter
);
let
mut
rng
=
rand
::
rng
();
let
jitter_scale
:
f32
=
rng
.
random
_range
(
-
jitter
..=
jitter
);
let
jitter_ms
=
(
delay_ms
as
f32
*
jitter_scale
)
.round
()
.max
(
-
(
delay_ms
as
f32
));
...
...
sgl-router/src/routers/pd_router.rs
View file @
6f81a710
...
...
@@ -2,7 +2,10 @@
// This module handles routing for disaggregated prefill-decode systems
use
super
::
pd_types
::{
api_path
,
PDRouterError
};
use
crate
::
config
::
types
::{
CircuitBreakerConfig
as
ConfigCircuitBreakerConfig
,
RetryConfig
};
use
crate
::
core
::{
CircuitBreakerConfig
,
HealthChecker
,
Worker
,
WorkerFactory
,
WorkerLoadGuard
};
use
crate
::
core
::{
is_retryable_status
,
CircuitBreakerConfig
,
HealthChecker
,
RetryExecutor
,
Worker
,
WorkerFactory
,
WorkerLoadGuard
,
};
use
crate
::
metrics
::
RouterMetrics
;
use
crate
::
openai_api_types
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
};
use
crate
::
policies
::
LoadBalancingPolicy
;
...
...
@@ -17,6 +20,7 @@ use axum::{
};
use
futures_util
::
StreamExt
;
use
reqwest
::
Client
;
use
serde
::
Serialize
;
use
serde_json
::
Value
;
use
std
::
collections
::
HashMap
;
use
std
::
sync
::{
Arc
,
RwLock
};
...
...
@@ -43,6 +47,16 @@ pub struct PDRouter {
_
decode_health_checker
:
Option
<
HealthChecker
>
,
}
/// Request context for PD router operations.
///
/// Bundles the per-request parameters that the route handlers extract from the
/// incoming body and hand to `execute_dual_dispatch`. It is `Clone` because the
/// retry loop makes a fresh copy for every attempt.
#[derive(Clone)]
struct PDRequestContext {
    /// Route the request arrived on (e.g. `"/generate"`); also used as the metrics label.
    route: &'static str,
    /// Batch size computed from the request body; forwarded to bootstrap injection.
    batch_size: Option<usize>,
    /// Streaming flag taken from the request body.
    is_stream: bool,
    /// Whether the request asked for log-probabilities.
    return_logprob: bool,
    /// Text handed to worker selection; `None` when the policies do not need it
    /// or the request carried no usable text.
    request_text: Option<String>,
}
impl
PDRouter
{
// Dynamic worker management methods for service discovery
...
...
@@ -218,12 +232,8 @@ impl PDRouter {
let
core_cb_config
=
CircuitBreakerConfig
{
failure_threshold
:
circuit_breaker_config
.failure_threshold
,
success_threshold
:
circuit_breaker_config
.success_threshold
,
timeout_duration
:
std
::
time
::
Duration
::
from_secs
(
circuit_breaker_config
.timeout_duration_secs
,
),
window_duration
:
std
::
time
::
Duration
::
from_secs
(
circuit_breaker_config
.window_duration_secs
,
),
timeout_duration
:
Duration
::
from_secs
(
circuit_breaker_config
.timeout_duration_secs
),
window_duration
:
Duration
::
from_secs
(
circuit_breaker_config
.window_duration_secs
),
};
// Convert URLs to Worker trait objects
...
...
@@ -459,8 +469,96 @@ impl PDRouter {
Ok
(
original
)
}
// Execute the dual dispatch to prefill and decode servers
async
fn
execute_dual_dispatch
(
/// Dispatches a request to a prefill/decode pair under the router's retry policy.
///
/// Each attempt is independent: it selects a prefill/decode pair via
/// `select_pd_pair`, serializes `original_request` to JSON, injects bootstrap
/// data for the chosen prefill worker, and forwards the result to
/// `execute_dual_dispatch_internal`. Afterwards the attempt's outcome (2xx vs.
/// anything else) is recorded on both selected workers.
///
/// Worker-selection and serialization failures end the attempt early and are
/// returned as ordinary responses. A response is retried when
/// `is_retryable_status` accepts its status code. Retries, backoff delays and
/// retry exhaustion are reported to `RouterMetrics` under `context.route`.
async fn execute_dual_dispatch<T: Serialize + Clone>(
    &self,
    headers: Option<&HeaderMap>,
    original_request: &T,
    context: PDRequestContext,
) -> Response {
    // Taken once, before the first attempt, and handed to every attempt.
    let start_time = Instant::now();
    // Copied out up front: `context` itself is moved into the attempt closure below.
    let route = context.route;

    RetryExecutor::execute_response_with_retry(
        &self.retry_config,
        // Operation per attempt
        {
            let request_template = original_request.clone();
            move |attempt: u32| {
                // Every attempt's future owns its own copy of the request and context.
                let request = request_template.clone();
                let ctx = context.clone();
                async move {
                    // Select workers fresh for each attempt instead of reusing a pair.
                    let selection = self.select_pd_pair(ctx.request_text.as_deref()).await;
                    let (prefill, decode) = match selection {
                        Ok(pair) => pair,
                        Err(e) => {
                            RouterMetrics::record_pd_error("server_selection");
                            return Self::handle_server_selection_error(e);
                        }
                    };

                    debug!(
                        "PD retry attempt {} using prefill={} decode={}",
                        attempt,
                        prefill.url(),
                        decode.url()
                    );

                    // Turn the typed request into a JSON value that can be amended.
                    let serialized = match serde_json::to_value(&request) {
                        Ok(v) => v,
                        Err(e) => return Self::handle_serialization_error(e),
                    };

                    // Bootstrap data is injected for the prefill worker chosen in this attempt.
                    let json_request = match Self::inject_bootstrap_into_value(
                        serialized,
                        prefill.as_ref(),
                        ctx.batch_size,
                    ) {
                        Ok(v) => v,
                        Err(e) => return Self::handle_serialization_error(e),
                    };

                    // Perform the actual dual dispatch (no retry logic inside).
                    let response = self
                        .execute_dual_dispatch_internal(
                            headers,
                            json_request,
                            ctx.route,
                            prefill.as_ref(),
                            decode.as_ref(),
                            ctx.is_stream,
                            ctx.return_logprob,
                            start_time,
                        )
                        .await;

                    // Record the outcome on both workers for their circuit breakers.
                    let succeeded = response.status().is_success();
                    prefill.record_outcome(succeeded);
                    decode.record_outcome(succeeded);

                    response
                }
            }
        },
        // Should-retry predicate
        |res, _attempt| is_retryable_status(res.status()),
        // Backoff hook: count the retry and record the delay for this attempt
        |delay, attempt| {
            RouterMetrics::record_retry(route);
            RouterMetrics::record_retry_backoff_duration(delay, attempt);
        },
        // Exhaustion hook
        || RouterMetrics::record_retries_exhausted(route),
    )
    .await
}
// Internal method that performs the actual dual dispatch (without retry logic)
async
fn
execute_dual_dispatch_internal
(
&
self
,
headers
:
Option
<&
HeaderMap
>
,
json_request
:
Value
,
...
...
@@ -696,7 +794,7 @@ impl PDRouter {
self
.prefill_policy
.needs_request_text
()
||
self
.decode_policy
.needs_request_text
()
}
// Select a pair of prefill and decode servers
// Select a pair of prefill and decode servers considering circuit breaker state
async
fn
select_pd_pair
(
&
self
,
request_text
:
Option
<&
str
>
,
...
...
@@ -711,29 +809,58 @@ impl PDRouter {
.read
()
.map_err
(|
e
|
format!
(
"Failed to acquire decode workers lock: {}"
,
e
))
?
;
// Check we have workers
if
prefill_workers
.is_empty
()
{
return
Err
(
"No prefill workers available. Please check if prefill servers are configured and healthy."
.to_string
());
}
if
decode_workers
.is_empty
()
{
return
Err
(
"No decode workers available. Please check if decode servers are configured and healthy."
.to_string
());
// Select workers using helper function
let
prefill
=
Self
::
pick_worker_by_policy
(
&*
prefill_workers
,
&*
self
.prefill_policy
,
request_text
,
"prefill"
,
)
?
;
let
decode
=
Self
::
pick_worker_by_policy
(
&*
decode_workers
,
&*
self
.decode_policy
,
request_text
,
"decode"
,
)
?
;
Ok
((
prefill
,
decode
))
}
// Helper function to select a worker using the policy
fn
pick_worker_by_policy
(
workers
:
&
[
Box
<
dyn
Worker
>
],
policy
:
&
dyn
LoadBalancingPolicy
,
request_text
:
Option
<&
str
>
,
worker_type
:
&
str
,
)
->
Result
<
Box
<
dyn
Worker
>
,
String
>
{
// Check if we have any workers
if
workers
.is_empty
()
{
return
Err
(
format!
(
"No {} workers available. Please check if {} servers are configured and healthy."
,
worker_type
,
worker_type
));
}
// Select prefill worker using prefill policy
let
prefill_idx
=
self
.prefill_policy
.select_worker
(
&
prefill_workers
,
request_text
)
.ok_or
(
"Failed to select prefill worker"
)
?
;
// Filter available workers (healthy + circuit breaker not open)
let
available_workers
:
Vec
<
Box
<
dyn
Worker
>>
=
workers
.iter
()
.filter
(|
w
|
w
.is_available
())
.map
(|
w
|
w
.clone_worker
())
.collect
();
// Select decode worker using decode policy
let
decode_idx
=
self
.decode_policy
.select_worker
(
&
decode_workers
,
request_text
)
.ok_or
(
"Failed to select decode worker"
)
?
;
if
available_workers
.is_empty
()
{
return
Err
(
format!
(
"No available {} workers (all circuits open or unhealthy)"
,
worker_type
));
}
let
prefill
=
prefill_workers
[
prefill_idx
]
.clone_worker
();
let
decode
=
decode_workers
[
decode_idx
]
.clone_worker
();
Ok
((
prefill
,
decode
))
// Let policy select from available workers only
match
policy
.select_worker
(
&
available_workers
,
request_text
)
{
Some
(
idx
)
=>
Ok
(
available_workers
[
idx
]
.clone_worker
()),
None
=>
Err
(
format!
(
"Policy could not select a {} worker"
,
worker_type
)),
}
}
// Background task to monitor worker loads with shared client
...
...
@@ -1449,61 +1576,41 @@ impl RouterTrait for PDRouter {
headers
:
Option
<&
HeaderMap
>
,
body
:
&
GenerateRequest
,
)
->
Response
{
let
start
=
Instant
::
now
();
// Extract flags for routing logic
// Extract parameters
let
is_stream
=
body
.stream
;
let
return_logprob
=
body
.return_logprob
;
// Extract text for cache-aware routing only if needed
// Extract text for cache-aware routing
let
request_text
=
if
self
.policies_need_request_text
()
{
body
.text
.as_deref
()
.or_else
(||
{
body
.prompt
.as_ref
()
.and_then
(|
p
|
match
p
{
crate
::
openai_api_types
::
StringOrArray
::
String
(
s
)
=>
Some
(
s
.as_str
()),
crate
::
openai_api_types
::
StringOrArray
::
Array
(
v
)
=>
{
v
.first
()
.map
(|
s
|
s
.as_str
())
}
body
.text
.as_deref
()
.or_else
(||
{
body
.prompt
.as_ref
()
.and_then
(|
p
|
match
p
{
crate
::
openai_api_types
::
StringOrArray
::
String
(
s
)
=>
Some
(
s
.as_str
()),
crate
::
openai_api_types
::
StringOrArray
::
Array
(
v
)
=>
{
v
.first
()
.map
(|
s
|
s
.as_str
())
}
})
})
}
)
.map
(|
s
|
s
.to_string
()
)
}
else
{
None
};
// Select servers
let
(
prefill
,
decode
)
=
match
self
.select_pd_pair
(
request_text
)
.await
{
Ok
(
pair
)
=>
pair
,
Err
(
e
)
=>
return
Self
::
handle_server_selection_error
(
e
),
};
// Log routing decision
info!
(
"PD routing decision route=/generate prefill_url={} decode_url={}"
,
prefill
.url
(),
decode
.url
()
);
// Calculate batch size
let
batch_size
=
Self
::
get_generate_batch_size
(
body
);
let
original
=
match
serde_json
::
to_value
(
body
)
{
Ok
(
v
)
=>
v
,
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
};
let
json
=
match
Self
::
inject_bootstrap_into_value
(
original
,
prefill
.as_ref
(),
batch_size
)
{
Ok
(
v
)
=>
v
,
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
};
// Execute dual dispatch
self
.execute_dual_dispatch
(
headers
,
json
,
"/generate"
,
prefill
.as_ref
(),
decode
.as_ref
(),
// Create context
let
context
=
PDRequestContext
{
route
:
"/generate"
,
batch_size
,
is_stream
,
return_logprob
,
start
,
)
.await
request_text
,
};
// Execute with retry and bootstrap injection
self
.execute_dual_dispatch
(
headers
,
body
,
context
)
.await
}
async
fn
route_chat
(
...
...
@@ -1511,25 +1618,19 @@ impl RouterTrait for PDRouter {
headers
:
Option
<&
HeaderMap
>
,
body
:
&
ChatCompletionRequest
,
)
->
Response
{
let
start
=
Instant
::
now
();
// Extract flags for routing logic
// Extract parameters
let
is_stream
=
body
.stream
;
let
return_logprob
=
body
.logprobs
;
// Extract text for cache-aware routing from chat messages only if needed
// Extract text for cache-aware routing
let
request_text
=
if
self
.policies_need_request_text
()
{
body
.messages
.first
()
.and_then
(|
msg
|
match
msg
{
crate
::
openai_api_types
::
ChatMessage
::
User
{
content
,
..
}
=>
{
match
content
{
crate
::
openai_api_types
::
UserMessageContent
::
Text
(
text
)
=>
{
Some
(
text
.as_str
())
}
crate
::
openai_api_types
::
UserMessageContent
::
Parts
(
_
)
=>
None
,
// Skip complex content
}
}
crate
::
openai_api_types
::
ChatMessage
::
User
{
content
,
..
}
=>
match
content
{
crate
::
openai_api_types
::
UserMessageContent
::
Text
(
text
)
=>
Some
(
text
.clone
()),
crate
::
openai_api_types
::
UserMessageContent
::
Parts
(
_
)
=>
None
,
},
crate
::
openai_api_types
::
ChatMessage
::
System
{
content
,
..
}
=>
{
Some
(
content
.
as_str
())
Some
(
content
.
clone
())
}
_
=>
None
,
})
...
...
@@ -1537,41 +1638,20 @@ impl RouterTrait for PDRouter {
None
};
// Select servers
let
(
prefill
,
decode
)
=
match
self
.select_pd_pair
(
request_text
)
.await
{
Ok
(
pair
)
=>
pair
,
Err
(
e
)
=>
return
Self
::
handle_server_selection_error
(
e
),
};
// Log routing decision
info!
(
"PD routing decision route=/v1/chat/completions prefill_url={} decode_url={}"
,
prefill
.url
(),
decode
.url
()
);
// Calculate batch size
let
batch_size
=
Self
::
get_chat_batch_size
(
body
);
let
original
=
match
serde_json
::
to_value
(
body
)
{
Ok
(
v
)
=>
v
,
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
};
let
json
=
match
Self
::
inject_bootstrap_into_value
(
original
,
prefill
.as_ref
(),
batch_size
)
{
Ok
(
v
)
=>
v
,
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
};
// Execute dual dispatch
self
.execute_dual_dispatch
(
headers
,
json
,
"/v1/chat/completions"
,
prefill
.as_ref
(),
decode
.as_ref
(),
// Create context
let
context
=
PDRequestContext
{
route
:
"/v1/chat/completions"
,
batch_size
,
is_stream
,
return_logprob
,
start
,
)
.await
request_text
,
};
// Execute with retry and bootstrap injection
self
.execute_dual_dispatch
(
headers
,
body
,
context
)
.await
}
async
fn
route_completion
(
...
...
@@ -1579,57 +1659,36 @@ impl RouterTrait for PDRouter {
headers
:
Option
<&
HeaderMap
>
,
body
:
&
CompletionRequest
,
)
->
Response
{
let
start
=
Instant
::
now
();
// Extract flags for routing logic
// Extract parameters
let
is_stream
=
body
.stream
;
let
return_logprob
=
body
.logprobs
.is_some
();
// Extract text for cache-aware routing only if needed
// Extract text for cache-aware routing
let
request_text
=
if
self
.policies_need_request_text
()
{
match
&
body
.prompt
{
crate
::
openai_api_types
::
StringOrArray
::
String
(
s
)
=>
Some
(
s
.as_str
()),
crate
::
openai_api_types
::
StringOrArray
::
Array
(
v
)
=>
v
.first
()
.map
(|
s
|
s
.as_str
()),
crate
::
openai_api_types
::
StringOrArray
::
String
(
s
)
=>
Some
(
s
.clone
()),
crate
::
openai_api_types
::
StringOrArray
::
Array
(
v
)
=>
{
v
.first
()
.map
(|
s
|
s
.to_string
())
}
}
}
else
{
None
};
// Select servers
let
(
prefill
,
decode
)
=
match
self
.select_pd_pair
(
request_text
)
.await
{
Ok
(
pair
)
=>
pair
,
Err
(
e
)
=>
return
Self
::
handle_server_selection_error
(
e
),
};
// Log routing decision
info!
(
"PD routing decision route=/v1/completions prefill_url={} decode_url={}"
,
prefill
.url
(),
decode
.url
()
);
// Calculate batch size
let
batch_size
=
Self
::
get_completion_batch_size
(
body
);
let
original
=
match
serde_json
::
to_value
(
body
)
{
Ok
(
v
)
=>
v
,
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
};
let
json
=
match
Self
::
inject_bootstrap_into_value
(
original
,
prefill
.as_ref
(),
batch_size
)
{
Ok
(
v
)
=>
v
,
Err
(
e
)
=>
return
Self
::
handle_serialization_error
(
e
),
};
// Execute dual dispatch
self
.execute_dual_dispatch
(
headers
,
json
,
"/v1/completions"
,
prefill
.as_ref
(),
decode
.as_ref
(),
// Create context
let
context
=
PDRequestContext
{
route
:
"/v1/completions"
,
batch_size
,
is_stream
,
return_logprob
,
start
,
)
.await
request_text
,
};
// Execute with retry and bootstrap injection
self
.execute_dual_dispatch
(
headers
,
body
,
context
)
.await
}
async
fn
flush_cache
(
&
self
)
->
Response
{
...
...
sgl-router/src/routers/router.rs
View file @
6f81a710
use
crate
::
config
::
types
::{
CircuitBreakerConfig
as
ConfigCircuitBreakerConfig
,
RetryConfig
};
use
crate
::
core
::{
CircuitBreakerConfig
,
HealthChecker
,
RetryExecutor
,
Worker
,
WorkerFactory
};
use
crate
::
core
::{
is_retryable_status
,
CircuitBreakerConfig
,
HealthChecker
,
RetryExecutor
,
Worker
,
WorkerFactory
,
};
use
crate
::
metrics
::
RouterMetrics
;
use
crate
::
openai_api_types
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
};
use
crate
::
policies
::
LoadBalancingPolicy
;
...
...
@@ -81,12 +83,8 @@ impl Router {
let
core_cb_config
=
CircuitBreakerConfig
{
failure_threshold
:
circuit_breaker_config
.failure_threshold
,
success_threshold
:
circuit_breaker_config
.success_threshold
,
timeout_duration
:
std
::
time
::
Duration
::
from_secs
(
circuit_breaker_config
.timeout_duration_secs
,
),
window_duration
:
std
::
time
::
Duration
::
from_secs
(
circuit_breaker_config
.window_duration_secs
,
),
timeout_duration
:
Duration
::
from_secs
(
circuit_breaker_config
.timeout_duration_secs
),
window_duration
:
Duration
::
from_secs
(
circuit_breaker_config
.window_duration_secs
),
};
// Create Worker trait objects from URLs
...
...
@@ -397,18 +395,6 @@ impl Router {
Some
(
available
[
idx
]
.clone_worker
())
}
/// Reports whether `status` is one of the HTTP status codes this router retries:
/// 408, 429, 500, 502, 503 or 504.
fn is_retryable_status(status: StatusCode) -> bool {
    status == StatusCode::REQUEST_TIMEOUT
        || status == StatusCode::TOO_MANY_REQUESTS
        || status == StatusCode::INTERNAL_SERVER_ERROR
        || status == StatusCode::BAD_GATEWAY
        || status == StatusCode::SERVICE_UNAVAILABLE
        || status == StatusCode::GATEWAY_TIMEOUT
}
pub
async
fn
route_typed_request
<
T
:
crate
::
openai_api_types
::
GenerationRequest
+
serde
::
Serialize
+
Clone
,
>
(
...
...
@@ -461,7 +447,7 @@ impl Router {
response
},
// should_retry predicate
|
res
,
_
attempt
|
Self
::
is_retryable_status
(
res
.status
()),
|
res
,
_
attempt
|
is_retryable_status
(
res
.status
()),
// on_backoff hook
|
delay
,
attempt
|
{
RouterMetrics
::
record_retry
(
route
);
...
...
@@ -476,7 +462,7 @@ impl Router {
let
duration
=
start
.elapsed
();
RouterMetrics
::
record_request
(
route
);
RouterMetrics
::
record_generate_duration
(
duration
);
}
else
if
!
Self
::
is_retryable_status
(
response
.status
())
{
}
else
if
!
is_retryable_status
(
response
.status
())
{
RouterMetrics
::
record_request_error
(
route
,
"non_retryable_error"
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment