"vscode:/vscode.git/clone" did not exist on "682d36b5c7c88ef3eaa4322b0af9727d40e41d55"
Unverified Commit 6f81a710 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[pd-router] add retry and circuit breaker for pd router (#9051)

parent a6452b71
...@@ -16,7 +16,7 @@ pub use circuit_breaker::{ ...@@ -16,7 +16,7 @@ pub use circuit_breaker::{
CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState, CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState,
}; };
pub use error::{WorkerError, WorkerResult}; pub use error::{WorkerError, WorkerResult};
pub use retry::{BackoffCalculator, RetryError, RetryExecutor}; pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor};
pub use worker::{ pub use worker::{
start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, Worker, WorkerCollection, start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, Worker, WorkerCollection,
WorkerFactory, WorkerLoadGuard, WorkerType, WorkerFactory, WorkerLoadGuard, WorkerType,
......
use crate::config::types::RetryConfig; use crate::config::types::RetryConfig;
use axum::http::StatusCode;
use axum::response::Response; use axum::response::Response;
use rand::Rng; use rand::Rng;
use std::time::Duration; use std::time::Duration;
use tracing::debug; use tracing::debug;
/// Check if an HTTP status code indicates a retryable error.
///
/// Retryable statuses are the transient failures: request timeout (408),
/// throttling (429), and the upstream/server errors 500, 502, 503, 504.
/// All other statuses (including 4xx client errors) are considered final.
pub fn is_retryable_status(status: StatusCode) -> bool {
    // Kept as a slice membership test so adding/removing a retryable
    // status is a one-line change.
    [
        StatusCode::REQUEST_TIMEOUT,
        StatusCode::TOO_MANY_REQUESTS,
        StatusCode::INTERNAL_SERVER_ERROR,
        StatusCode::BAD_GATEWAY,
        StatusCode::SERVICE_UNAVAILABLE,
        StatusCode::GATEWAY_TIMEOUT,
    ]
    .contains(&status)
}
/// Computes exponential backoff with optional jitter. /// Computes exponential backoff with optional jitter.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct BackoffCalculator; pub struct BackoffCalculator;
...@@ -21,8 +35,8 @@ impl BackoffCalculator { ...@@ -21,8 +35,8 @@ impl BackoffCalculator {
// Apply jitter in range [-j, +j] // Apply jitter in range [-j, +j]
let jitter = config.jitter_factor.max(0.0).min(1.0); let jitter = config.jitter_factor.max(0.0).min(1.0);
if jitter > 0.0 { if jitter > 0.0 {
let mut rng = rand::thread_rng(); let mut rng = rand::rng();
let jitter_scale: f32 = rng.gen_range(-jitter..=jitter); let jitter_scale: f32 = rng.random_range(-jitter..=jitter);
let jitter_ms = (delay_ms as f32 * jitter_scale) let jitter_ms = (delay_ms as f32 * jitter_scale)
.round() .round()
.max(-(delay_ms as f32)); .max(-(delay_ms as f32));
......
...@@ -2,7 +2,10 @@ ...@@ -2,7 +2,10 @@
// This module handles routing for disaggregated prefill-decode systems // This module handles routing for disaggregated prefill-decode systems
use super::pd_types::{api_path, PDRouterError}; use super::pd_types::{api_path, PDRouterError};
use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig}; use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig};
use crate::core::{CircuitBreakerConfig, HealthChecker, Worker, WorkerFactory, WorkerLoadGuard}; use crate::core::{
is_retryable_status, CircuitBreakerConfig, HealthChecker, RetryExecutor, Worker, WorkerFactory,
WorkerLoadGuard,
};
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::policies::LoadBalancingPolicy; use crate::policies::LoadBalancingPolicy;
...@@ -17,6 +20,7 @@ use axum::{ ...@@ -17,6 +20,7 @@ use axum::{
}; };
use futures_util::StreamExt; use futures_util::StreamExt;
use reqwest::Client; use reqwest::Client;
use serde::Serialize;
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
...@@ -43,6 +47,16 @@ pub struct PDRouter { ...@@ -43,6 +47,16 @@ pub struct PDRouter {
_decode_health_checker: Option<HealthChecker>, _decode_health_checker: Option<HealthChecker>,
} }
// Request context for PD router operations.
// Bundles the per-request parameters that each retry attempt of the dual
// (prefill + decode) dispatch needs; `Clone` so the retry closure can take
// an owned copy per attempt.
#[derive(Clone)]
struct PDRequestContext {
    // Route path used for metrics/labels (e.g. "/generate"); static since
    // routes are compile-time string literals.
    route: &'static str,
    // Number of items in a batched request, if the request is a batch;
    // used for bootstrap injection into the prefill payload.
    batch_size: Option<usize>,
    // Whether the client requested a streaming response.
    is_stream: bool,
    // Whether the client asked for log-probabilities in the response.
    return_logprob: bool,
    // Extracted request text for cache-aware load-balancing policies;
    // None when no policy needs the text (avoids the copy).
    request_text: Option<String>,
}
impl PDRouter { impl PDRouter {
// Dynamic worker management methods for service discovery // Dynamic worker management methods for service discovery
...@@ -218,12 +232,8 @@ impl PDRouter { ...@@ -218,12 +232,8 @@ impl PDRouter {
let core_cb_config = CircuitBreakerConfig { let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold, failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold, success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: std::time::Duration::from_secs( timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
circuit_breaker_config.timeout_duration_secs, window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
),
window_duration: std::time::Duration::from_secs(
circuit_breaker_config.window_duration_secs,
),
}; };
// Convert URLs to Worker trait objects // Convert URLs to Worker trait objects
...@@ -459,8 +469,96 @@ impl PDRouter { ...@@ -459,8 +469,96 @@ impl PDRouter {
Ok(original) Ok(original)
} }
// Execute the dual dispatch to prefill and decode servers // Execute the dual dispatch to prefill and decode servers with retries and bootstrap injection
async fn execute_dual_dispatch( async fn execute_dual_dispatch<T: Serialize + Clone>(
&self,
headers: Option<&HeaderMap>,
original_request: &T,
context: PDRequestContext,
) -> Response {
let start_time = Instant::now();
let route = context.route;
RetryExecutor::execute_response_with_retry(
&self.retry_config,
// Operation per attempt
{
let original_request = original_request.clone();
move |attempt: u32| {
let original_request = original_request.clone();
let context = context.clone();
async move {
// Select workers fresh for each attempt
let (prefill, decode) =
match self.select_pd_pair(context.request_text.as_deref()).await {
Ok(pair) => pair,
Err(e) => {
RouterMetrics::record_pd_error("server_selection");
return Self::handle_server_selection_error(e);
}
};
debug!(
"PD retry attempt {} using prefill={} decode={}",
attempt,
prefill.url(),
decode.url()
);
// Serialize the original request
let mut json_request = match serde_json::to_value(&original_request) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
// Inject bootstrap based on current prefill worker
json_request = match Self::inject_bootstrap_into_value(
json_request,
prefill.as_ref(),
context.batch_size,
) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
// Execute the actual dual dispatch
let response = self
.execute_dual_dispatch_internal(
headers,
json_request,
context.route,
prefill.as_ref(),
decode.as_ref(),
context.is_stream,
context.return_logprob,
start_time,
)
.await;
// Record outcomes for circuit breakers
let is_success = response.status().is_success();
prefill.record_outcome(is_success);
decode.record_outcome(is_success);
response
}
}
},
// Should retry predicate
|res, _attempt| is_retryable_status(res.status()),
// On backoff hook
|delay, attempt| {
RouterMetrics::record_retry(route);
RouterMetrics::record_retry_backoff_duration(delay, attempt);
},
// On exhausted hook
|| RouterMetrics::record_retries_exhausted(route),
)
.await
}
// Internal method that performs the actual dual dispatch (without retry logic)
async fn execute_dual_dispatch_internal(
&self, &self,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
json_request: Value, json_request: Value,
...@@ -696,7 +794,7 @@ impl PDRouter { ...@@ -696,7 +794,7 @@ impl PDRouter {
self.prefill_policy.needs_request_text() || self.decode_policy.needs_request_text() self.prefill_policy.needs_request_text() || self.decode_policy.needs_request_text()
} }
// Select a pair of prefill and decode servers // Select a pair of prefill and decode servers considering circuit breaker state
async fn select_pd_pair( async fn select_pd_pair(
&self, &self,
request_text: Option<&str>, request_text: Option<&str>,
...@@ -711,29 +809,58 @@ impl PDRouter { ...@@ -711,29 +809,58 @@ impl PDRouter {
.read() .read()
.map_err(|e| format!("Failed to acquire decode workers lock: {}", e))?; .map_err(|e| format!("Failed to acquire decode workers lock: {}", e))?;
// Check we have workers // Select workers using helper function
if prefill_workers.is_empty() { let prefill = Self::pick_worker_by_policy(
return Err("No prefill workers available. Please check if prefill servers are configured and healthy.".to_string()); &*prefill_workers,
&*self.prefill_policy,
request_text,
"prefill",
)?;
let decode = Self::pick_worker_by_policy(
&*decode_workers,
&*self.decode_policy,
request_text,
"decode",
)?;
Ok((prefill, decode))
} }
if decode_workers.is_empty() {
return Err("No decode workers available. Please check if decode servers are configured and healthy.".to_string()); // Helper function to select a worker using the policy
fn pick_worker_by_policy(
workers: &[Box<dyn Worker>],
policy: &dyn LoadBalancingPolicy,
request_text: Option<&str>,
worker_type: &str,
) -> Result<Box<dyn Worker>, String> {
// Check if we have any workers
if workers.is_empty() {
return Err(format!(
"No {} workers available. Please check if {} servers are configured and healthy.",
worker_type, worker_type
));
} }
// Select prefill worker using prefill policy // Filter available workers (healthy + circuit breaker not open)
let prefill_idx = self let available_workers: Vec<Box<dyn Worker>> = workers
.prefill_policy .iter()
.select_worker(&prefill_workers, request_text) .filter(|w| w.is_available())
.ok_or("Failed to select prefill worker")?; .map(|w| w.clone_worker())
.collect();
// Select decode worker using decode policy if available_workers.is_empty() {
let decode_idx = self return Err(format!(
.decode_policy "No available {} workers (all circuits open or unhealthy)",
.select_worker(&decode_workers, request_text) worker_type
.ok_or("Failed to select decode worker")?; ));
}
let prefill = prefill_workers[prefill_idx].clone_worker(); // Let policy select from available workers only
let decode = decode_workers[decode_idx].clone_worker(); match policy.select_worker(&available_workers, request_text) {
Ok((prefill, decode)) Some(idx) => Ok(available_workers[idx].clone_worker()),
None => Err(format!("Policy could not select a {} worker", worker_type)),
}
} }
// Background task to monitor worker loads with shared client // Background task to monitor worker loads with shared client
...@@ -1449,15 +1576,15 @@ impl RouterTrait for PDRouter { ...@@ -1449,15 +1576,15 @@ impl RouterTrait for PDRouter {
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
body: &GenerateRequest, body: &GenerateRequest,
) -> Response { ) -> Response {
let start = Instant::now(); // Extract parameters
// Extract flags for routing logic
let is_stream = body.stream; let is_stream = body.stream;
let return_logprob = body.return_logprob; let return_logprob = body.return_logprob;
// Extract text for cache-aware routing only if needed // Extract text for cache-aware routing
let request_text = if self.policies_need_request_text() { let request_text = if self.policies_need_request_text() {
body.text.as_deref().or_else(|| { body.text
.as_deref()
.or_else(|| {
body.prompt.as_ref().and_then(|p| match p { body.prompt.as_ref().and_then(|p| match p {
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()), crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
crate::openai_api_types::StringOrArray::Array(v) => { crate::openai_api_types::StringOrArray::Array(v) => {
...@@ -1465,45 +1592,25 @@ impl RouterTrait for PDRouter { ...@@ -1465,45 +1592,25 @@ impl RouterTrait for PDRouter {
} }
}) })
}) })
.map(|s| s.to_string())
} else { } else {
None None
}; };
// Select servers // Calculate batch size
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => return Self::handle_server_selection_error(e),
};
// Log routing decision
info!(
"PD routing decision route=/generate prefill_url={} decode_url={}",
prefill.url(),
decode.url()
);
let batch_size = Self::get_generate_batch_size(body); let batch_size = Self::get_generate_batch_size(body);
let original = match serde_json::to_value(body) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
// Execute dual dispatch // Create context
self.execute_dual_dispatch( let context = PDRequestContext {
headers, route: "/generate",
json, batch_size,
"/generate",
prefill.as_ref(),
decode.as_ref(),
is_stream, is_stream,
return_logprob, return_logprob,
start, request_text,
) };
.await
// Execute with retry and bootstrap injection
self.execute_dual_dispatch(headers, body, context).await
} }
async fn route_chat( async fn route_chat(
...@@ -1511,25 +1618,19 @@ impl RouterTrait for PDRouter { ...@@ -1511,25 +1618,19 @@ impl RouterTrait for PDRouter {
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
body: &ChatCompletionRequest, body: &ChatCompletionRequest,
) -> Response { ) -> Response {
let start = Instant::now(); // Extract parameters
// Extract flags for routing logic
let is_stream = body.stream; let is_stream = body.stream;
let return_logprob = body.logprobs; let return_logprob = body.logprobs;
// Extract text for cache-aware routing from chat messages only if needed // Extract text for cache-aware routing
let request_text = if self.policies_need_request_text() { let request_text = if self.policies_need_request_text() {
body.messages.first().and_then(|msg| match msg { body.messages.first().and_then(|msg| match msg {
crate::openai_api_types::ChatMessage::User { content, .. } => { crate::openai_api_types::ChatMessage::User { content, .. } => match content {
match content { crate::openai_api_types::UserMessageContent::Text(text) => Some(text.clone()),
crate::openai_api_types::UserMessageContent::Text(text) => { crate::openai_api_types::UserMessageContent::Parts(_) => None,
Some(text.as_str()) },
}
crate::openai_api_types::UserMessageContent::Parts(_) => None, // Skip complex content
}
}
crate::openai_api_types::ChatMessage::System { content, .. } => { crate::openai_api_types::ChatMessage::System { content, .. } => {
Some(content.as_str()) Some(content.clone())
} }
_ => None, _ => None,
}) })
...@@ -1537,41 +1638,20 @@ impl RouterTrait for PDRouter { ...@@ -1537,41 +1638,20 @@ impl RouterTrait for PDRouter {
None None
}; };
// Select servers // Calculate batch size
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => return Self::handle_server_selection_error(e),
};
// Log routing decision
info!(
"PD routing decision route=/v1/chat/completions prefill_url={} decode_url={}",
prefill.url(),
decode.url()
);
let batch_size = Self::get_chat_batch_size(body); let batch_size = Self::get_chat_batch_size(body);
let original = match serde_json::to_value(body) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
// Execute dual dispatch // Create context
self.execute_dual_dispatch( let context = PDRequestContext {
headers, route: "/v1/chat/completions",
json, batch_size,
"/v1/chat/completions",
prefill.as_ref(),
decode.as_ref(),
is_stream, is_stream,
return_logprob, return_logprob,
start, request_text,
) };
.await
// Execute with retry and bootstrap injection
self.execute_dual_dispatch(headers, body, context).await
} }
async fn route_completion( async fn route_completion(
...@@ -1579,57 +1659,36 @@ impl RouterTrait for PDRouter { ...@@ -1579,57 +1659,36 @@ impl RouterTrait for PDRouter {
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
body: &CompletionRequest, body: &CompletionRequest,
) -> Response { ) -> Response {
let start = Instant::now(); // Extract parameters
// Extract flags for routing logic
let is_stream = body.stream; let is_stream = body.stream;
let return_logprob = body.logprobs.is_some(); let return_logprob = body.logprobs.is_some();
// Extract text for cache-aware routing only if needed // Extract text for cache-aware routing
let request_text = if self.policies_need_request_text() { let request_text = if self.policies_need_request_text() {
match &body.prompt { match &body.prompt {
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()), crate::openai_api_types::StringOrArray::String(s) => Some(s.clone()),
crate::openai_api_types::StringOrArray::Array(v) => v.first().map(|s| s.as_str()), crate::openai_api_types::StringOrArray::Array(v) => {
v.first().map(|s| s.to_string())
}
} }
} else { } else {
None None
}; };
// Select servers // Calculate batch size
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => return Self::handle_server_selection_error(e),
};
// Log routing decision
info!(
"PD routing decision route=/v1/completions prefill_url={} decode_url={}",
prefill.url(),
decode.url()
);
let batch_size = Self::get_completion_batch_size(body); let batch_size = Self::get_completion_batch_size(body);
let original = match serde_json::to_value(body) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
// Execute dual dispatch // Create context
self.execute_dual_dispatch( let context = PDRequestContext {
headers, route: "/v1/completions",
json, batch_size,
"/v1/completions",
prefill.as_ref(),
decode.as_ref(),
is_stream, is_stream,
return_logprob, return_logprob,
start, request_text,
) };
.await
// Execute with retry and bootstrap injection
self.execute_dual_dispatch(headers, body, context).await
} }
async fn flush_cache(&self) -> Response { async fn flush_cache(&self) -> Response {
......
use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig}; use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig};
use crate::core::{CircuitBreakerConfig, HealthChecker, RetryExecutor, Worker, WorkerFactory}; use crate::core::{
is_retryable_status, CircuitBreakerConfig, HealthChecker, RetryExecutor, Worker, WorkerFactory,
};
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::policies::LoadBalancingPolicy; use crate::policies::LoadBalancingPolicy;
...@@ -81,12 +83,8 @@ impl Router { ...@@ -81,12 +83,8 @@ impl Router {
let core_cb_config = CircuitBreakerConfig { let core_cb_config = CircuitBreakerConfig {
failure_threshold: circuit_breaker_config.failure_threshold, failure_threshold: circuit_breaker_config.failure_threshold,
success_threshold: circuit_breaker_config.success_threshold, success_threshold: circuit_breaker_config.success_threshold,
timeout_duration: std::time::Duration::from_secs( timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
circuit_breaker_config.timeout_duration_secs, window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
),
window_duration: std::time::Duration::from_secs(
circuit_breaker_config.window_duration_secs,
),
}; };
// Create Worker trait objects from URLs // Create Worker trait objects from URLs
...@@ -397,18 +395,6 @@ impl Router { ...@@ -397,18 +395,6 @@ impl Router {
Some(available[idx].clone_worker()) Some(available[idx].clone_worker())
} }
/// Check if an HTTP status code indicates a transient, retryable error:
/// 408 (request timeout), 429 (too many requests), 500, 502, 503, 504.
/// All other statuses are treated as final and will not be retried.
fn is_retryable_status(status: StatusCode) -> bool {
    matches!(
        status,
        StatusCode::REQUEST_TIMEOUT
            | StatusCode::TOO_MANY_REQUESTS
            | StatusCode::INTERNAL_SERVER_ERROR
            | StatusCode::BAD_GATEWAY
            | StatusCode::SERVICE_UNAVAILABLE
            | StatusCode::GATEWAY_TIMEOUT
    )
}
pub async fn route_typed_request< pub async fn route_typed_request<
T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone, T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
>( >(
...@@ -461,7 +447,7 @@ impl Router { ...@@ -461,7 +447,7 @@ impl Router {
response response
}, },
// should_retry predicate // should_retry predicate
|res, _attempt| Self::is_retryable_status(res.status()), |res, _attempt| is_retryable_status(res.status()),
// on_backoff hook // on_backoff hook
|delay, attempt| { |delay, attempt| {
RouterMetrics::record_retry(route); RouterMetrics::record_retry(route);
...@@ -476,7 +462,7 @@ impl Router { ...@@ -476,7 +462,7 @@ impl Router {
let duration = start.elapsed(); let duration = start.elapsed();
RouterMetrics::record_request(route); RouterMetrics::record_request(route);
RouterMetrics::record_generate_duration(duration); RouterMetrics::record_generate_duration(duration);
} else if !Self::is_retryable_status(response.status()) { } else if !is_retryable_status(response.status()) {
RouterMetrics::record_request_error(route, "non_retryable_error"); RouterMetrics::record_request_error(route, "non_retryable_error");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment