Unverified commit b341b7db, authored by Simo Lin, committed by GitHub
Browse files

[router] introduce prefill response draining for http compliance (#9281)

parent b498cd21
...@@ -29,6 +29,7 @@ use serde_json::Value; ...@@ -29,6 +29,7 @@ use serde_json::Value;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
...@@ -49,6 +50,8 @@ pub struct PDRouter { ...@@ -49,6 +50,8 @@ pub struct PDRouter {
pub circuit_breaker_config: CircuitBreakerConfig, pub circuit_breaker_config: CircuitBreakerConfig,
_prefill_health_checker: Option<HealthChecker>, _prefill_health_checker: Option<HealthChecker>,
_decode_health_checker: Option<HealthChecker>, _decode_health_checker: Option<HealthChecker>,
// Channel for sending prefill responses to background workers for draining
prefill_drain_tx: mpsc::Sender<reqwest::Response>,
} }
// Request context for PD router operations // Request context for PD router operations
...@@ -501,6 +504,75 @@ impl PDRouter { ...@@ -501,6 +504,75 @@ impl PDRouter {
.build() .build()
.map_err(|e| format!("Failed to build prefill client: {}", e))?; .map_err(|e| format!("Failed to build prefill client: {}", e))?;
// Create bounded channel for prefill response draining
// Larger buffer for high concurrency scenarios
// Bounded at 2000 so producers get an explicit backpressure signal
// (try_send returns Full) instead of growing memory without limit.
let (prefill_drain_tx, mut prefill_drain_rx) = mpsc::channel::<reqwest::Response>(2000);
// Spawn a coordinator with limited concurrent drain tasks
// This prevents unbounded task spawning under extreme load
tokio::spawn(async move {
info!("Prefill drain coordinator started");
// Use a semaphore to limit concurrent drain operations
let max_concurrent_drains = 100;
let semaphore = Arc::new(tokio::sync::Semaphore::new(max_concurrent_drains));
// Loop ends (and the coordinator shuts down) when every sender clone
// of prefill_drain_tx has been dropped and the channel is empty.
while let Some(response) = prefill_drain_rx.recv().await {
// Acquire the permit BEFORE spawning, so at most
// max_concurrent_drains drain tasks run at once; recv() stalls
// here while all permits are in use.
let permit = semaphore.clone().acquire_owned().await;
match permit {
Ok(permit) => {
// Spawn a task to drain this response
tokio::spawn(async move {
let url = response.url().to_string();
let status = response.status();
if !status.is_success() {
error!("Prefill drain: error status={} url={}", status, url);
RouterMetrics::record_pd_prefill_error(&url);
}
// Drain the response body efficiently
// Use streaming to avoid loading entire body into memory
let start = std::time::Instant::now();
let mut stream = response.bytes_stream();
// bytes_drained is only used for the slow-drain log below.
let mut bytes_drained = 0;
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => bytes_drained += chunk.len(),
Err(e) => {
// Stream errors while draining are non-fatal: the
// response is being discarded anyway, so log at
// debug and stop reading.
debug!(
"Prefill drain: error streaming url={} error={}",
url, e
);
break;
}
}
}
let elapsed = start.elapsed();
if elapsed > Duration::from_millis(100) {
// Only log slow drains
debug!(
"Prefill drain: slow drain {} bytes from {} in {:?}",
bytes_drained, url, elapsed
);
}
// This explicit drop is what moves the permit into the spawned
// task's closure: without any use of `permit` here it would not
// be captured, and the concurrency slot would be released before
// draining even started. Dropping it last holds the slot for the
// full duration of the drain.
drop(permit);
});
}
Err(_) => {
// acquire_owned only fails if the semaphore is closed; the
// coordinator holds an Arc to it, so this branch should be
// unreachable in practice — kept as a defensive shutdown path.
// Semaphore closed, shutting down
break;
}
}
}
info!("Prefill drain coordinator shutting down");
});
Ok(PDRouter { Ok(PDRouter {
prefill_workers, prefill_workers,
decode_workers, decode_workers,
...@@ -512,6 +584,7 @@ impl PDRouter { ...@@ -512,6 +584,7 @@ impl PDRouter {
load_monitor_handle, load_monitor_handle,
client, client,
prefill_client, prefill_client,
prefill_drain_tx,
retry_config, retry_config,
circuit_breaker_config: core_cb_config, circuit_breaker_config: core_cb_config,
_prefill_health_checker: Some(prefill_health_checker), _prefill_health_checker: Some(prefill_health_checker),
...@@ -702,11 +775,9 @@ impl PDRouter { ...@@ -702,11 +775,9 @@ impl PDRouter {
.execute_dual_dispatch_internal( .execute_dual_dispatch_internal(
headers, headers,
json_request, json_request,
context.route, context,
prefill.as_ref(), prefill.as_ref(),
decode.as_ref(), decode.as_ref(),
context.is_stream,
context.return_logprob,
start_time, start_time,
) )
.await; .await;
...@@ -734,16 +805,13 @@ impl PDRouter { ...@@ -734,16 +805,13 @@ impl PDRouter {
} }
// Internal method that performs the actual dual dispatch (without retry logic) // Internal method that performs the actual dual dispatch (without retry logic)
#[allow(clippy::too_many_arguments)]
async fn execute_dual_dispatch_internal( async fn execute_dual_dispatch_internal(
&self, &self,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
json_request: Value, json_request: Value,
route: &str, context: PDRequestContext,
prefill: &dyn Worker, prefill: &dyn Worker,
decode: &dyn Worker, decode: &dyn Worker,
is_stream: bool,
return_logprob: bool,
start_time: Instant, start_time: Instant,
) -> Response { ) -> Response {
// Update load tracking for both workers // Update load tracking for both workers
...@@ -753,7 +821,7 @@ impl PDRouter { ...@@ -753,7 +821,7 @@ impl PDRouter {
let decode_request = self.build_post_with_headers( let decode_request = self.build_post_with_headers(
&self.client, &self.client,
decode.url(), decode.url(),
route, context.route,
&json_request, &json_request,
headers, headers,
false, false,
...@@ -766,12 +834,12 @@ impl PDRouter { ...@@ -766,12 +834,12 @@ impl PDRouter {
decode.url() decode.url()
); );
if return_logprob { if context.return_logprob {
// Build prefill request with shared client when we need response body // Build prefill request with shared client when we need response body
let prefill_request = self.build_post_with_headers( let prefill_request = self.build_post_with_headers(
&self.client, &self.client,
prefill.url(), prefill.url(),
route, context.route,
&json_request, &json_request,
headers, headers,
false, false,
...@@ -783,8 +851,8 @@ impl PDRouter { ...@@ -783,8 +851,8 @@ impl PDRouter {
// Update metrics // Update metrics
let duration = start_time.elapsed(); let duration = start_time.elapsed();
RouterMetrics::record_pd_request_duration(route, duration); RouterMetrics::record_pd_request_duration(context.route, duration);
RouterMetrics::record_pd_request(route); RouterMetrics::record_pd_request(context.route);
RouterMetrics::record_pd_prefill_request(prefill.url()); RouterMetrics::record_pd_prefill_request(prefill.url());
RouterMetrics::record_pd_decode_request(decode.url()); RouterMetrics::record_pd_decode_request(decode.url());
...@@ -818,14 +886,18 @@ impl PDRouter { ...@@ -818,14 +886,18 @@ impl PDRouter {
// Process prefill response for logprobs // Process prefill response for logprobs
let prefill_body = match self let prefill_body = match self
.process_prefill_response(prefill_result, prefill.url(), return_logprob) .process_prefill_response(
prefill_result,
prefill.url(),
context.return_logprob,
)
.await .await
{ {
Ok((_, body)) => body, Ok((_, body)) => body,
Err(error_response) => return error_response, Err(error_response) => return error_response,
}; };
if is_stream { if context.is_stream {
// Streaming response with logprobs // Streaming response with logprobs
let prefill_logprobs = prefill_body let prefill_logprobs = prefill_body
.as_ref() .as_ref()
...@@ -841,7 +913,7 @@ impl PDRouter { ...@@ -841,7 +913,7 @@ impl PDRouter {
res.bytes_stream(), res.bytes_stream(),
status, status,
prefill_logprobs, prefill_logprobs,
return_logprob, context.return_logprob,
None, None,
Some(response_headers), Some(response_headers),
) )
...@@ -850,7 +922,7 @@ impl PDRouter { ...@@ -850,7 +922,7 @@ impl PDRouter {
self.process_non_streaming_response( self.process_non_streaming_response(
res, res,
status, status,
return_logprob, context.return_logprob,
prefill_body, prefill_body,
) )
.await .await
...@@ -878,7 +950,7 @@ impl PDRouter { ...@@ -878,7 +950,7 @@ impl PDRouter {
.build_post_with_headers( .build_post_with_headers(
&self.prefill_client, &self.prefill_client,
prefill.url(), prefill.url(),
route, context.route,
&json_request, &json_request,
headers, headers,
true, true,
...@@ -886,11 +958,41 @@ impl PDRouter { ...@@ -886,11 +958,41 @@ impl PDRouter {
.send(); .send();
let decode_future = decode_request.send(); let decode_future = decode_request.send();
// Send prefill response to background worker for draining
// This ensures HTTP compliance without blocking
let drain_tx = self.prefill_drain_tx.clone();
let prefill_url = prefill.url().to_string();
tokio::spawn(async move { tokio::spawn(async move {
if let Ok(response) = prefill_future.await { if let Ok(response) = prefill_future.await {
// Consume the entire response body to maintain HTTP compliance // Try to send to drain worker
// This runs in the background and won't block the decode response // If channel is full (under extreme load), drain inline as fallback
let _ = response.bytes().await; match drain_tx.try_send(response) {
Ok(_) => {
// Successfully queued for draining
debug!("Prefill response queued for draining");
}
Err(mpsc::error::TrySendError::Full(response)) => {
// Channel full - drain inline as fallback
warn!("Prefill drain channel full (capacity exceeded), draining inline for {}", prefill_url);
RouterMetrics::record_pd_prefill_error(&prefill_url);
// Drain inline with timeout to prevent blocking too long
let drain_future = async {
let mut stream = response.bytes_stream();
while stream.next().await.is_some() {
// Just drain
}
};
match tokio::time::timeout(Duration::from_secs(1), drain_future).await {
Ok(_) => debug!("Inline drain completed for {}", prefill_url),
Err(_) => error!("Inline drain timeout for {}", prefill_url),
}
}
Err(mpsc::error::TrySendError::Closed(_)) => {
error!("Prefill drain channel closed!");
}
}
} }
}); });
...@@ -900,8 +1002,8 @@ impl PDRouter { ...@@ -900,8 +1002,8 @@ impl PDRouter {
// Update metrics // Update metrics
let duration = start_time.elapsed(); let duration = start_time.elapsed();
RouterMetrics::record_pd_request_duration(route, duration); RouterMetrics::record_pd_request_duration(context.route, duration);
RouterMetrics::record_pd_request(route); RouterMetrics::record_pd_request(context.route);
RouterMetrics::record_pd_prefill_request(prefill.url()); RouterMetrics::record_pd_prefill_request(prefill.url());
RouterMetrics::record_pd_decode_request(decode.url()); RouterMetrics::record_pd_decode_request(decode.url());
...@@ -928,7 +1030,7 @@ impl PDRouter { ...@@ -928,7 +1030,7 @@ impl PDRouter {
(status, format!("Decode server error: {}", e)).into_response() (status, format!("Decode server error: {}", e)).into_response()
} }
} }
} else if is_stream { } else if context.is_stream {
// Streaming response without logprobs - direct passthrough // Streaming response without logprobs - direct passthrough
let decode_url = decode.url().to_string(); let decode_url = decode.url().to_string();
let response_headers = let response_headers =
...@@ -1280,10 +1382,10 @@ impl PDRouter { ...@@ -1280,10 +1382,10 @@ impl PDRouter {
fn build_post_with_headers( fn build_post_with_headers(
&self, &self,
client: &reqwest::Client, client: &Client,
url: &str, url: &str,
route: &str, route: &str,
json_request: &serde_json::Value, json_request: &Value,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
connection_close: bool, connection_close: bool,
) -> reqwest::RequestBuilder { ) -> reqwest::RequestBuilder {
...@@ -1894,6 +1996,7 @@ mod tests { ...@@ -1894,6 +1996,7 @@ mod tests {
load_monitor_handle: None, load_monitor_handle: None,
client: Client::new(), client: Client::new(),
prefill_client: Client::new(), prefill_client: Client::new(),
prefill_drain_tx: mpsc::channel(100).0,
retry_config: RetryConfig::default(), retry_config: RetryConfig::default(),
circuit_breaker_config: CircuitBreakerConfig::default(), circuit_breaker_config: CircuitBreakerConfig::default(),
_prefill_health_checker: None, _prefill_health_checker: None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment