Unverified Commit 354ac435 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[pd-router] Add Configurable Retry Logic to reduce backend pressure (#8744)

parent d98a4913
...@@ -39,6 +39,8 @@ pub struct RouterConfig { ...@@ -39,6 +39,8 @@ pub struct RouterConfig {
pub max_concurrent_requests: usize, pub max_concurrent_requests: usize,
/// CORS allowed origins /// CORS allowed origins
pub cors_allowed_origins: Vec<String>, pub cors_allowed_origins: Vec<String>,
/// Retry configuration
pub retry: RetryConfig,
} }
/// Routing mode configuration /// Routing mode configuration
...@@ -182,6 +184,30 @@ impl Default for DiscoveryConfig { ...@@ -182,6 +184,30 @@ impl Default for DiscoveryConfig {
} }
} }
/// Retry configuration for request handling.
///
/// Stored on `RouterConfig` as the `retry` field and passed to the routers
/// when they are constructed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
/// Maximum number of retry attempts.
///
/// The regular router uses this value as the attempt limit against a single
/// worker and allows twice this value in total across workers.
/// NOTE(review): with a value of 0 both limits become 0, so that router's
/// retry loop body would not execute at all — confirm whether 0 should
/// still permit one initial attempt.
pub max_retries: u32,
/// Initial backoff delay in milliseconds.
///
/// Presumably the wait before the first retry — TODO confirm against the
/// code that consumes it.
pub initial_backoff_ms: u64,
/// Maximum backoff delay in milliseconds.
///
/// Presumably an upper bound on any single delay — TODO confirm.
pub max_backoff_ms: u64,
/// Backoff multiplier for exponential backoff.
///
/// Presumably each successive delay is the previous one times this factor,
/// capped at `max_backoff_ms` — TODO confirm against the consumer.
pub backoff_multiplier: f32,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
initial_backoff_ms: 100,
max_backoff_ms: 10000,
backoff_multiplier: 2.0,
}
}
}
/// Metrics configuration /// Metrics configuration
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsConfig { pub struct MetricsConfig {
...@@ -210,7 +236,7 @@ impl Default for RouterConfig { ...@@ -210,7 +236,7 @@ impl Default for RouterConfig {
host: "127.0.0.1".to_string(), host: "127.0.0.1".to_string(),
port: 3001, port: 3001,
max_payload_size: 268_435_456, // 256MB max_payload_size: 268_435_456, // 256MB
request_timeout_secs: 600, request_timeout_secs: 3600, // 1 hour to match Python mini LB
worker_startup_timeout_secs: 300, worker_startup_timeout_secs: 300,
worker_startup_check_interval_secs: 10, worker_startup_check_interval_secs: 10,
dp_aware: false, dp_aware: false,
...@@ -222,6 +248,7 @@ impl Default for RouterConfig { ...@@ -222,6 +248,7 @@ impl Default for RouterConfig {
request_id_headers: None, request_id_headers: None,
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(),
} }
} }
} }
...@@ -277,7 +304,7 @@ mod tests { ...@@ -277,7 +304,7 @@ mod tests {
assert_eq!(config.host, "127.0.0.1"); assert_eq!(config.host, "127.0.0.1");
assert_eq!(config.port, 3001); assert_eq!(config.port, 3001);
assert_eq!(config.max_payload_size, 268_435_456); assert_eq!(config.max_payload_size, 268_435_456);
assert_eq!(config.request_timeout_secs, 600); assert_eq!(config.request_timeout_secs, 3600);
assert_eq!(config.worker_startup_timeout_secs, 300); assert_eq!(config.worker_startup_timeout_secs, 300);
assert_eq!(config.worker_startup_check_interval_secs, 10); assert_eq!(config.worker_startup_check_interval_secs, 10);
assert!(config.discovery.is_none()); assert!(config.discovery.is_none());
...@@ -332,6 +359,7 @@ mod tests { ...@@ -332,6 +359,7 @@ mod tests {
request_id_headers: None, request_id_headers: None,
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(),
}; };
let json = serde_json::to_string(&config).unwrap(); let json = serde_json::to_string(&config).unwrap();
...@@ -759,6 +787,7 @@ mod tests { ...@@ -759,6 +787,7 @@ mod tests {
request_id_headers: None, request_id_headers: None,
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(),
}; };
assert!(config.mode.is_pd_mode()); assert!(config.mode.is_pd_mode());
...@@ -810,6 +839,7 @@ mod tests { ...@@ -810,6 +839,7 @@ mod tests {
request_id_headers: None, request_id_headers: None,
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(),
}; };
assert!(!config.mode.is_pd_mode()); assert!(!config.mode.is_pd_mode());
...@@ -857,6 +887,7 @@ mod tests { ...@@ -857,6 +887,7 @@ mod tests {
request_id_headers: None, request_id_headers: None,
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(),
}; };
assert!(config.has_service_discovery()); assert!(config.has_service_discovery());
......
...@@ -19,7 +19,7 @@ pub enum PolicyType { ...@@ -19,7 +19,7 @@ pub enum PolicyType {
Random, Random,
RoundRobin, RoundRobin,
CacheAware, CacheAware,
PowerOfTwo, // Moved from PD-specific, now shared PowerOfTwo,
} }
#[pyclass] #[pyclass]
...@@ -45,7 +45,6 @@ struct Router { ...@@ -45,7 +45,6 @@ struct Router {
selector: HashMap<String, String>, selector: HashMap<String, String>,
service_discovery_port: u16, service_discovery_port: u16,
service_discovery_namespace: Option<String>, service_discovery_namespace: Option<String>,
// PD service discovery fields
prefill_selector: HashMap<String, String>, prefill_selector: HashMap<String, String>,
decode_selector: HashMap<String, String>, decode_selector: HashMap<String, String>,
bootstrap_port_annotation: String, bootstrap_port_annotation: String,
...@@ -53,14 +52,11 @@ struct Router { ...@@ -53,14 +52,11 @@ struct Router {
prometheus_host: Option<String>, prometheus_host: Option<String>,
request_timeout_secs: u64, request_timeout_secs: u64,
request_id_headers: Option<Vec<String>>, request_id_headers: Option<Vec<String>>,
// PD mode flag
pd_disaggregation: bool, pd_disaggregation: bool,
// PD-specific fields (only used when pd_disaggregation is true)
prefill_urls: Option<Vec<(String, Option<u16>)>>, prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>, decode_urls: Option<Vec<String>>,
prefill_policy: Option<PolicyType>, prefill_policy: Option<PolicyType>,
decode_policy: Option<PolicyType>, decode_policy: Option<PolicyType>,
// Additional server config fields
max_concurrent_requests: usize, max_concurrent_requests: usize,
cors_allowed_origins: Vec<String>, cors_allowed_origins: Vec<String>,
} }
...@@ -150,6 +146,7 @@ impl Router { ...@@ -150,6 +146,7 @@ impl Router {
request_id_headers: self.request_id_headers.clone(), request_id_headers: self.request_id_headers.clone(),
max_concurrent_requests: self.max_concurrent_requests, max_concurrent_requests: self.max_concurrent_requests,
cors_allowed_origins: self.cors_allowed_origins.clone(), cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: config::RetryConfig::default(),
}) })
} }
} }
...@@ -289,7 +286,6 @@ impl Router { ...@@ -289,7 +286,6 @@ impl Router {
check_interval: std::time::Duration::from_secs(60), check_interval: std::time::Duration::from_secs(60),
port: self.service_discovery_port, port: self.service_discovery_port,
namespace: self.service_discovery_namespace.clone(), namespace: self.service_discovery_namespace.clone(),
// PD mode configuration
pd_mode: self.pd_disaggregation, pd_mode: self.pd_disaggregation,
prefill_selector: self.prefill_selector.clone(), prefill_selector: self.prefill_selector.clone(),
decode_selector: self.decode_selector.clone(), decode_selector: self.decode_selector.clone(),
......
...@@ -50,6 +50,7 @@ impl RouterFactory { ...@@ -50,6 +50,7 @@ impl RouterFactory {
ctx.router_config.worker_startup_check_interval_secs, ctx.router_config.worker_startup_check_interval_secs,
ctx.router_config.dp_aware, ctx.router_config.dp_aware,
ctx.router_config.api_key.clone(), ctx.router_config.api_key.clone(),
ctx.router_config.retry.clone(),
)?; )?;
Ok(Box::new(router)) Ok(Box::new(router))
...@@ -79,6 +80,7 @@ impl RouterFactory { ...@@ -79,6 +80,7 @@ impl RouterFactory {
ctx.client.clone(), ctx.client.clone(),
ctx.router_config.worker_startup_timeout_secs, ctx.router_config.worker_startup_timeout_secs,
ctx.router_config.worker_startup_check_interval_secs, ctx.router_config.worker_startup_check_interval_secs,
ctx.router_config.retry.clone(),
)?; )?;
Ok(Box::new(router)) Ok(Box::new(router))
......
This diff is collapsed.
use crate::config::types::RetryConfig;
use crate::core::{HealthChecker, Worker, WorkerFactory}; use crate::core::{HealthChecker, Worker, WorkerFactory};
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
...@@ -11,6 +12,7 @@ use axum::{ ...@@ -11,6 +12,7 @@ use axum::{
Json, Json,
}; };
use futures_util::StreamExt; use futures_util::StreamExt;
use reqwest::Client;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::thread; use std::thread;
...@@ -39,6 +41,7 @@ pub struct Router { ...@@ -39,6 +41,7 @@ pub struct Router {
interval_secs: u64, interval_secs: u64,
dp_aware: bool, dp_aware: bool,
api_key: Option<String>, api_key: Option<String>,
retry_config: RetryConfig,
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>, _worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>, _load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
_health_checker: Option<HealthChecker>, _health_checker: Option<HealthChecker>,
...@@ -54,6 +57,7 @@ impl Router { ...@@ -54,6 +57,7 @@ impl Router {
interval_secs: u64, interval_secs: u64,
dp_aware: bool, dp_aware: bool,
api_key: Option<String>, api_key: Option<String>,
retry_config: RetryConfig,
) -> Result<Self, String> { ) -> Result<Self, String> {
// Update active workers gauge // Update active workers gauge
RouterMetrics::set_active_workers(worker_urls.len()); RouterMetrics::set_active_workers(worker_urls.len());
...@@ -120,6 +124,7 @@ impl Router { ...@@ -120,6 +124,7 @@ impl Router {
interval_secs, interval_secs,
dp_aware, dp_aware,
api_key, api_key,
retry_config,
_worker_loads: worker_loads, _worker_loads: worker_loads,
_load_monitor_handle: load_monitor_handle, _load_monitor_handle: load_monitor_handle,
_health_checker: Some(health_checker), _health_checker: Some(health_checker),
...@@ -141,6 +146,12 @@ impl Router { ...@@ -141,6 +146,12 @@ impl Router {
timeout_secs: u64, timeout_secs: u64,
interval_secs: u64, interval_secs: u64,
) -> Result<(), String> { ) -> Result<(), String> {
if worker_urls.is_empty() {
return Err(
"Timeout waiting for workers to become healthy: no workers provided".to_string(),
);
}
let start_time = std::time::Instant::now(); let start_time = std::time::Instant::now();
let sync_client = reqwest::blocking::Client::builder() let sync_client = reqwest::blocking::Client::builder()
.timeout(Duration::from_secs(timeout_secs)) .timeout(Duration::from_secs(timeout_secs))
...@@ -365,11 +376,13 @@ impl Router { ...@@ -365,11 +376,13 @@ impl Router {
) -> Response { ) -> Response {
// Handle retries like the original implementation // Handle retries like the original implementation
let start = Instant::now(); let start = Instant::now();
const MAX_REQUEST_RETRIES: u32 = 3; // Use retry config for per-worker retries
const MAX_TOTAL_RETRIES: u32 = 6; let max_request_retries = self.retry_config.max_retries;
// Total retries across all workers (2x to allow trying multiple workers)
let max_total_retries = self.retry_config.max_retries * 2;
let mut total_retries = 0; let mut total_retries = 0;
while total_retries < MAX_TOTAL_RETRIES { while total_retries < max_total_retries {
// Extract routing text directly from typed request // Extract routing text directly from typed request
let text = typed_req.extract_text_for_routing(); let text = typed_req.extract_text_for_routing();
let is_stream = typed_req.is_stream(); let is_stream = typed_req.is_stream();
...@@ -379,7 +392,7 @@ impl Router { ...@@ -379,7 +392,7 @@ impl Router {
let mut request_retries = 0; let mut request_retries = 0;
// Try the same worker multiple times // Try the same worker multiple times
while request_retries < MAX_REQUEST_RETRIES { while request_retries < max_request_retries {
if total_retries >= 1 { if total_retries >= 1 {
info!("Retrying request after {} failed attempts", total_retries); info!("Retrying request after {} failed attempts", total_retries);
RouterMetrics::record_retry(route); RouterMetrics::record_retry(route);
...@@ -429,13 +442,13 @@ impl Router { ...@@ -429,13 +442,13 @@ impl Router {
route, route,
worker_url, worker_url,
request_retries + 1, request_retries + 1,
MAX_REQUEST_RETRIES max_request_retries
); );
request_retries += 1; request_retries += 1;
total_retries += 1; total_retries += 1;
if request_retries == MAX_REQUEST_RETRIES { if request_retries == max_request_retries {
warn!( warn!(
"Removing failed worker after typed request failures worker_url={}", "Removing failed worker after typed request failures worker_url={}",
worker_url worker_url
...@@ -1003,7 +1016,6 @@ impl Router { ...@@ -1003,7 +1016,6 @@ impl Router {
} }
use async_trait::async_trait; use async_trait::async_trait;
use reqwest::Client;
#[async_trait] #[async_trait]
impl WorkerManagement for Router { impl WorkerManagement for Router {
...@@ -1210,6 +1222,7 @@ mod tests { ...@@ -1210,6 +1222,7 @@ mod tests {
dp_aware: false, dp_aware: false,
api_key: None, api_key: None,
client: Client::new(), client: Client::new(),
retry_config: RetryConfig::default(),
_worker_loads: Arc::new(rx), _worker_loads: Arc::new(rx),
_load_monitor_handle: None, _load_monitor_handle: None,
_health_checker: None, _health_checker: None,
...@@ -1237,8 +1250,10 @@ mod tests { ...@@ -1237,8 +1250,10 @@ mod tests {
#[test] #[test]
fn test_wait_for_healthy_workers_empty_list() { fn test_wait_for_healthy_workers_empty_list() {
// Empty list will timeout as there are no workers to check
let result = Router::wait_for_healthy_workers(&[], 1, 1); let result = Router::wait_for_healthy_workers(&[], 1, 1);
assert!(result.is_ok()); assert!(result.is_err());
assert!(result.unwrap_err().contains("Timeout"));
} }
#[test] #[test]
......
...@@ -580,8 +580,17 @@ mod tests { ...@@ -580,8 +580,17 @@ mod tests {
use crate::routers::router::Router; use crate::routers::router::Router;
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random); let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
let router = let router = Router::new(
Router::new(vec![], policy, reqwest::Client::new(), 5, 1, false, None).unwrap(); vec![],
policy,
reqwest::Client::new(),
5,
1,
false,
None,
crate::config::types::RetryConfig::default(),
)
.unwrap();
Arc::new(router) as Arc<dyn RouterTrait> Arc::new(router) as Arc<dyn RouterTrait>
} }
......
...@@ -8,7 +8,7 @@ use axum::{ ...@@ -8,7 +8,7 @@ use axum::{
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use reqwest::Client; use reqwest::Client;
use serde_json::json; use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode}; use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc; use std::sync::Arc;
use tower::ServiceExt; use tower::ServiceExt;
...@@ -44,6 +44,7 @@ impl TestContext { ...@@ -44,6 +44,7 @@ impl TestContext {
request_id_headers: None, request_id_headers: None,
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(),
}; };
Self::new_with_config(config, worker_configs).await Self::new_with_config(config, worker_configs).await
...@@ -1085,6 +1086,7 @@ mod error_tests { ...@@ -1085,6 +1086,7 @@ mod error_tests {
request_id_headers: None, request_id_headers: None,
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(),
}; };
let ctx = TestContext::new_with_config( let ctx = TestContext::new_with_config(
...@@ -1431,6 +1433,7 @@ mod pd_mode_tests { ...@@ -1431,6 +1433,7 @@ mod pd_mode_tests {
request_id_headers: None, request_id_headers: None,
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(),
}; };
// Create app context // Create app context
...@@ -1584,6 +1587,7 @@ mod request_id_tests { ...@@ -1584,6 +1587,7 @@ mod request_id_tests {
request_id_headers: Some(vec!["custom-id".to_string(), "trace-id".to_string()]), request_id_headers: Some(vec!["custom-id".to_string(), "trace-id".to_string()]),
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(),
}; };
let ctx = TestContext::new_with_config( let ctx = TestContext::new_with_config(
......
...@@ -3,7 +3,7 @@ mod common; ...@@ -3,7 +3,7 @@ mod common;
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use reqwest::Client; use reqwest::Client;
use serde_json::json; use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode}; use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc; use std::sync::Arc;
...@@ -35,6 +35,7 @@ impl TestContext { ...@@ -35,6 +35,7 @@ impl TestContext {
request_id_headers: None, request_id_headers: None,
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(),
}; };
let mut workers = Vec::new(); let mut workers = Vec::new();
......
...@@ -4,7 +4,7 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType ...@@ -4,7 +4,7 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
use futures_util::StreamExt; use futures_util::StreamExt;
use reqwest::Client; use reqwest::Client;
use serde_json::json; use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode}; use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
use sglang_router_rs::routers::{RouterFactory, RouterTrait}; use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc; use std::sync::Arc;
...@@ -36,6 +36,7 @@ impl TestContext { ...@@ -36,6 +36,7 @@ impl TestContext {
request_id_headers: None, request_id_headers: None,
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(),
}; };
let mut workers = Vec::new(); let mut workers = Vec::new();
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
mod test_pd_routing { mod test_pd_routing {
use rand::Rng; use rand::Rng;
use serde_json::json; use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode}; use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
use sglang_router_rs::core::{WorkerFactory, WorkerType}; use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::routers::pd_types::get_hostname; use sglang_router_rs::routers::pd_types::get_hostname;
use sglang_router_rs::routers::pd_types::PDSelectionPolicy; use sglang_router_rs::routers::pd_types::PDSelectionPolicy;
...@@ -178,6 +178,7 @@ mod test_pd_routing { ...@@ -178,6 +178,7 @@ mod test_pd_routing {
request_id_headers: None, request_id_headers: None,
max_concurrent_requests: 64, max_concurrent_requests: 64,
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
retry: RetryConfig::default(),
}; };
// Router creation will fail due to health checks, but config should be valid // Router creation will fail due to health checks, but config should be valid
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment