Unverified Commit a59cbea9 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] harden retries + metrics; fix streaming load; header filtering (#8972)

parent 53f7874a
......@@ -7,7 +7,7 @@ use crate::routers::{RouterTrait, WorkerManagement};
use axum::{
body::Body,
extract::Request,
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
http::{header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
Json,
};
......@@ -351,9 +351,8 @@ impl Router {
Ok(worker_url) => {
let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint));
for (name, value) in headers {
if name.to_lowercase() != "content-type"
&& name.to_lowercase() != "content-length"
{
let name_lc = name.to_lowercase();
if name_lc != "content-type" && name_lc != "content-length" {
request_builder = request_builder.header(name, value);
}
}
......@@ -406,6 +405,14 @@ impl Router {
// Select worker based on text
let worker_url = self.select_generate_worker_from_text(&text);
if worker_url.is_empty() {
RouterMetrics::record_request_error(route, "no_healthy_workers");
return (
StatusCode::SERVICE_UNAVAILABLE,
"No healthy workers available",
)
.into_response();
}
let mut request_retries = 0;
// Try the same worker multiple times
......@@ -443,9 +450,15 @@ impl Router {
if response.status().is_success() {
let duration = start.elapsed();
RouterMetrics::record_request(route);
RouterMetrics::record_generate_duration(duration);
return response;
} else {
let status = response.status();
if status.is_client_error() && status != StatusCode::TOO_MANY_REQUESTS {
RouterMetrics::record_request_error(route, "client_error");
return response;
}
// if the worker is healthy, it means the request is bad, so return the error response
let health_response = self.send_health_check(&worker_url).await;
if health_response.status().is_success() {
......@@ -473,6 +486,9 @@ impl Router {
self.remove_worker(&worker_url);
break;
}
let backoff_ms = (100u64 * (request_retries as u64)).min(1000);
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
}
}
......@@ -524,8 +540,6 @@ impl Router {
is_stream: bool,
load_incremented: bool, // Whether load was incremented for this request
) -> Response {
let start = Instant::now();
let mut request_builder = if self.dp_aware {
let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
......@@ -582,9 +596,7 @@ impl Router {
if let Some(headers) = headers {
for (name, value) in headers {
// Skip Content-Type and Content-Length as .json() sets them
if name.to_string().to_lowercase() != "content-type"
&& name.to_string().to_lowercase() != "content-length"
{
if *name != CONTENT_TYPE && *name != CONTENT_LENGTH {
request_builder = request_builder.header(name, value);
}
}
......@@ -639,11 +651,6 @@ impl Router {
}
}
// Record metrics
let duration = start.elapsed();
RouterMetrics::record_generate_duration(duration);
RouterMetrics::record_request(route);
response
} else if load_incremented {
// For streaming with load tracking, we need to manually decrement when done
......@@ -656,6 +663,7 @@ impl Router {
// Spawn task to forward stream and detect completion
tokio::spawn(async move {
let mut stream = stream;
let mut decremented = false;
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
......@@ -674,6 +682,7 @@ impl Router {
&worker_url,
worker.load(),
);
decremented = true;
}
}
}
......@@ -687,6 +696,15 @@ impl Router {
}
}
}
if !decremented {
if let Ok(workers_guard) = workers.read() {
if let Some(worker) = workers_guard.iter().find(|w| w.url() == &worker_url)
{
worker.decrement_load();
RouterMetrics::set_running_requests(&worker_url, worker.load());
}
}
}
});
let stream = UnboundedReceiverStream::new(rx);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment