"examples/vscode:/vscode.git/clone" did not exist on "9f310225eb48964710baa9c465b842c64e2c35c9"
Unverified Commit 1142478a authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

fix(kv-router): route pinned requests through scheduler (#8142)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent b35db6e2
...@@ -520,6 +520,7 @@ impl RouterHandles { ...@@ -520,6 +520,7 @@ impl RouterHandles {
None, None,
0.0, 0.0,
None, None,
None,
allowed_worker_ids, allowed_worker_ids,
) )
.await .await
......
...@@ -1077,6 +1077,7 @@ impl KvRouter { ...@@ -1077,6 +1077,7 @@ impl KvRouter {
lora_name.clone(), lora_name.clone(),
0.0, 0.0,
None, None,
None,
None, // allowed_worker_ids: pass via RoutingHints in PreprocessedRequest path None, // allowed_worker_ids: pass via RoutingHints in PreprocessedRequest path
) )
.await .await
......
...@@ -185,6 +185,7 @@ where ...@@ -185,6 +185,7 @@ where
lora_name: Option<String>, lora_name: Option<String>,
priority_jump: f64, priority_jump: f64,
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
pinned_worker: Option<WorkerWithDpRank>,
allowed_worker_ids: Option<HashSet<WorkerId>>, allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> Result<SchedulingResponse, KvSchedulerError> { ) -> Result<SchedulingResponse, KvSchedulerError> {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
...@@ -204,6 +205,7 @@ where ...@@ -204,6 +205,7 @@ where
lora_name, lora_name,
priority_jump, priority_jump,
expected_output_tokens, expected_output_tokens,
pinned_worker,
allowed_worker_ids, allowed_worker_ids,
resp_tx: Some(resp_tx), resp_tx: Some(resp_tx),
}; };
...@@ -440,6 +442,7 @@ mod tests { ...@@ -440,6 +442,7 @@ mod tests {
0.0, 0.0,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
...@@ -480,6 +483,7 @@ mod tests { ...@@ -480,6 +483,7 @@ mod tests {
0.0, 0.0,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
...@@ -520,6 +524,7 @@ mod tests { ...@@ -520,6 +524,7 @@ mod tests {
0.0, 0.0,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
...@@ -539,6 +544,7 @@ mod tests { ...@@ -539,6 +544,7 @@ mod tests {
0.0, 0.0,
None, None,
None, None,
None,
) )
.await .await
}) })
...@@ -579,6 +585,7 @@ mod tests { ...@@ -579,6 +585,7 @@ mod tests {
0.0, 0.0,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
...@@ -598,6 +605,7 @@ mod tests { ...@@ -598,6 +605,7 @@ mod tests {
0.0, 0.0,
None, None,
None, None,
None,
) )
.await .await
}) })
...@@ -652,6 +660,7 @@ mod tests { ...@@ -652,6 +660,7 @@ mod tests {
0.0, 0.0,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
...@@ -671,6 +680,7 @@ mod tests { ...@@ -671,6 +680,7 @@ mod tests {
0.0, 0.0,
None, None,
None, None,
None,
) )
.await .await
}) })
...@@ -724,6 +734,7 @@ mod tests { ...@@ -724,6 +734,7 @@ mod tests {
0.0, 0.0,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
...@@ -743,6 +754,7 @@ mod tests { ...@@ -743,6 +754,7 @@ mod tests {
0.0, 0.0,
None, None,
None, None,
None,
) )
.await .await
}) })
...@@ -794,6 +806,7 @@ mod tests { ...@@ -794,6 +806,7 @@ mod tests {
0.0, 0.0,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
...@@ -894,6 +907,7 @@ mod tests { ...@@ -894,6 +907,7 @@ mod tests {
0.0, 0.0,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
...@@ -988,6 +1002,7 @@ mod tests { ...@@ -988,6 +1002,7 @@ mod tests {
0.0, 0.0,
None, None,
None, None,
None,
) )
.await .await
.unwrap(); .unwrap();
......
...@@ -60,9 +60,9 @@ impl SchedulingPolicy for LcfsPolicy { ...@@ -60,9 +60,9 @@ impl SchedulingPolicy for LcfsPolicy {
/// Weighted Shortest Processing Time (Smith's rule): /// Weighted Shortest Processing Time (Smith's rule):
/// key = (1 + priority_jump) / new_tokens, where new_tokens estimates the /// key = (1 + priority_jump) / new_tokens, where new_tokens estimates the
/// actual prefill cost by subtracting the max KV cache overlap from ISL. /// actual prefill cost by subtracting the effective KV cache overlap from ISL.
/// We use max because the downstream selector routes to the best-overlap /// Unpinned requests use the best available overlap. Pinned requests use only
/// worker, so the realized overlap is well-approximated by the best available. /// the overlap for their exact target worker so queue ordering matches routing.
/// ///
/// Optimizes for average TTFT — minimizes total weighted completion time /// Optimizes for average TTFT — minimizes total weighted completion time
/// (Smith 1956). Short or high-priority requests are scheduled before /// (Smith 1956). Short or high-priority requests are scheduled before
...@@ -76,8 +76,7 @@ impl SchedulingPolicy for WsptPolicy { ...@@ -76,8 +76,7 @@ impl SchedulingPolicy for WsptPolicy {
fn enqueue_key(&self, _arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key { fn enqueue_key(&self, _arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
let weight = 1.0 + request.priority_jump.max(0.0); let weight = 1.0 + request.priority_jump.max(0.0);
let max_overlap = request.overlaps.scores.values().copied().max().unwrap_or(0) as usize; let cached_tokens = request.overlap_blocks() as usize * self.block_size;
let cached_tokens = max_overlap * self.block_size;
let new_tokens = request.isl_tokens.saturating_sub(cached_tokens).max(1); let new_tokens = request.isl_tokens.saturating_sub(cached_tokens).max(1);
OrderedFloat(weight / new_tokens as f64) OrderedFloat(weight / new_tokens as f64)
} }
...@@ -141,6 +140,7 @@ mod tests { ...@@ -141,6 +140,7 @@ mod tests {
lora_name: None, lora_name: None,
priority_jump, priority_jump,
expected_output_tokens: None, expected_output_tokens: None,
pinned_worker: None,
allowed_worker_ids: None, allowed_worker_ids: None,
resp_tx: None, resp_tx: None,
} }
...@@ -289,6 +289,28 @@ mod tests { ...@@ -289,6 +289,28 @@ mod tests {
assert_eq!(key, expected); assert_eq!(key, expected);
} }
#[test]
fn wspt_uses_pinned_worker_overlap_when_present() {
let policy = WsptPolicy { block_size: 16 };
let mut req = request_with(1024, 0.0, overlaps_from(&[(0, 60), (1, 1)]));
req.pinned_worker = Some(WorkerWithDpRank::new(1, 0));
let key = policy.enqueue_key(Duration::ZERO, &req);
let expected = OrderedFloat(1.0 / 1008.0);
assert_eq!(key, expected);
}
#[test]
fn wspt_missing_pinned_overlap_uses_zero() {
let policy = WsptPolicy { block_size: 16 };
let mut req = request_with(1024, 0.0, overlaps_from(&[(0, 60)]));
req.pinned_worker = Some(WorkerWithDpRank::new(1, 0));
let key = policy.enqueue_key(Duration::ZERO, &req);
let expected = OrderedFloat(1.0 / 1024.0);
assert_eq!(key, expected);
}
#[test] #[test]
fn wspt_no_overlap_falls_back_to_isl() { fn wspt_no_overlap_falls_back_to_isl() {
let policy = WsptPolicy { block_size: 16 }; let policy = WsptPolicy { block_size: 16 };
......
...@@ -13,7 +13,7 @@ use tokio::time::Instant; ...@@ -13,7 +13,7 @@ use tokio::time::Instant;
use super::policy::{FcfsPolicy, SchedulingPolicy}; use super::policy::{FcfsPolicy, SchedulingPolicy};
use super::prefill_load::PrefillLoadEstimator; use super::prefill_load::PrefillLoadEstimator;
use super::selector::{DefaultWorkerSelector, WorkerSelector}; use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{SchedulingRequest, SchedulingResponse}; use super::types::{SchedulingRequest, SchedulingResponse, pinned_worker_config};
use crate::protocols::{PrefillLoadHint, WorkerConfigLike, WorkerId, WorkerWithDpRank}; use crate::protocols::{PrefillLoadHint, WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::sequences::{ActiveSequencesMultiWorker, SequencePublisher, SequenceRequest}; use crate::sequences::{ActiveSequencesMultiWorker, SequencePublisher, SequenceRequest};
...@@ -137,21 +137,31 @@ impl< ...@@ -137,21 +137,31 @@ impl<
/// If queueing is disabled or workers have capacity, schedule immediately. /// If queueing is disabled or workers have capacity, schedule immediately.
/// Otherwise park in the pending heap. /// Otherwise park in the pending heap.
/// ///
/// When `allowed_worker_ids` is set on the request (external routing), the /// When `allowed_worker_ids` is set on the request without an exact pin
/// capacity check is skipped. /// (external routing), the capacity check is skipped.
pub async fn enqueue(&self, request: SchedulingRequest) { pub async fn enqueue(&self, mut request: SchedulingRequest) {
if let Err(error) = request.validate_worker_constraints() {
request.respond(Err(error));
return;
}
let Some(threshold) = self.threshold_frac else { let Some(threshold) = self.threshold_frac else {
self.schedule(request, Instant::now()).await; self.schedule(request, Instant::now()).await;
return; return;
}; };
if request.allowed_worker_ids.is_some() { if request.bypass_capacity_check() {
self.schedule(request, Instant::now()).await; self.schedule(request, Instant::now()).await;
return; return;
} }
let decay_now = Instant::now(); let decay_now = Instant::now();
if self.all_workers_busy(threshold, request.allowed_worker_ids.as_ref(), decay_now) { if self.all_workers_busy(
threshold,
request.allowed_worker_ids.as_ref(),
request.pinned_worker,
decay_now,
) {
tracing::debug!("all workers busy, queueing request"); tracing::debug!("all workers busy, queueing request");
let arrival_offset = self.start_time.elapsed(); let arrival_offset = self.start_time.elapsed();
let key = self.policy.enqueue_key(arrival_offset, &request); let key = self.policy.enqueue_key(arrival_offset, &request);
...@@ -189,12 +199,24 @@ impl< ...@@ -189,12 +199,24 @@ impl<
loop { loop {
let decay_now = Instant::now(); let decay_now = Instant::now();
if self.all_workers_busy(threshold, None, decay_now) { let mut heap = self.pending.lock().await;
break; let Some(front) = heap.peek() else {
}
let Some(entry) = self.pending.lock().await.pop() else {
break; break;
}; };
// TODO: This preserves head-of-line blocking for now to keep queue
// drain overhead bounded to the heap front. A blocked pinned or
// otherwise constrained request can temporarily stall later
// schedulable entries until we adopt a cheaper non-HOL strategy.
if self.all_workers_busy(
threshold,
front.request.allowed_worker_ids.as_ref(),
front.request.pinned_worker,
decay_now,
) {
break;
}
let entry = heap.pop().expect("heap front vanished before pop");
drop(heap);
self.pending_count.fetch_sub(1, AtomicOrdering::Relaxed); self.pending_count.fetch_sub(1, AtomicOrdering::Relaxed);
self.pending_isl_tokens self.pending_isl_tokens
.fetch_sub(entry.request.isl_tokens, AtomicOrdering::Relaxed); .fetch_sub(entry.request.isl_tokens, AtomicOrdering::Relaxed);
...@@ -318,7 +340,8 @@ impl< ...@@ -318,7 +340,8 @@ impl<
} }
/// Check if all eligible workers are busy based on threshold. /// Check if all eligible workers are busy based on threshold.
/// When `allowed` is `Some`, only those worker IDs are considered; /// When `pinned_worker` is `Some`, only that exact worker/rank is considered.
/// Otherwise when `allowed` is `Some`, only those worker IDs are considered;
/// otherwise all registered workers are checked. /// otherwise all registered workers are checked.
/// Returns false when no eligible workers exist so the request falls /// Returns false when no eligible workers exist so the request falls
/// through to `schedule`, which returns a proper `NoEndpoints` error. /// through to `schedule`, which returns a proper `NoEndpoints` error.
...@@ -326,11 +349,24 @@ impl< ...@@ -326,11 +349,24 @@ impl<
&self, &self,
threshold: f64, threshold: f64,
allowed: Option<&HashSet<WorkerId>>, allowed: Option<&HashSet<WorkerId>>,
pinned_worker: Option<WorkerWithDpRank>,
decay_now: Instant, decay_now: Instant,
) -> bool { ) -> bool {
let active_tokens = self.slots.active_tokens(decay_now); let active_tokens = self.slots.active_tokens(decay_now);
let configs = self.workers_with_configs.borrow(); let configs = self.workers_with_configs.borrow();
if let Some(worker) = pinned_worker {
let Ok(config) = pinned_worker_config::<C>(&*configs, worker) else {
return false;
};
let max_batched = config
.max_num_batched_tokens()
.unwrap_or(DEFAULT_MAX_BATCHED_TOKENS);
let tokens = active_tokens.get(&worker).copied().unwrap_or(0);
return (tokens as f64) > threshold * (max_batched as f64);
}
let mut checked_any = false; let mut checked_any = false;
for (&worker_id, config) in configs.iter() { for (&worker_id, config) in configs.iter() {
if let Some(ids) = allowed if let Some(ids) = allowed
...@@ -367,6 +403,7 @@ mod tests { ...@@ -367,6 +403,7 @@ mod tests {
use super::*; use super::*;
use crate::protocols::OverlapScores; use crate::protocols::OverlapScores;
use crate::scheduling::types::KvSchedulerError;
use crate::selector::DefaultWorkerSelector; use crate::selector::DefaultWorkerSelector;
use crate::sequences::ActiveSequencesMultiWorker; use crate::sequences::ActiveSequencesMultiWorker;
use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig}; use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};
...@@ -476,6 +513,7 @@ mod tests { ...@@ -476,6 +513,7 @@ mod tests {
lora_name: None, lora_name: None,
priority_jump: 0.0, priority_jump: 0.0,
expected_output_tokens: None, expected_output_tokens: None,
pinned_worker: None,
allowed_worker_ids: None, allowed_worker_ids: None,
resp_tx: Some(tx), resp_tx: Some(tx),
}; };
...@@ -823,6 +861,7 @@ mod tests { ...@@ -823,6 +861,7 @@ mod tests {
lora_name: None, lora_name: None,
priority_jump: 0.0, priority_jump: 0.0,
expected_output_tokens: None, expected_output_tokens: None,
pinned_worker: None,
allowed_worker_ids: Some(allowed), allowed_worker_ids: Some(allowed),
resp_tx: Some(tx), resp_tx: Some(tx),
}; };
...@@ -842,6 +881,86 @@ mod tests { ...@@ -842,6 +881,86 @@ mod tests {
slots.free(&"filter-0".to_string(), decay_now()).unwrap(); slots.free(&"filter-0".to_string(), decay_now()).unwrap();
} }
#[tokio::test(flavor = "multi_thread")]
async fn test_pinned_worker_conflict_with_allowed_ids_fails_early() {
let (queue, _slots) = make_queue(1, 16, 256, Some(0.0));
let (mut req, rx) = make_request("conflict", 256);
req.pinned_worker = Some(WorkerWithDpRank::new(0, 0));
req.allowed_worker_ids = Some(HashSet::from([1]));
queue.enqueue(req).await;
let resp = rx.await.expect("oneshot dropped");
assert!(matches!(
resp,
Err(KvSchedulerError::PinnedWorkerNotAllowed { worker_id: 0 })
));
}
#[tokio::test(flavor = "multi_thread")]
async fn test_pinned_request_head_of_line_blocks_other_worker_capacity() {
let (queue, slots) = make_queue(2, 16, 256, Some(0.0));
let (mut first, first_rx) = make_request("pinned-1", 256);
first.pinned_worker = Some(WorkerWithDpRank::new(1, 0));
queue.enqueue(first).await;
let first_resp = first_rx.await.unwrap().unwrap();
assert_eq!(first_resp.best_worker, WorkerWithDpRank::new(1, 0));
let (mut second, mut second_rx) = make_request("pinned-2", 256);
second.pinned_worker = Some(WorkerWithDpRank::new(1, 0));
queue.enqueue(second).await;
assert_eq!(queue.pending_count(), 1);
assert!(
second_rx.try_recv().is_err(),
"request should remain queued"
);
let (occupy_other, occupy_other_rx) = make_request("worker-0", 256);
queue.enqueue(occupy_other).await;
let occupy_other_resp = occupy_other_rx.await.unwrap().unwrap();
assert_eq!(occupy_other_resp.best_worker, WorkerWithDpRank::new(0, 0));
let (unpinned, mut unpinned_rx) = make_request("unpinned", 256);
queue.enqueue(unpinned).await;
assert_eq!(queue.pending_count(), 2);
slots
.mark_prefill_completed(&"worker-0".to_string(), decay_now())
.unwrap();
slots.free(&"worker-0".to_string(), decay_now()).unwrap();
queue.update().await;
assert_eq!(queue.pending_count(), 2);
assert!(
unpinned_rx.try_recv().is_err(),
"unpinned request should remain queued behind the pinned head"
);
assert!(
second_rx.try_recv().is_err(),
"pinned request should still be queued"
);
slots
.mark_prefill_completed(&"pinned-1".to_string(), decay_now())
.unwrap();
slots.free(&"pinned-1".to_string(), decay_now()).unwrap();
queue.update().await;
let second_resp = second_rx
.try_recv()
.expect("pinned request should have been scheduled");
let second_resp = second_resp.expect("scheduling returned error");
assert_eq!(second_resp.best_worker, WorkerWithDpRank::new(1, 0));
let unpinned_resp = unpinned_rx
.try_recv()
.expect("unpinned request should have been scheduled");
let unpinned_resp = unpinned_resp.expect("scheduling returned error");
assert_eq!(unpinned_resp.best_worker, WorkerWithDpRank::new(0, 0));
assert_eq!(queue.pending_count(), 0);
}
#[tokio::test(flavor = "multi_thread")] #[tokio::test(flavor = "multi_thread")]
async fn test_queue_busy_check_ignores_untracked_prefill_tokens() { async fn test_queue_busy_check_ignores_untracked_prefill_tokens() {
let (queue, slots) = make_queue(1, 16, 256, Some(0.0)); let (queue, slots) = make_queue(1, 16, 256, Some(0.0));
......
...@@ -6,7 +6,7 @@ use std::collections::HashMap; ...@@ -6,7 +6,7 @@ use std::collections::HashMap;
use rand::Rng; use rand::Rng;
use super::config::KvRouterConfig; use super::config::KvRouterConfig;
use super::types::{KvSchedulerError, SchedulingRequest}; use super::types::{KvSchedulerError, SchedulingRequest, pinned_worker_config};
use crate::protocols::{WorkerConfigLike, WorkerId, WorkerSelectionResult, WorkerWithDpRank}; use crate::protocols::{WorkerConfigLike, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
/// A trait that users can implement to define custom selection logic. /// A trait that users can implement to define custom selection logic.
...@@ -98,6 +98,12 @@ pub struct DefaultWorkerSelector { ...@@ -98,6 +98,12 @@ pub struct DefaultWorkerSelector {
pub worker_type: &'static str, pub worker_type: &'static str,
} }
#[derive(Debug, Clone, Copy)]
struct WorkerScore {
overlap_blocks: u32,
logit: f64,
}
impl DefaultWorkerSelector { impl DefaultWorkerSelector {
pub fn new(kv_router_config: Option<KvRouterConfig>, worker_type: &'static str) -> Self { pub fn new(kv_router_config: Option<KvRouterConfig>, worker_type: &'static str) -> Self {
Self { Self {
...@@ -105,6 +111,44 @@ impl DefaultWorkerSelector { ...@@ -105,6 +111,44 @@ impl DefaultWorkerSelector {
worker_type, worker_type,
} }
} }
fn worker_score(
&self,
request: &SchedulingRequest,
worker: WorkerWithDpRank,
block_size: u32,
overlap_weight: f64,
formula_name: &'static str,
) -> WorkerScore {
let isl = request.isl_tokens;
let overlap_blocks = request.overlaps.scores.get(&worker).copied().unwrap_or(0);
let default_prefill_token = if request.track_prefill_tokens { isl } else { 0 };
let prefill_token = request
.prefill_tokens
.get(&worker)
.copied()
.unwrap_or(default_prefill_token);
let potential_prefill_block = (prefill_token as f64) / (block_size as f64);
let decode_block = request
.decode_blocks
.get(&worker)
.copied()
.unwrap_or(potential_prefill_block.floor() as usize) as f64;
let logit = overlap_weight * potential_prefill_block + decode_block;
tracing::debug!(
"{formula_name} for worker_id={} dp_rank={:?} with {overlap_blocks} cached blocks: {logit:.3} \
= {overlap_weight:.1} * prefill_blocks + decode_blocks \
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}",
worker.worker_id,
worker.dp_rank
);
WorkerScore {
overlap_blocks,
logit,
}
}
} }
impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
...@@ -115,12 +159,16 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -115,12 +159,16 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
block_size: u32, block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> { ) -> Result<WorkerSelectionResult, KvSchedulerError> {
assert!(request.isl_tokens > 0); assert!(request.isl_tokens > 0);
request.validate_worker_constraints()?;
let allowed_ids = request.allowed_worker_ids.as_ref(); let allowed_ids = request.allowed_worker_ids.as_ref();
let pinned_worker = request.pinned_worker;
if allowed_ids.map_or(workers.is_empty(), |ids| { if pinned_worker.is_none()
&& allowed_ids.map_or(workers.is_empty(), |ids| {
!workers.keys().any(|wid| ids.contains(wid)) !workers.keys().any(|wid| ids.contains(wid))
}) { })
{
return Err(KvSchedulerError::NoEndpoints); return Err(KvSchedulerError::NoEndpoints);
} }
...@@ -128,8 +176,28 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -128,8 +176,28 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
let request_blocks = isl.div_ceil(block_size as usize); let request_blocks = isl.div_ceil(block_size as usize);
let overlaps = &request.overlaps.scores; let overlaps = &request.overlaps.scores;
let decode_blocks = &request.decode_blocks; if let Some(worker) = pinned_worker {
let prefill_tokens = &request.prefill_tokens; pinned_worker_config(workers, worker)?;
let overlap_weight = request
.router_config_override
.as_ref()
.and_then(|cfg| cfg.overlap_score_weight)
.unwrap_or(self.kv_router_config.overlap_score_weight);
let score = self.worker_score(
request,
worker,
block_size,
overlap_weight,
"Pinned formula",
);
return Ok(WorkerSelectionResult {
worker,
required_blocks: request_blocks as u64,
overlap_blocks: score.overlap_blocks,
});
}
let overlap_weight = request let overlap_weight = request
.router_config_override .router_config_override
...@@ -144,30 +212,8 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -144,30 +212,8 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
.unwrap_or(self.kv_router_config.router_temperature); .unwrap_or(self.kv_router_config.router_temperature);
let get_score = |worker: WorkerWithDpRank| -> f64 { let get_score = |worker: WorkerWithDpRank| -> f64 {
let overlap = *overlaps.get(&worker).unwrap_or(&0); self.worker_score(request, worker, block_size, overlap_weight, "Formula")
// Use 0 for unregistered decode workers (track_prefill_tokens=false) .logit
// to match registered idle workers; use isl otherwise.
let default_prefill_token = if request.track_prefill_tokens { isl } else { 0 };
let prefill_token = *prefill_tokens
.get(&worker)
.unwrap_or(&default_prefill_token);
let potential_prefill_block = (prefill_token as f64) / (block_size as f64);
let decode_block = *decode_blocks
.get(&worker)
.unwrap_or(&(potential_prefill_block.floor() as usize))
as f64;
let logit = overlap_weight * potential_prefill_block + decode_block;
tracing::debug!(
"Formula for worker_id={} dp_rank={:?} with {overlap} cached blocks: {logit:.3} \
= {overlap_weight:.1} * prefill_blocks + decode_blocks \
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}",
worker.worker_id,
worker.dp_rank
);
logit
}; };
let worker_iter = workers let worker_iter = workers
......
...@@ -7,7 +7,7 @@ use dynamo_tokens::SequenceHash; ...@@ -7,7 +7,7 @@ use dynamo_tokens::SequenceHash;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::config::RouterConfigOverride; use super::config::RouterConfigOverride;
use crate::protocols::{DpRank, OverlapScores, WorkerId, WorkerWithDpRank}; use crate::protocols::{DpRank, OverlapScores, WorkerConfigLike, WorkerId, WorkerWithDpRank};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad { pub struct PotentialLoad {
...@@ -22,6 +22,9 @@ pub enum KvSchedulerError { ...@@ -22,6 +22,9 @@ pub enum KvSchedulerError {
#[error("no endpoints available to route work")] #[error("no endpoints available to route work")]
NoEndpoints, NoEndpoints,
#[error("pinned worker {worker_id} is not in allowed worker set")]
PinnedWorkerNotAllowed { worker_id: WorkerId },
#[error("endpoint subscriber shutdown")] #[error("endpoint subscriber shutdown")]
SubscriberShutdown, SubscriberShutdown,
...@@ -51,12 +54,44 @@ pub struct SchedulingRequest { ...@@ -51,12 +54,44 @@ pub struct SchedulingRequest {
/// Expected output tokens from agent_hints.osl, forwarded to the slot tracker /// Expected output tokens from agent_hints.osl, forwarded to the slot tracker
/// for output block decay estimation. /// for output block decay estimation.
pub expected_output_tokens: Option<u32>, pub expected_output_tokens: Option<u32>,
/// Exact worker/rank pin used by scheduler queueing, WSPT, and selection.
pub pinned_worker: Option<WorkerWithDpRank>,
/// Optional set of allowed worker IDs to restrict routing decisions (EPP). /// Optional set of allowed worker IDs to restrict routing decisions (EPP).
pub allowed_worker_ids: Option<HashSet<WorkerId>>, pub allowed_worker_ids: Option<HashSet<WorkerId>>,
pub resp_tx: Option<tokio::sync::oneshot::Sender<Result<SchedulingResponse, KvSchedulerError>>>, pub resp_tx: Option<tokio::sync::oneshot::Sender<Result<SchedulingResponse, KvSchedulerError>>>,
} }
impl SchedulingRequest { impl SchedulingRequest {
pub fn validate_worker_constraints(&self) -> Result<(), KvSchedulerError> {
let Some(pinned_worker) = self.pinned_worker else {
return Ok(());
};
let Some(allowed_worker_ids) = self.allowed_worker_ids.as_ref() else {
return Ok(());
};
if allowed_worker_ids.contains(&pinned_worker.worker_id) {
return Ok(());
}
Err(KvSchedulerError::PinnedWorkerNotAllowed {
worker_id: pinned_worker.worker_id,
})
}
/// Scheduling consumers use the exact pinned-worker overlap when present;
/// otherwise they use the best available overlap across eligible workers.
pub fn overlap_blocks(&self) -> u32 {
if let Some(worker) = self.pinned_worker {
return self.overlaps.scores.get(&worker).copied().unwrap_or(0);
}
self.overlaps.scores.values().copied().max().unwrap_or(0)
}
pub fn bypass_capacity_check(&self) -> bool {
self.pinned_worker.is_none() && self.allowed_worker_ids.is_some()
}
pub fn respond(&mut self, result: Result<SchedulingResponse, KvSchedulerError>) { pub fn respond(&mut self, result: Result<SchedulingResponse, KvSchedulerError>) {
let Some(tx) = self.resp_tx.take() else { let Some(tx) = self.resp_tx.take() else {
tracing::error!("respond called multiple times on same request"); tracing::error!("respond called multiple times on same request");
...@@ -67,3 +102,19 @@ impl SchedulingRequest { ...@@ -67,3 +102,19 @@ impl SchedulingRequest {
} }
} }
} }
pub fn pinned_worker_config<C: WorkerConfigLike>(
workers: &HashMap<WorkerId, C>,
worker: WorkerWithDpRank,
) -> Result<&C, KvSchedulerError> {
let Some(config) = workers.get(&worker.worker_id) else {
return Err(KvSchedulerError::NoEndpoints);
};
let dp_start_rank = config.data_parallel_start_rank();
let dp_end_rank = dp_start_rank + config.data_parallel_size();
if !(dp_start_rank..dp_end_rank).contains(&worker.dp_rank) {
return Err(KvSchedulerError::NoEndpoints);
}
Ok(config)
}
...@@ -283,6 +283,9 @@ where ...@@ -283,6 +283,9 @@ where
/// Returns the best worker (with dp_rank) and overlap amount in number of blocks. /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
/// Now also takes optional context_id for request tracking. /// Now also takes optional context_id for request tracking.
/// ///
/// When `pinned_worker` is Some, scheduling and queueing are constrained to
/// that exact worker/rank.
///
/// When `allowed_worker_ids` is Some, only workers in that set are considered for selection. /// When `allowed_worker_ids` is Some, only workers in that set are considered for selection.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn find_best_match( pub async fn find_best_match(
...@@ -295,6 +298,7 @@ where ...@@ -295,6 +298,7 @@ where
lora_name: Option<String>, lora_name: Option<String>,
priority_jump: f64, priority_jump: f64,
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
pinned_worker: Option<WorkerWithDpRank>,
allowed_worker_ids: Option<HashSet<WorkerId>>, allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> anyhow::Result<(WorkerWithDpRank, u32)> { ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
let start = Instant::now(); let start = Instant::now();
...@@ -345,6 +349,7 @@ where ...@@ -345,6 +349,7 @@ where
lora_name, lora_name,
priority_jump, priority_jump,
expected_output_tokens, expected_output_tokens,
pinned_worker,
allowed_worker_ids, allowed_worker_ids,
) )
.instrument(tracing::info_span!("kv_router.schedule")) .instrument(tracing::info_span!("kv_router.schedule"))
...@@ -600,6 +605,7 @@ where ...@@ -600,6 +605,7 @@ where
0.0, 0.0,
None, None,
None, None,
None,
) )
.await?; .await?;
......
...@@ -287,6 +287,7 @@ impl PrefillRouter { ...@@ -287,6 +287,7 @@ impl PrefillRouter {
lora_name, lora_name,
priority_jump, priority_jump,
None, None,
None,
allowed_worker_ids, allowed_worker_ids,
) )
.await?; .await?;
......
...@@ -27,6 +27,7 @@ use crate::{ ...@@ -27,6 +27,7 @@ use crate::{
preprocessor::PreprocessedRequest, preprocessor::PreprocessedRequest,
protocols::common::{ protocols::common::{
llm_backend::LLMEngineOutput, llm_backend::LLMEngineOutput,
preprocessor::RoutingHints,
timing::{RequestPhase, RequestTracker}, timing::{RequestPhase, RequestTracker},
}, },
}; };
...@@ -48,6 +49,23 @@ struct WorkerSelection { ...@@ -48,6 +49,23 @@ struct WorkerSelection {
overlap_amount: Option<u32>, overlap_amount: Option<u32>,
} }
fn pinned_worker_hint(
phase: RequestPhase,
routing: Option<&RoutingHints>,
) -> Option<(u64, Option<u32>)> {
let routing = routing?;
let worker_id = match phase {
RequestPhase::Prefill => routing.prefill_worker_id.or(routing.backend_instance_id),
RequestPhase::Decode => routing.decode_worker_id.or(routing.backend_instance_id),
RequestPhase::Aggregated => routing.backend_instance_id,
}?;
let dp_rank = match phase {
RequestPhase::Prefill => routing.prefill_dp_rank.or(routing.dp_rank),
RequestPhase::Decode | RequestPhase::Aggregated => routing.dp_rank,
};
Some((worker_id, dp_rank))
}
/// Drop guard that manages the full lifecycle of a routed request: /// Drop guard that manages the full lifecycle of a routed request:
/// per-item tracking (prefill, first token, output blocks) and final cleanup (free + metrics). /// per-item tracking (prefill, first token, output blocks) and final cleanup (free + metrics).
/// ///
...@@ -259,9 +277,8 @@ impl KvPushRouter { ...@@ -259,9 +277,8 @@ impl KvPushRouter {
} }
} }
/// Select a worker for the request, either using a preselected worker or finding the best match. /// Select a worker for the request, either using an exact phase-specific pin
/// /// or by finding the best KV overlap match.
/// When `is_query_only` is false, this also registers the request with the scheduler via `add_request`.
async fn select_worker( async fn select_worker(
&self, &self,
context_id: &str, context_id: &str,
...@@ -276,23 +293,7 @@ impl KvPushRouter { ...@@ -276,23 +293,7 @@ impl KvPushRouter {
let expected_output_tokens = routing.and_then(|r| r.expected_output_tokens); let expected_output_tokens = routing.and_then(|r| r.expected_output_tokens);
let allowed_worker_ids = routing.and_then(|r| r.allowed_worker_ids.clone()); let allowed_worker_ids = routing.and_then(|r| r.allowed_worker_ids.clone());
let (routing_token_ids, block_mm_infos) = request.block_mm_routing_info(); let (routing_token_ids, block_mm_infos) = request.block_mm_routing_info();
let Some((pinned_worker_id, requested_dp_rank)) = pinned_worker_hint(phase, routing) else {
// Get pre-selected worker based on phase, with backend_instance_id as fallback
let preselected_id = match phase {
RequestPhase::Prefill => {
routing.and_then(|r| r.prefill_worker_id.or(r.backend_instance_id))
}
RequestPhase::Decode => {
routing.and_then(|r| r.decode_worker_id.or(r.backend_instance_id))
}
RequestPhase::Aggregated => routing.and_then(|r| r.backend_instance_id),
};
let requested_dp_rank = match phase {
RequestPhase::Prefill => routing.and_then(|r| r.prefill_dp_rank.or(r.dp_rank)),
RequestPhase::Decode | RequestPhase::Aggregated => routing.and_then(|r| r.dp_rank),
};
let Some(id) = preselected_id else {
let _nvtx_kv = dynamo_nvtx_range!("route.kv_match"); let _nvtx_kv = dynamo_nvtx_range!("route.kv_match");
let (best_worker, overlap_amount) = self let (best_worker, overlap_amount) = self
.chooser .chooser
...@@ -305,6 +306,7 @@ impl KvPushRouter { ...@@ -305,6 +306,7 @@ impl KvPushRouter {
lora_name, lora_name,
priority_jump, priority_jump,
expected_output_tokens, expected_output_tokens,
None,
allowed_worker_ids, allowed_worker_ids,
) )
.await?; .await?;
...@@ -338,18 +340,46 @@ impl KvPushRouter { ...@@ -338,18 +340,46 @@ impl KvPushRouter {
}); });
}; };
let backend_dp_rank = let resolved_pinned_worker = requested_dp_rank
requested_dp_rank.or_else(|| self.chooser.unique_dp_rank_for_worker(id)); .or_else(|| self.chooser.unique_dp_rank_for_worker(pinned_worker_id))
.map(|dp_rank| WorkerWithDpRank::new(pinned_worker_id, dp_rank));
if !is_query_only && let Some(pinned_worker) = resolved_pinned_worker {
let (best_worker, overlap_amount) = self
.chooser
.find_best_match(
Some(context_id),
routing_token_ids,
block_mm_infos,
request.router_config_override.as_ref(),
true,
lora_name.clone(),
priority_jump,
expected_output_tokens,
Some(pinned_worker),
allowed_worker_ids,
)
.await?;
return Ok(WorkerSelection {
instance_id: best_worker.worker_id,
backend_dp_rank: Some(best_worker.dp_rank),
bookkeeping_dp_rank: Some(best_worker.dp_rank),
overlap_amount: Some(overlap_amount),
});
}
let backend_dp_rank = resolved_pinned_worker.map(|worker| worker.dp_rank);
tracing::debug!( tracing::debug!(
worker_id = id, worker_id = pinned_worker_id,
dp_rank = ?backend_dp_rank, dp_rank = ?backend_dp_rank,
?phase, ?phase,
"Routing to specified worker" "Routing to specified worker"
); );
let (bookkeeping_dp_rank, overlap_amount) = if let Some(dp_rank) = backend_dp_rank { let (bookkeeping_dp_rank, overlap_amount) = if let Some(dp_rank) = backend_dp_rank {
let worker = WorkerWithDpRank::new(id, dp_rank); let worker = WorkerWithDpRank::new(pinned_worker_id, dp_rank);
let overlap_blocks = self let overlap_blocks = self
.chooser .chooser
.get_overlap_blocks( .get_overlap_blocks(
...@@ -376,7 +406,7 @@ impl KvPushRouter { ...@@ -376,7 +406,7 @@ impl KvPushRouter {
} else { } else {
tracing::debug!( tracing::debug!(
request_id = %context_id, request_id = %context_id,
worker_id = id, worker_id = pinned_worker_id,
dp_rank = dp_rank, dp_rank = dp_rank,
"Skipping add_request - query-only request" "Skipping add_request - query-only request"
); );
...@@ -386,7 +416,7 @@ impl KvPushRouter { ...@@ -386,7 +416,7 @@ impl KvPushRouter {
} else { } else {
tracing::debug!( tracing::debug!(
request_id = %context_id, request_id = %context_id,
worker_id = id, worker_id = pinned_worker_id,
?phase, ?phase,
"Routing to specified worker without resolved dp_rank; skipping scheduler bookkeeping" "Routing to specified worker without resolved dp_rank; skipping scheduler bookkeeping"
); );
...@@ -394,7 +424,7 @@ impl KvPushRouter { ...@@ -394,7 +424,7 @@ impl KvPushRouter {
}; };
Ok(WorkerSelection { Ok(WorkerSelection {
instance_id: id, instance_id: pinned_worker_id,
backend_dp_rank, backend_dp_rank,
bookkeeping_dp_rank, bookkeeping_dp_rank,
overlap_amount, overlap_amount,
...@@ -413,10 +443,10 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -413,10 +443,10 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
/// - Does NOT update any router local states /// - Does NOT update any router local states
/// - Response includes worker_instance_id and token_data annotations /// - Response includes worker_instance_id and token_data annotations
/// ///
/// 2. **If `backend_instance_id` is set in the request**: /// 2. **If a phase-specific worker or `backend_instance_id` is set in the request**:
/// - Routes directly to the specified backend instance /// - Query-only requests return that worker selection without state updates
/// - DOES update router states to track this request (unless query_instance_id is also set) /// - Execution requests route through the scheduler as an exact pin when dp_rank is resolved
/// - Bypasses the normal KV matching logic /// - If dp_rank cannot be resolved, falls back to direct routing without scheduler bookkeeping
/// ///
/// 3. **If neither are set (default behavior)**: /// 3. **If neither are set (default behavior)**:
/// - Finds the best worker based on KV cache overlap /// - Finds the best worker based on KV cache overlap
...@@ -694,3 +724,48 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -694,3 +724,48 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
self.inner.direct(request, worker_id).await self.inner.direct(request, worker_id).await
} }
} }
#[cfg(test)]
mod tests {
use super::pinned_worker_hint;
use crate::protocols::common::{preprocessor::RoutingHints, timing::RequestPhase};
#[test]
fn pinned_worker_hint_prefill_uses_prefill_worker_before_backend() {
let routing = RoutingHints {
backend_instance_id: Some(1),
prefill_worker_id: Some(2),
dp_rank: Some(3),
prefill_dp_rank: Some(4),
..Default::default()
};
let hint = pinned_worker_hint(RequestPhase::Prefill, Some(&routing));
assert_eq!(hint, Some((2, Some(4))));
}
#[test]
fn pinned_worker_hint_decode_uses_decode_worker_before_backend() {
let routing = RoutingHints {
backend_instance_id: Some(1),
decode_worker_id: Some(5),
dp_rank: Some(6),
..Default::default()
};
let hint = pinned_worker_hint(RequestPhase::Decode, Some(&routing));
assert_eq!(hint, Some((5, Some(6))));
}
#[test]
fn pinned_worker_hint_aggregated_uses_backend_worker() {
let routing = RoutingHints {
backend_instance_id: Some(9),
dp_rank: Some(7),
..Default::default()
};
let hint = pinned_worker_hint(RequestPhase::Aggregated, Some(&routing));
assert_eq!(hint, Some((9, Some(7))));
}
}
...@@ -18,7 +18,7 @@ use anyhow::Result; ...@@ -18,7 +18,7 @@ use anyhow::Result;
use dynamo_kv_router::{ use dynamo_kv_router::{
PrefillLoadEstimator, PrefillLoadEstimator,
config::{KvRouterConfig, RouterConfigOverride}, config::{KvRouterConfig, RouterConfigOverride},
protocols::{OverlapScores, WorkerId}, protocols::{OverlapScores, WorkerId, WorkerWithDpRank},
}; };
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
...@@ -136,6 +136,7 @@ where ...@@ -136,6 +136,7 @@ where
lora_name: Option<String>, lora_name: Option<String>,
priority_jump: f64, priority_jump: f64,
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
pinned_worker: Option<WorkerWithDpRank>,
allowed_worker_ids: Option<HashSet<WorkerId>>, allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> Result<SchedulingResponse, KvSchedulerError> { ) -> Result<SchedulingResponse, KvSchedulerError> {
let response = self let response = self
...@@ -150,6 +151,7 @@ where ...@@ -150,6 +151,7 @@ where
lora_name, lora_name,
priority_jump, priority_jump,
expected_output_tokens, expected_output_tokens,
pinned_worker,
allowed_worker_ids, allowed_worker_ids,
) )
.await; .await;
......
...@@ -140,6 +140,7 @@ impl PendingRequest { ...@@ -140,6 +140,7 @@ impl PendingRequest {
lora_name: None, lora_name: None,
priority_jump: 0.0, priority_jump: 0.0,
expected_output_tokens: self.expected_output_tokens, expected_output_tokens: self.expected_output_tokens,
pinned_worker: None,
allowed_worker_ids: None, allowed_worker_ids: None,
resp_tx: None, resp_tx: None,
} }
......
...@@ -210,6 +210,7 @@ impl KvReplayRouter { ...@@ -210,6 +210,7 @@ impl KvReplayRouter {
.context("max_output_tokens does not fit into u32")?, .context("max_output_tokens does not fit into u32")?,
), ),
None, None,
None,
) )
.await?; .await?;
usize::try_from(response.best_worker.worker_id) usize::try_from(response.best_worker.worker_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment