Unverified Commit 354ac435 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[pd-router] Add configurable retry logic to reduce backend pressure (#8744)

parent d98a4913
......@@ -39,6 +39,8 @@ pub struct RouterConfig {
pub max_concurrent_requests: usize,
/// CORS allowed origins
pub cors_allowed_origins: Vec<String>,
/// Retry configuration
pub retry: RetryConfig,
}
/// Routing mode configuration
......@@ -182,6 +184,30 @@ impl Default for DiscoveryConfig {
}
}
/// Retry configuration for request handling.
///
/// Governs how many times a request is retried against a worker and the
/// exponential backoff schedule applied between attempts. Routers consume
/// this via `RouterConfig.retry`; the stock values come from `Default`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Maximum number of retry attempts per worker before it is considered failed
    pub max_retries: u32,
    /// Initial backoff delay in milliseconds (delay before the first retry)
    pub initial_backoff_ms: u64,
    /// Maximum backoff delay in milliseconds (cap on the exponential growth)
    pub max_backoff_ms: u64,
    /// Backoff multiplier for exponential backoff (each delay is the previous
    /// delay times this factor, up to `max_backoff_ms`)
    pub backoff_multiplier: f32,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
initial_backoff_ms: 100,
max_backoff_ms: 10000,
backoff_multiplier: 2.0,
}
}
}
/// Metrics configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsConfig {
......@@ -210,7 +236,7 @@ impl Default for RouterConfig {
host: "127.0.0.1".to_string(),
port: 3001,
max_payload_size: 268_435_456, // 256MB
request_timeout_secs: 600,
request_timeout_secs: 3600, // 1 hour to match Python mini LB
worker_startup_timeout_secs: 300,
worker_startup_check_interval_secs: 10,
dp_aware: false,
......@@ -222,6 +248,7 @@ impl Default for RouterConfig {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
}
}
}
......@@ -277,7 +304,7 @@ mod tests {
assert_eq!(config.host, "127.0.0.1");
assert_eq!(config.port, 3001);
assert_eq!(config.max_payload_size, 268_435_456);
assert_eq!(config.request_timeout_secs, 600);
assert_eq!(config.request_timeout_secs, 3600);
assert_eq!(config.worker_startup_timeout_secs, 300);
assert_eq!(config.worker_startup_check_interval_secs, 10);
assert!(config.discovery.is_none());
......@@ -332,6 +359,7 @@ mod tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
let json = serde_json::to_string(&config).unwrap();
......@@ -759,6 +787,7 @@ mod tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
assert!(config.mode.is_pd_mode());
......@@ -810,6 +839,7 @@ mod tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
assert!(!config.mode.is_pd_mode());
......@@ -857,6 +887,7 @@ mod tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
assert!(config.has_service_discovery());
......
......@@ -19,7 +19,7 @@ pub enum PolicyType {
Random,
RoundRobin,
CacheAware,
PowerOfTwo, // Moved from PD-specific, now shared
PowerOfTwo,
}
#[pyclass]
......@@ -45,7 +45,6 @@ struct Router {
selector: HashMap<String, String>,
service_discovery_port: u16,
service_discovery_namespace: Option<String>,
// PD service discovery fields
prefill_selector: HashMap<String, String>,
decode_selector: HashMap<String, String>,
bootstrap_port_annotation: String,
......@@ -53,14 +52,11 @@ struct Router {
prometheus_host: Option<String>,
request_timeout_secs: u64,
request_id_headers: Option<Vec<String>>,
// PD mode flag
pd_disaggregation: bool,
// PD-specific fields (only used when pd_disaggregation is true)
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
prefill_policy: Option<PolicyType>,
decode_policy: Option<PolicyType>,
// Additional server config fields
max_concurrent_requests: usize,
cors_allowed_origins: Vec<String>,
}
......@@ -150,6 +146,7 @@ impl Router {
request_id_headers: self.request_id_headers.clone(),
max_concurrent_requests: self.max_concurrent_requests,
cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: config::RetryConfig::default(),
})
}
}
......@@ -289,7 +286,6 @@ impl Router {
check_interval: std::time::Duration::from_secs(60),
port: self.service_discovery_port,
namespace: self.service_discovery_namespace.clone(),
// PD mode configuration
pd_mode: self.pd_disaggregation,
prefill_selector: self.prefill_selector.clone(),
decode_selector: self.decode_selector.clone(),
......
......@@ -50,6 +50,7 @@ impl RouterFactory {
ctx.router_config.worker_startup_check_interval_secs,
ctx.router_config.dp_aware,
ctx.router_config.api_key.clone(),
ctx.router_config.retry.clone(),
)?;
Ok(Box::new(router))
......@@ -79,6 +80,7 @@ impl RouterFactory {
ctx.client.clone(),
ctx.router_config.worker_startup_timeout_secs,
ctx.router_config.worker_startup_check_interval_secs,
ctx.router_config.retry.clone(),
)?;
Ok(Box::new(router))
......
This diff is collapsed.
use crate::config::types::RetryConfig;
use crate::core::{HealthChecker, Worker, WorkerFactory};
use crate::metrics::RouterMetrics;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
......@@ -11,6 +12,7 @@ use axum::{
Json,
};
use futures_util::StreamExt;
use reqwest::Client;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::thread;
......@@ -39,6 +41,7 @@ pub struct Router {
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
_health_checker: Option<HealthChecker>,
......@@ -54,6 +57,7 @@ impl Router {
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
) -> Result<Self, String> {
// Update active workers gauge
RouterMetrics::set_active_workers(worker_urls.len());
......@@ -120,6 +124,7 @@ impl Router {
interval_secs,
dp_aware,
api_key,
retry_config,
_worker_loads: worker_loads,
_load_monitor_handle: load_monitor_handle,
_health_checker: Some(health_checker),
......@@ -141,6 +146,12 @@ impl Router {
timeout_secs: u64,
interval_secs: u64,
) -> Result<(), String> {
if worker_urls.is_empty() {
return Err(
"Timeout waiting for workers to become healthy: no workers provided".to_string(),
);
}
let start_time = std::time::Instant::now();
let sync_client = reqwest::blocking::Client::builder()
.timeout(Duration::from_secs(timeout_secs))
......@@ -365,11 +376,13 @@ impl Router {
) -> Response {
// Handle retries like the original implementation
let start = Instant::now();
const MAX_REQUEST_RETRIES: u32 = 3;
const MAX_TOTAL_RETRIES: u32 = 6;
// Use retry config for per-worker retries
let max_request_retries = self.retry_config.max_retries;
// Total retries across all workers (2x to allow trying multiple workers)
let max_total_retries = self.retry_config.max_retries * 2;
let mut total_retries = 0;
while total_retries < MAX_TOTAL_RETRIES {
while total_retries < max_total_retries {
// Extract routing text directly from typed request
let text = typed_req.extract_text_for_routing();
let is_stream = typed_req.is_stream();
......@@ -379,7 +392,7 @@ impl Router {
let mut request_retries = 0;
// Try the same worker multiple times
while request_retries < MAX_REQUEST_RETRIES {
while request_retries < max_request_retries {
if total_retries >= 1 {
info!("Retrying request after {} failed attempts", total_retries);
RouterMetrics::record_retry(route);
......@@ -429,13 +442,13 @@ impl Router {
route,
worker_url,
request_retries + 1,
MAX_REQUEST_RETRIES
max_request_retries
);
request_retries += 1;
total_retries += 1;
if request_retries == MAX_REQUEST_RETRIES {
if request_retries == max_request_retries {
warn!(
"Removing failed worker after typed request failures worker_url={}",
worker_url
......@@ -1003,7 +1016,6 @@ impl Router {
}
use async_trait::async_trait;
use reqwest::Client;
#[async_trait]
impl WorkerManagement for Router {
......@@ -1210,6 +1222,7 @@ mod tests {
dp_aware: false,
api_key: None,
client: Client::new(),
retry_config: RetryConfig::default(),
_worker_loads: Arc::new(rx),
_load_monitor_handle: None,
_health_checker: None,
......@@ -1237,8 +1250,10 @@ mod tests {
#[test]
fn test_wait_for_healthy_workers_empty_list() {
    // An empty worker list is rejected up front: wait_for_healthy_workers
    // returns an error immediately instead of spinning until the timeout,
    // and the message still mentions "Timeout" for caller compatibility.
    let result = Router::wait_for_healthy_workers(&[], 1, 1);
    assert!(result.is_err());
    assert!(result.unwrap_err().contains("Timeout"));
}
#[test]
......
......@@ -580,8 +580,17 @@ mod tests {
use crate::routers::router::Router;
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
let router =
Router::new(vec![], policy, reqwest::Client::new(), 5, 1, false, None).unwrap();
let router = Router::new(
vec![],
policy,
reqwest::Client::new(),
5,
1,
false,
None,
crate::config::types::RetryConfig::default(),
)
.unwrap();
Arc::new(router) as Arc<dyn RouterTrait>
}
......
......@@ -8,7 +8,7 @@ use axum::{
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use reqwest::Client;
use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc;
use tower::ServiceExt;
......@@ -44,6 +44,7 @@ impl TestContext {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
Self::new_with_config(config, worker_configs).await
......@@ -1085,6 +1086,7 @@ mod error_tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
let ctx = TestContext::new_with_config(
......@@ -1431,6 +1433,7 @@ mod pd_mode_tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
// Create app context
......@@ -1584,6 +1587,7 @@ mod request_id_tests {
request_id_headers: Some(vec!["custom-id".to_string(), "trace-id".to_string()]),
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
let ctx = TestContext::new_with_config(
......
......@@ -3,7 +3,7 @@ mod common;
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use reqwest::Client;
use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc;
......@@ -35,6 +35,7 @@ impl TestContext {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
let mut workers = Vec::new();
......
......@@ -4,7 +4,7 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
use futures_util::StreamExt;
use reqwest::Client;
use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc;
......@@ -36,6 +36,7 @@ impl TestContext {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
let mut workers = Vec::new();
......
......@@ -2,7 +2,7 @@
mod test_pd_routing {
use rand::Rng;
use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::routers::pd_types::get_hostname;
use sglang_router_rs::routers::pd_types::PDSelectionPolicy;
......@@ -178,6 +178,7 @@ mod test_pd_routing {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
// Router creation will fail due to health checks, but config should be valid
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment