Unverified Commit 067068f2 authored by Simo Lin, committed by GitHub

[router] regular router circuit breaker (#8997)

parent 6beeff41
{
"name": "sglang",
"lockfileVersion": 3,
"requires": true,
"packages": {}
}
......@@ -116,6 +116,39 @@ python -m sglang_router.launch_router \
--prometheus-port 9000
```
### Retries and Circuit Breakers
- Retries (regular router) are enabled by default with exponential backoff and jitter. You can tune them via the CLI:
```bash
python -m sglang_router.launch_router \
--worker-urls http://localhost:8080 http://localhost:8081 \
--retry-max-retries 3 \
--retry-initial-backoff-ms 100 \
--retry-max-backoff-ms 10000 \
--retry-backoff-multiplier 2.0 \
--retry-jitter-factor 0.1
```
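With these defaults, the base delay doubles on each attempt and is capped at `--retry-max-backoff-ms`, with ±`--retry-jitter-factor` jitter applied. A quick sketch of the resulting schedule (a standalone illustration that mirrors the router's backoff formula, not a call into the router itself):
```python
import random

def backoff_ms(attempt: int, initial_ms: int = 100, max_ms: int = 10_000,
               multiplier: float = 2.0, jitter: float = 0.1) -> float:
    """Backoff before the next retry for a 0-based attempt index (mirrors the router's formula)."""
    base = min(initial_ms * (multiplier ** attempt), max_ms)
    return base * (1 + random.uniform(-jitter, jitter))

# Base delays for the first few retries: 100, 200, 400, 800, 1600 ms ... capped at 10000 ms
print([round(backoff_ms(a)) for a in range(5)])
```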
- Circuit breaker defaults protect workers and recover automatically. Tune its thresholds and timeouts via the CLI:
```bash
python -m sglang_router.launch_router \
--worker-urls http://localhost:8080 http://localhost:8081 \
--cb-failure-threshold 5 \
--cb-success-threshold 2 \
--cb-timeout-duration-secs 30 \
--cb-window-duration-secs 60
```
Behavior summary:
- Closed → Open after N consecutive failures (`--cb-failure-threshold`)
- Open → HalfOpen after the timeout elapses (`--cb-timeout-duration-secs`)
- HalfOpen → Closed after M consecutive successes (`--cb-success-threshold`)
- Any failure while HalfOpen immediately reopens the circuit

Retry predicate (regular router): responses with status 408, 429, 500, 502, 503, or 504 are retried; any other status is returned immediately. Exponential backoff with jitter is applied between attempts.
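The same settings can also be supplied programmatically when embedding the router. A minimal sketch, assuming `Router` is importable from the `sglang_router` package and accepts `worker_urls` alongside the retry/circuit-breaker keyword arguments added here (values shown are the defaults):
```python
from sglang_router import Router

# Keyword names mirror the CLI flags (retry-* / cb-*); all values below are the defaults.
router = Router(
    worker_urls=["http://localhost:8080", "http://localhost:8081"],
    retry_max_retries=3,
    retry_initial_backoff_ms=100,
    retry_max_backoff_ms=10_000,
    retry_backoff_multiplier=2.0,
    retry_jitter_factor=0.1,
    cb_failure_threshold=5,
    cb_success_threshold=2,
    cb_timeout_duration_secs=30,
    cb_window_duration_secs=60,
)
router.start()
```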
### Request ID Tracking
Track requests across distributed systems with configurable headers:
......
......@@ -74,6 +74,19 @@ class RouterArgs:
max_concurrent_requests: int = 64
# CORS allowed origins
cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
# Retry configuration
retry_max_retries: int = 3
retry_initial_backoff_ms: int = 100
retry_max_backoff_ms: int = 10_000
retry_backoff_multiplier: float = 2.0
retry_jitter_factor: float = 0.1
disable_retries: bool = False
# Circuit breaker configuration
cb_failure_threshold: int = 5
cb_success_threshold: int = 2
cb_timeout_duration_secs: int = 30
cb_window_duration_secs: int = 60
disable_circuit_breaker: bool = False
@staticmethod
def add_cli_args(
......@@ -289,6 +302,63 @@ class RouterArgs:
default=RouterArgs.request_timeout_secs,
help="Request timeout in seconds",
)
# Retry configuration
parser.add_argument(
f"--{prefix}retry-max-retries",
type=int,
default=RouterArgs.retry_max_retries,
)
parser.add_argument(
f"--{prefix}retry-initial-backoff-ms",
type=int,
default=RouterArgs.retry_initial_backoff_ms,
)
parser.add_argument(
f"--{prefix}retry-max-backoff-ms",
type=int,
default=RouterArgs.retry_max_backoff_ms,
)
parser.add_argument(
f"--{prefix}retry-backoff-multiplier",
type=float,
default=RouterArgs.retry_backoff_multiplier,
)
parser.add_argument(
f"--{prefix}retry-jitter-factor",
type=float,
default=RouterArgs.retry_jitter_factor,
)
parser.add_argument(
f"--{prefix}disable-retries",
action="store_true",
help="Disable retries (equivalent to setting retry_max_retries=1)",
)
# Circuit breaker configuration
parser.add_argument(
f"--{prefix}cb-failure-threshold",
type=int,
default=RouterArgs.cb_failure_threshold,
)
parser.add_argument(
f"--{prefix}cb-success-threshold",
type=int,
default=RouterArgs.cb_success_threshold,
)
parser.add_argument(
f"--{prefix}cb-timeout-duration-secs",
type=int,
default=RouterArgs.cb_timeout_duration_secs,
)
parser.add_argument(
f"--{prefix}cb-window-duration-secs",
type=int,
default=RouterArgs.cb_window_duration_secs,
)
parser.add_argument(
f"--{prefix}disable-circuit-breaker",
action="store_true",
help="Disable circuit breaker (equivalent to setting cb_failure_threshold to u32::MAX)",
)
parser.add_argument(
f"--{prefix}max-concurrent-requests",
type=int,
......@@ -372,6 +442,19 @@ class RouterArgs:
RouterArgs.max_concurrent_requests,
),
cors_allowed_origins=getattr(args, f"{prefix}cors_allowed_origins", []),
retry_max_retries=getattr(args, f"{prefix}retry_max_retries"),
retry_initial_backoff_ms=getattr(args, f"{prefix}retry_initial_backoff_ms"),
retry_max_backoff_ms=getattr(args, f"{prefix}retry_max_backoff_ms"),
retry_backoff_multiplier=getattr(args, f"{prefix}retry_backoff_multiplier"),
retry_jitter_factor=getattr(args, f"{prefix}retry_jitter_factor"),
cb_failure_threshold=getattr(args, f"{prefix}cb_failure_threshold"),
cb_success_threshold=getattr(args, f"{prefix}cb_success_threshold"),
cb_timeout_duration_secs=getattr(args, f"{prefix}cb_timeout_duration_secs"),
cb_window_duration_secs=getattr(args, f"{prefix}cb_window_duration_secs"),
disable_retries=getattr(args, f"{prefix}disable_retries", False),
disable_circuit_breaker=getattr(
args, f"{prefix}disable_circuit_breaker", False
),
)
@staticmethod
......@@ -558,6 +641,17 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
request_id_headers=router_args.request_id_headers,
max_concurrent_requests=router_args.max_concurrent_requests,
cors_allowed_origins=router_args.cors_allowed_origins,
retry_max_retries=router_args.retry_max_retries,
retry_initial_backoff_ms=router_args.retry_initial_backoff_ms,
retry_max_backoff_ms=router_args.retry_max_backoff_ms,
retry_backoff_multiplier=router_args.retry_backoff_multiplier,
retry_jitter_factor=router_args.retry_jitter_factor,
cb_failure_threshold=router_args.cb_failure_threshold,
cb_success_threshold=router_args.cb_success_threshold,
cb_timeout_duration_secs=router_args.cb_timeout_duration_secs,
cb_window_duration_secs=router_args.cb_window_duration_secs,
disable_retries=router_args.disable_retries,
disable_circuit_breaker=router_args.disable_circuit_breaker,
)
router.start()
......
......@@ -158,6 +158,7 @@ def main():
default=31000,
help="Base port number for data parallel workers",
)
# No extra retry/CB flags here; RouterArgs.add_cli_args already defines them with router- prefix
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
......
......@@ -104,6 +104,17 @@ class Router:
decode_policy: Optional[PolicyType] = None,
max_concurrent_requests: int = 64,
cors_allowed_origins: List[str] = None,
retry_max_retries: int = 3,
retry_initial_backoff_ms: int = 100,
retry_max_backoff_ms: int = 10_000,
retry_backoff_multiplier: float = 2.0,
retry_jitter_factor: float = 0.1,
cb_failure_threshold: int = 5,
cb_success_threshold: int = 2,
cb_timeout_duration_secs: int = 30,
cb_window_duration_secs: int = 60,
disable_retries: bool = False,
disable_circuit_breaker: bool = False,
):
if selector is None:
selector = {}
......@@ -149,6 +160,17 @@ class Router:
decode_policy=decode_policy,
max_concurrent_requests=max_concurrent_requests,
cors_allowed_origins=cors_allowed_origins,
retry_max_retries=retry_max_retries,
retry_initial_backoff_ms=retry_initial_backoff_ms,
retry_max_backoff_ms=retry_max_backoff_ms,
retry_backoff_multiplier=retry_backoff_multiplier,
retry_jitter_factor=retry_jitter_factor,
cb_failure_threshold=cb_failure_threshold,
cb_success_threshold=cb_success_threshold,
cb_timeout_duration_secs=cb_timeout_duration_secs,
cb_window_duration_secs=cb_window_duration_secs,
disable_retries=disable_retries,
disable_circuit_breaker=disable_circuit_breaker,
)
def start(self) -> None:
......
......@@ -53,6 +53,17 @@ class TestLaunchRouter(unittest.TestCase):
prefill=None,
decode=None,
worker_urls=[],
retry_max_retries=3,
retry_initial_backoff_ms=100,
retry_max_backoff_ms=10_000,
retry_backoff_multiplier=2.0,
retry_jitter_factor=0.1,
cb_failure_threshold=5,
cb_success_threshold=2,
cb_timeout_duration_secs=30,
cb_window_duration_secs=60,
disable_retries=False,
disable_circuit_breaker=False,
)
def create_router_args(self, **kwargs):
......
......@@ -31,6 +31,16 @@ def popen_launch_router(
prometheus_port: int = None,
prometheus_host: str = None,
dp_aware: bool = False,
# Router retry/CB tuning (optional)
router_retry_max_retries: int = None,
router_retry_initial_backoff_ms: int = None,
router_retry_max_backoff_ms: int = None,
router_retry_backoff_multiplier: float = None,
router_retry_jitter_factor: float = None,
router_cb_failure_threshold: int = None,
router_cb_success_threshold: int = None,
router_cb_timeout_duration_secs: int = None,
router_cb_window_duration_secs: int = None,
):
"""
Launch the router server process.
......@@ -107,6 +117,21 @@ def popen_launch_router(
if dp_aware:
command.append("--router-dp-aware")
# Append router retry/CB tuning flags if provided
def _add(flag: str, val):
if val is not None:
command.extend([flag, str(val)])
_add("--router-retry-max-retries", router_retry_max_retries)
_add("--router-retry-initial-backoff-ms", router_retry_initial_backoff_ms)
_add("--router-retry-max-backoff-ms", router_retry_max_backoff_ms)
_add("--router-retry-backoff-multiplier", router_retry_backoff_multiplier)
_add("--router-retry-jitter-factor", router_retry_jitter_factor)
_add("--router-cb-failure-threshold", router_cb_failure_threshold)
_add("--router-cb-success-threshold", router_cb_success_threshold)
_add("--router-cb-timeout-duration-secs", router_cb_timeout_duration_secs)
_add("--router-cb-window-duration-secs", router_cb_window_duration_secs)
process = subprocess.Popen(command, stdout=None, stderr=None)
start_time = time.perf_counter()
......
......@@ -43,6 +43,12 @@ pub struct RouterConfig {
pub retry: RetryConfig,
/// Circuit breaker configuration
pub circuit_breaker: CircuitBreakerConfig,
/// Disable retries (overrides retry.max_retries to 1 when true)
#[serde(default)]
pub disable_retries: bool,
/// Disable circuit breaker (overrides circuit_breaker.failure_threshold to u32::MAX when true)
#[serde(default)]
pub disable_circuit_breaker: bool,
}
/// Routing mode configuration
......@@ -197,6 +203,10 @@ pub struct RetryConfig {
pub max_backoff_ms: u64,
/// Backoff multiplier for exponential backoff
pub backoff_multiplier: f32,
/// Jitter factor applied to backoff (0.0 - 1.0)
/// Effective delay: D' = D * (1 + u), where u is drawn uniformly from [-j, +j]
#[serde(default = "default_retry_jitter_factor")]
pub jitter_factor: f32,
}
impl Default for RetryConfig {
......@@ -206,10 +216,15 @@ impl Default for RetryConfig {
initial_backoff_ms: 100,
max_backoff_ms: 10000,
backoff_multiplier: 2.0,
jitter_factor: 0.1,
}
}
}
fn default_retry_jitter_factor() -> f32 {
0.1
}
/// Circuit breaker configuration for worker reliability
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerConfig {
......@@ -276,6 +291,8 @@ impl Default for RouterConfig {
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
}
}
}
......@@ -312,6 +329,24 @@ impl RouterConfig {
pub fn has_metrics(&self) -> bool {
self.metrics.is_some()
}
/// Compute the effective retry config considering disable flag
pub fn effective_retry_config(&self) -> RetryConfig {
let mut cfg = self.retry.clone();
if self.disable_retries {
cfg.max_retries = 1;
}
cfg
}
/// Compute the effective circuit breaker config considering disable flag
pub fn effective_circuit_breaker_config(&self) -> CircuitBreakerConfig {
let mut cfg = self.circuit_breaker.clone();
if self.disable_circuit_breaker {
cfg.failure_threshold = u32::MAX;
}
cfg
}
}
#[cfg(test)]
......@@ -388,6 +423,8 @@ mod tests {
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
};
let json = serde_json::to_string(&config).unwrap();
......@@ -817,6 +854,8 @@ mod tests {
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
};
assert!(config.mode.is_pd_mode());
......@@ -870,6 +909,8 @@ mod tests {
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
};
assert!(!config.mode.is_pd_mode());
......@@ -919,6 +960,8 @@ mod tests {
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
};
assert!(config.has_service_discovery());
......
......@@ -23,6 +23,12 @@ impl ConfigValidator {
Self::validate_compatibility(config)?;
// Validate effective retry/CB configs (respect disable flags)
let retry_cfg = config.effective_retry_config();
let cb_cfg = config.effective_circuit_breaker_config();
Self::validate_retry(&retry_cfg)?;
Self::validate_circuit_breaker(&cb_cfg)?;
Ok(())
}
......@@ -263,6 +269,79 @@ impl ConfigValidator {
Ok(())
}
/// Validate retry configuration
fn validate_retry(retry: &RetryConfig) -> ConfigResult<()> {
if retry.max_retries < 1 {
return Err(ConfigError::InvalidValue {
field: "retry.max_retries".to_string(),
value: retry.max_retries.to_string(),
reason: "Must be >= 1 (set to 1 to effectively disable retries)".to_string(),
});
}
if retry.initial_backoff_ms == 0 {
return Err(ConfigError::InvalidValue {
field: "retry.initial_backoff_ms".to_string(),
value: retry.initial_backoff_ms.to_string(),
reason: "Must be > 0".to_string(),
});
}
if retry.max_backoff_ms < retry.initial_backoff_ms {
return Err(ConfigError::InvalidValue {
field: "retry.max_backoff_ms".to_string(),
value: retry.max_backoff_ms.to_string(),
reason: "Must be >= initial_backoff_ms".to_string(),
});
}
if retry.backoff_multiplier < 1.0 {
return Err(ConfigError::InvalidValue {
field: "retry.backoff_multiplier".to_string(),
value: retry.backoff_multiplier.to_string(),
reason: "Must be >= 1.0".to_string(),
});
}
if !(0.0..=1.0).contains(&retry.jitter_factor) {
return Err(ConfigError::InvalidValue {
field: "retry.jitter_factor".to_string(),
value: retry.jitter_factor.to_string(),
reason: "Must be between 0.0 and 1.0".to_string(),
});
}
Ok(())
}
/// Validate circuit breaker configuration
fn validate_circuit_breaker(cb: &CircuitBreakerConfig) -> ConfigResult<()> {
if cb.failure_threshold < 1 {
return Err(ConfigError::InvalidValue {
field: "circuit_breaker.failure_threshold".to_string(),
value: cb.failure_threshold.to_string(),
reason: "Must be >= 1 (set to u32::MAX to effectively disable CB)".to_string(),
});
}
if cb.success_threshold < 1 {
return Err(ConfigError::InvalidValue {
field: "circuit_breaker.success_threshold".to_string(),
value: cb.success_threshold.to_string(),
reason: "Must be >= 1".to_string(),
});
}
if cb.timeout_duration_secs == 0 {
return Err(ConfigError::InvalidValue {
field: "circuit_breaker.timeout_duration_secs".to_string(),
value: cb.timeout_duration_secs.to_string(),
reason: "Must be > 0".to_string(),
});
}
if cb.window_duration_secs == 0 {
return Err(ConfigError::InvalidValue {
field: "circuit_breaker.window_duration_secs".to_string(),
value: cb.window_duration_secs.to_string(),
reason: "Must be > 0".to_string(),
});
}
Ok(())
}
/// Validate compatibility between different configuration sections
fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> {
// All policies are now supported for both router types thanks to the unified trait design
......
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tracing::info;
/// Circuit breaker configuration
#[derive(Debug, Clone)]
......@@ -113,6 +114,7 @@ impl CircuitBreaker {
self.total_successes.fetch_add(1, Ordering::Relaxed);
self.consecutive_failures.store(0, Ordering::Release);
let successes = self.consecutive_successes.fetch_add(1, Ordering::AcqRel) + 1;
// Outcome-level metrics are recorded at the worker level where the worker label is known
let current_state = *self.state.read().unwrap();
......@@ -138,6 +140,7 @@ impl CircuitBreaker {
self.total_failures.fetch_add(1, Ordering::Relaxed);
self.consecutive_successes.store(0, Ordering::Release);
let failures = self.consecutive_failures.fetch_add(1, Ordering::AcqRel) + 1;
// Outcome-level metrics are recorded at the worker level where the worker label is known
// Update last failure time
{
......@@ -204,11 +207,18 @@ impl CircuitBreaker {
}
}
let from = match old_state {
CircuitState::Closed => "closed",
CircuitState::Open => "open",
CircuitState::HalfOpen => "half_open",
};
let to = match new_state {
CircuitState::Closed => "closed",
CircuitState::Open => "open",
CircuitState::HalfOpen => "half_open",
};
info!("Circuit breaker state transition: {} -> {}", from, to);
// Transition metrics are recorded at the worker level where the worker label is known
}
}
......
......@@ -8,6 +8,7 @@
pub mod circuit_breaker;
pub mod error;
pub mod retry;
pub mod worker;
// Re-export commonly used types at the module level
......@@ -15,6 +16,7 @@ pub use circuit_breaker::{
CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState,
};
pub use error::{WorkerError, WorkerResult};
pub use retry::{BackoffCalculator, RetryError, RetryExecutor};
pub use worker::{
start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, Worker, WorkerCollection,
WorkerFactory, WorkerLoadGuard, WorkerType,
......
use crate::config::types::RetryConfig;
use axum::response::Response;
use rand::Rng;
use std::time::Duration;
use tracing::debug;
/// Computes exponential backoff with optional jitter.
#[derive(Debug, Clone)]
pub struct BackoffCalculator;
impl BackoffCalculator {
/// Calculate backoff delay for a given attempt index (0-based).
pub fn calculate_delay(config: &RetryConfig, attempt: u32) -> Duration {
// Base exponential backoff
let pow = config.backoff_multiplier.powi(attempt as i32);
let mut delay_ms = (config.initial_backoff_ms as f32 * pow) as u64;
if delay_ms > config.max_backoff_ms {
delay_ms = config.max_backoff_ms;
}
// Apply jitter in range [-j, +j]
let jitter = config.jitter_factor.max(0.0).min(1.0);
if jitter > 0.0 {
let mut rng = rand::thread_rng();
let jitter_scale: f32 = rng.gen_range(-jitter..=jitter);
let jitter_ms = (delay_ms as f32 * jitter_scale)
.round()
.max(-(delay_ms as f32));
let adjusted = (delay_ms as i64 + jitter_ms as i64).max(0) as u64;
return Duration::from_millis(adjusted);
}
Duration::from_millis(delay_ms)
}
}
#[derive(Debug, thiserror::Error)]
pub enum RetryError {
#[error("no available workers")]
NoAvailableWorkers,
#[error("maximum retry attempts exceeded")]
MaxRetriesExceeded,
}
/// A thin async retry executor for generic operations.
#[derive(Debug, Clone, Default)]
pub struct RetryExecutor;
impl RetryExecutor {
/// Execute an async operation with retries and backoff.
/// The `operation` closure is invoked each attempt with the attempt index.
pub async fn execute_with_retry<F, Fut, T>(
config: &RetryConfig,
mut operation: F,
) -> Result<T, RetryError>
where
F: FnMut(u32) -> Fut,
Fut: std::future::Future<Output = Result<T, ()>>,
{
let max = config.max_retries.max(1);
let mut attempt: u32 = 0;
loop {
match operation(attempt).await {
Ok(val) => return Ok(val),
Err(_) => {
// Use the number of failures so far (0-indexed) to compute delay,
// so the first retry uses `initial_backoff_ms`.
let is_last = attempt + 1 >= max;
if is_last {
return Err(RetryError::MaxRetriesExceeded);
}
let delay = BackoffCalculator::calculate_delay(config, attempt);
attempt += 1; // advance to the next attempt after computing delay
tokio::time::sleep(delay).await;
}
}
}
}
/// Execute an operation that returns an HTTP Response with retries and backoff.
///
/// Usage pattern:
/// - `operation(attempt)`: perform one attempt (0-based). Construct and send the request,
/// then return the `Response`. Do any per-attempt bookkeeping (e.g., load tracking,
/// circuit-breaker outcome recording) inside this closure.
/// - `should_retry(&response, attempt)`: decide if the given response should be retried
/// (e.g., based on HTTP status). Returning false short-circuits and returns the response.
/// - `on_backoff(delay, next_attempt)`: called before sleeping between attempts.
/// Use this to record metrics.
/// - `on_exhausted()`: called when the executor has exhausted all retry attempts.
///
/// Example:
/// ```ignore
/// let resp = RetryExecutor::execute_response_with_retry(
/// &retry_cfg,
/// |attempt| async move {
/// let worker = select_cb_aware_worker()?;
/// let resp = send_request(worker).await;
/// worker.record_outcome(resp.status().is_success());
/// resp
/// },
/// |res, _| matches!(res.status(), StatusCode::REQUEST_TIMEOUT | StatusCode::TOO_MANY_REQUESTS | StatusCode::INTERNAL_SERVER_ERROR | StatusCode::BAD_GATEWAY | StatusCode::SERVICE_UNAVAILABLE | StatusCode::GATEWAY_TIMEOUT),
/// |delay, attempt| RouterMetrics::record_retry_backoff_duration(delay, attempt),
/// || RouterMetrics::record_retries_exhausted("/route"),
/// ).await;
/// ```
pub async fn execute_response_with_retry<Op, Fut, ShouldRetry, OnBackoff, OnExhausted>(
config: &RetryConfig,
mut operation: Op,
should_retry: ShouldRetry,
on_backoff: OnBackoff,
mut on_exhausted: OnExhausted,
) -> Response
where
Op: FnMut(u32) -> Fut,
Fut: std::future::Future<Output = Response>,
ShouldRetry: Fn(&Response, u32) -> bool,
OnBackoff: Fn(Duration, u32),
OnExhausted: FnMut(),
{
let max = config.max_retries.max(1);
let mut attempt: u32 = 0;
loop {
let response = operation(attempt).await;
let is_last = attempt + 1 >= max;
if !should_retry(&response, attempt) {
return response;
}
if is_last {
// Exhausted retries
on_exhausted();
return response;
}
// Backoff before next attempt
let next_attempt = attempt + 1;
// Compute delay based on the number of failures so far (0-indexed)
let delay = BackoffCalculator::calculate_delay(config, attempt);
debug!(
attempt = attempt,
next_attempt = next_attempt,
delay_ms = delay.as_millis() as u64,
"Retry backoff"
);
on_backoff(delay, next_attempt);
tokio::time::sleep(delay).await;
attempt = next_attempt;
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
fn base_retry_config() -> RetryConfig {
RetryConfig {
max_retries: 3,
initial_backoff_ms: 1,
max_backoff_ms: 4,
backoff_multiplier: 2.0,
jitter_factor: 0.0,
}
}
#[test]
fn test_backoff_no_jitter_progression_and_cap() {
let cfg = RetryConfig {
max_retries: 10,
initial_backoff_ms: 100,
max_backoff_ms: 250,
backoff_multiplier: 2.0,
jitter_factor: 0.0,
};
// attempt=0 => 100ms
assert_eq!(
BackoffCalculator::calculate_delay(&cfg, 0),
Duration::from_millis(100)
);
// attempt=1 => 200ms
assert_eq!(
BackoffCalculator::calculate_delay(&cfg, 1),
Duration::from_millis(200)
);
// attempt=2 => 400ms -> capped to 250ms
assert_eq!(
BackoffCalculator::calculate_delay(&cfg, 2),
Duration::from_millis(250)
);
// large attempt still capped
assert_eq!(
BackoffCalculator::calculate_delay(&cfg, 10),
Duration::from_millis(250)
);
}
#[test]
fn test_backoff_with_jitter_within_bounds() {
let cfg = RetryConfig {
max_retries: 5,
initial_backoff_ms: 100,
max_backoff_ms: 10_000,
backoff_multiplier: 2.0,
jitter_factor: 0.5,
};
// attempt=2 => base 400ms, jitter in [0.5x, 1.5x]
let base = 400.0;
for _ in 0..50 {
let d = BackoffCalculator::calculate_delay(&cfg, 2).as_millis() as f32;
assert!(d >= base * 0.5 - 1.0 && d <= base * 1.5 + 1.0);
}
}
#[tokio::test]
async fn test_execute_with_retry_success_after_failures() {
let cfg = base_retry_config();
let remaining = Arc::new(AtomicU32::new(2));
let calls = Arc::new(AtomicU32::new(0));
let res: Result<u32, RetryError> = RetryExecutor::execute_with_retry(&cfg, {
let remaining = remaining.clone();
let calls = calls.clone();
move |_attempt| {
calls.fetch_add(1, Ordering::Relaxed);
let remaining = remaining.clone();
async move {
if remaining
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| v.checked_sub(1))
.is_ok()
{
Err(())
} else {
Ok(42u32)
}
}
}
})
.await;
assert!(res.is_ok());
assert_eq!(res.unwrap(), 42);
assert_eq!(calls.load(Ordering::Relaxed), 3); // 2 fails + 1 success
}
#[tokio::test]
async fn test_execute_with_retry_exhausted() {
let cfg = base_retry_config();
let calls = Arc::new(AtomicU32::new(0));
let res: Result<u32, RetryError> = RetryExecutor::execute_with_retry(&cfg, {
let calls = calls.clone();
move |_attempt| {
calls.fetch_add(1, Ordering::Relaxed);
async move { Err(()) }
}
})
.await;
assert!(matches!(res, Err(RetryError::MaxRetriesExceeded)));
assert_eq!(calls.load(Ordering::Relaxed), cfg.max_retries);
}
#[tokio::test]
async fn test_execute_response_with_retry_success_path_and_hooks() {
let cfg = base_retry_config();
let remaining = Arc::new(AtomicU32::new(2));
let calls = Arc::new(AtomicU32::new(0));
let backoffs = Arc::new(AtomicU32::new(0));
let exhausted = Arc::new(AtomicU32::new(0));
let response = RetryExecutor::execute_response_with_retry(
&cfg,
{
let remaining = remaining.clone();
let calls = calls.clone();
move |_attempt| {
calls.fetch_add(1, Ordering::Relaxed);
let remaining = remaining.clone();
async move {
if remaining
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| v.checked_sub(1))
.is_ok()
{
(StatusCode::SERVICE_UNAVAILABLE, "fail").into_response()
} else {
(StatusCode::OK, "ok").into_response()
}
}
}
},
|res, _attempt| !res.status().is_success(), // retry until success
{
let backoffs = backoffs.clone();
move |_delay, _next_attempt| {
backoffs.fetch_add(1, Ordering::Relaxed);
}
},
{
let exhausted = exhausted.clone();
move || {
exhausted.fetch_add(1, Ordering::Relaxed);
}
},
)
.await;
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(calls.load(Ordering::Relaxed), 3); // 2 fails + 1 success
assert_eq!(backoffs.load(Ordering::Relaxed), 2);
assert_eq!(exhausted.load(Ordering::Relaxed), 0);
}
#[tokio::test]
async fn test_execute_response_with_retry_non_retryable_short_circuit() {
let cfg = base_retry_config();
let calls = Arc::new(AtomicU32::new(0));
let backoffs = Arc::new(AtomicU32::new(0));
let exhausted = Arc::new(AtomicU32::new(0));
let response = RetryExecutor::execute_response_with_retry(
&cfg,
{
let calls = calls.clone();
move |_attempt| {
calls.fetch_add(1, Ordering::Relaxed);
async move { (StatusCode::BAD_REQUEST, "bad").into_response() }
}
},
|_res, _attempt| false, // never retry
{
let backoffs = backoffs.clone();
move |_delay, _next_attempt| {
backoffs.fetch_add(1, Ordering::Relaxed);
}
},
{
let exhausted = exhausted.clone();
move || {
exhausted.fetch_add(1, Ordering::Relaxed);
}
},
)
.await;
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
assert_eq!(calls.load(Ordering::Relaxed), 1);
assert_eq!(backoffs.load(Ordering::Relaxed), 0);
assert_eq!(exhausted.load(Ordering::Relaxed), 0);
}
#[tokio::test]
async fn test_execute_response_with_retry_exhausted_hooks() {
let cfg = base_retry_config();
let calls = Arc::new(AtomicU32::new(0));
let backoffs = Arc::new(AtomicU32::new(0));
let exhausted = Arc::new(AtomicU32::new(0));
let response = RetryExecutor::execute_response_with_retry(
&cfg,
{
let calls = calls.clone();
move |_attempt| {
calls.fetch_add(1, Ordering::Relaxed);
async move { (StatusCode::SERVICE_UNAVAILABLE, "fail").into_response() }
}
},
|_res, _attempt| true, // keep retrying
{
let backoffs = backoffs.clone();
move |_delay, _next_attempt| {
backoffs.fetch_add(1, Ordering::Relaxed);
}
},
{
let exhausted = exhausted.clone();
move || {
exhausted.fetch_add(1, Ordering::Relaxed);
}
},
)
.await;
assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
assert_eq!(calls.load(Ordering::Relaxed), cfg.max_retries);
assert_eq!(backoffs.load(Ordering::Relaxed), cfg.max_retries - 1);
assert_eq!(exhausted.load(Ordering::Relaxed), 1);
}
}
......@@ -77,7 +77,35 @@ pub trait Worker: Send + Sync + fmt::Debug {
/// Record the outcome of a request to this worker
fn record_outcome(&self, success: bool) {
// Record outcome-level metric with worker label
let outcome_str = if success { "success" } else { "failure" };
RouterMetrics::record_cb_outcome(self.url(), outcome_str);
// Record into circuit breaker and infer state change for metrics
let before = self.circuit_breaker().state();
self.circuit_breaker().record_outcome(success);
let after = self.circuit_breaker().state();
if before != after {
let from = match before {
crate::core::CircuitState::Closed => "closed",
crate::core::CircuitState::Open => "open",
crate::core::CircuitState::HalfOpen => "half_open",
};
let to = match after {
crate::core::CircuitState::Closed => "closed",
crate::core::CircuitState::Open => "open",
crate::core::CircuitState::HalfOpen => "half_open",
};
RouterMetrics::record_cb_state_transition(self.url(), from, to);
}
let state_code = match self.circuit_breaker().state() {
crate::core::CircuitState::Closed => 0u8,
crate::core::CircuitState::Open => 1u8,
crate::core::CircuitState::HalfOpen => 2u8,
};
RouterMetrics::set_cb_state(self.url(), state_code);
}
// === DP-aware methods ===
......
......@@ -59,6 +59,19 @@ struct Router {
decode_policy: Option<PolicyType>,
max_concurrent_requests: usize,
cors_allowed_origins: Vec<String>,
// Retry configuration
retry_max_retries: u32,
retry_initial_backoff_ms: u64,
retry_max_backoff_ms: u64,
retry_backoff_multiplier: f32,
retry_jitter_factor: f32,
disable_retries: bool,
// Circuit breaker configuration
cb_failure_threshold: u32,
cb_success_threshold: u32,
cb_timeout_duration_secs: u64,
cb_window_duration_secs: u64,
disable_circuit_breaker: bool,
}
impl Router {
......@@ -146,8 +159,21 @@ impl Router {
request_id_headers: self.request_id_headers.clone(),
max_concurrent_requests: self.max_concurrent_requests,
cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: config::RetryConfig {
max_retries: self.retry_max_retries,
initial_backoff_ms: self.retry_initial_backoff_ms,
max_backoff_ms: self.retry_max_backoff_ms,
backoff_multiplier: self.retry_backoff_multiplier,
jitter_factor: self.retry_jitter_factor,
},
circuit_breaker: config::CircuitBreakerConfig {
failure_threshold: self.cb_failure_threshold,
success_threshold: self.cb_success_threshold,
timeout_duration_secs: self.cb_timeout_duration_secs,
window_duration_secs: self.cb_window_duration_secs,
},
disable_retries: self.disable_retries,
disable_circuit_breaker: self.disable_circuit_breaker,
})
}
}
......@@ -189,7 +215,20 @@ impl Router {
prefill_policy = None,
decode_policy = None,
max_concurrent_requests = 64,
cors_allowed_origins = vec![],
// Retry defaults
retry_max_retries = 3,
retry_initial_backoff_ms = 100,
retry_max_backoff_ms = 10_000,
retry_backoff_multiplier = 2.0,
retry_jitter_factor = 0.1,
disable_retries = false,
// Circuit breaker defaults
cb_failure_threshold = 5,
cb_success_threshold = 2,
cb_timeout_duration_secs = 30,
cb_window_duration_secs = 60,
disable_circuit_breaker = false,
))]
fn new(
worker_urls: Vec<String>,
......@@ -226,6 +265,17 @@ impl Router {
decode_policy: Option<PolicyType>,
max_concurrent_requests: usize,
cors_allowed_origins: Vec<String>,
retry_max_retries: u32,
retry_initial_backoff_ms: u64,
retry_max_backoff_ms: u64,
retry_backoff_multiplier: f32,
retry_jitter_factor: f32,
disable_retries: bool,
cb_failure_threshold: u32,
cb_success_threshold: u32,
cb_timeout_duration_secs: u64,
cb_window_duration_secs: u64,
disable_circuit_breaker: bool,
) -> PyResult<Self> {
Ok(Router {
host,
......@@ -262,6 +312,17 @@ impl Router {
decode_policy,
max_concurrent_requests,
cors_allowed_origins,
retry_max_retries,
retry_initial_backoff_ms,
retry_max_backoff_ms,
retry_backoff_multiplier,
retry_jitter_factor,
disable_retries,
cb_failure_threshold,
cb_success_threshold,
cb_timeout_duration_secs,
cb_window_duration_secs,
disable_circuit_breaker,
})
}
......
......@@ -36,6 +36,28 @@ pub fn init_metrics() {
"sgl_router_retries_total",
"Total number of request retries by route"
);
describe_histogram!(
"sgl_router_retry_backoff_duration_seconds",
"Backoff duration in seconds by attempt index"
);
describe_counter!(
"sgl_router_retries_exhausted_total",
"Total number of requests that exhausted retries by route"
);
// Circuit breaker metrics
describe_gauge!(
"sgl_router_cb_state",
"Circuit breaker state per worker (0=closed, 1=open, 2=half_open)"
);
describe_counter!(
"sgl_router_cb_state_transitions_total",
"Total number of circuit breaker state transitions by worker"
);
describe_counter!(
"sgl_router_cb_outcomes_total",
"Total number of circuit breaker outcomes by worker and outcome type (success/failure)"
);
// Worker metrics
describe_gauge!(
......@@ -186,6 +208,20 @@ impl RouterMetrics {
.increment(1);
}
pub fn record_retry_backoff_duration(duration: Duration, attempt: u32) {
histogram!("sgl_router_retry_backoff_duration_seconds",
"attempt" => attempt.to_string()
)
.record(duration.as_secs_f64());
}
pub fn record_retries_exhausted(route: &str) {
counter!("sgl_router_retries_exhausted_total",
"route" => route.to_string()
)
.increment(1);
}
// Worker metrics
pub fn set_active_workers(count: usize) {
gauge!("sgl_router_active_workers").set(count as f64);
......@@ -321,6 +357,31 @@ impl RouterMetrics {
)
.set(count as f64);
}
// Circuit breaker metrics
pub fn set_cb_state(worker: &str, state_code: u8) {
gauge!("sgl_router_cb_state",
"worker" => worker.to_string()
)
.set(state_code as f64);
}
pub fn record_cb_state_transition(worker: &str, from: &str, to: &str) {
counter!("sgl_router_cb_state_transitions_total",
"worker" => worker.to_string(),
"from" => from.to_string(),
"to" => to.to_string()
)
.increment(1);
}
pub fn record_cb_outcome(worker: &str, outcome: &str) {
counter!("sgl_router_cb_outcomes_total",
"worker" => worker.to_string(),
"outcome" => outcome.to_string()
)
.increment(1);
}
}
#[cfg(test)]
......
......@@ -109,7 +109,7 @@ pub(crate) fn get_healthy_worker_indices(workers: &[Box<dyn Worker>]) -> Vec<usi
workers
.iter()
.enumerate()
.filter(|(_, w)| w.is_healthy() && w.circuit_breaker().can_execute())
.map(|(idx, _)| idx)
.collect()
}
......
......@@ -1845,7 +1845,7 @@ impl RouterTrait for PDRouter {
mod tests {
use super::*;
use crate::core::{BasicWorker, WorkerType};
use crate::policies::RandomPolicy;
fn create_test_pd_router() -> PDRouter {
let prefill_policy = Arc::new(RandomPolicy::new());
......
use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig};
use crate::core::{CircuitBreakerConfig, HealthChecker, RetryExecutor, Worker, WorkerFactory};
use crate::metrics::RouterMetrics;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::policies::LoadBalancingPolicy;
......@@ -382,6 +382,33 @@ impl Router {
}
// New method to route typed requests directly
/// Select worker considering circuit breaker state
fn select_worker_with_circuit_breaker(&self, text: Option<&str>) -> Option<Box<dyn Worker>> {
let workers = self.workers.read().ok()?;
let available: Vec<Box<dyn Worker>> = workers
.iter()
.filter(|w| w.is_available())
.map(|w| w.clone_worker())
.collect();
if available.is_empty() {
return None;
}
let idx = self.policy.select_worker(&available, text)?;
Some(available[idx].clone_worker())
}
fn is_retryable_status(status: StatusCode) -> bool {
matches!(
status,
StatusCode::REQUEST_TIMEOUT
| StatusCode::TOO_MANY_REQUESTS
| StatusCode::INTERNAL_SERVER_ERROR
| StatusCode::BAD_GATEWAY
| StatusCode::SERVICE_UNAVAILABLE
| StatusCode::GATEWAY_TIMEOUT
)
}
pub async fn route_typed_request<
T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
>(
......@@ -390,127 +417,70 @@ impl Router {
typed_req: &T,
route: &str,
) -> Response {
        let start = Instant::now();
        let is_stream = typed_req.is_stream();
        let text = typed_req.extract_text_for_routing();

        let response = RetryExecutor::execute_response_with_retry(
            &self.retry_config,
            // operation per attempt
            |_: u32| async {
                let worker = match self.select_worker_with_circuit_breaker(Some(&text)) {
                    Some(w) => w,
                    None => {
                        RouterMetrics::record_request_error(route, "no_available_workers");
                        return (
                            StatusCode::SERVICE_UNAVAILABLE,
                            "No available workers (all circuits open or unhealthy)",
                        )
                            .into_response();
                    }
                };

                // Optional load tracking for cache-aware policy
                let load_incremented = if self.policy.name() == "cache_aware" {
                    worker.increment_load();
                    RouterMetrics::set_running_requests(worker.url(), worker.load());
                    true
                } else {
                    false
                };

                // Send typed request directly
                let response = self
                    .send_typed_request(
                        headers,
                        typed_req,
                        route,
                        worker.url(),
                        is_stream,
                        load_incremented,
                    )
                    .await;

                worker.record_outcome(response.status().is_success());
                response
            },
            // should_retry predicate
            |res, _attempt| Self::is_retryable_status(res.status()),
            // on_backoff hook
            |delay, attempt| {
                RouterMetrics::record_retry(route);
                RouterMetrics::record_retry_backoff_duration(delay, attempt);
            },
            // on_exhausted hook
            || RouterMetrics::record_retries_exhausted(route),
        )
        .await;

        if response.status().is_success() {
            let duration = start.elapsed();
            RouterMetrics::record_request(route);
            RouterMetrics::record_generate_duration(duration);
        } else if !Self::is_retryable_status(response.status()) {
            RouterMetrics::record_request_error(route, "non_retryable_error");
        }
        response
    }
// TODO (rui): Better accommodate to the Worker abstraction
......
......@@ -48,6 +48,8 @@ impl TestContext {
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
};
Self::new_with_config(config, worker_configs).await
......@@ -1091,6 +1093,8 @@ mod error_tests {
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
};
let ctx = TestContext::new_with_config(
......@@ -1439,6 +1443,8 @@ mod pd_mode_tests {
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
};
// Create app context
......@@ -1594,6 +1600,8 @@ mod request_id_tests {
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
};
let ctx = TestContext::new_with_config(
......
......@@ -39,6 +39,8 @@ impl TestContext {
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
};
let mut workers = Vec::new();
......