// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use std::time::Instant;

use tokio::sync::Mutex;
use tokio::sync::watch;

use super::policy::{FcfsPolicy, SchedulingPolicy};
use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{SchedulingRequest, SchedulingResponse};
use crate::protocols::{WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::sequences::{ActiveSequencesMultiWorker, SequencePublisher, SequenceRequest};

/// Large default for max_num_batched_tokens when not configured (effectively disables queueing for that worker)
///
/// Used as the per-worker token budget in the capacity check whenever a worker
/// config reports no `max_num_batched_tokens`.
pub const DEFAULT_MAX_BATCHED_TOKENS: u64 = 10_000_000;

// NOTE(review): generic parameter lists appear to be missing from this copy of the
// file (e.g. `K` below is never declared and `Option` has no type argument);
// confirm against the upstream source before building.

/// Entry in the priority queue, ordered by key (higher key = higher priority).
///
/// Equality and ordering look only at `key`; the request payload is ignored.
/// `BinaryHeap` is a max-heap, so the entry with the greatest key pops first.
struct QueueEntry {
    key: K,
    request: SchedulingRequest,
}

impl Eq for QueueEntry {}

impl PartialEq for QueueEntry {
    fn eq(&self, other: &Self) -> bool {
        self.key == other.key
    }
}

impl Ord for QueueEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        self.key.cmp(&other.key)
    }
}

impl PartialOrd for QueueEntry {
    // Delegates to `Ord` so the two orderings can never disagree.
    fn partial_cmp(&self, other: &Self) -> Option {
        Some(self.cmp(other))
    }
}

/// Queue that gates scheduling requests behind a capacity check.
/// When all workers exceed `threshold_frac` utilisation the request is parked in `pending`.
/// When capacity frees up (`update()`), pending requests are scheduled in priority order.
/// If queueing is disabled (threshold_frac is None), requests are scheduled immediately.
pub struct SchedulerQueue<
    P: SequencePublisher,
    C: WorkerConfigLike,
    S: SchedulingPolicy = FcfsPolicy,
    Sel: WorkerSelector = DefaultWorkerSelector,
> {
    /// Requests parked while every eligible worker is over the threshold,
    /// keyed by the scheduling policy (highest key is popped first).
    pending: Mutex>>,
    /// Number of requests currently parked in the pending queue.
    /// Incremented after push, decremented after pop.
    /// Lock-free reads via `Relaxed` load.
    pending_count: AtomicUsize,
    /// Active-sequence tracker: supplies per-worker active token counts and load
    /// estimates, and is where scheduled requests are booked.
    slots: Arc>,
    /// Discovery watch channel with per-worker configs (DP start rank / size,
    /// optional `max_num_batched_tokens`).
    workers_with_configs: watch::Receiver>,
    /// Cached threshold fraction; None means queueing is disabled.
    threshold_frac: Option,
    /// Reference instant for computing arrival offsets.
    start_time: Instant,
    /// Block size handed to the worker selector.
    block_size: u32,
    /// Picks a worker for each scheduled request.
    selector: Sel,
    /// Computes (and, for dynamic policies, recomputes) pending-queue keys.
    policy: S,
}

impl<
    P: SequencePublisher + 'static,
    C: WorkerConfigLike,
    S: SchedulingPolicy,
    Sel: WorkerSelector,
> SchedulerQueue {
    /// Build a queue over the given slot tracker and discovery config channel.
    ///
    /// `threshold_frac = None` disables queueing; `Some(frac)` enables it and is
    /// logged at info level. Arrival offsets are measured from construction time.
    pub fn new(
        slots: Arc>,
        workers_with_configs: watch::Receiver>,
        threshold_frac: Option,
        block_size: u32,
        selector: Sel,
        policy: S,
    ) -> Self {
        if let Some(frac) = threshold_frac {
            tracing::info!("Router queue enabled with threshold fraction {frac}");
        }
        Self {
            pending: Mutex::new(BinaryHeap::new()),
            pending_count: AtomicUsize::new(0),
            slots,
            workers_with_configs,
            threshold_frac,
            start_time: Instant::now(),
            block_size,
            selector,
            policy,
        }
    }

    /// Register externally-provided workers in the slot tracker.
    ///
    /// Looks up DP rank/size from the discovery watch channel; defaults to
    /// `(0, 1)` for workers not yet known to discovery.
    pub fn register_workers(&self, worker_ids: &std::collections::HashSet) {
        // NOTE(review): this watch borrow (a read lock on the channel value) stays
        // alive until the end of the function, i.e. across
        // `register_external_workers`; consider dropping it once `dp_range` is built.
        let discovery_workers = self.workers_with_configs.borrow();
        let dp_range: std::collections::HashMap = worker_ids
            .iter()
            .map(|&id| {
                let (dp_start, dp_size) = discovery_workers
                    .get(&id)
                    .map(|runtime_config| {
                        (
                            runtime_config.data_parallel_start_rank(),
                            runtime_config.data_parallel_size(),
                        )
                    })
                    // Not yet in discovery: assume a single DP rank starting at 0.
                    .unwrap_or((0, 1));
                (id, (dp_start, dp_size))
            })
            .collect();
        self.slots.register_external_workers(&dp_range);
    }

    /// Enqueue a new request.
    /// If queueing is disabled or workers have capacity, schedule immediately.
    /// Otherwise park in the pending heap.
    ///
    /// When `allowed_worker_ids` is set on the request (external routing), the
    /// capacity check is skipped and the request is scheduled immediately.
pub async fn enqueue(&self, request: SchedulingRequest) {
    // Queueing disabled: every request goes straight to the scheduler.
    let Some(threshold) = self.threshold_frac else {
        self.schedule(request).await;
        return;
    };

    // External routing (an explicit `allowed_worker_ids` set) bypasses the capacity gate.
    if request.allowed_worker_ids.is_some() {
        self.schedule(request).await;
        return;
    }

    // The busy check and the push must happen under ONE `pending` lock. With
    // separate critical sections, a concurrent `update()` (triggered by a free)
    // could try to pop from the still-empty heap between our "all busy" verdict
    // and the push, leaving this request parked until some unrelated later update.
    {
        let mut pending = self.pending.lock().await;
        if self.all_workers_busy(threshold, None) {
            tracing::debug!("all workers busy, queueing request");
            let arrival_offset = self.start_time.elapsed();
            let key = self.policy.enqueue_key(arrival_offset, &request);
            pending.push(QueueEntry { key, request });
            // Count the entry before releasing the lock: if a concurrent `update()`
            // popped it first, its `fetch_sub` would wrap the counter to usize::MAX.
            self.pending_count.fetch_add(1, AtomicOrdering::Relaxed);
            return;
        }
    }

    // Capacity is available. The lock was released above so that `schedule`'s
    // awaits do not block other enqueue/update callers.
    self.schedule(request).await;
}

/// Called on prefill_complete/free. Drains pending requests while workers have capacity.
/// Each scheduled request updates active_tokens via add_request, so the busy check
/// sees fresh state on the next iteration.
///
/// No-op when queueing is disabled (`threshold_frac` is `None`).
pub async fn update(&self) {
    let Some(threshold) = self.threshold_frac else {
        return;
    };

    // Policies flagged `DYNAMIC` recompute every parked key against the current
    // elapsed time before anything is popped.
    if S::DYNAMIC {
        let now = self.start_time.elapsed();
        let mut heap = self.pending.lock().await;
        let rekeyed: Vec<_> = std::mem::take(&mut *heap)
            .into_vec()
            .into_iter()
            .map(|e| QueueEntry {
                key: self.policy.rekey(now, &e.key, &e.request),
                request: e.request,
            })
            .collect();
        *heap = BinaryHeap::from(rekeyed);
    }

    loop {
        if self.all_workers_busy(threshold, None) {
            break;
        }
        let entry = {
            let mut pending = self.pending.lock().await;
            let Some(entry) = pending.pop() else {
                break;
            };
            // Decrement under the lock, mirroring the increment in `enqueue`, so
            // the counter can never be observed below zero (wrapped).
            self.pending_count.fetch_sub(1, AtomicOrdering::Relaxed);
            entry
        };
        // The lock is released before scheduling; `schedule` awaits on the tracker.
        tracing::debug!("scheduling request from pending queue");
        self.schedule(entry.request).await;
    }
}

/// Run the full scheduling pipeline for a single request:
/// compute potential load -> select worker -> respond -> book via add_request.
///
/// Selection errors are logged and sent back through the request's responder;
/// booking is skipped when `update_states` is false or no request id is set.
async fn schedule(&self, mut request: SchedulingRequest) { let (decode_blocks, prefill_tokens) = self.slots.potential_blocks_and_tokens( request.token_seq.as_deref(), request.isl_tokens, request.overlaps.clone(), ); request.decode_blocks = decode_blocks; request.prefill_tokens = prefill_tokens; let selection = { let workers = self.workers_with_configs.borrow(); self.selector .select_worker(&workers, &request, self.block_size) }; let selection = match selection { Ok(s) => s, Err(e) => { tracing::warn!("scheduling failed: {e}"); request.respond(Err(e)); return; } }; request.respond(Ok(SchedulingResponse { best_worker: selection.worker, overlap_blocks: selection.overlap_blocks, })); if !request.update_states { return; } let Some(request_id) = request.maybe_request_id else { tracing::error!("No request_id provided to add_request to the slot tracker"); return; }; if let Err(e) = self .slots .add_request(SequenceRequest { request_id: request_id.clone(), token_sequence: request.token_seq, isl: request.isl_tokens, overlap: selection.overlap_blocks, expected_output_tokens: request.expected_output_tokens, worker: selection.worker, lora_name: request.lora_name.clone(), }) .await { tracing::warn!("Failed to add request {request_id}: {e}"); } } /// Number of requests currently parked in the pending queue (lock-free). pub fn pending_count(&self) -> usize { self.pending_count.load(AtomicOrdering::Relaxed) } /// Check if all eligible workers are busy based on threshold. /// When `allowed` is `Some`, only those worker IDs are considered; /// otherwise all registered workers are checked. /// Returns false when no eligible workers exist so the request falls /// through to `schedule`, which returns a proper `NoEndpoints` error. 
fn all_workers_busy(&self, threshold: f64, allowed: Option<&HashSet>) -> bool {
    // Snapshot of active tokens per (worker, dp_rank); missing entries count as 0.
    let active_tokens = self.slots.active_tokens();
    // Only workers present in the discovery config watch are iterated.
    let configs = self.workers_with_configs.borrow();
    // Stays false when no eligible (worker, dp_rank) pair exists.
    let mut checked_any = false;
    for (&worker_id, config) in configs.iter() {
        if let Some(ids) = allowed
            && !ids.contains(&worker_id)
        {
            continue;
        }
        let dp_size = config.data_parallel_size();
        let dp_start_rank = config.data_parallel_start_rank();
        // Workers without a configured limit get the very large default budget.
        let max_batched = config
            .max_num_batched_tokens()
            .unwrap_or(DEFAULT_MAX_BATCHED_TOKENS);
        for dp_rank in dp_start_rank..dp_start_rank + dp_size {
            checked_any = true;
            let worker = WorkerWithDpRank::new(worker_id, dp_rank);
            let tokens = active_tokens.get(&worker).copied().unwrap_or(0);
            // A single rank at or under the threshold means capacity is available.
            if (tokens as f64) <= threshold * (max_batched as f64) {
                return false;
            }
        }
    }
    checked_any
}
}

// NOTE(review): generic argument lists appear to be missing from this copy of the
// file (e.g. `Option<&HashSet>` above, `Arc>` / `Result, >` below); confirm against
// the upstream source before building.
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use tokio::sync::watch;

    use super::*;
    use crate::protocols::OverlapScores;
    use crate::selector::DefaultWorkerSelector;
    use crate::sequences::ActiveSequencesMultiWorker;
    use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};

    /// Build a queue over `num_workers` workers and drop the config sender.
    /// See `make_queue_with_sender` for the worker setup.
    fn make_queue(
        num_workers: usize,
        block_size: u32,
        isl: usize,
        threshold_frac: Option,
    ) -> (
        Arc>,
        Arc>,
    ) {
        let (queue, slots, _tx) =
            make_queue_with_sender(num_workers, block_size, isl, threshold_frac);
        (queue, slots)
    }

    /// Build a queue over workers `0..num_workers` (each with DP range `(0, 1)`
    /// and `max_num_batched_tokens = isl`). Also returns the config watch sender
    /// so a test can publish new worker configs later.
    #[allow(clippy::type_complexity)]
    fn make_queue_with_sender(
        num_workers: usize,
        block_size: u32,
        isl: usize,
        threshold_frac: Option,
    ) -> (
        Arc>,
        Arc>,
        watch::Sender>,
    ) {
        let dp_range: HashMap =
            (0..num_workers as u64).map(|id| (id, (0, 1))).collect();
        let slots = Arc::new(ActiveSequencesMultiWorker::new(
            NoopSequencePublisher,
            block_size as usize,
            dp_range,
            false,
            0,
            "test",
        ));

        let mut configs: HashMap = HashMap::new();
        for id in 0..num_workers as u64 {
            configs.insert(
                id,
                SimpleWorkerConfig {
                    max_num_batched_tokens: Some(isl as u64),
                    ..Default::default()
                },
            );
        }
        let (cfg_tx, cfg_rx) = watch::channel(configs);

        let selector = DefaultWorkerSelector::new(None, "test");
        let queue = Arc::new(SchedulerQueue::new(
            Arc::clone(&slots),
            cfg_rx,
            threshold_frac,
            block_size,
            selector,
            FcfsPolicy,
        ));
        (queue, slots, cfg_tx)
    }

    /// Build a request with the given id and ISL that books into the tracker
    /// (`update_states: true`) and has no worker filter; returns it together with
    /// the receiver for its scheduling response.
    fn make_request(
        request_id: &str,
        isl_tokens: usize,
    ) -> (
        SchedulingRequest,
        tokio::sync::oneshot::Receiver<
            Result,
        >,
    ) {
        let (tx, rx) = tokio::sync::oneshot::channel();
        let req = SchedulingRequest {
            maybe_request_id: Some(request_id.to_string()),
            token_seq: None,
            isl_tokens,
            overlaps: OverlapScores::default(),
            decode_blocks: HashMap::new(),
            prefill_tokens: HashMap::new(),
            router_config_override: None,
            update_states: true,
            lora_name: None,
            priority_jump: 0.0,
            expected_output_tokens: None,
            allowed_worker_ids: None,
            resp_tx: Some(tx),
        };
        (req, rx)
    }

    /// Many concurrent tasks with queueing disabled: each request must land on an
    /// existing worker, and after prefill-complete + free no worker may keep
    /// active tokens.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_concurrent_flood() {
        let block_size = 16;
        let isl = 512;
        let num_workers = 4;
        let num_tasks = 25;
        let (queue, slots) = make_queue(num_workers, block_size, isl, None);

        let mut handles = Vec::new();
        for i in 0..num_tasks {
            let queue = Arc::clone(&queue);
            let slots = Arc::clone(&slots);
            handles.push(tokio::spawn(async move {
                let req_id = format!("req-{i}");
                let (req, rx) = make_request(&req_id, isl);
                queue.enqueue(req).await;
                let resp = rx.await.expect("oneshot dropped");
                let resp = resp.expect("scheduling failed");
                assert!(resp.best_worker.worker_id < num_workers as u64);
                slots.mark_prefill_completed(&req_id).await.unwrap();
                slots.free(&req_id).await.unwrap();
                queue.update().await;
            }));
        }
        for h in handles {
            h.await.expect("task panicked");
        }

        let active = slots.active_tokens();
        for (worker, tokens) in &active {
            assert_eq!(
                *tokens, 0,
                "worker {worker:?} still has {tokens} active tokens"
            );
        }
    }

    /// With threshold 0.0 (any active tokens mark a worker busy), every request
    /// must eventually be scheduled by repeatedly freeing and calling `update()`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_queueing_under_pressure() {
        let block_size = 16;
        let isl = 512;
        let num_workers = 2;
        let num_requests = 10;
        let (queue, slots) = make_queue(num_workers, block_size, isl, Some(0.0));

        let mut receivers = Vec::new();
        let mut req_ids = Vec::new();
        for i in 0..num_requests {
            let req_id = format!("pressure-{i}");
            let (req, rx) = make_request(&req_id, isl);
            queue.enqueue(req).await;
            receivers.push(rx);
            req_ids.push(req_id);
        }

        // Drain pending by cycling mark_prefill_completed + free + update
        // on already-scheduled requests until all receivers have a response.
        for _ in 0..num_requests {
            queue.update().await;
            for rid in &req_ids {
                let _ = slots.mark_prefill_completed(rid).await;
                let _ = slots.free(rid).await;
            }
        }
        queue.update().await;

        let mut ok_count = 0;
        for mut rx in receivers {
            if let Ok(result) = rx.try_recv() {
                result.expect("scheduling returned error");
                ok_count += 1;
            }
        }
        assert_eq!(ok_count, num_requests, "not all requests were scheduled");
    }

    /// `pending_count` follows parked requests through enqueue and update on a
    /// single worker with threshold 0.0.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_pending_count() {
        let block_size = 16;
        let isl = 512;
        let num_workers = 1;
        // threshold_frac=0.0 means any active tokens trigger queueing
        let (queue, slots) = make_queue(num_workers, block_size, isl, Some(0.0));

        assert_eq!(queue.pending_count(), 0);

        // First request goes through (worker is idle)
        let (req1, rx1) = make_request("req-1", isl);
        queue.enqueue(req1).await;
        let _resp1 = rx1.await.unwrap().unwrap();
        assert_eq!(queue.pending_count(), 0); // scheduled immediately

        // Second and third requests should be queued (worker is now busy)
        let (req2, _rx2) = make_request("req-2", isl);
        queue.enqueue(req2).await;
        assert_eq!(queue.pending_count(), 1);

        let (req3, _rx3) = make_request("req-3", isl);
        queue.enqueue(req3).await;
        assert_eq!(queue.pending_count(), 2);

        // Free the first request and update — should drain one from pending
        slots
            .mark_prefill_completed(&"req-1".to_string())
            .await
            .unwrap();
        slots.free(&"req-1".to_string()).await.unwrap();
        queue.update().await;

        // After update, one pending request should have been scheduled
        assert!(
            queue.pending_count() < 2,
            "pending_count should decrease after free+update, got {}",
            queue.pending_count()
        );

        // Free req-2 and update to drain remaining
        let _ = slots.mark_prefill_completed(&"req-2".to_string()).await;
        let _ = slots.free(&"req-2".to_string()).await;
        queue.update().await;

        let _ = slots.mark_prefill_completed(&"req-3".to_string()).await;
        let _ = slots.free(&"req-3".to_string()).await;
        queue.update().await;

        assert_eq!(queue.pending_count(), 0, "all requests should be drained");
    }

    /// With zero workers, scheduling must answer with `NoEndpoints`.
    #[tokio::test]
    async fn test_no_workers_returns_error() {
        let (queue, _slots) = make_queue(0, 16, 512, None);
        let (req, rx) = make_request("lonely-req", 512);
        queue.enqueue(req).await;
        let resp = rx.await.expect("oneshot dropped");
        assert!(
            matches!(
                resp,
                Err(crate::scheduling::types::KvSchedulerError::NoEndpoints)
            ),
            "expected NoEndpoints, got {resp:?}"
        );
    }

    /// Simulates the EPP path: router starts with zero workers (skip_initial_worker_wait),
    /// then workers are injected lazily before routing. The test calls the slot
    /// tracker's `register_external_workers` directly and publishes matching
    /// configs on the watch channel; `SchedulerQueue::register_workers` itself is
    /// not exercised here.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_register_workers_lazy_epp_path() {
        let block_size = 16;
        let isl = 512;

        // Start with zero workers (mimics skip_initial_worker_wait=true)
        let (queue, slots, cfg_tx) = make_queue_with_sender(0, block_size, isl, None);

        // Routing with no workers must fail
        let (req_fail, rx_fail) = make_request("before-register", isl);
        queue.enqueue(req_fail).await;
        let resp = rx_fail.await.expect("oneshot dropped");
        assert!(
            matches!(
                resp,
                Err(crate::scheduling::types::KvSchedulerError::NoEndpoints)
            ),
            "expected NoEndpoints before register_workers, got {resp:?}"
        );

        // Lazily register two workers in the slot tracker (EPP supplies pod list)
        let mut dp_range = std::collections::HashMap::new();
        dp_range.insert(100_u64, (0_u32, 1_u32));
        dp_range.insert(200_u64, (0_u32, 1_u32));
        slots.register_external_workers(&dp_range);

        // Also update the config watch so the selector can see these workers
        let mut configs = HashMap::new();
        for &id in &[100_u64, 200_u64] {
            configs.insert(
                id,
                SimpleWorkerConfig {
                    max_num_batched_tokens: Some(isl as u64),
                    ..Default::default()
                },
            );
        }
        cfg_tx.send(configs).unwrap();

        // Routing after registration must succeed and pick one of the registered workers
        let (req_ok, rx_ok) = make_request("after-register", isl);
        queue.enqueue(req_ok).await;
        let resp = rx_ok
            .await
            .expect("oneshot dropped")
            .expect("scheduling failed");
        assert!(
            resp.best_worker.worker_id == 100 || resp.best_worker.worker_id == 200,
            "expected worker 100 or 200, got {}",
            resp.best_worker.worker_id
        );

        // Clean up
        slots
            .mark_prefill_completed(&"after-register".to_string())
            .await
            .unwrap();
        slots.free(&"after-register".to_string()).await.unwrap();
    }

    /// `register_external_workers` on the slot tracker is additive: calling it with
    /// a new set does NOT remove previously registered workers.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_register_workers_additive() {
        let block_size = 16;
        let isl = 256;

        let (queue, slots, cfg_tx) = make_queue_with_sender(0, block_size, isl, None);

        // Register worker 10 in slots and config
        let mut dp1 = std::collections::HashMap::new();
        dp1.insert(10_u64, (0_u32, 1_u32));
        slots.register_external_workers(&dp1);
        let mut configs = HashMap::new();
        configs.insert(
            10_u64,
            SimpleWorkerConfig {
                max_num_batched_tokens: Some(isl as u64),
                ..Default::default()
            },
        );
        cfg_tx.send(configs.clone()).unwrap();

        // Register worker 20 (worker 10 must NOT be evicted)
        let mut dp2 = std::collections::HashMap::new();
        dp2.insert(20_u64, (0_u32, 1_u32));
        slots.register_external_workers(&dp2);
        configs.insert(
            20_u64,
            SimpleWorkerConfig {
                max_num_batched_tokens: Some(isl as u64),
                ..Default::default()
            },
        );
        cfg_tx.send(configs).unwrap();

        // Send enough requests to statistically prove both workers are available
        let mut seen = std::collections::HashSet::new();
        for i in 0..20 {
            let req_id = format!("add-{i}");
            let (req, rx) = make_request(&req_id, isl);
            queue.enqueue(req).await;
            let resp = rx
                .await
                .expect("oneshot dropped")
                .expect("scheduling failed");
            seen.insert(resp.best_worker.worker_id);
            slots.mark_prefill_completed(&req_id).await.unwrap();
            slots.free(&req_id).await.unwrap();
        }
        assert!(
            seen.contains(&10) && seen.contains(&20),
            "both workers should be reachable after additive registration, saw: {seen:?}"
        );
    }

    /// Requests with allowed_worker_ids should only route to the specified subset.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_allowed_worker_ids_filter() {
        let block_size = 16;
        let isl = 256;

        let (queue, slots, cfg_tx) = make_queue_with_sender(0, block_size, isl, None);

        // Register three workers
        let mut dp = std::collections::HashMap::new();
        dp.insert(1_u64, (0_u32, 1_u32));
        dp.insert(2_u64, (0_u32, 1_u32));
        dp.insert(3_u64, (0_u32, 1_u32));
        slots.register_external_workers(&dp);
        let mut configs = HashMap::new();
        for &id in &[1_u64, 2_u64, 3_u64] {
            configs.insert(
                id,
                SimpleWorkerConfig {
                    max_num_batched_tokens: Some(isl as u64),
                    ..Default::default()
                },
            );
        }
        cfg_tx.send(configs).unwrap();

        // Send a request with allowed_worker_ids = {2} only
        let mut allowed = std::collections::HashSet::new();
        allowed.insert(2_u64);
        let (tx, rx) = tokio::sync::oneshot::channel();
        let req = SchedulingRequest {
            maybe_request_id: Some("filter-0".to_string()),
            token_seq: None,
            isl_tokens: isl,
            overlaps: OverlapScores::default(),
            decode_blocks: HashMap::new(),
            prefill_tokens: HashMap::new(),
            router_config_override: None,
            update_states: true,
            lora_name: None,
            priority_jump: 0.0,
            expected_output_tokens: None,
            allowed_worker_ids: Some(allowed),
            resp_tx: Some(tx),
        };
        queue.enqueue(req).await;
        let resp = rx
            .await
            .expect("oneshot dropped")
            .expect("scheduling failed");
        assert_eq!(
            resp.best_worker.worker_id, 2,
            "request must be routed to allowed worker 2, got {}",
            resp.best_worker.worker_id
        );

        slots
            .mark_prefill_completed(&"filter-0".to_string())
            .await
            .unwrap();
        slots.free(&"filter-0".to_string()).await.unwrap();
    }
}