Unverified Commit b341b7db authored by Simo Lin, committed by GitHub

[router] introduce prefill response draining for http compliance (#9281)

parent b498cd21
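In PD (prefill/decode) mode the router issues the prefill request mainly for its side effects, but the prefill server's response body still has to be read to completion so the HTTP connection can be reused. This commit forwards those responses over a bounded mpsc channel to a background coordinator that drains them, with a semaphore capping how many drains run at once. The sketch below illustrates that pattern only; it assumes reqwest built with the `stream` feature plus `futures_util::StreamExt`, and the helper name `spawn_drain_coordinator` is invented for illustration, not taken from this commit.

use std::sync::Arc;
use futures_util::StreamExt;
use tokio::sync::{mpsc, Semaphore};

/// Sketch only: spawn a coordinator that drains queued prefill responses
/// with bounded concurrency, mirroring the channel + semaphore in this diff.
fn spawn_drain_coordinator() -> mpsc::Sender<reqwest::Response> {
    let (tx, mut rx) = mpsc::channel::<reqwest::Response>(2000);
    tokio::spawn(async move {
        // At most 100 responses are drained concurrently (assumed cap, as in the diff).
        let semaphore = Arc::new(Semaphore::new(100));
        while let Some(response) = rx.recv().await {
            let Ok(permit) = semaphore.clone().acquire_owned().await else { break };
            tokio::spawn(async move {
                // Stream the body to the end instead of buffering it in memory.
                let mut body = response.bytes_stream();
                while let Some(chunk) = body.next().await {
                    if chunk.is_err() {
                        break;
                    }
                }
                drop(permit); // frees a slot for the next queued response
            });
        }
    });
    tx
}

Draining off the request path keeps decode latency unaffected while still satisfying keep-alive semantics.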
......@@ -29,6 +29,7 @@ use serde_json::Value;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info, warn};
......@@ -49,6 +50,8 @@ pub struct PDRouter {
pub circuit_breaker_config: CircuitBreakerConfig,
_prefill_health_checker: Option<HealthChecker>,
_decode_health_checker: Option<HealthChecker>,
// Channel for sending prefill responses to background workers for draining
prefill_drain_tx: mpsc::Sender<reqwest::Response>,
}
// Request context for PD router operations
......@@ -501,6 +504,75 @@ impl PDRouter {
.build()
.map_err(|e| format!("Failed to build prefill client: {}", e))?;
// Create bounded channel for prefill response draining
// Larger buffer for high concurrency scenarios
let (prefill_drain_tx, mut prefill_drain_rx) = mpsc::channel::<reqwest::Response>(2000);
// Spawn a coordinator with limited concurrent drain tasks
// This prevents unbounded task spawning under extreme load
tokio::spawn(async move {
info!("Prefill drain coordinator started");
// Use a semaphore to limit concurrent drain operations
let max_concurrent_drains = 100;
let semaphore = Arc::new(tokio::sync::Semaphore::new(max_concurrent_drains));
while let Some(response) = prefill_drain_rx.recv().await {
let permit = semaphore.clone().acquire_owned().await;
match permit {
Ok(permit) => {
// Spawn a task to drain this response
tokio::spawn(async move {
let url = response.url().to_string();
let status = response.status();
if !status.is_success() {
error!("Prefill drain: error status={} url={}", status, url);
RouterMetrics::record_pd_prefill_error(&url);
}
// Drain the response body efficiently
// Use streaming to avoid loading entire body into memory
let start = std::time::Instant::now();
let mut stream = response.bytes_stream();
let mut bytes_drained = 0;
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => bytes_drained += chunk.len(),
Err(e) => {
debug!(
"Prefill drain: error streaming url={} error={}",
url, e
);
break;
}
}
}
let elapsed = start.elapsed();
if elapsed > Duration::from_millis(100) {
// Only log slow drains
debug!(
"Prefill drain: slow drain {} bytes from {} in {:?}",
bytes_drained, url, elapsed
);
}
// Permit is automatically released when dropped
drop(permit);
});
}
Err(_) => {
// Semaphore closed, shutting down
break;
}
}
}
info!("Prefill drain coordinator shutting down");
});
Ok(PDRouter {
prefill_workers,
decode_workers,
......@@ -512,6 +584,7 @@ impl PDRouter {
load_monitor_handle,
client,
prefill_client,
prefill_drain_tx,
retry_config,
circuit_breaker_config: core_cb_config,
_prefill_health_checker: Some(prefill_health_checker),
......@@ -702,11 +775,9 @@ impl PDRouter {
.execute_dual_dispatch_internal(
headers,
json_request,
-context.route,
context,
prefill.as_ref(),
decode.as_ref(),
-context.is_stream,
-context.return_logprob,
start_time,
)
.await;
......@@ -734,16 +805,13 @@ impl PDRouter {
}
// Internal method that performs the actual dual dispatch (without retry logic)
-#[allow(clippy::too_many_arguments)]
async fn execute_dual_dispatch_internal(
&self,
headers: Option<&HeaderMap>,
json_request: Value,
-route: &str,
context: PDRequestContext,
prefill: &dyn Worker,
decode: &dyn Worker,
-is_stream: bool,
-return_logprob: bool,
start_time: Instant,
) -> Response {
// Update load tracking for both workers
......@@ -753,7 +821,7 @@ impl PDRouter {
let decode_request = self.build_post_with_headers(
&self.client,
decode.url(),
-route,
context.route,
&json_request,
headers,
false,
......@@ -766,12 +834,12 @@ impl PDRouter {
decode.url()
);
-if return_logprob {
if context.return_logprob {
// Build prefill request with shared client when we need response body
let prefill_request = self.build_post_with_headers(
&self.client,
prefill.url(),
-route,
context.route,
&json_request,
headers,
false,
......@@ -783,8 +851,8 @@ impl PDRouter {
// Update metrics
let duration = start_time.elapsed();
-RouterMetrics::record_pd_request_duration(route, duration);
-RouterMetrics::record_pd_request(route);
RouterMetrics::record_pd_request_duration(context.route, duration);
RouterMetrics::record_pd_request(context.route);
RouterMetrics::record_pd_prefill_request(prefill.url());
RouterMetrics::record_pd_decode_request(decode.url());
......@@ -818,14 +886,18 @@ impl PDRouter {
// Process prefill response for logprobs
let prefill_body = match self
-.process_prefill_response(prefill_result, prefill.url(), return_logprob)
.process_prefill_response(
prefill_result,
prefill.url(),
context.return_logprob,
)
.await
{
Ok((_, body)) => body,
Err(error_response) => return error_response,
};
-if is_stream {
if context.is_stream {
// Streaming response with logprobs
let prefill_logprobs = prefill_body
.as_ref()
......@@ -841,7 +913,7 @@ impl PDRouter {
res.bytes_stream(),
status,
prefill_logprobs,
-return_logprob,
context.return_logprob,
None,
Some(response_headers),
)
......@@ -850,7 +922,7 @@ impl PDRouter {
self.process_non_streaming_response(
res,
status,
-return_logprob,
context.return_logprob,
prefill_body,
)
.await
......@@ -878,7 +950,7 @@ impl PDRouter {
.build_post_with_headers(
&self.prefill_client,
prefill.url(),
-route,
context.route,
&json_request,
headers,
true,
......@@ -886,11 +958,41 @@ impl PDRouter {
.send();
let decode_future = decode_request.send();
// Send prefill response to background worker for draining
// This ensures HTTP compliance without blocking
let drain_tx = self.prefill_drain_tx.clone();
let prefill_url = prefill.url().to_string();
tokio::spawn(async move {
if let Ok(response) = prefill_future.await {
-// Consume the entire response body to maintain HTTP compliance
-// This runs in the background and won't block the decode response
-let _ = response.bytes().await;
// Try to send to drain worker
// If channel is full (under extreme load), drain inline as fallback
match drain_tx.try_send(response) {
Ok(_) => {
// Successfully queued for draining
debug!("Prefill response queued for draining");
}
Err(mpsc::error::TrySendError::Full(response)) => {
// Channel full - drain inline as fallback
warn!("Prefill drain channel full (capacity exceeded), draining inline for {}", prefill_url);
RouterMetrics::record_pd_prefill_error(&prefill_url);
// Drain inline with timeout to prevent blocking too long
let drain_future = async {
let mut stream = response.bytes_stream();
while stream.next().await.is_some() {
// Just drain
}
};
match tokio::time::timeout(Duration::from_secs(1), drain_future).await {
Ok(_) => debug!("Inline drain completed for {}", prefill_url),
Err(_) => error!("Inline drain timeout for {}", prefill_url),
}
}
Err(mpsc::error::TrySendError::Closed(_)) => {
error!("Prefill drain channel closed!");
}
}
}
});
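The enqueue side above never awaits: `try_send` either hands the response to the coordinator or fails immediately, and a full channel falls back to draining inline under a one-second timeout so the request task cannot stall indefinitely. A condensed sketch of that fallback; the helper name `queue_or_drain` is illustrative, and `futures_util::StreamExt` is assumed:

use std::time::Duration;
use futures_util::StreamExt;
use tokio::sync::mpsc;

// Sketch only: queue the response for background draining, or drain it here under a timeout.
async fn queue_or_drain(tx: &mpsc::Sender<reqwest::Response>, response: reqwest::Response) {
    match tx.try_send(response) {
        Ok(()) => {} // coordinator will drain it
        Err(mpsc::error::TrySendError::Full(response)) => {
            // Channel saturated: drain inline, but give up after one second.
            let _ = tokio::time::timeout(Duration::from_secs(1), async move {
                let mut body = response.bytes_stream();
                while body.next().await.is_some() {}
            })
            .await;
        }
        Err(mpsc::error::TrySendError::Closed(_)) => {
            // Coordinator shut down; the response is dropped and the connection may not be reused.
        }
    }
}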
......@@ -900,8 +1002,8 @@ impl PDRouter {
// Update metrics
let duration = start_time.elapsed();
-RouterMetrics::record_pd_request_duration(route, duration);
-RouterMetrics::record_pd_request(route);
RouterMetrics::record_pd_request_duration(context.route, duration);
RouterMetrics::record_pd_request(context.route);
RouterMetrics::record_pd_prefill_request(prefill.url());
RouterMetrics::record_pd_decode_request(decode.url());
......@@ -928,7 +1030,7 @@ impl PDRouter {
(status, format!("Decode server error: {}", e)).into_response()
}
}
-} else if is_stream {
} else if context.is_stream {
// Streaming response without logprobs - direct passthrough
let decode_url = decode.url().to_string();
let response_headers =
......@@ -1280,10 +1382,10 @@ impl PDRouter {
fn build_post_with_headers(
&self,
-client: &reqwest::Client,
client: &Client,
url: &str,
route: &str,
-json_request: &serde_json::Value,
json_request: &Value,
headers: Option<&HeaderMap>,
connection_close: bool,
) -> reqwest::RequestBuilder {
......@@ -1894,6 +1996,7 @@ mod tests {
load_monitor_handle: None,
client: Client::new(),
prefill_client: Client::new(),
prefill_drain_tx: mpsc::channel(100).0,
retry_config: RetryConfig::default(),
circuit_breaker_config: CircuitBreakerConfig::default(),
_prefill_health_checker: None,
......
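Most of the remaining hunks replace the separate `route`, `is_stream`, and `return_logprob` parameters of `execute_dual_dispatch_internal` with the existing `PDRequestContext`, which also lets the `clippy::too_many_arguments` allowance be dropped. The struct is defined outside the shown hunks; judging only from the fields accessed here, it carries at least something like:

// Hypothetical shape, inferred from context.route / context.is_stream / context.return_logprob above.
struct PDRequestContext<'a> {
    route: &'a str,
    is_stream: bool,
    return_logprob: bool,
}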