// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;

use tokio::sync::{mpsc, watch};
use tokio_util::sync::CancellationToken;

use super::policy::{RouterSchedulingPolicy, SchedulingPolicy};
use super::queue::SchedulerQueue;
use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse};
use crate::protocols::{OverlapScores, WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::sequences::{
    ActiveSequencesMultiWorker, SequenceError, SequencePublisher, SequenceRequest,
};
use dynamo_tokens::SequenceHash;

/// Period of the interval timer in the background task spawned by
/// [`LocalScheduler::new`]; every tick calls `update()` on the scheduler queue.
const RECHECK_INTERVAL: Duration = Duration::from_secs(60);

/// Front-end handle for scheduling requests onto workers.
///
/// It owns the sending half of a request channel that is drained by a
/// background task (spawned in [`LocalScheduler::new`]), plus shared handles to
/// the active-sequence slots and to the scheduler queue. Most methods are thin
/// forwards to one of those two shared objects.
//
// NOTE(review): the generic parameter/argument lists look like they were
// stripped from this copy of the file -- `P`, `C`, `S` and `Sel` are bounded in
// the `where` clauses but never declared, and types such as `mpsc::Sender`,
// `Arc>`, `Option` and `Result` have lost their arguments. Restore them from
// the upstream source before building; the code below is otherwise unchanged.
pub struct LocalScheduler
where
    P: SequencePublisher,
    C: WorkerConfigLike,
    S: SchedulingPolicy,
    Sel: WorkerSelector,
{
    // Sending half of the channel consumed by the background scheduling task.
    request_tx: mpsc::Sender,
    // Shared active-sequence bookkeeping; also handed to the queue and to the
    // optional worker-config monitor task.
    slots: Arc>,
    // Shared scheduler queue that receives enqueued requests and `update()` calls.
    queue: Arc>,
    // Static label returned verbatim by `worker_type()`.
    worker_type: &'static str,
}

impl LocalScheduler
where
    P: SequencePublisher + 'static,
    C: WorkerConfigLike + Clone + PartialEq + Send + Sync + 'static,
    S: SchedulingPolicy + 'static,
    Sel: WorkerSelector + Send + Sync + 'static,
{
    /// Builds the scheduler and spawns its background task(s).
    ///
    /// * `slots` - shared active-sequence state; cloned into the queue and the
    ///   optional monitor task.
    /// * `workers_with_configs` - watch channel carrying the current worker
    ///   configuration map; passed to the queue and, when monitoring is enabled,
    ///   observed for changes.
    /// * `threshold_frac`, `block_size`, `selector`, `policy` - forwarded
    ///   unchanged to `SchedulerQueue::new`.
    /// * `cancellation_token` - cancelling it stops both spawned tasks.
    /// * `worker_type` - label later returned by [`Self::worker_type`].
    /// * `monitor_worker_configs` - when `true`, spawns a task that pushes
    ///   worker data-parallel ranges into `slots` whenever the watched
    ///   configuration map changes.
    ///
    /// This function calls `tokio::spawn`, so it has to run inside a Tokio
    /// runtime context.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        slots: Arc>,
        workers_with_configs: watch::Receiver>,
        threshold_frac: Option,
        block_size: u32,
        selector: Sel,
        policy: S,
        cancellation_token: CancellationToken,
        worker_type: &'static str,
        monitor_worker_configs: bool,
    ) -> Self {
        if monitor_worker_configs {
            let slots_monitor = Arc::clone(&slots);
            // Separate receiver for the monitor so the original can still be
            // moved into the queue below.
            let mut monitor_rx = workers_with_configs.clone();
            // Snapshot used to skip notifications that did not change the map.
            let mut last_workers = monitor_rx.borrow().clone();
            let monitor_cancel_token = cancellation_token.clone();
            tokio::spawn(async move {
                tracing::trace!("LocalScheduler workers monitoring task started");
                loop {
                    tokio::select! {
                        _ = monitor_cancel_token.cancelled() => {
                            tracing::trace!("LocalScheduler workers monitoring task shutting down");
                            break;
                        }
                        result = monitor_rx.changed() => {
                            // An error here means the watch can no longer
                            // deliver updates, so the task exits.
                            if result.is_err() {
                                tracing::warn!("LocalScheduler worker config watch dropped, shutting down");
                                break;
                            }
                        }
                    }

                    // Reached only after a successful `changed()`: read the new
                    // value (marking it as seen) and ignore no-op updates.
                    let current_workers = monitor_rx.borrow_and_update().clone();
                    if current_workers == last_workers {
                        continue;
                    }

                    // Per worker id: (data-parallel start rank, data-parallel size).
                    let dp_range: HashMap = current_workers
                        .iter()
                        .map(|(&id, cfg)| {
                            (
                                id,
                                (cfg.data_parallel_start_rank(), cfg.data_parallel_size()),
                            )
                        })
                        .collect();
                    slots_monitor.update_workers(&dp_range);
                    last_workers = current_workers;
                }
            });
        }

        let queue = Arc::new(SchedulerQueue::new(
            Arc::clone(&slots),
            workers_with_configs,
            threshold_frac,
            block_size,
            selector,
            policy,
        ));

        // Bounded channel (capacity 1024) between `schedule()` callers and the
        // background task.
        let (request_tx, request_rx) = mpsc::channel::(1024);
        let queue_clone = Arc::clone(&queue);

        tokio::spawn(async move {
            let mut request_rx = request_rx;
            // NOTE(review): per Tokio's documentation the first tick of an
            // interval completes immediately, so `update()` presumably runs once
            // right after start-up -- confirm this is intended.
            let mut recheck_interval = tokio::time::interval(RECHECK_INTERVAL);
            tracing::trace!("LocalScheduler background task started");

            loop {
                tokio::select! {
                    _ = cancellation_token.cancelled() => {
                        tracing::trace!("LocalScheduler background task shutting down");
                        break;
                    }
                    request = request_rx.recv() => {
                        // `None` means the channel is closed; stop the task.
                        let Some(request) = request else {
                            tracing::warn!("LocalScheduler request channel closed");
                            break;
                        };
                        tracing::trace!("received request to be scheduled");
                        queue_clone.enqueue(request).await;
                    }
                    _ = recheck_interval.tick() => {
                        // Periodic refresh of the queue, independent of
                        // incoming requests.
                        queue_clone.update().await;
                    }
                }
            }
        });

        Self {
            request_tx,
            slots,
            queue,
            worker_type,
        }
    }

    /// Submits one scheduling request and waits for its response.
    ///
    /// The arguments are packed into a `SchedulingRequest` (with empty
    /// `decode_blocks` / `prefill_tokens` maps and a fresh oneshot response
    /// sender), sent to the background task, and the reply is awaited on the
    /// oneshot receiver.
    ///
    /// # Errors
    ///
    /// Returns `KvSchedulerError::SubscriberShutdown` if the request cannot be
    /// sent or if the response sender is dropped without replying. Otherwise
    /// the result delivered through the oneshot channel is returned as-is.
    #[expect(clippy::too_many_arguments)]
    pub async fn schedule(
        &self,
        maybe_request_id: Option,
        isl_tokens: usize,
        token_seq: Option>,
        overlaps: OverlapScores,
        router_config_override: Option<&super::config::RouterConfigOverride>,
        update_states: bool,
        lora_name: Option,
        priority_jump: f64,
        expected_output_tokens: Option,
        allowed_worker_ids: Option>,
    ) -> Result {
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
            maybe_request_id,
            token_seq,
            isl_tokens,
            overlaps,
            // Left empty here; presumably filled in downstream -- TODO confirm.
            decode_blocks: HashMap::new(),
            prefill_tokens: HashMap::new(),
            // The override is borrowed by the caller, so it is cloned into the
            // owned request.
            router_config_override: router_config_override.cloned(),
            update_states,
            lora_name,
            priority_jump,
            expected_output_tokens,
            allowed_worker_ids,
            resp_tx: Some(resp_tx),
        };

        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;

        // `?` strips the receive error; the remaining value is the function's
        // own return value.
        resp_rx
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?
    }

    /// Forwards the given worker ids to the queue's `register_workers`.
    pub fn register_workers(&self, worker_ids: &HashSet) {
        self.queue.register_workers(worker_ids);
    }

    /// Forwards `req` to the slots' `add_request` and returns its result.
    pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
        self.slots.add_request(req).await
    }

    /// Marks prefill as completed for `request_id` in the slots, then refreshes
    /// the queue.
    ///
    /// # Errors
    ///
    /// Propagates the error from the slots; in that case the queue is not
    /// updated.
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
        self.slots
            .mark_prefill_completed(&request_id.to_string())
            .await?;
        self.queue.update().await;
        Ok(())
    }

    /// Frees `request_id` in the slots, then refreshes the queue.
    ///
    /// # Errors
    ///
    /// Propagates the error from the slots; in that case the queue is not
    /// updated.
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        self.slots.free(&request_id.to_string()).await?;
        self.queue.update().await;
        Ok(())
    }

    /// Returns the queue's `pending_count()`.
    pub fn pending_count(&self) -> usize {
        self.queue.pending_count()
    }

    /// Returns the worker-type label supplied to [`Self::new`].
    pub fn worker_type(&self) -> &'static str {
        self.worker_type
    }

    /// Forwards to the slots' `add_output_block` for `request_id`, passing
    /// `decay_fraction` through unchanged.
    pub fn add_output_block(
        &self,
        request_id: &str,
        decay_fraction: Option,
    ) -> Result<(), SequenceError> {
        self.slots
            .add_output_block(&request_id.to_string(), decay_fraction)
    }

    /// Reports the potential load per worker/dp-rank for a prospective request.
    ///
    /// Queries the slots via `potential_blocks_and_tokens` and emits one
    /// `PotentialLoad` for every key present in either returned map. A key
    /// missing from the prefill map defaults to `isl_tokens`; a key missing
    /// from the decode map defaults to `0`.
    ///
    /// The entries are produced by iterating a `HashSet`, so their order is
    /// unspecified.
    pub fn get_potential_loads(
        &self,
        token_seq: Option>,
        isl_tokens: usize,
        overlaps: OverlapScores,
    ) -> Vec {
        let (decode_blocks, prefill_tokens) =
            self.slots
                .potential_blocks_and_tokens(token_seq.as_deref(), isl_tokens, overlaps);

        // Union of the keys of both maps.
        let mut workers: HashSet = HashSet::new();
        workers.extend(decode_blocks.keys().copied());
        workers.extend(prefill_tokens.keys().copied());

        let mut loads = Vec::with_capacity(workers.len());
        for worker in workers {
            loads.push(PotentialLoad {
                worker_id: worker.worker_id,
                dp_rank: worker.dp_rank,
                potential_prefill_tokens: prefill_tokens
                    .get(&worker)
                    .copied()
                    .unwrap_or(isl_tokens),
                potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
            });
        }

        loads
    }

    /// Returns the slots' `get_active_lora_counts()`.
    pub fn get_active_lora_counts(&self) -> HashMap {
        self.slots.get_active_lora_counts()
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::time::Duration;

    use tokio::sync::watch;

    use super::*;
    use crate::protocols::OverlapScores;
    use crate::scheduling::policy::FcfsPolicy;
    use crate::scheduling::selector::DefaultWorkerSelector;
    use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};

    /// Test fixture: builds the slots, a watch channel seeded with `workers`,
    /// and a `LocalScheduler` using `FcfsPolicy` and `DefaultWorkerSelector`.
    ///
    /// Returns the scheduler, the shared slots, the watch sender (so tests can
    /// publish config changes) and the cancellation token for shutting the
    /// background tasks down.
    #[allow(clippy::type_complexity)]
    fn make_scheduler(
        workers: HashMap,
        threshold_frac: Option,
        monitor_worker_configs: bool,
    ) -> (
        Arc>,
        Arc>,
        watch::Sender>,
        CancellationToken,
    ) {
        // Seed the slots with each worker's (start rank, size) pair.
        let dp_range = workers
            .iter()
            .map(|(&id, cfg)| (id, (cfg.data_parallel_start_rank, cfg.data_parallel_size)))
            .collect();
        let slots = Arc::new(ActiveSequencesMultiWorker::new(
            NoopSequencePublisher,
            64,
            dp_range,
            false,
            0,
            "test",
        ));
        let (cfg_tx, cfg_rx) = watch::channel(workers);
        let cancel_token = CancellationToken::new();
        let scheduler = Arc::new(LocalScheduler::new(
            Arc::clone(&slots),
            cfg_rx,
            threshold_frac,
            64,
            DefaultWorkerSelector::new(None, "test"),
            FcfsPolicy,
            cancel_token.clone(),
            "test",
            monitor_worker_configs,
        ));
        (scheduler, slots, cfg_tx, cancel_token)
    }

    // Scheduling with `update_states = true` and a LoRA name should pick the
    // only worker and show that adapter in the active LoRA counts.
    #[tokio::test]
    async fn test_schedule_books_request_into_active_sequences() {
        let mut workers = HashMap::new();
        workers.insert(
            0,
            SimpleWorkerConfig {
                max_num_batched_tokens: Some(64),
                ..Default::default()
            },
        );
        let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);

        let response = scheduler
            .schedule(
                Some("req-1".to_string()),
                64,
                Some(vec![1, 2, 3, 4]),
                OverlapScores::default(),
                None,
                true,
                Some("adapter-a".to_string()),
                0.0,
                None,
                None,
            )
            .await
            .unwrap();

        assert_eq!(response.best_worker.worker_id, 0);
        assert_eq!(
            scheduler.get_active_lora_counts(),
            HashMap::from([(String::from("adapter-a"), 1)])
        );

        cancel_token.cancel();
    }

    // With a threshold of 0.5 the second request is expected to stay pending
    // until `mark_prefill_completed` for the first one refreshes the queue.
    // NOTE(review): presumably the first 64-token request makes the single
    // worker count as busy -- confirm against `SchedulerQueue`.
    #[tokio::test]
    async fn test_mark_prefill_completed_drains_pending_queue() {
        let mut workers = HashMap::new();
        workers.insert(
            0,
            SimpleWorkerConfig {
                max_num_batched_tokens: Some(64),
                ..Default::default()
            },
        );
        let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, Some(0.5), true);

        scheduler
            .schedule(
                Some("req-1".to_string()),
                64,
                Some(vec![1, 2, 3, 4]),
                OverlapScores::default(),
                None,
                true,
                None,
                0.0,
                None,
                None,
            )
            .await
            .unwrap();

        // Run the second request on its own task so this test can observe the
        // queue while that call is still waiting for a response.
        let queued = {
            let scheduler = Arc::clone(&scheduler);
            tokio::spawn(async move {
                scheduler
                    .schedule(
                        Some("req-2".to_string()),
                        64,
                        Some(vec![5, 6, 7, 8]),
                        OverlapScores::default(),
                        None,
                        true,
                        None,
                        0.0,
                        None,
                        None,
                    )
                    .await
            })
        };

        // Give the spawned task and the background loop time to enqueue req-2.
        tokio::time::sleep(Duration::from_millis(25)).await;
        assert_eq!(scheduler.pending_count(), 1);

        scheduler.mark_prefill_completed("req-1").await.unwrap();
        queued.await.unwrap().unwrap();
        assert_eq!(scheduler.pending_count(), 0);

        cancel_token.cancel();
    }

    // Freeing a scheduled request should remove its adapter from the active
    // LoRA counts.
    #[tokio::test]
    async fn test_free_updates_active_state() {
        let mut workers = HashMap::new();
        workers.insert(
            0,
            SimpleWorkerConfig {
                max_num_batched_tokens: Some(64),
                ..Default::default()
            },
        );
        let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);

        scheduler
            .schedule(
                Some("req-1".to_string()),
                64,
                Some(vec![1, 2, 3, 4]),
                OverlapScores::default(),
                None,
                true,
                Some("adapter-a".to_string()),
                0.0,
                None,
                None,
            )
            .await
            .unwrap();
        assert_eq!(
            scheduler.get_active_lora_counts(),
            HashMap::from([(String::from("adapter-a"), 1)])
        );

        scheduler.free("req-1").await.unwrap();
        assert!(scheduler.get_active_lora_counts().is_empty());

        cancel_token.cancel();
    }

    // `get_potential_loads` should agree with what the slots report directly.
    // Both sides are sorted by (worker_id, dp_rank) before the element-wise
    // comparison because the scheduler's output order is unspecified.
    #[tokio::test]
    async fn test_get_potential_loads_matches_slots() {
        let mut workers = HashMap::new();
        workers.insert(
            0,
            SimpleWorkerConfig {
                max_num_batched_tokens: Some(256),
                ..Default::default()
            },
        );
        workers.insert(
            1,
            SimpleWorkerConfig {
                max_num_batched_tokens: Some(256),
                ..Default::default()
            },
        );
        let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);
        let token_seq = vec![11, 22, 33, 44];
        let overlaps = OverlapScores::default();

        let (decode_blocks, prefill_tokens) =
            slots.potential_blocks_and_tokens(Some(&token_seq), 128, overlaps.clone());
        // NOTE(review): the expectation is built from the decode map's keys
        // only, whereas the scheduler uses the union of both maps' keys; the
        // length assertion below presumably holds because both maps share the
        // same keys -- confirm.
        let mut expected: Vec<_> = decode_blocks
            .keys()
            .map(|worker| PotentialLoad {
                worker_id: worker.worker_id,
                dp_rank: worker.dp_rank,
                potential_prefill_tokens: prefill_tokens.get(worker).copied().unwrap_or(128),
                potential_decode_blocks: decode_blocks.get(worker).copied().unwrap_or(0),
            })
            .collect();
        expected.sort_by_key(|load| (load.worker_id, load.dp_rank));

        let mut actual = scheduler.get_potential_loads(Some(token_seq), 128, overlaps);
        actual.sort_by_key(|load| (load.worker_id, load.dp_rank));

        assert_eq!(actual.len(), expected.len());
        for (actual, expected) in actual.iter().zip(expected.iter()) {
            assert_eq!(actual.worker_id, expected.worker_id);
            assert_eq!(actual.dp_rank, expected.dp_rank);
            assert_eq!(
                actual.potential_prefill_tokens,
                expected.potential_prefill_tokens
            );
            assert_eq!(
                actual.potential_decode_blocks,
                expected.potential_decode_blocks
            );
        }

        cancel_token.cancel();
    }

    // Starting with no workers and monitoring disabled, registering worker 42
    // should yield exactly one potential load for it, at dp_rank 0.
    #[tokio::test]
    async fn test_register_workers_uses_default_dp_fallback() {
        let (scheduler, _slots, _cfg_tx, cancel_token) =
            make_scheduler(HashMap::new(), None, false);

        scheduler.register_workers(&HashSet::from([42]));

        let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default());
        assert_eq!(loads.len(), 1);
        assert_eq!(loads[0].worker_id, 42);
        assert_eq!(loads[0].dp_rank, 0);

        cancel_token.cancel();
    }

    // Publishing a new config map through the watch channel should eventually
    // change the number of reported loads from 1 to 3.
    // NOTE(review): 3 is presumably two ranks for worker 0 plus one for the
    // default-configured worker 1 -- confirm `SimpleWorkerConfig::default()`.
    #[tokio::test]
    async fn test_worker_watch_updates_slot_ranges() {
        let mut workers = HashMap::new();
        workers.insert(0, SimpleWorkerConfig::default());
        let (scheduler, _slots, cfg_tx, cancel_token) = make_scheduler(workers, None, true);

        assert_eq!(
            scheduler
                .get_potential_loads(None, 64, OverlapScores::default())
                .len(),
            1
        );

        let mut updated_workers = HashMap::new();
        updated_workers.insert(
            0,
            SimpleWorkerConfig {
                data_parallel_size: 2,
                ..Default::default()
            },
        );
        updated_workers.insert(1, SimpleWorkerConfig::default());
        cfg_tx.send(updated_workers).unwrap();

        // Poll (yielding between attempts) until the monitor task has applied
        // the update; fail the test if that takes longer than one second.
        tokio::time::timeout(Duration::from_secs(1), async {
            loop {
                if scheduler
                    .get_potential_loads(None, 64, OverlapScores::default())
                    .len()
                    == 3
                {
                    break;
                }
                tokio::task::yield_now().await;
            }
        })
        .await
        .unwrap();

        cancel_token.cancel();
    }
}