// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::time::Duration; use tokio::sync::{mpsc, watch}; use tokio::time::Instant; use tokio_util::sync::CancellationToken; use super::policy::{RouterSchedulingPolicy, SchedulingPolicy}; use super::prefill_load::PrefillLoadEstimator; use super::queue::SchedulerQueue; use super::selector::{DefaultWorkerSelector, WorkerSelector}; use super::types::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse}; use crate::protocols::{OverlapScores, WorkerConfigLike, WorkerId, WorkerWithDpRank}; use crate::sequences::{ ActiveSequencesMultiWorker, SequenceError, SequencePublisher, SequenceRequest, }; use dynamo_tokens::SequenceHash; pub struct LocalScheduler where P: SequencePublisher, C: WorkerConfigLike, S: SchedulingPolicy, Sel: WorkerSelector, { request_tx: mpsc::Sender, slots: Arc>, queue: Arc>, queue_updates: watch::Sender<()>, track_prefill_tokens_default: bool, worker_type: &'static str, } impl LocalScheduler where P: SequencePublisher + 'static, C: WorkerConfigLike + Clone + PartialEq + Send + Sync + 'static, S: SchedulingPolicy + 'static, Sel: WorkerSelector + Send + Sync + 'static, { #[allow(clippy::too_many_arguments)] pub fn new( slots: Arc>, workers_with_configs: watch::Receiver>, threshold_frac: Option, block_size: u32, selector: Sel, policy: S, prefill_load_estimator: Option>, recheck_interval: Duration, track_prefill_tokens_default: bool, cancellation_token: CancellationToken, worker_type: &'static str, monitor_worker_configs: bool, ) -> Self { if monitor_worker_configs { let slots_monitor = Arc::clone(&slots); let mut monitor_rx = workers_with_configs.clone(); let mut last_workers = monitor_rx.borrow().clone(); let monitor_cancel_token = cancellation_token.clone(); tokio::spawn(async move { tracing::trace!("LocalScheduler workers monitoring task started"); loop { tokio::select! { _ = monitor_cancel_token.cancelled() => { tracing::trace!("LocalScheduler workers monitoring task shutting down"); break; } result = monitor_rx.changed() => { if result.is_err() { tracing::warn!("LocalScheduler worker config watch dropped, shutting down"); break; } } } let current_workers = monitor_rx.borrow_and_update().clone(); if current_workers == last_workers { continue; } let dp_range: HashMap = current_workers .iter() .map(|(&id, cfg)| { ( id, (cfg.data_parallel_start_rank(), cfg.data_parallel_size()), ) }) .collect(); slots_monitor.update_workers(&dp_range); last_workers = current_workers; } }); } let queue = Arc::new(SchedulerQueue::new( Arc::clone(&slots), workers_with_configs, threshold_frac, block_size, selector, policy, prefill_load_estimator, )); let (queue_updates, _) = watch::channel(()); let (request_tx, request_rx) = mpsc::channel::(1024); let queue_clone = Arc::clone(&queue); let queue_remote_updates = Arc::clone(&queue); let mut remote_state_updates = slots.subscribe_remote_state_changes(); let remote_update_cancel_token = cancellation_token.clone(); let queue_updates_remote = queue_updates.clone(); tokio::spawn(async move { tracing::trace!("LocalScheduler remote state listener started"); loop { tokio::select! { _ = remote_update_cancel_token.cancelled() => { tracing::trace!("LocalScheduler remote state listener shutting down"); break; } result = remote_state_updates.changed() => { if result.is_err() { tracing::trace!("LocalScheduler remote state listener shutting down"); break; } queue_remote_updates.update().await; let _ = queue_updates_remote.send(()); } } } }); tokio::spawn(async move { let mut request_rx = request_rx; let mut recheck_interval = tokio::time::interval(recheck_interval); tracing::trace!("LocalScheduler background task started"); loop { tokio::select! { _ = cancellation_token.cancelled() => { tracing::trace!("LocalScheduler background task shutting down"); break; } request = request_rx.recv() => { let Some(request) = request else { tracing::warn!("LocalScheduler request channel closed"); break; }; tracing::trace!("received request to be scheduled"); queue_clone.enqueue(request).await; } _ = recheck_interval.tick() => { queue_clone.update().await; } } } }); Self { request_tx, slots, queue, queue_updates, track_prefill_tokens_default, worker_type, } } #[expect(clippy::too_many_arguments)] pub async fn schedule( &self, maybe_request_id: Option, isl_tokens: usize, token_seq: Option>, overlaps: OverlapScores, router_config_override: Option<&super::config::RouterConfigOverride>, update_states: bool, lora_name: Option, priority_jump: f64, expected_output_tokens: Option, allowed_worker_ids: Option>, ) -> Result { let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let track_prefill_tokens = router_config_override .and_then(|cfg| cfg.track_prefill_tokens) .unwrap_or(self.track_prefill_tokens_default); let request = SchedulingRequest { maybe_request_id, token_seq, isl_tokens, overlaps, decode_blocks: HashMap::new(), prefill_tokens: HashMap::new(), track_prefill_tokens, router_config_override: router_config_override.cloned(), update_states, lora_name, priority_jump, expected_output_tokens, allowed_worker_ids, resp_tx: Some(resp_tx), }; self.request_tx .send(request) .await .map_err(|_| KvSchedulerError::SubscriberShutdown)?; resp_rx .await .map_err(|_| KvSchedulerError::SubscriberShutdown)? } pub fn register_workers(&self, worker_ids: &HashSet) { self.queue.register_workers(worker_ids); } pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> { self.slots.add_request(req, Instant::now()) } pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> { self.slots .mark_prefill_completed(&request_id.to_string(), Instant::now())?; self.queue.update().await; Ok(()) } pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> { self.slots.free(&request_id.to_string(), Instant::now())?; self.queue.update().await; Ok(()) } pub fn pending_count(&self) -> usize { self.queue.pending_count() } pub fn pending_isl_tokens(&self) -> usize { self.queue.pending_isl_tokens() } pub fn worker_type(&self) -> &'static str { self.worker_type } pub fn subscribe_queue_updates(&self) -> watch::Receiver<()> { self.queue_updates.subscribe() } pub fn add_output_block( &self, request_id: &str, decay_fraction: Option, ) -> Result<(), SequenceError> { self.slots .add_output_block(&request_id.to_string(), decay_fraction) } pub fn get_potential_loads( &self, token_seq: Option>, isl_tokens: usize, overlaps: OverlapScores, track_prefill_tokens: bool, ) -> Vec { let decay_now = Instant::now(); let (decode_blocks, prefill_tokens) = self .slots .potential_blocks_and_tokens_with_prefill_tracking( token_seq.as_deref(), isl_tokens, overlaps, track_prefill_tokens, decay_now, ); let mut workers: HashSet = HashSet::new(); workers.extend(decode_blocks.keys().copied()); workers.extend(prefill_tokens.keys().copied()); let mut loads = Vec::with_capacity(workers.len()); for worker in workers { loads.push(PotentialLoad { worker_id: worker.worker_id, dp_rank: worker.dp_rank, potential_prefill_tokens: prefill_tokens .get(&worker) .copied() .unwrap_or(isl_tokens), potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0), }); } loads } pub fn get_active_lora_counts(&self) -> HashMap { self.slots.get_active_lora_counts() } } #[cfg(test)] mod tests { use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use tokio::sync::{mpsc, watch}; use super::*; use crate::protocols::{ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores}; use crate::scheduling::PrefillLoadEstimator; use crate::scheduling::policy::FcfsPolicy; use crate::scheduling::selector::DefaultWorkerSelector; use crate::sequences::SequenceSubscriber; use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig}; struct TestSequenceSubscriber { rx: mpsc::UnboundedReceiver, } impl SequenceSubscriber for TestSequenceSubscriber { async fn next_event(&mut self) -> Option> { self.rx.recv().await.map(Ok) } } struct FixedPrefillLoadEstimator { duration: Duration, } impl PrefillLoadEstimator for FixedPrefillLoadEstimator { fn predict_prefill_duration( &self, _batch_size: usize, _effective_isl: usize, _prefix: usize, ) -> anyhow::Result { Ok(self.duration) } } #[allow(clippy::type_complexity)] fn make_scheduler( workers: HashMap, threshold_frac: Option, monitor_worker_configs: bool, prefill_load_estimator: Option>, ) -> ( Arc>, Arc>, watch::Sender>, CancellationToken, ) { let dp_range = workers .iter() .map(|(&id, cfg)| (id, (cfg.data_parallel_start_rank, cfg.data_parallel_size))) .collect(); let slots = Arc::new(ActiveSequencesMultiWorker::new( NoopSequencePublisher, 64, dp_range, false, 0, "test", )); let (cfg_tx, cfg_rx) = watch::channel(workers); let cancel_token = CancellationToken::new(); let scheduler = Arc::new(LocalScheduler::new( Arc::clone(&slots), cfg_rx, threshold_frac, 64, DefaultWorkerSelector::new(None, "test"), FcfsPolicy, prefill_load_estimator, Duration::from_secs(60), true, cancel_token.clone(), "test", monitor_worker_configs, )); (scheduler, slots, cfg_tx, cancel_token) } fn start_replica_sync( slots: &Arc>, cancel_token: &CancellationToken, ) -> mpsc::UnboundedSender { let (tx, rx) = mpsc::unbounded_channel(); slots.start_replica_sync(TestSequenceSubscriber { rx }, cancel_token.clone()); tx } async fn wait_for_pending_count( scheduler: &Arc>, expected: usize, ) { tokio::time::timeout(Duration::from_millis(250), async { loop { if scheduler.pending_count() == expected { break; } tokio::time::sleep(Duration::from_millis(5)).await; } }) .await .unwrap(); } #[tokio::test] async fn test_schedule_books_request_into_active_sequences() { let mut workers = HashMap::new(); workers.insert( 0, SimpleWorkerConfig { max_num_batched_tokens: Some(64), ..Default::default() }, ); let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None); let response = scheduler .schedule( Some("req-1".to_string()), 64, Some(vec![1, 2, 3, 4]), OverlapScores::default(), None, true, Some("adapter-a".to_string()), 0.0, None, None, ) .await .unwrap(); assert_eq!(response.best_worker.worker_id, 0); assert_eq!( scheduler.get_active_lora_counts(), HashMap::from([(String::from("adapter-a"), 1)]) ); cancel_token.cancel(); } #[tokio::test] async fn test_schedule_override_can_disable_prefill_tracking() { let mut workers = HashMap::new(); workers.insert( 0, SimpleWorkerConfig { max_num_batched_tokens: Some(64), ..Default::default() }, ); let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None); scheduler .schedule( Some("req-1".to_string()), 64, Some(vec![1, 2, 3, 4]), OverlapScores::default(), Some(&crate::config::RouterConfigOverride { track_prefill_tokens: Some(false), ..Default::default() }), true, None, 0.0, None, None, ) .await .unwrap(); assert_eq!( slots .active_tokens(Instant::now()) .get(&WorkerWithDpRank::new(0, 0)) .copied(), Some(0) ); cancel_token.cancel(); } #[tokio::test] async fn test_mark_prefill_completed_drains_pending_queue() { let mut workers = HashMap::new(); workers.insert( 0, SimpleWorkerConfig { max_num_batched_tokens: Some(64), ..Default::default() }, ); let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, Some(0.5), true, None); scheduler .schedule( Some("req-1".to_string()), 64, Some(vec![1, 2, 3, 4]), OverlapScores::default(), None, true, None, 0.0, None, None, ) .await .unwrap(); let queued = { let scheduler = Arc::clone(&scheduler); tokio::spawn(async move { scheduler .schedule( Some("req-2".to_string()), 64, Some(vec![5, 6, 7, 8]), OverlapScores::default(), None, true, None, 0.0, None, None, ) .await }) }; wait_for_pending_count(&scheduler, 1).await; scheduler.mark_prefill_completed("req-1").await.unwrap(); queued.await.unwrap().unwrap(); assert_eq!(scheduler.pending_count(), 0); cancel_token.cancel(); } #[tokio::test] async fn test_remote_mark_prefill_completed_drains_pending_queue() { let mut workers = HashMap::new(); workers.insert( 0, SimpleWorkerConfig { max_num_batched_tokens: Some(64), ..Default::default() }, ); let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, Some(0.5), true, None); let event_tx = start_replica_sync(&slots, &cancel_token); scheduler .schedule( Some("req-1".to_string()), 64, Some(vec![1, 2, 3, 4]), OverlapScores::default(), None, true, None, 0.0, None, None, ) .await .unwrap(); let queued = { let scheduler = Arc::clone(&scheduler); tokio::spawn(async move { scheduler .schedule( Some("req-2".to_string()), 64, Some(vec![5, 6, 7, 8]), OverlapScores::default(), None, true, None, 0.0, None, None, ) .await }) }; wait_for_pending_count(&scheduler, 1).await; event_tx .send(ActiveSequenceEvent { request_id: "req-1".to_string(), worker: WorkerWithDpRank::new(0, 0), data: ActiveSequenceEventData::MarkPrefillCompleted, router_id: 1, lora_name: None, }) .unwrap(); tokio::time::timeout(Duration::from_millis(250), async { queued.await.unwrap().unwrap(); }) .await .unwrap(); assert_eq!(scheduler.pending_count(), 0); cancel_token.cancel(); } #[tokio::test] async fn test_remote_queue_update_notification_fires_after_drain() { let mut workers = HashMap::new(); workers.insert( 0, SimpleWorkerConfig { max_num_batched_tokens: Some(64), ..Default::default() }, ); let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, Some(0.5), true, None); let event_tx = start_replica_sync(&slots, &cancel_token); let mut queue_updates = scheduler.subscribe_queue_updates(); scheduler .schedule( Some("req-1".to_string()), 64, Some(vec![1, 2, 3, 4]), OverlapScores::default(), None, true, None, 0.0, None, None, ) .await .unwrap(); let queued = { let scheduler = Arc::clone(&scheduler); tokio::spawn(async move { scheduler .schedule( Some("req-2".to_string()), 64, Some(vec![5, 6, 7, 8]), OverlapScores::default(), None, true, None, 0.0, None, None, ) .await }) }; wait_for_pending_count(&scheduler, 1).await; event_tx .send(ActiveSequenceEvent { request_id: "req-1".to_string(), worker: WorkerWithDpRank::new(0, 0), data: ActiveSequenceEventData::Free, router_id: 1, lora_name: None, }) .unwrap(); tokio::time::timeout(Duration::from_millis(250), queue_updates.changed()) .await .unwrap() .unwrap(); assert_eq!(scheduler.pending_count(), 0); queued.await.unwrap().unwrap(); cancel_token.cancel(); } #[tokio::test] async fn test_remote_free_drains_pending_queue() { let mut workers = HashMap::new(); workers.insert( 0, SimpleWorkerConfig { max_num_batched_tokens: Some(64), ..Default::default() }, ); let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, Some(0.5), true, None); let event_tx = start_replica_sync(&slots, &cancel_token); scheduler .schedule( Some("req-1".to_string()), 64, Some(vec![1, 2, 3, 4]), OverlapScores::default(), None, true, None, 0.0, None, None, ) .await .unwrap(); let queued = { let scheduler = Arc::clone(&scheduler); tokio::spawn(async move { scheduler .schedule( Some("req-2".to_string()), 64, Some(vec![5, 6, 7, 8]), OverlapScores::default(), None, true, None, 0.0, None, None, ) .await }) }; wait_for_pending_count(&scheduler, 1).await; event_tx .send(ActiveSequenceEvent { request_id: "req-1".to_string(), worker: WorkerWithDpRank::new(0, 0), data: ActiveSequenceEventData::Free, router_id: 1, lora_name: None, }) .unwrap(); tokio::time::timeout(Duration::from_millis(250), async { queued.await.unwrap().unwrap(); }) .await .unwrap(); assert_eq!(scheduler.pending_count(), 0); cancel_token.cancel(); } #[tokio::test] async fn test_free_updates_active_state() { let mut workers = HashMap::new(); workers.insert( 0, SimpleWorkerConfig { max_num_batched_tokens: Some(64), ..Default::default() }, ); let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None); scheduler .schedule( Some("req-1".to_string()), 64, Some(vec![1, 2, 3, 4]), OverlapScores::default(), None, true, Some("adapter-a".to_string()), 0.0, None, None, ) .await .unwrap(); assert_eq!( scheduler.get_active_lora_counts(), HashMap::from([(String::from("adapter-a"), 1)]) ); scheduler.free("req-1").await.unwrap(); assert!(scheduler.get_active_lora_counts().is_empty()); cancel_token.cancel(); } #[tokio::test] async fn test_get_potential_loads_matches_slots() { let mut workers = HashMap::new(); workers.insert( 0, SimpleWorkerConfig { max_num_batched_tokens: Some(256), ..Default::default() }, ); workers.insert( 1, SimpleWorkerConfig { max_num_batched_tokens: Some(256), ..Default::default() }, ); let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None); let token_seq = vec![11, 22, 33, 44]; let overlaps = OverlapScores::default(); let (decode_blocks, prefill_tokens) = slots.potential_blocks_and_tokens( Some(&token_seq), 128, overlaps.clone(), Instant::now(), ); let mut expected: Vec<_> = decode_blocks .keys() .map(|worker| PotentialLoad { worker_id: worker.worker_id, dp_rank: worker.dp_rank, potential_prefill_tokens: prefill_tokens.get(worker).copied().unwrap_or(128), potential_decode_blocks: decode_blocks.get(worker).copied().unwrap_or(0), }) .collect(); expected.sort_by_key(|load| (load.worker_id, load.dp_rank)); let mut actual = scheduler.get_potential_loads(Some(token_seq), 128, overlaps, true); actual.sort_by_key(|load| (load.worker_id, load.dp_rank)); assert_eq!(actual.len(), expected.len()); for (actual, expected) in actual.iter().zip(expected.iter()) { assert_eq!(actual.worker_id, expected.worker_id); assert_eq!(actual.dp_rank, expected.dp_rank); assert_eq!( actual.potential_prefill_tokens, expected.potential_prefill_tokens ); assert_eq!( actual.potential_decode_blocks, expected.potential_decode_blocks ); } cancel_token.cancel(); } #[tokio::test(start_paused = true)] async fn test_get_potential_loads_uses_decayed_prefill_tokens() { let mut workers = HashMap::new(); workers.insert( 0, SimpleWorkerConfig { max_num_batched_tokens: Some(256), ..Default::default() }, ); let estimator: Arc = Arc::new(FixedPrefillLoadEstimator { duration: Duration::from_secs(10), }); let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, Some(estimator)); scheduler .schedule( Some("req-1".to_string()), 100, Some(vec![1, 2, 3, 4]), OverlapScores::default(), None, true, None, 0.0, None, None, ) .await .unwrap(); tokio::time::advance(Duration::from_secs(6)).await; let loads = scheduler.get_potential_loads(None, 0, OverlapScores::default(), true); assert_eq!(loads.len(), 1); assert_eq!(loads[0].potential_prefill_tokens, 40); cancel_token.cancel(); } #[tokio::test] async fn test_register_workers_uses_default_dp_fallback() { let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(HashMap::new(), None, false, None); scheduler.register_workers(&HashSet::from([42])); let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default(), true); assert_eq!(loads.len(), 1); assert_eq!(loads[0].worker_id, 42); assert_eq!(loads[0].dp_rank, 0); cancel_token.cancel(); } #[tokio::test] async fn test_worker_watch_updates_slot_ranges() { let mut workers = HashMap::new(); workers.insert(0, SimpleWorkerConfig::default()); let (scheduler, _slots, cfg_tx, cancel_token) = make_scheduler(workers, None, true, None); assert_eq!( scheduler .get_potential_loads(None, 64, OverlapScores::default(), true) .len(), 1 ); let mut updated_workers = HashMap::new(); updated_workers.insert( 0, SimpleWorkerConfig { data_parallel_size: 2, ..Default::default() }, ); updated_workers.insert(1, SimpleWorkerConfig::default()); cfg_tx.send(updated_workers).unwrap(); tokio::time::timeout(Duration::from_secs(1), async { loop { if scheduler .get_potential_loads(None, 64, OverlapScores::default(), true) .len() == 3 { break; } tokio::task::yield_now().await; } }) .await .unwrap(); cancel_token.cancel(); } #[tokio::test] async fn test_get_potential_loads_can_ignore_prefill_tokens() { let mut workers = HashMap::new(); workers.insert( 0, SimpleWorkerConfig { max_num_batched_tokens: Some(256), ..Default::default() }, ); let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None); scheduler .schedule( Some("req-1".to_string()), 64, Some(vec![11, 22]), OverlapScores::default(), None, true, None, 0.0, None, None, ) .await .unwrap(); let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default(), false); assert_eq!(loads.len(), 1); assert_eq!(loads[0].potential_prefill_tokens, 64); cancel_token.cancel(); } }