Unverified Commit 3578eb1e authored by Simo Lin, committed by GitHub

[router] address worker load tracking consistency (#9523)


Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
parent 0936c766
...@@ -55,6 +55,12 @@ pub trait Worker: Send + Sync + fmt::Debug {
/// Decrement the load counter
fn decrement_load(&self);
/// Reset the load counter to 0 (for sync/recovery)
fn reset_load(&self) {
// Default implementation - does nothing
// Workers that track load should override this
}
/// Get the number of processed requests
fn processed_requests(&self) -> usize;
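The no-op default body is what makes this trait change backward compatible: implementors that do not track load compile unchanged, while `BasicWorker` and `DPAwareWorker` override it below. A minimal sketch of that pattern, assuming a hypothetical `MockWorker` with an atomic counter (illustrative only, not a router type):

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Sketch of the default-method pattern; `MockWorker` is illustrative,
// not a type from the router.
trait Worker {
    fn increment_load(&self);
    fn load(&self) -> usize;
    /// Reset the load counter to 0 (for sync/recovery).
    /// No-op by default so existing implementors keep compiling.
    fn reset_load(&self) {}
}

#[derive(Default)]
struct MockWorker {
    load_counter: AtomicUsize,
}

impl Worker for MockWorker {
    fn increment_load(&self) {
        self.load_counter.fetch_add(1, Ordering::Relaxed);
    }
    fn load(&self) -> usize {
        self.load_counter.load(Ordering::Relaxed)
    }
    // Workers that actually track load override the default and zero it.
    fn reset_load(&self) {
        self.load_counter.store(0, Ordering::Relaxed);
    }
}

fn main() {
    let w = MockWorker::default();
    w.increment_load();
    w.increment_load();
    assert_eq!(w.load(), 2);
    w.reset_load();
    assert_eq!(w.load(), 0);
}
```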
...@@ -364,6 +370,10 @@ impl Worker for BasicWorker {
.ok();
}
fn reset_load(&self) {
self.load_counter.store(0, Ordering::Relaxed);
}
fn processed_requests(&self) -> usize {
self.processed_counter.load(Ordering::Relaxed)
}
...@@ -449,6 +459,10 @@ impl Worker for DPAwareWorker {
self.base_worker.decrement_load();
}
fn reset_load(&self) {
self.base_worker.reset_load();
}
fn processed_requests(&self) -> usize {
self.base_worker.processed_requests()
}
...@@ -825,6 +839,10 @@ pub fn start_health_checker(
let mut interval =
tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs));
// Counter for periodic load reset (every 10 health check cycles)
let mut check_count = 0u64;
const LOAD_RESET_INTERVAL: u64 = 10;
loop {
interval.tick().await;
...@@ -834,6 +852,8 @@
break;
}
check_count += 1;
// Check health of all workers
let workers_to_check = match workers.read() {
Ok(guard) => guard.iter().map(|w| w.clone_worker()).collect::<Vec<_>>(),
...@@ -843,6 +863,22 @@
}
};
// Periodically reset load counters to prevent drift
// Only do this when we believe all workers should be idle
if check_count.is_multiple_of(LOAD_RESET_INTERVAL) {
let max_load = workers_to_check.iter().map(|w| w.load()).max().unwrap_or(0);
// Only reset if load appears to be very low (likely drift)
if max_load <= 2 {
tracing::debug!(
"Resetting load counters to prevent drift (max_load: {})",
max_load
);
for worker in &workers_to_check {
worker.reset_load();
}
}
}
// Perform health checks concurrently
let health_checks = workers_to_check.iter().map(|worker| {
let worker_url = worker.url().to_string();
...
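The reset only fires on every tenth health-check cycle, and only when the highest load across workers is at most 2, on the assumption that a small persistent load during a quiet period is leftover drift from a missed decrement rather than real traffic. A standalone sketch of the heuristic, with a plain `Vec<usize>` standing in for the per-worker atomic counters:

```rust
// Standalone model of the reset heuristic; `loads` stands in for the
// per-worker atomic counters read via `Worker::load()`.
const LOAD_RESET_INTERVAL: u64 = 10;

fn maybe_reset(check_count: u64, loads: &mut [usize]) -> bool {
    // Only consider a reset every LOAD_RESET_INTERVAL health-check cycles.
    if check_count % LOAD_RESET_INTERVAL != 0 {
        return false;
    }
    let max_load = loads.iter().copied().max().unwrap_or(0);
    // Only reset when the load looks like residual drift, not real traffic.
    if max_load <= 2 {
        for load in loads.iter_mut() {
            *load = 0;
        }
        return true;
    }
    false
}

fn main() {
    // A missed decrement left two idle workers at 1 and 2.
    let mut drifted = vec![1, 2, 0];
    assert!(!maybe_reset(9, &mut drifted)); // not a reset cycle
    assert!(maybe_reset(10, &mut drifted)); // reset cycle + low load
    assert_eq!(drifted, vec![0, 0, 0]);

    // Real traffic (load 5) is left alone even on a reset cycle.
    let mut busy = vec![5, 1, 0];
    assert!(!maybe_reset(10, &mut busy));
    assert_eq!(busy, vec![5, 1, 0]);
}
```

The threshold is a tradeoff: a request genuinely in flight when the reset fires would be zeroed too, but that error is bounded by the threshold and self-corrects on the next cycle, while uncorrected drift would otherwise grow without bound.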
...@@ -1243,10 +1243,19 @@ impl PDRouter {
let decode_workers = self.decode_workers.clone();
tokio::spawn(async move {
// Use a flag to track whether stream completed successfully
let mut stream_completed = false;
futures_util::pin_mut!(stream);
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
// Check for stream end marker to decrement load early
let is_done = chunk
.as_ref()
.windows(12)
.any(|window| window == b"data: [DONE]");
let result = if return_logprob && prefill_logprobs.is_some() {
// Try to merge logprobs
Self::merge_streaming_logprobs(prefill_logprobs.clone(), &chunk)
...@@ -1258,6 +1267,12 @@ impl PDRouter {
if tx.send(Ok(result)).is_err() {
break;
}
// If we see the done marker, decrement load immediately
if is_done {
stream_completed = true;
break;
}
}
Err(e) => {
if let Some(ref url) = decode_url {
...@@ -1270,20 +1285,30 @@ impl PDRouter {
}
}
// Always decrement load after streaming (either completes or errors)
// Find and decrement prefill worker
if let Ok(prefill_workers_guard) = prefill_workers.read() {
for worker in prefill_workers_guard.iter() {
if worker.url() == prefill_url.as_str() {
worker.decrement_load();
debug!(
"Decremented load for prefill worker: {} (stream_completed: {})",
prefill_url, stream_completed
);
break;
}
}
}
// Find and decrement decode worker
if let Ok(decode_workers_guard) = decode_workers.read() {
for worker in decode_workers_guard.iter() {
if worker.url() == decode_url_str.as_str() {
worker.decrement_load();
debug!(
"Decremented load for decode worker: {} (stream_completed: {})",
decode_url_str, stream_completed
);
break;
}
}
...
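The early-exit check scans each chunk for the SSE terminator `data: [DONE]`; since the marker is exactly 12 bytes, `windows(12)` compares every 12-byte slice, so the match works regardless of where the marker sits in the chunk. A marker split across two chunks would be missed, but the decrement after the loop still runs when the stream ends, so the load accounting stays correct either way. A minimal sketch of the scan:

```rust
// Sketch of the end-marker scan. `windows(12)` yields every 12-byte
// slice of the chunk, so "data: [DONE]" is found at any offset.
fn contains_done_marker(chunk: &[u8]) -> bool {
    chunk.windows(12).any(|window| window == b"data: [DONE]")
}

fn main() {
    assert!(contains_done_marker(b"data: [DONE]\n\n"));
    // Terminator arriving in the same chunk as a final payload event:
    assert!(contains_done_marker(
        b"data: {\"text\":\"hi\"}\n\ndata: [DONE]\n\n"
    ));
    // No marker, and chunks shorter than 12 bytes simply yield no windows.
    assert!(!contains_done_marker(b"data: {\"text\":\"bye\"}"));
    assert!(!contains_done_marker(b"[DONE]"));
}
```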
...@@ -490,6 +490,13 @@ impl Router {
false
};
// Keep a clone for potential cleanup on retry
let worker_for_cleanup = if load_incremented {
Some(worker.clone_worker())
} else {
None
};
let response = self
.send_typed_request(
headers,
...@@ -502,6 +509,19 @@ impl Router {
.await;
worker.record_outcome(response.status().is_success());
// For retryable failures, we need to decrement load since send_typed_request
// won't have done it (it only decrements on success or non-retryable failures)
if is_retryable_status(response.status()) && load_incremented {
if let Some(cleanup_worker) = worker_for_cleanup {
cleanup_worker.decrement_load();
RouterMetrics::set_running_requests(
cleanup_worker.url(),
cleanup_worker.load(),
);
}
}
response
},
// should_retry predicate
...@@ -657,13 +677,25 @@ impl Router {
response
}
Err(e) => {
// IMPORTANT: Decrement load on error before returning
if load_incremented {
if let Ok(workers_guard) = self.workers.read() {
if let Some(worker) =
workers_guard.iter().find(|w| w.url() == worker_url)
{
worker.decrement_load();
RouterMetrics::set_running_requests(worker_url, worker.load());
}
}
}
let error_msg = format!("Failed to get response body: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
}
};
// Decrement load counter if it was incremented
if load_incremented {
if let Ok(workers_guard) = self.workers.read() {
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
worker.decrement_load();
...
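Taken together, the changes enforce one invariant: every `increment_load` is paired with exactly one `decrement_load`, whether the request succeeds, fails with a retryable status, or errors while reading the body. One common way to make that pairing structural rather than hand-maintained on each exit path is an RAII guard; the sketch below is a hypothetical alternative design, not what the router does:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Hypothetical `LoadGuard`: increments on creation, decrements on drop,
// so every exit path (success, error, early return) pays the decrement.
struct LoadGuard {
    counter: Arc<AtomicUsize>,
}

impl LoadGuard {
    fn new(counter: Arc<AtomicUsize>) -> Self {
        counter.fetch_add(1, Ordering::Relaxed);
        LoadGuard { counter }
    }
}

impl Drop for LoadGuard {
    fn drop(&mut self) {
        self.counter.fetch_sub(1, Ordering::Relaxed);
    }
}

fn main() {
    let load = Arc::new(AtomicUsize::new(0));
    {
        let _guard = LoadGuard::new(Arc::clone(&load));
        assert_eq!(load.load(Ordering::Relaxed), 1);
        // ... request handling would run here ...
    } // `_guard` dropped: load returns to 0 on any path out of the scope
    assert_eq!(load.load(Ordering::Relaxed), 0);
}
```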