Unverified commit a59cbea9, authored by Simo Lin, committed by GitHub
Browse files

[router] harden retries + metrics; fix streaming load; header filtering (#8972)

parent 53f7874a
...@@ -7,7 +7,7 @@ use crate::routers::{RouterTrait, WorkerManagement}; ...@@ -7,7 +7,7 @@ use crate::routers::{RouterTrait, WorkerManagement};
use axum::{ use axum::{
body::Body, body::Body,
extract::Request, extract::Request,
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, http::{header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
Json, Json,
}; };
...@@ -351,9 +351,8 @@ impl Router { ...@@ -351,9 +351,8 @@ impl Router {
Ok(worker_url) => { Ok(worker_url) => {
let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint)); let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint));
for (name, value) in headers { for (name, value) in headers {
if name.to_lowercase() != "content-type" let name_lc = name.to_lowercase();
&& name.to_lowercase() != "content-length" if name_lc != "content-type" && name_lc != "content-length" {
{
request_builder = request_builder.header(name, value); request_builder = request_builder.header(name, value);
} }
} }
...@@ -406,6 +405,14 @@ impl Router { ...@@ -406,6 +405,14 @@ impl Router {
// Select worker based on text // Select worker based on text
let worker_url = self.select_generate_worker_from_text(&text); let worker_url = self.select_generate_worker_from_text(&text);
if worker_url.is_empty() {
RouterMetrics::record_request_error(route, "no_healthy_workers");
return (
StatusCode::SERVICE_UNAVAILABLE,
"No healthy workers available",
)
.into_response();
}
let mut request_retries = 0; let mut request_retries = 0;
// Try the same worker multiple times // Try the same worker multiple times
...@@ -443,9 +450,15 @@ impl Router { ...@@ -443,9 +450,15 @@ impl Router {
if response.status().is_success() { if response.status().is_success() {
let duration = start.elapsed(); let duration = start.elapsed();
RouterMetrics::record_request(route);
RouterMetrics::record_generate_duration(duration); RouterMetrics::record_generate_duration(duration);
return response; return response;
} else { } else {
let status = response.status();
if status.is_client_error() && status != StatusCode::TOO_MANY_REQUESTS {
RouterMetrics::record_request_error(route, "client_error");
return response;
}
// if the worker is healthy, it means the request is bad, so return the error response // if the worker is healthy, it means the request is bad, so return the error response
let health_response = self.send_health_check(&worker_url).await; let health_response = self.send_health_check(&worker_url).await;
if health_response.status().is_success() { if health_response.status().is_success() {
...@@ -473,6 +486,9 @@ impl Router { ...@@ -473,6 +486,9 @@ impl Router {
self.remove_worker(&worker_url); self.remove_worker(&worker_url);
break; break;
} }
let backoff_ms = (100u64 * (request_retries as u64)).min(1000);
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
} }
} }
...@@ -524,8 +540,6 @@ impl Router { ...@@ -524,8 +540,6 @@ impl Router {
is_stream: bool, is_stream: bool,
load_incremented: bool, // Whether load was incremented for this request load_incremented: bool, // Whether load was incremented for this request
) -> Response { ) -> Response {
let start = Instant::now();
let mut request_builder = if self.dp_aware { let mut request_builder = if self.dp_aware {
let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) { let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup, Ok(tup) => tup,
...@@ -582,9 +596,7 @@ impl Router { ...@@ -582,9 +596,7 @@ impl Router {
if let Some(headers) = headers { if let Some(headers) = headers {
for (name, value) in headers { for (name, value) in headers {
// Skip Content-Type and Content-Length as .json() sets them // Skip Content-Type and Content-Length as .json() sets them
if name.to_string().to_lowercase() != "content-type" if *name != CONTENT_TYPE && *name != CONTENT_LENGTH {
&& name.to_string().to_lowercase() != "content-length"
{
request_builder = request_builder.header(name, value); request_builder = request_builder.header(name, value);
} }
} }
...@@ -639,11 +651,6 @@ impl Router { ...@@ -639,11 +651,6 @@ impl Router {
} }
} }
// Record metrics
let duration = start.elapsed();
RouterMetrics::record_generate_duration(duration);
RouterMetrics::record_request(route);
response response
} else if load_incremented { } else if load_incremented {
// For streaming with load tracking, we need to manually decrement when done // For streaming with load tracking, we need to manually decrement when done
...@@ -656,6 +663,7 @@ impl Router { ...@@ -656,6 +663,7 @@ impl Router {
// Spawn task to forward stream and detect completion // Spawn task to forward stream and detect completion
tokio::spawn(async move { tokio::spawn(async move {
let mut stream = stream; let mut stream = stream;
let mut decremented = false;
while let Some(chunk) = stream.next().await { while let Some(chunk) = stream.next().await {
match chunk { match chunk {
Ok(bytes) => { Ok(bytes) => {
...@@ -674,6 +682,7 @@ impl Router { ...@@ -674,6 +682,7 @@ impl Router {
&worker_url, &worker_url,
worker.load(), worker.load(),
); );
decremented = true;
} }
} }
} }
...@@ -687,6 +696,15 @@ impl Router { ...@@ -687,6 +696,15 @@ impl Router {
} }
} }
} }
if !decremented {
if let Ok(workers_guard) = workers.read() {
if let Some(worker) = workers_guard.iter().find(|w| w.url() == &worker_url)
{
worker.decrement_load();
RouterMetrics::set_running_requests(&worker_url, worker.load());
}
}
}
}); });
let stream = UnboundedReceiverStream::new(rx); let stream = UnboundedReceiverStream::new(rx);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment