Unverified Commit 8bdf18e5 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

fix: should route based on waiting requests, not active (#989)

parent 5c98f8d1
......@@ -215,8 +215,10 @@ pub fn process_worker_selection(
.get_mut(&selection.worker_id)
.expect("worker not found");
// Update worker state
worker.data.request_active_slots += 1;
// Update worker state predictively
// Will be overwritten on next polling of metrics
worker.data.num_requests_waiting += 1;
// Assumes radix attention so KV load is only incremented by uncached blocks
worker.data.kv_active_blocks += selection.required_blocks - selection.overlap_blocks as u64;
// Emit event
......@@ -245,7 +247,7 @@ impl WorkerSelector for DefaultWorkerSelector {
assert!(request.isl_tokens > 0);
let mut worker_scores = HashMap::new();
let mut max_active = 0.0;
let mut max_waiting = 0.0;
// Calculate worker scores and find max waiting requests
for (worker_id, ep) in workers.endpoints.iter() {
......@@ -256,16 +258,16 @@ impl WorkerSelector for DefaultWorkerSelector {
}
// Track max waiting requests
max_active = f64::max(max_active, ep.data.request_active_slots as f64);
max_waiting = f64::max(max_waiting, ep.data.num_requests_waiting as f64);
}
if max_active == 0.0 {
if max_waiting == 0.0 {
return Err(KvSchedulerError::NoEndpoints);
}
// make immutable
let worker_scores = worker_scores;
let max_active = max_active;
let max_waiting = max_waiting;
// Calculate logits for each worker
let mut best_logit = f64::NEG_INFINITY;
......@@ -280,14 +282,14 @@ impl WorkerSelector for DefaultWorkerSelector {
// Calculate normalized metrics
assert!(ep.data.kv_total_blocks > 0);
let gpu_cache_usage = ep.data.kv_active_blocks as f64 / ep.data.kv_total_blocks as f64;
let normalized_active = if max_active > 0.0 {
ep.data.request_active_slots as f64 / max_active
let normalized_waiting = if max_waiting > 0.0 {
ep.data.num_requests_waiting as f64 / max_waiting
} else {
0.0
};
// Calculate logit using same formula as Python
let logit = 2.0 * score - gpu_cache_usage - normalized_active;
let logit = 2.0 * score - gpu_cache_usage - normalized_waiting;
tracing::info!(
"Formula for {}: {:.3} = 2.0 * {:.3} - {:.3} - {:.3}",
......@@ -295,7 +297,7 @@ impl WorkerSelector for DefaultWorkerSelector {
logit,
score,
gpu_cache_usage,
normalized_active
normalized_waiting
);
// Track best workers
......@@ -313,8 +315,10 @@ impl WorkerSelector for DefaultWorkerSelector {
}
// Return early if no valid workers found
if best_workers.is_empty() || best_logit == 0.0 {
if best_workers.is_empty() {
return Err(KvSchedulerError::NoEndpoints);
} else if best_logit == 0.0 {
tracing::warn!("best worker logit is 0");
}
let worker_id = if best_workers.len() == 1 {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment