Unverified commit c3a1d775 authored by Simo Lin, committed by GitHub

[router] remove pd router draining channel (#10767)

parent 89971c4c
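
This change replaces the router's fire-and-forget prefill path, where the prefill response was handed off to a background drain task over an mpsc channel, with a flow that always awaits the prefill and decode responses together via tokio::join!. A minimal, self-contained sketch of that pattern, with a hypothetical send function standing in for the reqwest requests the router actually builds:

// Hypothetical sketch of the new request flow; the real router builds
// reqwest requests via build_post_with_headers and checks status codes.
async fn send(url: &str) -> Result<String, String> {
    // Stand-in for an HTTP round trip to a prefill or decode server.
    Ok(format!("response from {url}"))
}

#[tokio::main]
async fn main() {
    // Both requests are dispatched concurrently and BOTH are awaited, so
    // the prefill response no longer needs a background drain channel.
    let (prefill_result, decode_result) =
        tokio::join!(send("http://prefill:30001"), send("http://decode:30002"));
    println!("prefill: {prefill_result:?}, decode: {decode_result:?}");
}

Because both futures are awaited, the prefill response can be checked for errors (and read for logprobs when requested) in place, so the dedicated prefill client, the drain channel, and the drain coordinator task in the diff below all become unnecessary.
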
@@ -219,6 +219,7 @@ jobs:
--decode http://127.0.0.7:30007 \
--decode http://127.0.0.8:30008 \
--host 127.0.0.9 \
--log-level warning \
--port 8000 &
ROUTER_PID=$!
@@ -300,8 +301,8 @@ jobs:
--task text-to-text \
--num-concurrency 64 \
--traffic-scenario "D(8000,2000)" \
--max-requests-per-run 640 \
--max-time-per-run 2 \
--max-requests-per-run 1000 \
--max-time-per-run 5 \
--experiment-folder-name "benchmark_${policy}" \
--experiment-base-dir "."
@@ -341,7 +342,7 @@ jobs:
# These can be adjusted based on your performance requirements
ttft_threshold=4.7 # Max 4.7 seconds for mean TTFT
e2e_latency_threshold=35.0 # Max 35.0 seconds for mean E2E latency
input_throughput_threshold=12000 # Min 12000 tokens/s for mean input throughput
input_throughput_threshold=10000 # Min 10000 tokens/s for mean input throughput
output_throughput_threshold=68 # Min 68 tokens/s for mean output throughput
@@ -558,12 +559,12 @@ jobs:
# Check thresholds (using same values as in main workflow)
validation_status="✅"
if [ "$ttft" != "N/A" ] && [ "$ttft" != "null" ]; then
if (( $(echo "$ttft > 2.0" | bc -l 2>/dev/null || echo "0") )); then
if (( $(echo "$ttft > 4.7" | bc -l 2>/dev/null || echo "0") )); then
validation_status="❌"
fi
fi
if [ "$e2e_latency" != "N/A" ] && [ "$e2e_latency" != "null" ]; then
if (( $(echo "$e2e_latency > 24.0" | bc -l 2>/dev/null || echo "0") )); then
if (( $(echo "$e2e_latency > 35.0" | bc -l 2>/dev/null || echo "0") )); then
validation_status="❌"
fi
fi
@@ -573,7 +574,7 @@ jobs:
fi
fi
if [ "$output_throughput" != "N/A" ] && [ "$output_throughput" != "null" ]; then
if (( $(echo "$output_throughput < 90" | bc -l 2>/dev/null || echo "0") )); then
if (( $(echo "$output_throughput < 68" | bc -l 2>/dev/null || echo "0") )); then
validation_status="❌"
fi
fi
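
The summary-job checks above are brought in line with the thresholds in the main workflow (TTFT 4.7 s, E2E latency 35.0 s, output throughput 68 tokens/s). Each check follows the same guard pattern: skip missing values, then compare through bc -l with a fallback of 0 so a bc failure never breaks the arithmetic test. A minimal sketch of that pattern, with example values in place of the workflow's measured metrics:

#!/usr/bin/env bash
# Sketch of the threshold-check pattern used in the workflow above.
ttft="3.2"           # example measured value; may also be "N/A" or "null"
ttft_threshold=4.7   # matches the main workflow threshold

if [ "$ttft" != "N/A" ] && [ "$ttft" != "null" ]; then
    # bc -l prints 1 when the comparison holds; `|| echo "0"` keeps the
    # arithmetic test well-formed if bc fails on malformed input.
    if (( $(echo "$ttft > $ttft_threshold" | bc -l 2>/dev/null || echo "0") )); then
        echo "validation failed: mean TTFT ${ttft}s exceeds ${ttft_threshold}s"
    fi
fi
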
@@ -27,7 +27,6 @@ use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info, warn};
@@ -38,11 +37,9 @@ pub struct PDRouter {
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
pub client: Client,
pub prefill_client: Client,
pub retry_config: RetryConfig,
pub api_key: Option<String>,
pub enable_igw: bool,
prefill_drain_tx: mpsc::Sender<reqwest::Response>,
}
#[derive(Clone)]
@@ -241,72 +238,7 @@ impl PDRouter {
None
};
let prefill_client = Client::builder()
.pool_max_idle_per_host(0)
.http1_only()
.connect_timeout(Duration::from_millis(300))
.timeout(Duration::from_secs(ctx.router_config.request_timeout_secs))
.build()
.map_err(|e| format!("Failed to build prefill client: {}", e))?;
let (prefill_drain_tx, mut prefill_drain_rx) = mpsc::channel::<reqwest::Response>(2000);
// TODO reevaluate a simpler approach (e.g. do we really need to deal with fire and forget)
tokio::spawn(async move {
info!("Prefill drain coordinator started");
let max_concurrent_drains = 100;
let semaphore = Arc::new(tokio::sync::Semaphore::new(max_concurrent_drains));
while let Some(response) = prefill_drain_rx.recv().await {
let permit = semaphore.clone().acquire_owned().await;
match permit {
Ok(permit) => {
tokio::spawn(async move {
let url = response.url().to_string();
let status = response.status();
if !status.is_success() {
error!("Prefill drain: error status={} url={}", status, url);
RouterMetrics::record_pd_prefill_error(&url);
}
let start = Instant::now();
let mut stream = response.bytes_stream();
let mut bytes_drained = 0;
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => bytes_drained += chunk.len(),
Err(e) => {
debug!(
"Prefill drain: error streaming url={} error={}",
url, e
);
break;
}
}
}
let elapsed = start.elapsed();
if elapsed > Duration::from_millis(100) {
debug!(
"Prefill drain: slow drain {} bytes from {} in {:?}",
bytes_drained, url, elapsed
);
}
drop(permit);
});
}
Err(_) => {
break;
}
}
}
info!("Prefill drain coordinator shutting down");
});
// No longer need prefill drain channel - we'll wait for both responses
Ok(PDRouter {
worker_registry: Arc::clone(&ctx.worker_registry),
@@ -314,8 +246,6 @@ impl PDRouter {
worker_loads,
load_monitor_handle,
client: ctx.client.clone(),
prefill_client,
prefill_drain_tx,
retry_config: ctx.router_config.effective_retry_config(),
api_key: ctx.router_config.api_key.clone(),
enable_igw: ctx.router_config.enable_igw,
@@ -585,7 +515,15 @@ impl PDRouter {
None
};
// Build decode request with shared client
// Build both requests
let prefill_request = self.build_post_with_headers(
&self.client,
prefill.url(),
context.route,
&json_request,
headers,
false,
);
let decode_request = self.build_post_with_headers(
&self.client,
decode.url(),
@@ -595,57 +533,46 @@ impl PDRouter {
false,
);
// Send both requests concurrently
// Send both requests concurrently and wait for both
debug!(
"Sending concurrent requests to prefill={} decode={}",
prefill.url(),
decode.url()
);
if context.return_logprob {
// Build prefill request with shared client when we need response body
let prefill_request = self.build_post_with_headers(
&self.client,
prefill.url(),
context.route,
&json_request,
headers,
false,
);
// When we need logprobs, wait for both responses
let (prefill_result, decode_result) =
tokio::join!(prefill_request.send(), decode_request.send());
debug!("Received responses from both servers");
let duration = start_time.elapsed();
RouterMetrics::record_pd_request_duration(context.route, duration);
RouterMetrics::record_pd_request(context.route);
RouterMetrics::record_pd_prefill_request(prefill.url());
RouterMetrics::record_pd_decode_request(decode.url());
// Process decode response with prefill for logprobs
debug!("Processing decode response with logprobs");
match decode_result {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
debug!("Decode response status: {}", status);
if !status.is_success() {
RouterMetrics::record_pd_decode_error(decode.url());
error!(
"Decode server returned error status decode_url={} status={}",
decode.url(),
status
);
let (prefill_result, decode_result) =
tokio::join!(prefill_request.send(), decode_request.send());
debug!("Received responses from both servers");
return self
.handle_decode_error_response(res, &context, prefill, decode)
.await;
}
let duration = start_time.elapsed();
RouterMetrics::record_pd_request_duration(context.route, duration);
RouterMetrics::record_pd_request(context.route);
RouterMetrics::record_pd_prefill_request(prefill.url());
RouterMetrics::record_pd_decode_request(decode.url());
// Process decode response
match decode_result {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
debug!("Decode response status: {}", status);
if !status.is_success() {
RouterMetrics::record_pd_decode_error(decode.url());
error!(
"Decode server returned error status decode_url={} status={}",
decode.url(),
status
);
return self
.handle_decode_error_response(res, &context, prefill, decode)
.await;
}
// Process prefill response for logprobs
let prefill_body = match self
// Process prefill response
let prefill_body = if context.return_logprob {
match self
.process_prefill_response(
prefill_result,
prefill.url(),
@@ -655,32 +582,46 @@ impl PDRouter {
{
Ok((_, body)) => body,
Err(error_response) => return error_response,
};
}
} else {
// Even if we don't need logprobs, we should check prefill status
match self
.process_prefill_response(prefill_result, prefill.url(), false)
.await
{
Ok((_, body)) => body,
Err(error_response) => return error_response,
}
};
if context.is_stream {
// Streaming response with logprobs
let prefill_logprobs = prefill_body
if context.is_stream {
// Streaming response
let prefill_logprobs = if context.return_logprob {
prefill_body
.as_ref()
.and_then(|body| serde_json::from_slice::<Value>(body).ok())
.and_then(|json| {
json.pointer("/meta_info/input_token_logprobs").cloned()
});
let response_headers =
header_utils::preserve_response_headers(res.headers());
self.create_streaming_response(
res.bytes_stream(),
status,
prefill_logprobs,
context.return_logprob,
None,
Some(response_headers),
prefill,
decode,
)
})
} else {
// Non-streaming response with logprobs
None
};
let response_headers = header_utils::preserve_response_headers(res.headers());
self.create_streaming_response(
res.bytes_stream(),
status,
prefill_logprobs,
context.return_logprob,
None,
Some(response_headers),
prefill,
decode,
)
} else {
// Non-streaming response
if context.return_logprob {
self.process_non_streaming_response(
res,
status,
@@ -688,122 +629,8 @@ impl PDRouter {
prefill_body,
)
.await
}
}
Err(e) => {
error!(
decode_url = %decode.url(),
error = %e,
"Decode request failed"
);
RouterMetrics::record_pd_decode_error(decode.url());
(
StatusCode::BAD_GATEWAY,
format!("Decode server error: {}", e),
)
.into_response()
}
}
} else {
// When we don't need logprobs, only wait for decode response
// Send both requests concurrently but don't wait for prefill
// Use dedicated prefill client with Connection: close
let prefill_future = self
.build_post_with_headers(
&self.prefill_client,
prefill.url(),
context.route,
&json_request,
headers,
true,
)
.send();
let decode_future = decode_request.send();
// Send prefill response to background worker for draining
// This ensures HTTP compliance without blocking
let drain_tx = self.prefill_drain_tx.clone();
let prefill_url = prefill.url().to_string();
tokio::spawn(async move {
if let Ok(response) = prefill_future.await {
// Try to send to drain worker
// If channel is full (under extreme load), drain inline as fallback
match drain_tx.try_send(response) {
Ok(_) => {
// Successfully queued for draining
debug!("Prefill response queued for draining");
}
Err(mpsc::error::TrySendError::Full(response)) => {
// Channel full - drain inline as fallback
warn!("Prefill drain channel full (capacity exceeded), draining inline for {}", prefill_url);
RouterMetrics::record_pd_prefill_error(&prefill_url);
// Drain inline with timeout to prevent blocking too long
let drain_future = async {
let mut stream = response.bytes_stream();
while stream.next().await.is_some() {
// Just drain
}
};
match tokio::time::timeout(Duration::from_secs(1), drain_future).await {
Ok(_) => debug!("Inline drain completed for {}", prefill_url),
Err(_) => error!("Inline drain timeout for {}", prefill_url),
}
}
Err(mpsc::error::TrySendError::Closed(_)) => {
error!("Prefill drain channel closed!");
}
}
}
});
// Wait only for decode response
let decode_result = decode_future.await;
debug!("Received decode response");
let duration = start_time.elapsed();
RouterMetrics::record_pd_request_duration(context.route, duration);
RouterMetrics::record_pd_request(context.route);
RouterMetrics::record_pd_prefill_request(prefill.url());
RouterMetrics::record_pd_decode_request(decode.url());
// Process decode response immediately
debug!("Processing decode response (no logprobs)");
match decode_result {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
debug!("Decode response status: {}", status);
if !status.is_success() {
RouterMetrics::record_pd_decode_error(decode.url());
error!(
"Decode server returned error status decode_url={} status={}",
decode.url(),
status
);
self.handle_decode_error_response(res, &context, prefill, decode)
.await
} else if context.is_stream {
// Streaming response without logprobs - direct passthrough
let decode_url = decode.url().to_string();
let response_headers =
header_utils::preserve_response_headers(res.headers());
self.create_streaming_response(
res.bytes_stream(),
status,
None,
false,
Some(decode_url),
Some(response_headers),
prefill,
decode,
)
} else {
// Non-streaming response without logprobs - direct passthrough like fast version
// Direct passthrough when no logprobs needed
let response_headers =
header_utils::preserve_response_headers(res.headers());
@@ -823,19 +650,19 @@ impl PDRouter {
}
}
}
Err(e) => {
error!(
decode_url = %decode.url(),
error = %e,
"Decode request failed"
);
RouterMetrics::record_pd_decode_error(decode.url());
(
StatusCode::BAD_GATEWAY,
format!("Decode server error: {}", e),
)
.into_response()
}
}
Err(e) => {
error!(
decode_url = %decode.url(),
error = %e,
"Decode request failed"
);
RouterMetrics::record_pd_decode_error(decode.url());
(
StatusCode::BAD_GATEWAY,
format!("Decode server error: {}", e),
)
.into_response()
}
}
}
@@ -1802,8 +1629,6 @@ mod tests {
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
load_monitor_handle: None,
client: Client::new(),
prefill_client: Client::new(),
prefill_drain_tx: mpsc::channel(100).0,
retry_config: RetryConfig::default(),
api_key: Some("test_api_key".to_string()),
enable_igw: false,
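
For reference, the deleted prefill client disabled connection pooling and forced HTTP/1, while the surviving shared client keeps reqwest's defaults. A sketch of the two builder configurations, assuming reqwest's builder API and using a placeholder 600 s timeout in place of router_config.request_timeout_secs:

use std::time::Duration;

fn build_clients() -> Result<(), reqwest::Error> {
    // Old: dedicated prefill client, removed by this commit. Keeping zero
    // idle connections per host forced a fresh connection per prefill call.
    let _prefill_only = reqwest::Client::builder()
        .pool_max_idle_per_host(0)
        .http1_only()
        .connect_timeout(Duration::from_millis(300))
        .timeout(Duration::from_secs(600)) // placeholder for request_timeout_secs
        .build()?;

    // New: the single shared client, with default pooling, now serves
    // both prefill and decode requests.
    let _shared = reqwest::Client::builder()
        .timeout(Duration::from_secs(600)) // placeholder for request_timeout_secs
        .build()?;
    Ok(())
}
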