// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use dynamo_runtime::component::{Component, Instance}; use dynamo_runtime::traits::events::EventPublisher; use rand::Rng; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use super::indexer::OverlapScores; use super::protocols::WorkerSelectionResult; use super::sequence::ActiveSequencesMultiWorker; use super::KvRouterConfig; use super::WorkerSelector; use super::KV_HIT_RATE_SUBJECT; use crate::tokens::SequenceHash; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct KVHitRateEvent { pub worker_id: i64, pub isl_blocks: usize, pub overlap_blocks: u32, } #[derive(Debug, thiserror::Error)] pub enum KvSchedulerError { #[error("no endpoints aviailable to route work")] NoEndpoints, #[error("all workers busy")] AllWorkersBusy, #[error("endpoint subscriber shutdown")] SubscriberShutdown, } #[derive(Debug)] pub struct SchedulingResponse { pub best_worker_id: i64, pub overlap_blocks: u32, } pub struct SchedulingRequest { pub request_id: String, pub token_seq: Vec, pub isl_tokens: usize, pub overlaps: OverlapScores, pub decode_blocks: HashMap, pub prefill_tokens: HashMap, // Option to take it out to send the response without moving the struct resp_tx: Option>, } impl SchedulingRequest { pub fn respond(&mut self, response: SchedulingResponse) { // Changed to &mut self if let Some(tx) = self.resp_tx.take() { // Use take() to extract the sender if tx.send(response).is_err() { tracing::error!("failed to send response to requestor"); } } else { tracing::error!("respond called multiple times on same request"); } } } pub struct KvScheduler { request_tx: tokio::sync::mpsc::Sender, slots: Arc, } impl KvScheduler { pub async fn start( component: Component, block_size: u32, mut instances_rx: tokio::sync::watch::Receiver>, // Changed from ProcessedEndpoints selector: Option>, replica_sync: bool, ) -> Result { let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default())); let mut instances: Vec = instances_rx.borrow_and_update().clone(); let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::(); let ns_clone = component.namespace().clone(); tokio::spawn(async move { let mut event_rx = event_rx; while let Some(event) = event_rx.recv().await { if let Err(e) = ns_clone.publish(KV_HIT_RATE_SUBJECT, &event).await { tracing::warn!("Failed to publish KV hit rate event: {:?}", e); } } }); let worker_ids: Vec = instances .iter() .map(|instance| instance.instance_id) .collect(); let slots = Arc::new(ActiveSequencesMultiWorker::new( component, block_size as usize, worker_ids, replica_sync, )); let slots_clone = slots.clone(); let (request_tx, request_rx) = tokio::sync::mpsc::channel::(1024); // Background task to handle scheduling requests tokio::spawn(async move { let mut request_rx = request_rx; tracing::trace!("scheduler background task started"); loop { // First, check for instance updates (non-blocking) match instances_rx.has_changed() { Ok(true) => { instances = instances_rx.borrow_and_update().clone(); let worker_ids: Vec = instances .iter() .map(|instance| instance.instance_id) .collect(); slots_clone.update_workers(worker_ids); } Ok(false) => { // No changes, continue. This is the happy path. } Err(_) => { tracing::warn!("endpoint watch sender shutdown"); break; } } // Then, wait for a new request let Some(mut request) = request_rx.recv().await else { tracing::warn!("scheduler shutdown"); break; }; tracing::trace!("received request to be scheduled"); let (decode_blocks, prefill_tokens) = slots_clone .potential_blocks_and_tokens( request.token_seq.clone(), request.isl_tokens, request.overlaps.clone(), ) .await; request.decode_blocks = decode_blocks; request.prefill_tokens = prefill_tokens; match selector.select_worker(&instances, &request, block_size) { Ok(selection) => { if let Err(e) = event_tx.send(KVHitRateEvent { worker_id: selection.worker_id, isl_blocks: selection.required_blocks as usize, overlap_blocks: selection.overlap_blocks, }) { tracing::warn!("Failed to send KV hit rate event: {:?}", e); } let response = SchedulingResponse { best_worker_id: selection.worker_id, overlap_blocks: selection.overlap_blocks, }; request.respond(response); let _ = slots_clone .add_request( request.request_id, request.token_seq, request.isl_tokens, selection.overlap_blocks, selection.worker_id, ) .await; continue; } Err(KvSchedulerError::NoEndpoints) => { tracing::trace!("no endpoints available; waiting for endpoints update"); tokio::time::sleep(Duration::from_millis(5)).await; continue; } // TODO: this is not actually hooked up Err(KvSchedulerError::AllWorkersBusy) => { tracing::trace!("all workers busy; waiting for more capacity"); tokio::time::sleep(Duration::from_millis(5)).await; continue; } Err(e) => { tracing::error!("error scheduling request: {:?}", e); break; } } } tracing::trace!("background endpoint subscriber shutting down"); }); Ok(KvScheduler { request_tx, slots }) } pub async fn schedule( &self, request_id: String, isl_tokens: usize, token_seq: Vec, overlaps: OverlapScores, ) -> Result { let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let request = SchedulingRequest { request_id, token_seq, isl_tokens, overlaps, decode_blocks: HashMap::new(), prefill_tokens: HashMap::new(), resp_tx: Some(resp_tx), // Wrap in Some() }; self.request_tx .send(request) .await .map_err(|_| KvSchedulerError::SubscriberShutdown)?; let response = resp_rx .await .map_err(|_| KvSchedulerError::SubscriberShutdown)?; let best_worker_id = response.best_worker_id; Ok(best_worker_id) } pub async fn mark_prefill_completed(&self, request_id: &str) { let _ = self .slots .mark_prefill_completed(&request_id.to_string()) .await; } pub async fn free(&self, request_id: &str) { let _ = self.slots.free(&request_id.to_string()).await; } } // Helper function for softmax sampling fn softmax_sample(logits: &HashMap, temperature: f64) -> i64 { if logits.is_empty() { panic!("Empty logits for softmax sampling"); } // Guard: if temperature is 0, return the key with the smallest logit value if temperature == 0.0 { // Find the minimum logit value let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b)); // Collect all keys with the minimum logit value (to handle ties) let min_keys: Vec<_> = logits .iter() .filter(|(_, &v)| v == min_logit) .map(|(k, _)| *k) .collect(); // Randomly select from the minimum keys (handles single key case naturally) let mut rng = rand::rng(); let index = rng.random_range(0..min_keys.len()); return min_keys[index]; } let keys: Vec<_> = logits.keys().copied().collect(); let values: Vec<_> = logits.values().copied().collect(); // Find min and max for normalization let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b)); let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)); let probabilities = if min_val == max_val { // All values are the same, uniform probability vec![1.0 / keys.len() as f64; keys.len()] } else { // Normalize values let normalized: Vec<_> = values .iter() .map(|&v| { // Lower is better, so negate // Note we don't need to do actual min-max norm here, just off by an offset let norm = v / (max_val - min_val); -norm }) .collect(); // Apply temperature and softmax let scaled: Vec<_> = normalized.iter().map(|&v| v / temperature).collect(); let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)); let exp_values: Vec<_> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect(); let sum_exp: f64 = exp_values.iter().sum(); exp_values.iter().map(|&v| v / sum_exp).collect() }; // Sample from the probability distribution let mut rng = rand::rng(); let sample: f64 = rng.random(); let mut cumsum = 0.0; for (i, &prob) in probabilities.iter().enumerate() { cumsum += prob; if sample <= cumsum { return keys[i]; } } // Fallback to last key (shouldn't normally reach here) keys[keys.len() - 1] } // Default implementation matching the Python _cost_function #[derive(Debug, Clone, Default)] pub struct DefaultWorkerSelector { pub kv_router_config: KvRouterConfig, } impl DefaultWorkerSelector { pub fn new(kv_router_config: Option) -> Self { Self { kv_router_config: kv_router_config.unwrap_or_default(), } } } impl WorkerSelector for DefaultWorkerSelector { fn select_worker( &self, workers: &[Instance], request: &SchedulingRequest, block_size: u32, ) -> Result { assert!(request.isl_tokens > 0); if workers.is_empty() { return Err(KvSchedulerError::NoEndpoints); } let isl = request.isl_tokens; let request_blocks = isl.div_ceil(block_size as usize); let overlaps = &request.overlaps.scores; let decode_blocks = &request.decode_blocks; let prefill_tokens = &request.prefill_tokens; let mut worker_logits = HashMap::new(); let mut max_logit = f64::NEG_INFINITY; // Calculate logits for each worker for instance in workers.iter() { let worker_id = instance.instance_id; let overlap = *overlaps.get(&worker_id).unwrap_or(&0); // this is the number of prefill tokens the worker would have if the request were scheduled there let prefill_token = *prefill_tokens.get(&worker_id).unwrap_or(&isl); let potential_prefill_block = (prefill_token as f64) / (block_size as f64); // this is the number of decode blocks the worker would have if the request were scheduled there let decode_block = *decode_blocks .get(&worker_id) .unwrap_or(&(potential_prefill_block.floor() as usize)) as f64; // Calculate logit (lower is better) let logit = self.kv_router_config.overlap_score_weight * potential_prefill_block + decode_block; max_logit = max_logit.max(logit); worker_logits.insert(worker_id, logit); let overlap_weight = self.kv_router_config.overlap_score_weight; tracing::info!( "Formula for {worker_id} with {overlap} cached blocks: {logit:.3} \ = {overlap_weight:.1} * prefill_blocks + decode_blocks \ = {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}" ); } // Use softmax sampling to select worker let temperature = self.kv_router_config.router_temperature; let best_worker_id = softmax_sample(&worker_logits, temperature); let best_logit = worker_logits[&best_worker_id]; tracing::info!( "Selected worker: {}, logit: {:.3}", best_worker_id, best_logit ); Ok(WorkerSelectionResult { worker_id: best_worker_id, required_blocks: request_blocks as u64, overlap_blocks: overlaps.get(&best_worker_id).copied().unwrap_or(0), }) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_softmax_sample_single_key() { // Test that with a single key, softmax_sample always returns that key let mut logits = HashMap::new(); let worker_id = 42; logits.insert(worker_id, 0.5); // The value doesn't matter // Test with different temperatures for temperature in &[0.1, 1.0, 10.0] { let result = softmax_sample(&logits, *temperature); assert_eq!(result, worker_id, "Should return the only available worker"); } // Test with different logit values logits.clear(); logits.insert(worker_id, -100.0); // Very negative value assert_eq!(softmax_sample(&logits, 1.0), worker_id); logits.clear(); logits.insert(worker_id, 100.0); // Very positive value assert_eq!(softmax_sample(&logits, 1.0), worker_id); logits.clear(); logits.insert(worker_id, 0.0); // Zero value assert_eq!(softmax_sample(&logits, 1.0), worker_id); } #[test] fn test_softmax_sample_zero_temperature() { // Test that with temperature 0, softmax_sample returns the key with smallest logit let mut logits = HashMap::new(); logits.insert(1, 5.0); logits.insert(2, 3.0); // This has the smallest logit logits.insert(3, 7.0); logits.insert(4, 3.5); // With temperature 0, should always return worker 2 (smallest logit) for _ in 0..10 { let result = softmax_sample(&logits, 0.0); assert_eq!( result, 2, "Should return worker with smallest logit when temperature is 0" ); } // Test with negative values logits.clear(); logits.insert(10, -1.0); logits.insert(20, -5.0); // This has the smallest logit logits.insert(30, 0.0); let result = softmax_sample(&logits, 0.0); assert_eq!(result, 20, "Should handle negative logits correctly"); } }