Unverified Commit 61a46804 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] router circuit breaker core (#8941)

parent 9020f7fc
...@@ -41,6 +41,8 @@ pub struct RouterConfig { ...@@ -41,6 +41,8 @@ pub struct RouterConfig {
pub cors_allowed_origins: Vec<String>, pub cors_allowed_origins: Vec<String>,
/// Retry configuration /// Retry configuration
pub retry: RetryConfig, pub retry: RetryConfig,
/// Circuit breaker configuration
pub circuit_breaker: CircuitBreakerConfig,
} }
/// Routing mode configuration /// Routing mode configuration
...@@ -208,6 +210,30 @@ impl Default for RetryConfig { ...@@ -208,6 +210,30 @@ impl Default for RetryConfig {
} }
} }
/// Circuit breaker configuration for worker reliability
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerConfig {
/// Number of consecutive failures before opening circuit
pub failure_threshold: u32,
/// Number of consecutive successes before closing circuit
pub success_threshold: u32,
/// Time before attempting to recover from open state (in seconds)
pub timeout_duration_secs: u64,
/// Window duration for failure tracking (in seconds)
pub window_duration_secs: u64,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
success_threshold: 2,
timeout_duration_secs: 30,
window_duration_secs: 60,
}
}
}
/// Metrics configuration /// Metrics configuration
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsConfig { pub struct MetricsConfig {
...@@ -249,6 +275,7 @@ impl Default for RouterConfig { ...@@ -249,6 +275,7 @@ impl Default for RouterConfig {
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(), retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
} }
} }
} }
...@@ -360,6 +387,7 @@ mod tests { ...@@ -360,6 +387,7 @@ mod tests {
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(), retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
}; };
let json = serde_json::to_string(&config).unwrap(); let json = serde_json::to_string(&config).unwrap();
...@@ -788,6 +816,7 @@ mod tests { ...@@ -788,6 +816,7 @@ mod tests {
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(), retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
}; };
assert!(config.mode.is_pd_mode()); assert!(config.mode.is_pd_mode());
...@@ -840,6 +869,7 @@ mod tests { ...@@ -840,6 +869,7 @@ mod tests {
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(), retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
}; };
assert!(!config.mode.is_pd_mode()); assert!(!config.mode.is_pd_mode());
...@@ -888,6 +918,7 @@ mod tests { ...@@ -888,6 +918,7 @@ mod tests {
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(), retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
}; };
assert!(config.has_service_discovery()); assert!(config.has_service_discovery());
......
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
/// Circuit breaker configuration
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
/// Number of consecutive failures to open the circuit
pub failure_threshold: u32,
/// Success threshold to close circuit from half-open
pub success_threshold: u32,
/// Duration to wait before attempting half-open
pub timeout_duration: Duration,
/// Time window for failure counting
pub window_duration: Duration,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
success_threshold: 2,
timeout_duration: Duration::from_secs(30),
window_duration: Duration::from_secs(60),
}
}
}
/// Circuit breaker state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
/// Normal operation - requests are allowed
Closed,
/// Circuit is open - requests are rejected
Open,
/// Testing if service has recovered - limited requests allowed
HalfOpen,
}
impl std::fmt::Display for CircuitState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CircuitState::Closed => write!(f, "Closed"),
CircuitState::Open => write!(f, "Open"),
CircuitState::HalfOpen => write!(f, "HalfOpen"),
}
}
}
/// Circuit breaker implementation
#[derive(Debug)]
pub struct CircuitBreaker {
state: Arc<RwLock<CircuitState>>,
consecutive_failures: Arc<AtomicU32>,
consecutive_successes: Arc<AtomicU32>,
total_failures: Arc<AtomicU64>,
total_successes: Arc<AtomicU64>,
last_failure_time: Arc<RwLock<Option<Instant>>>,
last_state_change: Arc<RwLock<Instant>>,
config: CircuitBreakerConfig,
}
impl CircuitBreaker {
/// Create a new circuit breaker with default configuration
pub fn new() -> Self {
Self::with_config(CircuitBreakerConfig::default())
}
/// Create a new circuit breaker with custom configuration
pub fn with_config(config: CircuitBreakerConfig) -> Self {
Self {
state: Arc::new(RwLock::new(CircuitState::Closed)),
consecutive_failures: Arc::new(AtomicU32::new(0)),
consecutive_successes: Arc::new(AtomicU32::new(0)),
total_failures: Arc::new(AtomicU64::new(0)),
total_successes: Arc::new(AtomicU64::new(0)),
last_failure_time: Arc::new(RwLock::new(None)),
last_state_change: Arc::new(RwLock::new(Instant::now())),
config,
}
}
/// Check if a request can be executed
pub fn can_execute(&self) -> bool {
// First check if we need to transition from Open to HalfOpen
self.check_and_update_state();
let state = *self.state.read().unwrap();
match state {
CircuitState::Closed => true,
CircuitState::Open => false,
CircuitState::HalfOpen => true, // Allow limited requests in half-open state
}
}
/// Get the current state
pub fn state(&self) -> CircuitState {
self.check_and_update_state();
*self.state.read().unwrap()
}
/// Record the outcome of a request
pub fn record_outcome(&self, success: bool) {
if success {
self.record_success();
} else {
self.record_failure();
}
}
/// Record a successful request
pub fn record_success(&self) {
self.total_successes.fetch_add(1, Ordering::Relaxed);
self.consecutive_failures.store(0, Ordering::Release);
let successes = self.consecutive_successes.fetch_add(1, Ordering::AcqRel) + 1;
let current_state = *self.state.read().unwrap();
match current_state {
CircuitState::HalfOpen => {
// Check if we've reached the success threshold to close the circuit
if successes >= self.config.success_threshold {
self.transition_to(CircuitState::Closed);
}
}
CircuitState::Closed => {
// Already closed, nothing to do
}
CircuitState::Open => {
// Shouldn't happen, but if it does, stay open
tracing::warn!("Success recorded while circuit is open");
}
}
}
/// Record a failed request
pub fn record_failure(&self) {
self.total_failures.fetch_add(1, Ordering::Relaxed);
self.consecutive_successes.store(0, Ordering::Release);
let failures = self.consecutive_failures.fetch_add(1, Ordering::AcqRel) + 1;
// Update last failure time
{
let mut last_failure = self.last_failure_time.write().unwrap();
*last_failure = Some(Instant::now());
}
let current_state = *self.state.read().unwrap();
match current_state {
CircuitState::Closed => {
// Check if we've reached the failure threshold to open the circuit
if failures >= self.config.failure_threshold {
self.transition_to(CircuitState::Open);
}
}
CircuitState::HalfOpen => {
// Single failure in half-open state reopens the circuit
self.transition_to(CircuitState::Open);
}
CircuitState::Open => {
// Already open, nothing to do
}
}
}
/// Check and update state based on timeout
fn check_and_update_state(&self) {
let current_state = *self.state.read().unwrap();
if current_state == CircuitState::Open {
// Check if timeout has expired
let last_change = *self.last_state_change.read().unwrap();
if last_change.elapsed() >= self.config.timeout_duration {
self.transition_to(CircuitState::HalfOpen);
}
}
}
/// Transition to a new state
fn transition_to(&self, new_state: CircuitState) {
let mut state = self.state.write().unwrap();
let old_state = *state;
if old_state != new_state {
*state = new_state;
// Update last state change time
let mut last_change = self.last_state_change.write().unwrap();
*last_change = Instant::now();
// Reset counters based on transition
match new_state {
CircuitState::Closed => {
self.consecutive_failures.store(0, Ordering::Release);
self.consecutive_successes.store(0, Ordering::Release);
}
CircuitState::Open => {
self.consecutive_successes.store(0, Ordering::Release);
}
CircuitState::HalfOpen => {
self.consecutive_failures.store(0, Ordering::Release);
self.consecutive_successes.store(0, Ordering::Release);
}
}
tracing::info!(
"Circuit breaker state transition: {} -> {}",
old_state,
new_state
);
}
}
/// Get the number of consecutive failures
pub fn failure_count(&self) -> u32 {
self.consecutive_failures.load(Ordering::Acquire)
}
/// Get the number of consecutive successes
pub fn success_count(&self) -> u32 {
self.consecutive_successes.load(Ordering::Acquire)
}
/// Get total failures
pub fn total_failures(&self) -> u64 {
self.total_failures.load(Ordering::Relaxed)
}
/// Get total successes
pub fn total_successes(&self) -> u64 {
self.total_successes.load(Ordering::Relaxed)
}
/// Get time since last failure
pub fn time_since_last_failure(&self) -> Option<Duration> {
self.last_failure_time.read().unwrap().map(|t| t.elapsed())
}
/// Get time since last state change
pub fn time_since_last_state_change(&self) -> Duration {
self.last_state_change.read().unwrap().elapsed()
}
/// Check if the circuit is in a half-open state
pub fn is_half_open(&self) -> bool {
self.state() == CircuitState::HalfOpen
}
/// Record a test success (for health check probing)
pub fn record_test_success(&self) {
if self.is_half_open() {
self.record_success();
}
}
/// Record a test failure (for health check probing)
pub fn record_test_failure(&self) {
if self.is_half_open() {
self.record_failure();
}
}
/// Reset the circuit breaker to closed state
pub fn reset(&self) {
self.transition_to(CircuitState::Closed);
self.consecutive_failures.store(0, Ordering::Release);
self.consecutive_successes.store(0, Ordering::Release);
}
/// Force the circuit to open (for manual intervention)
pub fn force_open(&self) {
self.transition_to(CircuitState::Open);
}
/// Get circuit breaker statistics
pub fn stats(&self) -> CircuitBreakerStats {
CircuitBreakerStats {
state: self.state(),
consecutive_failures: self.failure_count(),
consecutive_successes: self.success_count(),
total_failures: self.total_failures(),
total_successes: self.total_successes(),
time_since_last_failure: self.time_since_last_failure(),
time_since_last_state_change: self.time_since_last_state_change(),
}
}
}
impl Clone for CircuitBreaker {
fn clone(&self) -> Self {
Self {
state: Arc::clone(&self.state),
consecutive_failures: Arc::clone(&self.consecutive_failures),
consecutive_successes: Arc::clone(&self.consecutive_successes),
total_failures: Arc::clone(&self.total_failures),
total_successes: Arc::clone(&self.total_successes),
last_failure_time: Arc::clone(&self.last_failure_time),
last_state_change: Arc::clone(&self.last_state_change),
config: self.config.clone(),
}
}
}
impl Default for CircuitBreaker {
fn default() -> Self {
Self::new()
}
}
/// Circuit breaker statistics
#[derive(Debug, Clone)]
pub struct CircuitBreakerStats {
pub state: CircuitState,
pub consecutive_failures: u32,
pub consecutive_successes: u32,
pub total_failures: u64,
pub total_successes: u64,
pub time_since_last_failure: Option<Duration>,
pub time_since_last_state_change: Duration,
}
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
#[test]
fn test_circuit_breaker_initial_state() {
let cb = CircuitBreaker::new();
assert_eq!(cb.state(), CircuitState::Closed);
assert!(cb.can_execute());
assert_eq!(cb.failure_count(), 0);
assert_eq!(cb.success_count(), 0);
}
#[test]
fn test_circuit_opens_on_threshold() {
let config = CircuitBreakerConfig {
failure_threshold: 3,
..Default::default()
};
let cb = CircuitBreaker::with_config(config);
// Record failures up to threshold
assert_eq!(cb.state(), CircuitState::Closed);
cb.record_failure();
assert_eq!(cb.state(), CircuitState::Closed);
cb.record_failure();
assert_eq!(cb.state(), CircuitState::Closed);
cb.record_failure();
// Circuit should now be open
assert_eq!(cb.state(), CircuitState::Open);
assert!(!cb.can_execute());
assert_eq!(cb.failure_count(), 3);
}
#[test]
fn test_circuit_half_open_after_timeout() {
let config = CircuitBreakerConfig {
failure_threshold: 1,
timeout_duration: Duration::from_millis(100),
..Default::default()
};
let cb = CircuitBreaker::with_config(config);
// Open the circuit
cb.record_failure();
assert_eq!(cb.state(), CircuitState::Open);
// Wait for timeout
thread::sleep(Duration::from_millis(150));
// Circuit should be half-open
assert_eq!(cb.state(), CircuitState::HalfOpen);
assert!(cb.can_execute());
}
#[test]
fn test_circuit_closes_on_success_threshold() {
let config = CircuitBreakerConfig {
failure_threshold: 1,
success_threshold: 2,
timeout_duration: Duration::from_millis(50),
..Default::default()
};
let cb = CircuitBreaker::with_config(config);
// Open the circuit
cb.record_failure();
assert_eq!(cb.state(), CircuitState::Open);
// Wait for timeout
thread::sleep(Duration::from_millis(100));
assert_eq!(cb.state(), CircuitState::HalfOpen);
// Record successes
cb.record_success();
assert_eq!(cb.state(), CircuitState::HalfOpen);
cb.record_success();
// Circuit should now be closed
assert_eq!(cb.state(), CircuitState::Closed);
assert!(cb.can_execute());
}
#[test]
fn test_circuit_reopens_on_half_open_failure() {
let config = CircuitBreakerConfig {
failure_threshold: 1,
timeout_duration: Duration::from_millis(50),
..Default::default()
};
let cb = CircuitBreaker::with_config(config);
// Open the circuit
cb.record_failure();
assert_eq!(cb.state(), CircuitState::Open);
// Wait for timeout
thread::sleep(Duration::from_millis(100));
assert_eq!(cb.state(), CircuitState::HalfOpen);
// Record a failure in half-open state
cb.record_failure();
// Circuit should reopen immediately
assert_eq!(cb.state(), CircuitState::Open);
assert!(!cb.can_execute());
}
#[test]
fn test_success_resets_failure_count() {
let config = CircuitBreakerConfig {
failure_threshold: 3,
..Default::default()
};
let cb = CircuitBreaker::with_config(config);
// Record some failures
cb.record_failure();
cb.record_failure();
assert_eq!(cb.failure_count(), 2);
// Success should reset failure count
cb.record_success();
assert_eq!(cb.failure_count(), 0);
assert_eq!(cb.success_count(), 1);
// Can now record more failures without opening
cb.record_failure();
cb.record_failure();
assert_eq!(cb.state(), CircuitState::Closed);
}
#[test]
fn test_manual_reset() {
let config = CircuitBreakerConfig {
failure_threshold: 1,
..Default::default()
};
let cb = CircuitBreaker::with_config(config);
// Open the circuit
cb.record_failure();
assert_eq!(cb.state(), CircuitState::Open);
// Manual reset
cb.reset();
assert_eq!(cb.state(), CircuitState::Closed);
assert_eq!(cb.failure_count(), 0);
assert_eq!(cb.success_count(), 0);
}
#[test]
fn test_force_open() {
let cb = CircuitBreaker::new();
assert_eq!(cb.state(), CircuitState::Closed);
cb.force_open();
assert_eq!(cb.state(), CircuitState::Open);
assert!(!cb.can_execute());
}
#[test]
fn test_stats() {
let config = CircuitBreakerConfig {
failure_threshold: 2,
..Default::default()
};
let cb = CircuitBreaker::with_config(config);
cb.record_success();
cb.record_failure();
cb.record_failure();
let stats = cb.stats();
assert_eq!(stats.state, CircuitState::Open);
assert_eq!(stats.consecutive_failures, 2);
assert_eq!(stats.consecutive_successes, 0);
assert_eq!(stats.total_failures, 2);
assert_eq!(stats.total_successes, 1);
}
#[test]
fn test_clone() {
let cb1 = CircuitBreaker::new();
cb1.record_failure();
let cb2 = cb1.clone();
assert_eq!(cb2.failure_count(), 1);
// Changes to cb1 affect cb2 (shared state)
cb1.record_failure();
assert_eq!(cb2.failure_count(), 2);
}
#[test]
fn test_thread_safety() {
use std::sync::Arc;
let cb = Arc::new(CircuitBreaker::new());
let mut handles = vec![];
// Spawn threads that record failures
for _ in 0..10 {
let cb_clone = Arc::clone(&cb);
let handle = thread::spawn(move || {
for _ in 0..100 {
cb_clone.record_failure();
}
});
handles.push(handle);
}
// Wait for all threads
for handle in handles {
handle.join().unwrap();
}
// Should have recorded 1000 failures
assert_eq!(cb.total_failures(), 1000);
}
}
...@@ -3,12 +3,17 @@ ...@@ -3,12 +3,17 @@
//! This module contains the fundamental types and traits used throughout the router: //! This module contains the fundamental types and traits used throughout the router:
//! - Worker trait and implementations //! - Worker trait and implementations
//! - Error types //! - Error types
//! - Circuit breaker for reliability
//! - Common utilities //! - Common utilities
pub mod circuit_breaker;
pub mod error; pub mod error;
pub mod worker; pub mod worker;
// Re-export commonly used types at the module level // Re-export commonly used types at the module level
pub use circuit_breaker::{
CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState,
};
pub use error::{WorkerError, WorkerResult}; pub use error::{WorkerError, WorkerResult};
pub use worker::{ pub use worker::{
start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, Worker, WorkerCollection, start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, Worker, WorkerCollection,
......
use super::{WorkerError, WorkerResult}; use super::{CircuitBreaker, CircuitBreakerConfig, WorkerError, WorkerResult};
use async_trait::async_trait; use async_trait::async_trait;
use futures; use futures;
use serde_json; use serde_json;
...@@ -66,6 +66,19 @@ pub trait Worker: Send + Sync + fmt::Debug { ...@@ -66,6 +66,19 @@ pub trait Worker: Send + Sync + fmt::Debug {
/// Clone the worker (for trait objects) /// Clone the worker (for trait objects)
fn clone_worker(&self) -> Box<dyn Worker>; fn clone_worker(&self) -> Box<dyn Worker>;
/// Get the circuit breaker for this worker
fn circuit_breaker(&self) -> &CircuitBreaker;
/// Check if the worker is available (healthy + circuit closed/half-open)
fn is_available(&self) -> bool {
self.is_healthy() && self.circuit_breaker().can_execute()
}
/// Record the outcome of a request to this worker
fn record_outcome(&self, success: bool) {
self.circuit_breaker().record_outcome(success);
}
// === DP-aware methods === // === DP-aware methods ===
/// Check if this worker is DP-aware /// Check if this worker is DP-aware
...@@ -172,6 +185,7 @@ pub struct BasicWorker { ...@@ -172,6 +185,7 @@ pub struct BasicWorker {
load_counter: Arc<AtomicUsize>, load_counter: Arc<AtomicUsize>,
processed_counter: Arc<AtomicUsize>, processed_counter: Arc<AtomicUsize>,
healthy: Arc<AtomicBool>, healthy: Arc<AtomicBool>,
circuit_breaker: CircuitBreaker,
} }
impl BasicWorker { impl BasicWorker {
...@@ -188,6 +202,7 @@ impl BasicWorker { ...@@ -188,6 +202,7 @@ impl BasicWorker {
load_counter: Arc::new(AtomicUsize::new(0)), load_counter: Arc::new(AtomicUsize::new(0)),
processed_counter: Arc::new(AtomicUsize::new(0)), processed_counter: Arc::new(AtomicUsize::new(0)),
healthy: Arc::new(AtomicBool::new(true)), healthy: Arc::new(AtomicBool::new(true)),
circuit_breaker: CircuitBreaker::new(),
} }
} }
...@@ -201,6 +216,11 @@ impl BasicWorker { ...@@ -201,6 +216,11 @@ impl BasicWorker {
self self
} }
pub fn with_circuit_breaker_config(mut self, config: CircuitBreakerConfig) -> Self {
self.circuit_breaker = CircuitBreaker::with_config(config);
self
}
pub fn normalised_url(&self) -> WorkerResult<&str> { pub fn normalised_url(&self) -> WorkerResult<&str> {
if self.url().contains("@") { if self.url().contains("@") {
// Need to extract the URL from "http://host:port@dp_rank" // Need to extract the URL from "http://host:port@dp_rank"
...@@ -304,6 +324,10 @@ impl Worker for BasicWorker { ...@@ -304,6 +324,10 @@ impl Worker for BasicWorker {
fn clone_worker(&self) -> Box<dyn Worker> { fn clone_worker(&self) -> Box<dyn Worker> {
Box::new(self.clone()) Box::new(self.clone())
} }
fn circuit_breaker(&self) -> &CircuitBreaker {
&self.circuit_breaker
}
} }
/// A DP-aware worker that handles data-parallel routing /// A DP-aware worker that handles data-parallel routing
...@@ -421,6 +445,10 @@ impl Worker for DPAwareWorker { ...@@ -421,6 +445,10 @@ impl Worker for DPAwareWorker {
Box::new(self.clone()) Box::new(self.clone())
} }
fn circuit_breaker(&self) -> &CircuitBreaker {
self.base_worker.circuit_breaker()
}
// DP-aware specific implementations // DP-aware specific implementations
fn is_dp_aware(&self) -> bool { fn is_dp_aware(&self) -> bool {
...@@ -469,6 +497,17 @@ impl WorkerFactory { ...@@ -469,6 +497,17 @@ impl WorkerFactory {
Box::new(BasicWorker::new(url, WorkerType::Regular)) Box::new(BasicWorker::new(url, WorkerType::Regular))
} }
/// Create a regular worker with custom circuit breaker configuration
pub fn create_regular_with_config(
url: String,
circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
Box::new(
BasicWorker::new(url, WorkerType::Regular)
.with_circuit_breaker_config(circuit_breaker_config),
)
}
/// Create a prefill worker with optional bootstrap port /// Create a prefill worker with optional bootstrap port
pub fn create_prefill(url: String, bootstrap_port: Option<u16>) -> Box<dyn Worker> { pub fn create_prefill(url: String, bootstrap_port: Option<u16>) -> Box<dyn Worker> {
Box::new(BasicWorker::new( Box::new(BasicWorker::new(
...@@ -477,11 +516,34 @@ impl WorkerFactory { ...@@ -477,11 +516,34 @@ impl WorkerFactory {
)) ))
} }
/// Create a prefill worker with custom circuit breaker configuration
pub fn create_prefill_with_config(
url: String,
bootstrap_port: Option<u16>,
circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
Box::new(
BasicWorker::new(url, WorkerType::Prefill { bootstrap_port })
.with_circuit_breaker_config(circuit_breaker_config),
)
}
/// Create a decode worker /// Create a decode worker
pub fn create_decode(url: String) -> Box<dyn Worker> { pub fn create_decode(url: String) -> Box<dyn Worker> {
Box::new(BasicWorker::new(url, WorkerType::Decode)) Box::new(BasicWorker::new(url, WorkerType::Decode))
} }
/// Create a decode worker with custom circuit breaker configuration
pub fn create_decode_with_config(
url: String,
circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
Box::new(
BasicWorker::new(url, WorkerType::Decode)
.with_circuit_breaker_config(circuit_breaker_config),
)
}
/// Create workers from URLs with automatic type detection /// Create workers from URLs with automatic type detection
pub fn create_from_urls( pub fn create_from_urls(
regular_urls: Vec<String>, regular_urls: Vec<String>,
...@@ -796,6 +858,7 @@ pub fn start_health_checker( ...@@ -796,6 +858,7 @@ pub fn start_health_checker(
mod tests { mod tests {
use super::*; use super::*;
use std::sync::RwLock; use std::sync::RwLock;
use std::thread;
use std::time::Duration; use std::time::Duration;
use tokio::time::timeout; use tokio::time::timeout;
...@@ -1574,6 +1637,94 @@ mod tests { ...@@ -1574,6 +1637,94 @@ mod tests {
assert_eq!(workers[1].url(), "http://w2:8080"); assert_eq!(workers[1].url(), "http://w2:8080");
} }
// ===== Circuit Breaker Integration Tests =====
#[test]
fn test_worker_circuit_breaker() {
let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular);
// Initial state should be available
assert!(worker.is_available());
assert_eq!(
worker.circuit_breaker().state(),
crate::core::CircuitState::Closed
);
// Record some failures
worker.record_outcome(false);
worker.record_outcome(false);
// Still available (default threshold is 5)
assert!(worker.is_available());
// Record more failures to open circuit
worker.record_outcome(false);
worker.record_outcome(false);
worker.record_outcome(false);
// Circuit should be open, worker not available
assert!(!worker.is_available());
assert!(worker.is_healthy()); // Still healthy
assert!(!worker.circuit_breaker().can_execute()); // But circuit is open
}
#[test]
fn test_worker_with_circuit_breaker_config() {
let config = crate::core::CircuitBreakerConfig {
failure_threshold: 2,
success_threshold: 1,
timeout_duration: Duration::from_millis(100),
window_duration: Duration::from_secs(60),
};
let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular)
.with_circuit_breaker_config(config);
// Should open after 2 failures
worker.record_outcome(false);
assert!(worker.is_available());
worker.record_outcome(false);
assert!(!worker.is_available());
// Wait for timeout
thread::sleep(Duration::from_millis(150));
// Should be half-open
assert!(worker.is_available());
assert_eq!(
worker.circuit_breaker().state(),
crate::core::CircuitState::HalfOpen
);
// Success should close it
worker.record_outcome(true);
assert_eq!(
worker.circuit_breaker().state(),
crate::core::CircuitState::Closed
);
}
#[test]
fn test_dp_aware_worker_circuit_breaker() {
let dp_worker =
DPAwareWorker::new("http://worker:8080".to_string(), 0, 2, WorkerType::Regular);
// Should have circuit breaker
assert!(dp_worker.is_available());
// Record failures
for _ in 0..5 {
dp_worker.record_outcome(false);
}
// Should not be available
assert!(!dp_worker.is_available());
assert_eq!(
dp_worker.circuit_breaker().state(),
crate::core::CircuitState::Open
);
}
// ===== Integration tests ===== // ===== Integration tests =====
#[tokio::test] #[tokio::test]
......
...@@ -147,6 +147,7 @@ impl Router { ...@@ -147,6 +147,7 @@ impl Router {
max_concurrent_requests: self.max_concurrent_requests, max_concurrent_requests: self.max_concurrent_requests,
cors_allowed_origins: self.cors_allowed_origins.clone(), cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: config::RetryConfig::default(), retry: config::RetryConfig::default(),
circuit_breaker: config::CircuitBreakerConfig::default(),
}) })
} }
} }
......
...@@ -51,6 +51,7 @@ impl RouterFactory { ...@@ -51,6 +51,7 @@ impl RouterFactory {
ctx.router_config.dp_aware, ctx.router_config.dp_aware,
ctx.router_config.api_key.clone(), ctx.router_config.api_key.clone(),
ctx.router_config.retry.clone(), ctx.router_config.retry.clone(),
ctx.router_config.circuit_breaker.clone(),
)?; )?;
Ok(Box::new(router)) Ok(Box::new(router))
...@@ -81,6 +82,7 @@ impl RouterFactory { ...@@ -81,6 +82,7 @@ impl RouterFactory {
ctx.router_config.worker_startup_timeout_secs, ctx.router_config.worker_startup_timeout_secs,
ctx.router_config.worker_startup_check_interval_secs, ctx.router_config.worker_startup_check_interval_secs,
ctx.router_config.retry.clone(), ctx.router_config.retry.clone(),
ctx.router_config.circuit_breaker.clone(),
)?; )?;
Ok(Box::new(router)) Ok(Box::new(router))
......
// PD (Prefill-Decode) Router Implementation // PD (Prefill-Decode) Router Implementation
// This module handles routing for disaggregated prefill-decode systems // This module handles routing for disaggregated prefill-decode systems
use super::pd_types::{api_path, PDRouterError}; use super::pd_types::{api_path, PDRouterError};
use crate::config::types::RetryConfig; use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig};
use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard}; use crate::core::{CircuitBreakerConfig, HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::policies::LoadBalancingPolicy; use crate::policies::LoadBalancingPolicy;
...@@ -41,6 +41,7 @@ pub struct PDRouter { ...@@ -41,6 +41,7 @@ pub struct PDRouter {
// Dedicated client for prefill fire-and-forget (non-logprob) requests // Dedicated client for prefill fire-and-forget (non-logprob) requests
pub prefill_client: Client, pub prefill_client: Client,
pub retry_config: RetryConfig, pub retry_config: RetryConfig,
pub circuit_breaker_config: CircuitBreakerConfig,
_prefill_health_checker: Option<HealthChecker>, _prefill_health_checker: Option<HealthChecker>,
_decode_health_checker: Option<HealthChecker>, _decode_health_checker: Option<HealthChecker>,
} }
...@@ -68,8 +69,12 @@ impl PDRouter { ...@@ -68,8 +69,12 @@ impl PDRouter {
// Wait for the new server to be healthy // Wait for the new server to be healthy
self.wait_for_server_health(&url).await?; self.wait_for_server_health(&url).await?;
// Create Worker for the new prefill server // Create Worker for the new prefill server with circuit breaker configuration
let worker = WorkerFactory::create_prefill(url.clone(), bootstrap_port); let worker = WorkerFactory::create_prefill_with_config(
url.clone(),
bootstrap_port,
self.circuit_breaker_config.clone(),
);
// Add to prefill workers list // Add to prefill workers list
let mut workers = self let mut workers = self
...@@ -99,8 +104,11 @@ impl PDRouter { ...@@ -99,8 +104,11 @@ impl PDRouter {
// Wait for the new server to be healthy // Wait for the new server to be healthy
self.wait_for_server_health(&url).await?; self.wait_for_server_health(&url).await?;
// Create Worker for the new decode server // Create Worker for the new decode server with circuit breaker configuration
let worker = WorkerFactory::create_decode(url.clone()); let worker = WorkerFactory::create_decode_with_config(
url.clone(),
self.circuit_breaker_config.clone(),
);
// Add to decode workers list // Add to decode workers list
let mut workers = self let mut workers = self
...@@ -189,16 +197,31 @@ impl PDRouter { ...@@ -189,16 +197,31 @@ impl PDRouter {
timeout_secs: u64, timeout_secs: u64,
interval_secs: u64, interval_secs: u64,
retry_config: RetryConfig, retry_config: RetryConfig,
circuit_breaker_config: ConfigCircuitBreakerConfig,
) -> Result<Self, String> { ) -> Result<Self, String> {
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: std::time::Duration::from_secs(
circuit_breaker_config.timeout_duration_secs,
),
window_duration: std::time::Duration::from_secs(
circuit_breaker_config.window_duration_secs,
),
};
// Convert URLs to Worker trait objects // Convert URLs to Worker trait objects
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
.into_iter() .into_iter()
.map(|(url, port)| WorkerFactory::create_prefill(url, port)) .map(|(url, port)| {
WorkerFactory::create_prefill_with_config(url, port, core_cb_config.clone())
})
.collect(); .collect();
let decode_workers: Vec<Box<dyn Worker>> = decode_urls let decode_workers: Vec<Box<dyn Worker>> = decode_urls
.into_iter() .into_iter()
.map(WorkerFactory::create_decode) .map(|url| WorkerFactory::create_decode_with_config(url, core_cb_config.clone()))
.collect(); .collect();
// Wait for PD workers to be healthy (skip if empty - for service discovery mode) // Wait for PD workers to be healthy (skip if empty - for service discovery mode)
...@@ -280,6 +303,7 @@ impl PDRouter { ...@@ -280,6 +303,7 @@ impl PDRouter {
client, client,
prefill_client, prefill_client,
retry_config, retry_config,
circuit_breaker_config: core_cb_config,
_prefill_health_checker: Some(prefill_health_checker), _prefill_health_checker: Some(prefill_health_checker),
_decode_health_checker: Some(decode_health_checker), _decode_health_checker: Some(decode_health_checker),
}) })
...@@ -1848,6 +1872,7 @@ mod tests { ...@@ -1848,6 +1872,7 @@ mod tests {
client: Client::new(), client: Client::new(),
prefill_client: Client::new(), prefill_client: Client::new(),
retry_config: RetryConfig::default(), retry_config: RetryConfig::default(),
circuit_breaker_config: CircuitBreakerConfig::default(),
_prefill_health_checker: None, _prefill_health_checker: None,
_decode_health_checker: None, _decode_health_checker: None,
} }
......
use crate::config::types::RetryConfig; use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig};
use crate::core::{HealthChecker, Worker, WorkerFactory}; use crate::core::{CircuitBreakerConfig, HealthChecker, Worker, WorkerFactory};
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::policies::LoadBalancingPolicy; use crate::policies::LoadBalancingPolicy;
...@@ -42,6 +42,7 @@ pub struct Router { ...@@ -42,6 +42,7 @@ pub struct Router {
dp_aware: bool, dp_aware: bool,
api_key: Option<String>, api_key: Option<String>,
retry_config: RetryConfig, retry_config: RetryConfig,
circuit_breaker_config: CircuitBreakerConfig,
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>, _worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>, _load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
_health_checker: Option<HealthChecker>, _health_checker: Option<HealthChecker>,
...@@ -58,6 +59,7 @@ impl Router { ...@@ -58,6 +59,7 @@ impl Router {
dp_aware: bool, dp_aware: bool,
api_key: Option<String>, api_key: Option<String>,
retry_config: RetryConfig, retry_config: RetryConfig,
circuit_breaker_config: ConfigCircuitBreakerConfig,
) -> Result<Self, String> { ) -> Result<Self, String> {
// Update active workers gauge // Update active workers gauge
RouterMetrics::set_active_workers(worker_urls.len()); RouterMetrics::set_active_workers(worker_urls.len());
...@@ -75,10 +77,24 @@ impl Router { ...@@ -75,10 +77,24 @@ impl Router {
worker_urls worker_urls
}; };
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: std::time::Duration::from_secs(
circuit_breaker_config.timeout_duration_secs,
),
window_duration: std::time::Duration::from_secs(
circuit_breaker_config.window_duration_secs,
),
};
// Create Worker trait objects from URLs // Create Worker trait objects from URLs
let workers: Vec<Box<dyn Worker>> = worker_urls let workers: Vec<Box<dyn Worker>> = worker_urls
.iter() .iter()
.map(|url| WorkerFactory::create_regular(url.clone())) .map(|url| {
WorkerFactory::create_regular_with_config(url.clone(), core_cb_config.clone())
})
.collect(); .collect();
// Initialize policy with workers if needed (e.g., for cache-aware) // Initialize policy with workers if needed (e.g., for cache-aware)
...@@ -125,6 +141,7 @@ impl Router { ...@@ -125,6 +141,7 @@ impl Router {
dp_aware, dp_aware,
api_key, api_key,
retry_config, retry_config,
circuit_breaker_config: core_cb_config,
_worker_loads: worker_loads, _worker_loads: worker_loads,
_load_monitor_handle: load_monitor_handle, _load_monitor_handle: load_monitor_handle,
_health_checker: Some(health_checker), _health_checker: Some(health_checker),
...@@ -752,7 +769,10 @@ impl Router { ...@@ -752,7 +769,10 @@ impl Router {
continue; continue;
} }
info!("Added worker: {}", dp_url); info!("Added worker: {}", dp_url);
let new_worker = WorkerFactory::create_regular(dp_url.to_string()); let new_worker = WorkerFactory::create_regular_with_config(
dp_url.to_string(),
self.circuit_breaker_config.clone(),
);
workers_guard.push(new_worker); workers_guard.push(new_worker);
worker_added = true; worker_added = true;
} }
...@@ -764,7 +784,10 @@ impl Router { ...@@ -764,7 +784,10 @@ impl Router {
return Err(format!("Worker {} already exists", worker_url)); return Err(format!("Worker {} already exists", worker_url));
} }
info!("Added worker: {}", worker_url); info!("Added worker: {}", worker_url);
let new_worker = WorkerFactory::create_regular(worker_url.to_string()); let new_worker = WorkerFactory::create_regular_with_config(
worker_url.to_string(),
self.circuit_breaker_config.clone(),
);
workers_guard.push(new_worker); workers_guard.push(new_worker);
} }
...@@ -1223,6 +1246,7 @@ mod tests { ...@@ -1223,6 +1246,7 @@ mod tests {
api_key: None, api_key: None,
client: Client::new(), client: Client::new(),
retry_config: RetryConfig::default(), retry_config: RetryConfig::default(),
circuit_breaker_config: CircuitBreakerConfig::default(),
_worker_loads: Arc::new(rx), _worker_loads: Arc::new(rx),
_load_monitor_handle: None, _load_monitor_handle: None,
_health_checker: None, _health_checker: None,
......
...@@ -589,6 +589,7 @@ mod tests { ...@@ -589,6 +589,7 @@ mod tests {
false, false,
None, None,
crate::config::types::RetryConfig::default(), crate::config::types::RetryConfig::default(),
crate::config::types::CircuitBreakerConfig::default(),
) )
.unwrap(); .unwrap();
Arc::new(router) as Arc<dyn RouterTrait> Arc::new(router) as Arc<dyn RouterTrait>
......
...@@ -8,7 +8,9 @@ use axum::{ ...@@ -8,7 +8,9 @@ use axum::{
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use reqwest::Client; use reqwest::Client;
use serde_json::json; use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode}; use sglang_router_rs::config::{
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
};
use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc; use std::sync::Arc;
use tower::ServiceExt; use tower::ServiceExt;
...@@ -45,6 +47,7 @@ impl TestContext { ...@@ -45,6 +47,7 @@ impl TestContext {
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(), retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
}; };
Self::new_with_config(config, worker_configs).await Self::new_with_config(config, worker_configs).await
...@@ -1087,6 +1090,7 @@ mod error_tests { ...@@ -1087,6 +1090,7 @@ mod error_tests {
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(), retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
}; };
let ctx = TestContext::new_with_config( let ctx = TestContext::new_with_config(
...@@ -1434,6 +1438,7 @@ mod pd_mode_tests { ...@@ -1434,6 +1438,7 @@ mod pd_mode_tests {
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(), retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
}; };
// Create app context // Create app context
...@@ -1588,6 +1593,7 @@ mod request_id_tests { ...@@ -1588,6 +1593,7 @@ mod request_id_tests {
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(), retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
}; };
let ctx = TestContext::new_with_config( let ctx = TestContext::new_with_config(
......
...@@ -3,7 +3,9 @@ mod common; ...@@ -3,7 +3,9 @@ mod common;
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use reqwest::Client; use reqwest::Client;
use serde_json::json; use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode}; use sglang_router_rs::config::{
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
};
use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc; use std::sync::Arc;
...@@ -36,6 +38,7 @@ impl TestContext { ...@@ -36,6 +38,7 @@ impl TestContext {
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(), retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
}; };
let mut workers = Vec::new(); let mut workers = Vec::new();
......
...@@ -4,7 +4,9 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType ...@@ -4,7 +4,9 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
use futures_util::StreamExt; use futures_util::StreamExt;
use reqwest::Client; use reqwest::Client;
use serde_json::json; use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode}; use sglang_router_rs::config::{
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
};
use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc; use std::sync::Arc;
...@@ -37,6 +39,7 @@ impl TestContext { ...@@ -37,6 +39,7 @@ impl TestContext {
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(), retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
}; };
let mut workers = Vec::new(); let mut workers = Vec::new();
......
...@@ -2,7 +2,9 @@ ...@@ -2,7 +2,9 @@
mod test_pd_routing { mod test_pd_routing {
use rand::Rng; use rand::Rng;
use serde_json::json; use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode}; use sglang_router_rs::config::{
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
};
use sglang_router_rs::core::{WorkerFactory, WorkerType}; use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::routers::pd_types::get_hostname; use sglang_router_rs::routers::pd_types::get_hostname;
use sglang_router_rs::routers::pd_types::PDSelectionPolicy; use sglang_router_rs::routers::pd_types::PDSelectionPolicy;
...@@ -179,6 +181,7 @@ mod test_pd_routing { ...@@ -179,6 +181,7 @@ mod test_pd_routing {
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(), retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
}; };
// Router creation will fail due to health checks, but config should be valid // Router creation will fail due to health checks, but config should be valid
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment