// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use std::cmp::Ordering; use std::collections::{BinaryHeap, HashMap}; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering}; use std::time::{Duration, Instant}; use tokio::sync::Mutex; use tokio::sync::watch; use super::selector::WorkerSelector; use super::types::{SchedulingRequest, SchedulingResponse}; use crate::protocols::{WorkerConfigLike, WorkerId, WorkerWithDpRank}; use crate::sequences::{ActiveSequencesMultiWorker, SequencePublisher, SequenceRequest}; /// Large default for max_num_batched_tokens when not configured (effectively disables queueing for that worker) pub const DEFAULT_MAX_BATCHED_TOKENS: u64 = 10_000_000; /// Entry in the priority queue, ordered by effective arrival time (lower = higher priority). /// Effective arrival = elapsed time since queue start minus `priority_jump`. struct QueueEntry { effective_offset: Duration, request: SchedulingRequest, } impl Eq for QueueEntry {} impl PartialEq for QueueEntry { fn eq(&self, other: &Self) -> bool { self.effective_offset == other.effective_offset } } impl Ord for QueueEntry { fn cmp(&self, other: &Self) -> Ordering { // BinaryHeap is a max-heap; reverse so lower effective_offset = higher priority other.effective_offset.cmp(&self.effective_offset) } } impl PartialOrd for QueueEntry { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } /// Queue that gates scheduling requests behind a capacity check. /// When all workers exceed `threshold_frac` utilisation the request is parked in `pending`. /// When capacity frees up (`update()`), pending requests are scheduled in priority order. /// If queueing is disabled (threshold_frac is None), requests are scheduled immediately. pub struct SchedulerQueue { pending: Mutex>, /// Number of requests currently parked in the pending queue. /// Incremented after push, decremented after pop. Lock-free reads via `Relaxed` load. pending_count: AtomicUsize, slots: Arc>, workers_with_configs: watch::Receiver>, /// Cached threshold fraction; None means queueing is disabled. threshold_frac: Option, /// Reference instant for computing arrival offsets. start_time: Instant, block_size: u32, selector: Box + Send + Sync>, } impl SchedulerQueue { pub fn new( slots: Arc>, workers_with_configs: watch::Receiver>, threshold_frac: Option, block_size: u32, selector: Box + Send + Sync>, ) -> Self { if let Some(frac) = threshold_frac { tracing::info!("Router queue enabled with threshold fraction {frac}"); } Self { pending: Mutex::new(BinaryHeap::new()), pending_count: AtomicUsize::new(0), slots, workers_with_configs, threshold_frac, start_time: Instant::now(), block_size, selector, } } /// Build a QueueEntry for a request, computing its effective arrival offset. fn make_entry(&self, request: SchedulingRequest) -> QueueEntry { let arrival_offset = self.start_time.elapsed(); let jump = Duration::from_secs_f64(request.priority_jump.max(0.0)); let effective_offset = arrival_offset.saturating_sub(jump); QueueEntry { effective_offset, request, } } /// Enqueue a new request. /// If queueing is disabled or workers have capacity, schedule immediately. /// Otherwise park in the pending heap. pub async fn enqueue(&self, request: SchedulingRequest) { let Some(threshold) = self.threshold_frac else { self.schedule(request).await; return; }; if self.all_workers_busy(threshold) { tracing::debug!("all workers busy, queueing request"); let entry = self.make_entry(request); self.pending.lock().await.push(entry); self.pending_count.fetch_add(1, AtomicOrdering::Relaxed); } else { self.schedule(request).await; } } /// Called on prefill_complete/free. Drains pending requests while workers have capacity. /// Each scheduled request updates active_tokens via add_request, so the busy check /// sees fresh state on the next iteration. pub async fn update(&self) { let Some(threshold) = self.threshold_frac else { return; }; loop { if self.all_workers_busy(threshold) { break; } let Some(entry) = self.pending.lock().await.pop() else { break; }; self.pending_count.fetch_sub(1, AtomicOrdering::Relaxed); tracing::debug!("scheduling request from pending queue"); self.schedule(entry.request).await; } } /// Run the full scheduling pipeline for a single request: /// compute potential load -> select worker -> respond -> book via add_request. async fn schedule(&self, mut request: SchedulingRequest) { let (decode_blocks, prefill_tokens) = self.slots.potential_blocks_and_tokens( request.token_seq.as_deref(), request.isl_tokens, request.overlaps.clone(), ); request.decode_blocks = decode_blocks; request.prefill_tokens = prefill_tokens; let selection = { let workers = self.workers_with_configs.borrow(); self.selector .select_worker(&workers, &request, self.block_size) }; let selection = match selection { Ok(s) => s, Err(e) => { tracing::warn!("scheduling failed: {e}"); request.respond(Err(e)); return; } }; request.respond(Ok(SchedulingResponse { best_worker: selection.worker, overlap_blocks: selection.overlap_blocks, })); if !request.update_states { return; } let Some(request_id) = request.maybe_request_id else { tracing::error!("No request_id provided to add_request to the slot tracker"); return; }; if let Err(e) = self .slots .add_request(SequenceRequest { request_id: request_id.clone(), token_sequence: request.token_seq, isl: request.isl_tokens, overlap: selection.overlap_blocks, expected_output_tokens: request.expected_output_tokens, worker: selection.worker, lora_name: request.lora_name.clone(), }) .await { tracing::warn!("Failed to add request {request_id}: {e}"); } } /// Number of requests currently parked in the pending queue (lock-free). pub fn pending_count(&self) -> usize { self.pending_count.load(AtomicOrdering::Relaxed) } /// Check if all workers are busy based on threshold. /// Returns true only if ALL workers exceed the threshold (no worker has capacity). fn all_workers_busy(&self, threshold: f64) -> bool { let active_tokens = self.slots.active_tokens(); let configs = self.workers_with_configs.borrow(); for (&worker_id, config) in configs.iter() { let dp_size = config.data_parallel_size(); let dp_start_rank = config.data_parallel_start_rank(); let max_batched = config .max_num_batched_tokens() .unwrap_or(DEFAULT_MAX_BATCHED_TOKENS); for dp_rank in dp_start_rank..dp_start_rank + dp_size { let worker = WorkerWithDpRank::new(worker_id, dp_rank); let tokens = active_tokens.get(&worker).copied().unwrap_or(0); if (tokens as f64) <= threshold * (max_batched as f64) { return false; } } } true } } #[cfg(test)] mod tests { use std::collections::HashMap; use std::sync::Arc; use tokio::sync::watch; use super::*; use crate::protocols::OverlapScores; use crate::selector::DefaultWorkerSelector; use crate::sequences::ActiveSequencesMultiWorker; use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig}; fn make_queue( num_workers: usize, block_size: u32, isl: usize, threshold_frac: Option, ) -> ( Arc>, Arc>, ) { let dp_range: HashMap = (0..num_workers as u64).map(|id| (id, (0, 1))).collect(); let slots = Arc::new(ActiveSequencesMultiWorker::new( NoopSequencePublisher, block_size as usize, dp_range, false, 0, "test", )); let mut configs: HashMap = HashMap::new(); for id in 0..num_workers as u64 { configs.insert( id, SimpleWorkerConfig { max_num_batched_tokens: Some(isl as u64), ..Default::default() }, ); } let (cfg_tx, cfg_rx) = watch::channel(configs); std::mem::forget(cfg_tx); let selector = Box::new(DefaultWorkerSelector::new(None, "test")); let queue = Arc::new(SchedulerQueue::new( Arc::clone(&slots), cfg_rx, threshold_frac, block_size, selector, )); (queue, slots) } fn make_request( request_id: &str, isl_tokens: usize, ) -> ( SchedulingRequest, tokio::sync::oneshot::Receiver< Result, >, ) { let (tx, rx) = tokio::sync::oneshot::channel(); let req = SchedulingRequest { maybe_request_id: Some(request_id.to_string()), token_seq: None, isl_tokens, overlaps: OverlapScores::default(), decode_blocks: HashMap::new(), prefill_tokens: HashMap::new(), router_config_override: None, update_states: true, lora_name: None, priority_jump: 0.0, expected_output_tokens: None, allowed_worker_ids: None, resp_tx: Some(tx), }; (req, rx) } #[tokio::test(flavor = "multi_thread")] async fn test_concurrent_flood() { let block_size = 16; let isl = 512; let num_workers = 4; let num_tasks = 25; let (queue, slots) = make_queue(num_workers, block_size, isl, None); let mut handles = Vec::new(); for i in 0..num_tasks { let queue = Arc::clone(&queue); let slots = Arc::clone(&slots); handles.push(tokio::spawn(async move { let req_id = format!("req-{i}"); let (req, rx) = make_request(&req_id, isl); queue.enqueue(req).await; let resp = rx.await.expect("oneshot dropped"); let resp = resp.expect("scheduling failed"); assert!(resp.best_worker.worker_id < num_workers as u64); slots.mark_prefill_completed(&req_id).await.unwrap(); slots.free(&req_id).await.unwrap(); queue.update().await; })); } for h in handles { h.await.expect("task panicked"); } let active = slots.active_tokens(); for (worker, tokens) in &active { assert_eq!( *tokens, 0, "worker {worker:?} still has {tokens} active tokens" ); } } #[tokio::test(flavor = "multi_thread")] async fn test_queueing_under_pressure() { let block_size = 16; let isl = 512; let num_workers = 2; let num_requests = 10; let (queue, slots) = make_queue(num_workers, block_size, isl, Some(0.0)); let mut receivers = Vec::new(); let mut req_ids = Vec::new(); for i in 0..num_requests { let req_id = format!("pressure-{i}"); let (req, rx) = make_request(&req_id, isl); queue.enqueue(req).await; receivers.push(rx); req_ids.push(req_id); } // Drain pending by cycling mark_prefill_completed + free + update // on already-scheduled requests until all receivers have a response. for _ in 0..num_requests { queue.update().await; for rid in &req_ids { let _ = slots.mark_prefill_completed(rid).await; let _ = slots.free(rid).await; } } queue.update().await; let mut ok_count = 0; for mut rx in receivers { if let Ok(result) = rx.try_recv() { result.expect("scheduling returned error"); ok_count += 1; } } assert_eq!(ok_count, num_requests, "not all requests were scheduled"); } #[tokio::test(flavor = "multi_thread")] async fn test_pending_count() { let block_size = 16; let isl = 512; let num_workers = 1; // threshold_frac=0.0 means any active tokens trigger queueing let (queue, slots) = make_queue(num_workers, block_size, isl, Some(0.0)); assert_eq!(queue.pending_count(), 0); // First request goes through (worker is idle) let (req1, rx1) = make_request("req-1", isl); queue.enqueue(req1).await; let _resp1 = rx1.await.unwrap().unwrap(); assert_eq!(queue.pending_count(), 0); // scheduled immediately // Second and third requests should be queued (worker is now busy) let (req2, _rx2) = make_request("req-2", isl); queue.enqueue(req2).await; assert_eq!(queue.pending_count(), 1); let (req3, _rx3) = make_request("req-3", isl); queue.enqueue(req3).await; assert_eq!(queue.pending_count(), 2); // Free the first request and update — should drain one from pending slots .mark_prefill_completed(&"req-1".to_string()) .await .unwrap(); slots.free(&"req-1".to_string()).await.unwrap(); queue.update().await; // After update, one pending request should have been scheduled assert!( queue.pending_count() < 2, "pending_count should decrease after free+update, got {}", queue.pending_count() ); // Free req-2 and update to drain remaining let _ = slots.mark_prefill_completed(&"req-2".to_string()).await; let _ = slots.free(&"req-2".to_string()).await; queue.update().await; let _ = slots.mark_prefill_completed(&"req-3".to_string()).await; let _ = slots.free(&"req-3".to_string()).await; queue.update().await; assert_eq!(queue.pending_count(), 0, "all requests should be drained"); } #[tokio::test] async fn test_no_workers_returns_error() { let (queue, _slots) = make_queue(0, 16, 512, None); let (req, rx) = make_request("lonely-req", 512); queue.enqueue(req).await; let resp = rx.await.expect("oneshot dropped"); assert!( matches!( resp, Err(crate::scheduling::types::KvSchedulerError::NoEndpoints) ), "expected NoEndpoints, got {resp:?}" ); } }