Unverified Commit 354ac435 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[pd-router] Add configurable retry logic to reduce backend pressure (#8744)

parent d98a4913
......@@ -39,6 +39,8 @@ pub struct RouterConfig {
pub max_concurrent_requests: usize,
/// CORS allowed origins
pub cors_allowed_origins: Vec<String>,
/// Retry configuration
pub retry: RetryConfig,
}
/// Routing mode configuration
......@@ -182,6 +184,30 @@ impl Default for DiscoveryConfig {
}
}
/// Retry configuration for request handling
///
/// Controls the exponential-backoff retry loop used when dispatching
/// requests to backend workers: up to `max_retries` attempts are made,
/// with the delay starting at `initial_backoff_ms`, multiplied by
/// `backoff_multiplier` after each failed attempt, and capped at
/// `max_backoff_ms`. Jitter is added by the dispatch loop to avoid
/// thundering-herd retries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Maximum number of retry attempts
    pub max_retries: u32,
    /// Initial backoff delay in milliseconds
    pub initial_backoff_ms: u64,
    /// Maximum backoff delay in milliseconds (cap for the exponential growth)
    pub max_backoff_ms: u64,
    /// Backoff multiplier for exponential backoff
    pub backoff_multiplier: f32,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
initial_backoff_ms: 100,
max_backoff_ms: 10000,
backoff_multiplier: 2.0,
}
}
}
/// Metrics configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsConfig {
......@@ -210,7 +236,7 @@ impl Default for RouterConfig {
host: "127.0.0.1".to_string(),
port: 3001,
max_payload_size: 268_435_456, // 256MB
request_timeout_secs: 600,
request_timeout_secs: 3600, // 1 hour to match Python mini LB
worker_startup_timeout_secs: 300,
worker_startup_check_interval_secs: 10,
dp_aware: false,
......@@ -222,6 +248,7 @@ impl Default for RouterConfig {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
}
}
}
......@@ -277,7 +304,7 @@ mod tests {
assert_eq!(config.host, "127.0.0.1");
assert_eq!(config.port, 3001);
assert_eq!(config.max_payload_size, 268_435_456);
assert_eq!(config.request_timeout_secs, 600);
assert_eq!(config.request_timeout_secs, 3600);
assert_eq!(config.worker_startup_timeout_secs, 300);
assert_eq!(config.worker_startup_check_interval_secs, 10);
assert!(config.discovery.is_none());
......@@ -332,6 +359,7 @@ mod tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
let json = serde_json::to_string(&config).unwrap();
......@@ -759,6 +787,7 @@ mod tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
assert!(config.mode.is_pd_mode());
......@@ -810,6 +839,7 @@ mod tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
assert!(!config.mode.is_pd_mode());
......@@ -857,6 +887,7 @@ mod tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
assert!(config.has_service_discovery());
......
......@@ -19,7 +19,7 @@ pub enum PolicyType {
Random,
RoundRobin,
CacheAware,
PowerOfTwo, // Moved from PD-specific, now shared
PowerOfTwo,
}
#[pyclass]
......@@ -45,7 +45,6 @@ struct Router {
selector: HashMap<String, String>,
service_discovery_port: u16,
service_discovery_namespace: Option<String>,
// PD service discovery fields
prefill_selector: HashMap<String, String>,
decode_selector: HashMap<String, String>,
bootstrap_port_annotation: String,
......@@ -53,14 +52,11 @@ struct Router {
prometheus_host: Option<String>,
request_timeout_secs: u64,
request_id_headers: Option<Vec<String>>,
// PD mode flag
pd_disaggregation: bool,
// PD-specific fields (only used when pd_disaggregation is true)
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
prefill_policy: Option<PolicyType>,
decode_policy: Option<PolicyType>,
// Additional server config fields
max_concurrent_requests: usize,
cors_allowed_origins: Vec<String>,
}
......@@ -150,6 +146,7 @@ impl Router {
request_id_headers: self.request_id_headers.clone(),
max_concurrent_requests: self.max_concurrent_requests,
cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: config::RetryConfig::default(),
})
}
}
......@@ -289,7 +286,6 @@ impl Router {
check_interval: std::time::Duration::from_secs(60),
port: self.service_discovery_port,
namespace: self.service_discovery_namespace.clone(),
// PD mode configuration
pd_mode: self.pd_disaggregation,
prefill_selector: self.prefill_selector.clone(),
decode_selector: self.decode_selector.clone(),
......
......@@ -50,6 +50,7 @@ impl RouterFactory {
ctx.router_config.worker_startup_check_interval_secs,
ctx.router_config.dp_aware,
ctx.router_config.api_key.clone(),
ctx.router_config.retry.clone(),
)?;
Ok(Box::new(router))
......@@ -79,6 +80,7 @@ impl RouterFactory {
ctx.client.clone(),
ctx.router_config.worker_startup_timeout_secs,
ctx.router_config.worker_startup_check_interval_secs,
ctx.router_config.retry.clone(),
)?;
Ok(Box::new(router))
......
......@@ -3,6 +3,7 @@
use super::pd_types::{api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError};
use super::request_adapter::ToPdRequest;
use crate::config::types::RetryConfig;
use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
use crate::metrics::RouterMetrics;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
......@@ -16,6 +17,8 @@ use axum::{
Json,
};
use futures_util::StreamExt;
use rand::Rng;
use reqwest::Client;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
......@@ -36,6 +39,7 @@ pub struct PDRouter {
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
pub client: Client,
pub retry_config: RetryConfig,
_prefill_health_checker: Option<HealthChecker>,
_decode_health_checker: Option<HealthChecker>,
}
......@@ -180,6 +184,7 @@ impl PDRouter {
client: Client,
timeout_secs: u64,
interval_secs: u64,
retry_config: RetryConfig,
) -> Result<Self, String> {
// Convert URLs to Worker trait objects
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
......@@ -260,6 +265,7 @@ impl PDRouter {
worker_loads,
load_monitor_handle,
client,
retry_config,
_prefill_health_checker: Some(prefill_health_checker),
_decode_health_checker: Some(decode_health_checker),
})
......@@ -294,6 +300,38 @@ impl PDRouter {
}
}
// Helper to handle server selection errors
//
// Logs the failure, bumps the `server_selection` PD error metric, and
// replies 503 so the client knows no prefill/decode pair was available.
fn handle_server_selection_error(error: String) -> Response {
    error!("Failed to select PD pair error={}", error);
    RouterMetrics::record_pd_error("server_selection");
    let body = format!("No available servers: {}", error);
    (StatusCode::SERVICE_UNAVAILABLE, body).into_response()
}
// Helper to handle bootstrap injection errors
//
// Logs the failure, bumps the `bootstrap_injection` PD error metric, and
// replies 500 since the request could not be prepared for dispatch.
fn handle_bootstrap_error(error: impl std::fmt::Display) -> Response {
    error!("Failed to add bootstrap info error={}", error);
    RouterMetrics::record_pd_error("bootstrap_injection");
    let body = format!("Bootstrap injection failed: {}", error);
    (StatusCode::INTERNAL_SERVER_ERROR, body).into_response()
}
// Helper to handle serialization errors
//
// Logs the failure and replies 500 with a static message (no metric is
// recorded for this path, matching the other call sites).
fn handle_serialization_error(error: impl std::fmt::Display) -> Response {
    error!("Failed to serialize request error={}", error);
    let reply = (
        StatusCode::INTERNAL_SERVER_ERROR,
        "Failed to serialize request",
    );
    reply.into_response()
}
// Route a typed generate request
pub async fn route_generate(
&self,
......@@ -320,15 +358,7 @@ impl PDRouter {
// Select servers
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => {
error!("Failed to select PD pair error={}", e);
RouterMetrics::record_pd_error("server_selection");
return (
StatusCode::SERVICE_UNAVAILABLE,
format!("No available servers: {}", e),
)
.into_response();
}
Err(e) => return Self::handle_server_selection_error(e),
};
// Log routing decision
......@@ -341,26 +371,13 @@ impl PDRouter {
// Add bootstrap info using the trait method
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
error!("Failed to add bootstrap info error={}", e);
RouterMetrics::record_pd_error("bootstrap_injection");
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Bootstrap injection failed: {}", e),
)
.into_response();
return Self::handle_bootstrap_error(e);
}
// Convert to JSON after bootstrap injection
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
Ok(json) => json,
Err(e) => {
error!("Failed to serialize request error={}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to serialize request",
)
.into_response();
}
Err(e) => return Self::handle_serialization_error(e),
};
// Execute dual dispatch
......@@ -406,15 +423,7 @@ impl PDRouter {
// Select servers
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => {
error!("Failed to select PD pair error={}", e);
RouterMetrics::record_pd_error("server_selection");
return (
StatusCode::SERVICE_UNAVAILABLE,
format!("No available servers: {}", e),
)
.into_response();
}
Err(e) => return Self::handle_server_selection_error(e),
};
// Log routing decision
......@@ -425,28 +434,14 @@ impl PDRouter {
decode.url()
);
// Add bootstrap info using the trait method
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
error!("Failed to add bootstrap info error={}", e);
RouterMetrics::record_pd_error("bootstrap_injection");
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Bootstrap injection failed: {}", e),
)
.into_response();
return Self::handle_bootstrap_error(e);
}
// Convert to JSON after bootstrap injection
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
Ok(json) => json,
Err(e) => {
error!("Failed to serialize request error={}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to serialize request",
)
.into_response();
}
Err(e) => return Self::handle_serialization_error(e),
};
// Execute dual dispatch
......@@ -485,15 +480,7 @@ impl PDRouter {
// Select servers
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => {
error!("Failed to select PD pair error={}", e);
RouterMetrics::record_pd_error("server_selection");
return (
StatusCode::SERVICE_UNAVAILABLE,
format!("No available servers: {}", e),
)
.into_response();
}
Err(e) => return Self::handle_server_selection_error(e),
};
// Log routing decision
......@@ -504,28 +491,14 @@ impl PDRouter {
decode.url()
);
// Add bootstrap info using the trait method
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
error!("Failed to add bootstrap info error={}", e);
RouterMetrics::record_pd_error("bootstrap_injection");
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Bootstrap injection failed: {}", e),
)
.into_response();
return Self::handle_bootstrap_error(e);
}
// Convert to JSON after bootstrap injection
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
Ok(json) => json,
Err(e) => {
error!("Failed to serialize request error={}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to serialize request",
)
.into_response();
}
Err(e) => return Self::handle_serialization_error(e),
};
// Execute dual dispatch
......@@ -542,7 +515,7 @@ impl PDRouter {
.await
}
// Execute the dual dispatch to prefill and decode servers
// Execute the dual dispatch to prefill and decode servers with retry logic
async fn execute_dual_dispatch(
&self,
headers: Option<&HeaderMap>,
......@@ -554,37 +527,127 @@ impl PDRouter {
return_logprob: bool,
start_time: Instant,
) -> Response {
// Update load tracking for both workers
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
for attempt in 0..self.retry_config.max_retries {
if attempt > 0 {
// Calculate backoff with exponential growth and jitter
let base_backoff = self.retry_config.initial_backoff_ms as f64
* self
.retry_config
.backoff_multiplier
.powf((attempt - 1) as f32) as f64;
let backoff_ms = base_backoff.min(self.retry_config.max_backoff_ms as f64) as u64;
// Add jitter to prevent thundering herd
let jitter = {
let mut rng = rand::thread_rng();
rng.gen_range(0..backoff_ms / 2)
};
let total_backoff = Duration::from_millis(backoff_ms + jitter);
info!(
"Retrying request (attempt {}/{}) after {:?} backoff",
attempt + 1,
self.retry_config.max_retries,
total_backoff
);
// Build requests using .json() method
let mut prefill_request = self
.client
.post(api_path(prefill.url(), route))
.json(&json_request);
tokio::time::sleep(total_backoff).await;
}
let mut decode_request = self
.client
.post(api_path(decode.url(), route))
.json(&json_request);
debug!(
"Executing request attempt {}/{}",
attempt + 1,
self.retry_config.max_retries
);
let result = self
.execute_dual_dispatch_inner(
headers,
json_request.clone(),
route,
prefill,
decode,
is_stream,
return_logprob,
start_time,
)
.await;
// Copy headers from original request (excluding content-type and content-length which are set by .json())
if let Some(headers) = headers {
for (name, value) in headers.iter() {
let name_str = name.as_str();
if name_str != "content-type" && name_str != "content-length" {
// Skip headers with non-ASCII values
if value.to_str().is_ok() {
prefill_request = prefill_request.header(name, value);
decode_request = decode_request.header(name, value);
}
}
// Check if we should retry based on the response status
let status = result.status();
debug!(
"Request attempt {} returned status: {}",
attempt + 1,
status
);
// Don't retry client errors (4xx) or successful responses
if status.is_client_error() || status.is_success() {
debug!(
"Returning response with status {} (no retry needed)",
status
);
return result;
}
// Check if this is the last attempt
if attempt == self.retry_config.max_retries - 1 {
warn!("Final attempt failed with status {}", status);
return result;
}
// Log retry decision for retryable errors
if status.is_server_error()
|| status == StatusCode::BAD_GATEWAY
|| status == StatusCode::GATEWAY_TIMEOUT
{
warn!(
"Retryable error status: {} on attempt {}/{}. Will retry.",
status,
attempt + 1,
self.retry_config.max_retries
);
} else {
// Don't retry other statuses
debug!("Status {} is not retryable, returning response", status);
return result;
}
}
// This should never be reached due to the loop logic, but just in case
unreachable!("Retry loop completed without returning")
}
// Inner implementation of dual dispatch (extracted for retry logic)
async fn execute_dual_dispatch_inner(
&self,
headers: Option<&HeaderMap>,
json_request: Value,
route: &str,
prefill: &dyn Worker,
decode: &dyn Worker,
is_stream: bool,
return_logprob: bool,
start_time: Instant,
) -> Response {
// Update load tracking for both workers
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
// Build requests with headers
let prefill_request =
self.build_request_with_headers(prefill.url(), route, &json_request, headers);
let decode_request =
self.build_request_with_headers(decode.url(), route, &json_request, headers);
// Send both requests concurrently
debug!(
"Sending concurrent requests to prefill={} decode={}",
prefill.url(),
decode.url()
);
let (prefill_result, decode_result) =
tokio::join!(prefill_request.send(), decode_request.send());
debug!("Received responses from both servers");
// Update metrics
let duration = start_time.elapsed();
......@@ -593,11 +656,22 @@ impl PDRouter {
RouterMetrics::record_pd_prefill_request(prefill.url());
RouterMetrics::record_pd_decode_request(decode.url());
// Process prefill response
let (_prefill_status, prefill_body) = match self
.process_prefill_response(prefill_result, prefill.url(), return_logprob)
.await
{
Ok(result) => result,
Err(error_response) => return error_response,
};
// Process decode response
debug!("Processing decode response");
match decode_result {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
debug!("Decode response status: {}", status);
if !status.is_success() {
RouterMetrics::record_pd_decode_error(decode.url());
......@@ -618,128 +692,36 @@ impl PDRouter {
}
}
// Log prefill errors for debugging
if let Err(e) = &prefill_result {
error!(
"Prefill server failed (non-critical) prefill_url={} error={}",
prefill.url(),
e
);
RouterMetrics::record_pd_prefill_error(prefill.url());
}
if is_stream {
// Streaming response
if return_logprob {
// Get prefill logprobs for merging
let prefill_logprobs =
match prefill_result {
Ok(prefill_res) => match prefill_res.bytes().await {
Ok(body) => serde_json::from_slice::<Value>(&body)
.ok()
.and_then(|json| {
json.pointer("/meta_info/input_token_logprobs").cloned()
}),
Err(_) => None,
},
Err(_) => None,
};
// Stream with logprob merging
let stream = res.bytes_stream();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
tokio::spawn(async move {
let mut stream = stream;
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
// Try to merge logprobs
if let Ok(merged) = Self::merge_streaming_logprobs(
prefill_logprobs.clone(),
&chunk,
) {
if tx.send(Ok(merged)).is_err() {
break;
}
} else {
if tx.send(Ok(chunk)).is_err() {
break;
}
}
}
Err(e) => {
let _ = tx.send(Err(format!("Stream error: {}", e)));
break;
}
}
}
});
let stream = UnboundedReceiverStream::new(rx);
let body = Body::from_stream(stream);
let mut response = Response::new(body);
*response.status_mut() = status;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
response
let prefill_logprobs = if return_logprob {
prefill_body
.as_ref()
.and_then(|body| serde_json::from_slice::<Value>(body).ok())
.and_then(|json| {
json.pointer("/meta_info/input_token_logprobs").cloned()
})
} else {
// No logprob merging needed
let stream = res.bytes_stream();
let decode_url = decode.url().to_string();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
tokio::spawn(async move {
let mut stream = stream;
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
if tx.send(Ok(bytes)).is_err() {
break;
}
}
Err(e) => {
error!(
"Stream error from decode server {}: {}",
decode_url, e
);
RouterMetrics::record_pd_stream_error(&decode_url);
let _ = tx.send(Err(format!("Stream error: {}", e)));
break;
}
}
}
});
None
};
let stream = UnboundedReceiverStream::new(rx);
let body = Body::from_stream(stream);
let decode_url = if !return_logprob {
Some(decode.url().to_string())
} else {
None
};
let mut response = Response::new(body);
*response.status_mut() = status;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
response
}
Self::create_streaming_response(
res.bytes_stream(),
status,
prefill_logprobs,
return_logprob,
decode_url,
)
} else {
// Non-streaming response
match res.bytes().await {
Ok(decode_body) => {
if return_logprob {
self.merge_logprobs(prefill_result, decode_body, status)
.await
} else {
(status, decode_body).into_response()
}
}
Err(e) => {
error!("Failed to read decode response: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response")
.into_response()
}
}
// Non-streaming response - use helper
self.process_non_streaming_response(res, status, return_logprob, prefill_body)
.await
}
}
Err(e) => {
......@@ -758,62 +740,6 @@ impl PDRouter {
}
}
// Merge logprobs from prefill and decode responses
async fn merge_logprobs(
&self,
prefill_result: Result<reqwest::Response, reqwest::Error>,
decode_body: bytes::Bytes,
status: StatusCode,
) -> Response {
match prefill_result {
Ok(prefill_res) => {
match prefill_res.bytes().await {
Ok(prefill_body) => {
match (
serde_json::from_slice::<Value>(&prefill_body),
serde_json::from_slice::<Value>(&decode_body),
) {
(Ok(prefill_json), Ok(mut decode_json)) => {
// Merge input_token_logprobs
if let (Some(prefill_meta), Some(decode_meta)) = (
prefill_json.get("meta_info"),
decode_json.get_mut("meta_info"),
) {
if let (Some(prefill_logprobs), Some(decode_logprobs)) = (
prefill_meta.get("input_token_logprobs"),
decode_meta.get_mut("input_token_logprobs"),
) {
if let (Some(p_arr), Some(d_arr)) = (
prefill_logprobs.as_array(),
decode_logprobs.as_array(),
) {
let mut merged = p_arr.clone();
merged.extend(d_arr.clone());
decode_meta["input_token_logprobs"] =
Value::Array(merged);
}
}
}
let mut response = Json(decode_json).into_response();
*response.status_mut() = status;
response
}
_ => {
warn!("Failed to parse responses for logprob merging");
(status, decode_body).into_response()
}
}
}
Err(e) => {
warn!("Failed to read prefill response: {}", e);
(status, decode_body).into_response()
}
}
}
Err(_) => (status, decode_body).into_response(),
}
}
// Select a pair of prefill and decode servers
async fn select_pd_pair(
&self,
......@@ -900,6 +826,229 @@ impl PDRouter {
}
}
// Helper to create a streaming response
//
// Spawns a task that forwards chunks from `stream` into an unbounded mpsc
// channel — optionally merging prefill logprobs into each chunk — and
// returns an SSE (`text/event-stream`) response backed by the receiver.
// `decode_url` is Some only when errors should be logged/counted against
// the decode server's metrics.
fn create_streaming_response(
    stream: impl futures_util::Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
    status: StatusCode,
    prefill_logprobs: Option<Value>,
    return_logprob: bool,
    decode_url: Option<String>,
) -> Response {
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

    tokio::spawn(async move {
        futures_util::pin_mut!(stream);
        while let Some(chunk_result) = stream.next().await {
            match chunk_result {
                Ok(chunk) => {
                    let result = if return_logprob && prefill_logprobs.is_some() {
                        // Try to merge logprobs; fall back to the raw chunk on failure
                        Self::merge_streaming_logprobs(prefill_logprobs.clone(), &chunk)
                            .unwrap_or(chunk)
                    } else {
                        chunk
                    };
                    // send fails only when the receiver is dropped (client gone)
                    if tx.send(Ok(result)).is_err() {
                        break;
                    }
                }
                Err(e) => {
                    if let Some(ref url) = decode_url {
                        error!("Stream error from decode server {}: {}", url, e);
                        RouterMetrics::record_pd_stream_error(url);
                    }
                    let _ = tx.send(Err(format!("Stream error: {}", e)));
                    break;
                }
            }
        }
    });

    let stream = UnboundedReceiverStream::new(rx);
    let body = Body::from_stream(stream);

    let mut response = Response::new(body);
    *response.status_mut() = status;
    response
        .headers_mut()
        .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
    response
}
// Helper to process non-streaming decode response with logprob merging
//
// Reads the decode body; when logprobs were requested and a prefill body is
// available, merges the prefill logprobs into the decode JSON before
// replying. Any parse/serialize failure falls back to the raw decode body.
async fn process_non_streaming_response(
    &self,
    res: reqwest::Response,
    status: StatusCode,
    return_logprob: bool,
    prefill_body: Option<bytes::Bytes>,
) -> Response {
    // Guard: reading the decode body is the only hard failure here.
    let decode_body = match res.bytes().await {
        Ok(bytes) => bytes,
        Err(e) => {
            error!("Failed to read decode response: {}", e);
            return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response").into_response();
        }
    };

    // No merging requested (or no prefill body captured): pass through.
    let prefill_bytes = match prefill_body {
        Some(bytes) if return_logprob => bytes,
        _ => return (status, decode_body).into_response(),
    };

    let parsed = (
        serde_json::from_slice::<Value>(&prefill_bytes),
        serde_json::from_slice::<Value>(&decode_body),
    );
    if let (Ok(prefill_json), Ok(mut decode_json)) = parsed {
        // Use helper to merge logprobs
        Self::merge_logprobs_in_json(&prefill_json, &mut decode_json);

        match serde_json::to_vec(&decode_json) {
            Ok(merged_body) => (status, merged_body).into_response(),
            Err(e) => {
                error!("Failed to serialize merged response: {}", e);
                (status, decode_body).into_response()
            }
        }
    } else {
        // If parsing fails, just return decode response
        warn!("Failed to parse responses for logprob merging");
        (status, decode_body).into_response()
    }
}
// Helper to process prefill response and extract body if needed for logprobs
//
// Returns the prefill status plus the raw body (only retained when
// `return_logprob` is set); a transport failure or non-2xx status is
// converted into an error Response, since decode cannot proceed without
// the prefill KV cache.
async fn process_prefill_response(
    &self,
    prefill_result: Result<reqwest::Response, reqwest::Error>,
    prefill_url: &str,
    return_logprob: bool,
) -> Result<(StatusCode, Option<bytes::Bytes>), Response> {
    // A failed prefill is critical in disaggregated mode: fail fast rather
    // than letting decode time out.
    let prefill_response = prefill_result.map_err(|e| {
        RouterMetrics::record_pd_prefill_error(prefill_url);
        error!(
            "Prefill server failed (CRITICAL) prefill_url={} error={}. Decode will timeout without prefill KV cache.",
            prefill_url,
            e
        );
        (
            StatusCode::BAD_GATEWAY,
            format!(
                "Prefill server error: {}. This will cause decode timeout.",
                e
            ),
        )
            .into_response()
    })?;

    let prefill_status = StatusCode::from_u16(prefill_response.status().as_u16())
        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

    // Non-2xx from prefill: surface its body to the caller as the error.
    if !prefill_status.is_success() {
        RouterMetrics::record_pd_prefill_error(prefill_url);
        let error_msg = prefill_response
            .text()
            .await
            .unwrap_or_else(|_| "Unknown prefill error".to_string());
        error!(
            "Prefill server returned error status prefill_url={} status={} body={}",
            prefill_url, prefill_status, error_msg
        );
        return Err((
            prefill_status,
            format!("Prefill server error ({}): {}", prefill_status, error_msg),
        )
            .into_response());
    }

    let prefill_body = if return_logprob {
        // Keep the body so logprobs can be merged into the decode response.
        match prefill_response.bytes().await {
            Ok(body) => Some(body),
            Err(e) => {
                warn!("Failed to read prefill response body for logprobs: {}", e);
                None
            }
        }
    } else {
        // Drain the body without retaining it.
        debug!("Consuming prefill response body (non-logprob request)");
        match prefill_response.bytes().await {
            Ok(_) => debug!("Prefill response consumed successfully"),
            Err(e) => warn!("Error consuming prefill response: {}", e),
        }
        None
    };

    Ok((prefill_status, prefill_body))
}
// Helper to build a request with headers copied from the original request
//
// Forwards the caller's headers except content-type/content-length (both
// set by `.json()`) and any header whose value is not valid visible ASCII.
fn build_request_with_headers(
    &self,
    url: &str,
    route: &str,
    json_request: &Value,
    headers: Option<&HeaderMap>,
) -> reqwest::RequestBuilder {
    let mut request = self.client.post(api_path(url, route)).json(json_request);

    if let Some(original_headers) = headers {
        let forwardable = original_headers.iter().filter(|(name, value)| {
            let key = name.as_str();
            key != "content-type" && key != "content-length" && value.to_str().is_ok()
        });
        for (name, value) in forwardable {
            request = request.header(name, value);
        }
    }

    request
}
// Helper to merge logprobs from prefill and decode responses
fn merge_logprobs_in_json(prefill_json: &Value, decode_json: &mut Value) -> bool {
if let (Some(prefill_meta), Some(decode_meta)) = (
prefill_json.get("meta_info"),
decode_json.get_mut("meta_info"),
) {
if let (Some(prefill_logprobs), Some(decode_logprobs)) = (
prefill_meta.get("input_token_logprobs"),
decode_meta.get_mut("input_token_logprobs"),
) {
if let (Some(prefill_arr), Some(decode_arr)) =
(prefill_logprobs.as_array(), decode_logprobs.as_array_mut())
{
let mut merged = prefill_arr.clone();
merged.extend(decode_arr.clone());
decode_meta["input_token_logprobs"] = Value::Array(merged);
return true;
}
}
}
false
}
// Simple helper to merge logprobs in streaming responses
fn merge_streaming_logprobs(
prefill_logprobs: Option<Value>,
......@@ -1316,7 +1465,6 @@ impl PDRouter {
use crate::routers::{RouterTrait, WorkerManagement};
use async_trait::async_trait;
use reqwest::Client;
#[async_trait]
impl WorkerManagement for PDRouter {
......@@ -1558,6 +1706,7 @@ mod tests {
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
load_monitor_handle: None,
client: Client::new(),
retry_config: RetryConfig::default(),
_prefill_health_checker: None,
_decode_health_checker: None,
}
......
use crate::config::types::RetryConfig;
use crate::core::{HealthChecker, Worker, WorkerFactory};
use crate::metrics::RouterMetrics;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
......@@ -11,6 +12,7 @@ use axum::{
Json,
};
use futures_util::StreamExt;
use reqwest::Client;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::thread;
......@@ -39,6 +41,7 @@ pub struct Router {
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
_health_checker: Option<HealthChecker>,
......@@ -54,6 +57,7 @@ impl Router {
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
) -> Result<Self, String> {
// Update active workers gauge
RouterMetrics::set_active_workers(worker_urls.len());
......@@ -120,6 +124,7 @@ impl Router {
interval_secs,
dp_aware,
api_key,
retry_config,
_worker_loads: worker_loads,
_load_monitor_handle: load_monitor_handle,
_health_checker: Some(health_checker),
......@@ -141,6 +146,12 @@ impl Router {
timeout_secs: u64,
interval_secs: u64,
) -> Result<(), String> {
if worker_urls.is_empty() {
return Err(
"Timeout waiting for workers to become healthy: no workers provided".to_string(),
);
}
let start_time = std::time::Instant::now();
let sync_client = reqwest::blocking::Client::builder()
.timeout(Duration::from_secs(timeout_secs))
......@@ -365,11 +376,13 @@ impl Router {
) -> Response {
// Handle retries like the original implementation
let start = Instant::now();
const MAX_REQUEST_RETRIES: u32 = 3;
const MAX_TOTAL_RETRIES: u32 = 6;
// Use retry config for per-worker retries
let max_request_retries = self.retry_config.max_retries;
// Total retries across all workers (2x to allow trying multiple workers)
let max_total_retries = self.retry_config.max_retries * 2;
let mut total_retries = 0;
while total_retries < MAX_TOTAL_RETRIES {
while total_retries < max_total_retries {
// Extract routing text directly from typed request
let text = typed_req.extract_text_for_routing();
let is_stream = typed_req.is_stream();
......@@ -379,7 +392,7 @@ impl Router {
let mut request_retries = 0;
// Try the same worker multiple times
while request_retries < MAX_REQUEST_RETRIES {
while request_retries < max_request_retries {
if total_retries >= 1 {
info!("Retrying request after {} failed attempts", total_retries);
RouterMetrics::record_retry(route);
......@@ -429,13 +442,13 @@ impl Router {
route,
worker_url,
request_retries + 1,
MAX_REQUEST_RETRIES
max_request_retries
);
request_retries += 1;
total_retries += 1;
if request_retries == MAX_REQUEST_RETRIES {
if request_retries == max_request_retries {
warn!(
"Removing failed worker after typed request failures worker_url={}",
worker_url
......@@ -1003,7 +1016,6 @@ impl Router {
}
use async_trait::async_trait;
use reqwest::Client;
#[async_trait]
impl WorkerManagement for Router {
......@@ -1210,6 +1222,7 @@ mod tests {
dp_aware: false,
api_key: None,
client: Client::new(),
retry_config: RetryConfig::default(),
_worker_loads: Arc::new(rx),
_load_monitor_handle: None,
_health_checker: None,
......@@ -1237,8 +1250,10 @@ mod tests {
#[test]
fn test_wait_for_healthy_workers_empty_list() {
// Empty list will timeout as there are no workers to check
let result = Router::wait_for_healthy_workers(&[], 1, 1);
assert!(result.is_ok());
assert!(result.is_err());
assert!(result.unwrap_err().contains("Timeout"));
}
#[test]
......
......@@ -580,8 +580,17 @@ mod tests {
use crate::routers::router::Router;
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
let router =
Router::new(vec![], policy, reqwest::Client::new(), 5, 1, false, None).unwrap();
let router = Router::new(
vec![],
policy,
reqwest::Client::new(),
5,
1,
false,
None,
crate::config::types::RetryConfig::default(),
)
.unwrap();
Arc::new(router) as Arc<dyn RouterTrait>
}
......
......@@ -8,7 +8,7 @@ use axum::{
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use reqwest::Client;
use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc;
use tower::ServiceExt;
......@@ -44,6 +44,7 @@ impl TestContext {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
Self::new_with_config(config, worker_configs).await
......@@ -1085,6 +1086,7 @@ mod error_tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
let ctx = TestContext::new_with_config(
......@@ -1431,6 +1433,7 @@ mod pd_mode_tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
// Create app context
......@@ -1584,6 +1587,7 @@ mod request_id_tests {
request_id_headers: Some(vec!["custom-id".to_string(), "trace-id".to_string()]),
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
let ctx = TestContext::new_with_config(
......
......@@ -3,7 +3,7 @@ mod common;
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use reqwest::Client;
use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc;
......@@ -35,6 +35,7 @@ impl TestContext {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
let mut workers = Vec::new();
......
......@@ -4,7 +4,7 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
use futures_util::StreamExt;
use reqwest::Client;
use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc;
......@@ -36,6 +36,7 @@ impl TestContext {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
let mut workers = Vec::new();
......
......@@ -2,7 +2,7 @@
mod test_pd_routing {
use rand::Rng;
use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::routers::pd_types::get_hostname;
use sglang_router_rs::routers::pd_types::PDSelectionPolicy;
......@@ -178,6 +178,7 @@ mod test_pd_routing {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};
// Router creation will fail due to health checks, but config should be valid
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment