Unverified Commit e5850e23 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(kv-router): add ActiveSequences benchmark and extract common bench utils (#6633)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent b302ec41
......@@ -3,10 +3,14 @@
//! Shared test utilities for radix tree tests.
use std::future;
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, RouterEvent, WorkerId,
ActiveLoad, ActiveSequenceEvent, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData,
KvCacheRemoveData, KvCacheStoreData, KvCacheStoredBlockData, LocalBlockHash, RouterEvent,
WorkerConfigLike, WorkerId, WorkerWithDpRank,
};
use crate::sequences::SequencePublisher;
/// Creates blocks with artificial hash mapping (hash * 100) for testing.
pub fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
......@@ -61,3 +65,51 @@ pub fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>)
},
}
}
/// No-op [`SequencePublisher`] for tests and benchmarks that don't need event transport.
pub struct NoopSequencePublisher;
impl SequencePublisher for NoopSequencePublisher {
fn publish_event(
&self,
_event: &ActiveSequenceEvent,
) -> impl future::Future<Output = anyhow::Result<()>> + Send {
future::ready(Ok(()))
}
fn publish_load(&self, _load: ActiveLoad) {}
fn observe_load(&self, _: &WorkerWithDpRank, _: &str, _: usize, _: usize) {}
}
/// Minimal [`WorkerConfigLike`] for scheduler/queue tests and benchmarks.
#[derive(Debug, Clone)]
pub struct SimpleWorkerConfig {
pub data_parallel_size: u32,
pub max_num_batched_tokens: Option<u64>,
pub total_kv_blocks: Option<u64>,
}
impl Default for SimpleWorkerConfig {
fn default() -> Self {
Self {
data_parallel_size: 1,
max_num_batched_tokens: None,
total_kv_blocks: None,
}
}
}
impl WorkerConfigLike for SimpleWorkerConfig {
fn data_parallel_size(&self) -> u32 {
self.data_parallel_size
}
fn max_num_batched_tokens(&self) -> Option<u64> {
self.max_num_batched_tokens
}
fn total_kv_blocks(&self) -> Option<u64> {
self.total_kv_blocks
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
......@@ -27,6 +26,8 @@ use validator::Validate;
pub use dynamo_kv_router::approx;
pub use dynamo_kv_router::indexer;
pub use dynamo_kv_router::protocols;
pub use dynamo_kv_router::scheduling;
pub use dynamo_kv_router::selector;
pub mod cache_control;
pub mod config;
......@@ -56,10 +57,10 @@ use crate::{
indexer::{GetWorkersRequest, KvIndexer, KvIndexerInterface, KvRouterError},
protocols::{
BlockExtraInfo, DpRank, LocalBlockHash, OverlapScores, RouterEvent, RouterRequest,
RouterResponse, TokensWithHashes, WorkerId, WorkerSelectionResult, WorkerWithDpRank,
RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank,
compute_block_hash_for_seq,
},
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
scheduler::{KvScheduler, PotentialLoad},
sequence::{SequenceError, SequenceRequest},
},
local_model::runtime_config::ModelRuntimeConfig,
......@@ -118,15 +119,9 @@ pub fn router_discovery_query(namespace: String, component: String) -> Discovery
}
}
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
fn select_worker(
&self,
workers: &HashMap<protocols::WorkerId, ModelRuntimeConfig>,
request: &SchedulingRequest,
block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
/// Concrete `WorkerSelector` bound to the runtime config type.
pub type WorkerSelector =
dyn dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync;
#[derive(Clone)]
pub enum Indexer {
......@@ -297,7 +292,7 @@ impl KvRouter {
client: Client,
mut workers_with_configs: RuntimeConfigWatch,
block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
selector: Option<Box<WorkerSelector>>,
kv_router_config: Option<KvRouterConfig>,
worker_type: &'static str,
) -> Result<Self> {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use derive_builder::Builder;
use rand::Rng;
use serde::{Deserialize, Serialize};
use validator::{Validate, ValidationError};
use crate::kv_router::protocols::{compute_block_hash_for_seq, compute_seq_hash_for_block};
/// Override configuration for router settings that can be specified per-request
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize, Validate)]
pub struct RouterConfigOverride {
#[builder(default)]
pub overlap_score_weight: Option<f64>,
#[builder(default)]
#[validate(range(min = 0.0))]
pub router_temperature: Option<f64>,
#[builder(default)]
pub assume_kv_reuse: Option<bool>,
}
/// KV Router configuration parameters
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Validate)]
#[validate(schema(function = "validate_kv_router_config"))]
pub struct KvRouterConfig {
#[validate(range(min = 0.0))]
pub overlap_score_weight: f64,
#[validate(range(min = 0.0))]
pub router_temperature: f64,
pub use_kv_events: bool,
/// **Deprecated:** Enable durable KV events using NATS JetStream instead of the default event plane.
/// This option will be removed in a future release. The event-plane subscriber
/// (local_indexer mode) is now the recommended path.
pub durable_kv_events: bool,
pub router_replica_sync: bool,
/// Whether to track active blocks in the router (default: true)
pub router_track_active_blocks: bool,
/// Whether to track output blocks during generation (default: false)
/// When enabled, the router adds placeholder blocks as tokens are generated
/// and applies fractional decay based on progress toward agent_hints.osl.
pub router_track_output_blocks: bool,
/// Whether to assume KV cache reuse when tracking active blocks (default: true).
/// When true, computes actual block hashes for sequence tracking.
/// When false, generates random hashes (assuming no KV cache reuse).
pub router_assume_kv_reuse: bool,
/// Threshold for triggering snapshots. If None, no snapshots will be performed.
#[validate(range(min = 1))]
pub router_snapshot_threshold: Option<u32>,
/// Whether to reset the router state on startup (default: false)
pub router_reset_states: bool,
/// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
#[validate(range(min = 0.0))]
pub router_ttl_secs: f64,
/// Maximum tree size before pruning (only used when use_kv_events is false, default: 2^20 = 1048576)
#[validate(range(min = 1))]
pub router_max_tree_size: usize,
/// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
#[validate(range(min = 0.0, max = 1.0))]
pub router_prune_target_ratio: f64,
/// Queue threshold fraction for prefill token capacity.
/// When set, requests are queued if all workers exceed this fraction of max_num_batched_tokens.
/// If None (default), queueing is disabled and all requests go directly to ready.
/// Must be > 0.
#[validate(range(min = 0.0))]
pub router_queue_threshold: Option<f64>,
/// Number of event processing threads for the KV indexer.
/// When > 1, uses ConcurrentRadixTree with a thread pool instead of the
/// single-threaded RadixTree. Default: 4.
#[validate(range(min = 1))]
pub router_event_threads: u32,
/// Enable cache control (PIN with TTL) via the worker's cache_control service mesh endpoint.
/// When true, the router creates a cache_control client and honors nvext.cache_control on
/// requests, firing a pin_prefix call (with TTL) to the worker after generation completes.
/// When false (default), cache_control is ignored and no cache_control client is created.
pub router_enable_cache_control: bool,
}
impl Default for KvRouterConfig {
fn default() -> Self {
Self {
overlap_score_weight: 1.0,
router_temperature: 0.0,
use_kv_events: true,
durable_kv_events: false, // default to NATS Core (local indexer mode)
router_replica_sync: false,
router_track_active_blocks: true,
router_track_output_blocks: false,
router_assume_kv_reuse: true,
router_snapshot_threshold: Some(1000000),
router_reset_states: false,
router_ttl_secs: 120.0,
router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
router_prune_target_ratio: 0.8,
router_queue_threshold: None,
router_event_threads: 4,
router_enable_cache_control: false,
}
}
}
fn validate_kv_router_config(config: &KvRouterConfig) -> Result<(), ValidationError> {
if config.durable_kv_events {
tracing::warn!(
"--durable-kv-events is deprecated and will be removed in a future release. \
The event-plane subscriber (local_indexer mode) is now the recommended path."
);
}
if config.durable_kv_events && !config.use_kv_events {
return Err(ValidationError::new(
"durable_kv_events requires use_kv_events=true",
));
}
if config.router_track_output_blocks && !config.router_track_active_blocks {
return Err(ValidationError::new(
"router_track_output_blocks requires router_track_active_blocks=true",
));
}
Ok(())
}
impl KvRouterConfig {
/// Compute sequence hashes for active block tracking based on configuration.
///
/// Returns:
/// - `None` if `router_track_active_blocks` is false
/// - Random hashes if `router_track_active_blocks` is true but `router_assume_kv_reuse` is false
/// - Actual sequence hashes if both are true
pub fn compute_seq_hashes_for_tracking(
&self,
tokens: &[u32],
block_size: u32,
config_override: Option<&RouterConfigOverride>,
lora_name: Option<&str>,
) -> Option<Vec<u64>> {
if !self.router_track_active_blocks {
return None;
}
let num_blocks = tokens.len() / block_size as usize;
if num_blocks == 0 {
return Some(Vec::new());
}
let assume_kv_reuse = config_override
.and_then(|cfg| cfg.assume_kv_reuse)
.unwrap_or(self.router_assume_kv_reuse);
if assume_kv_reuse {
let block_hashes = compute_block_hash_for_seq(tokens, block_size, None, lora_name);
Some(compute_seq_hash_for_block(&block_hashes))
} else {
let mut rng = rand::rng();
Some((0..num_blocks).map(|_| rng.random::<u64>()).collect())
}
}
/// Check if KV event subscription should be started.
///
/// Returns false if:
/// - KV events are disabled (`use_kv_events=false`)
/// - Overlap scoring is disabled (`overlap_score_weight=0`)
///
/// When false, the router skips starting the KV event subscription entirely,
/// avoiding the need to query workers for their local indexer state.
pub fn should_subscribe_to_kv_events(&self) -> bool {
self.use_kv_events && self.overlap_score_weight > 0.0
}
}
pub use dynamo_kv_router::config::{KvRouterConfig, RouterConfigOverride};
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::sync::Arc;
use std::time::{Duration, Instant};
pub use dynamo_kv_router::queue::DEFAULT_MAX_BATCHED_TOKENS;
use tokio::sync::Mutex;
use crate::kv_router::sequence::RuntimeSequencePublisher;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use super::WorkerSelector;
use super::protocols::WorkerWithDpRank;
use super::scheduler::{SchedulingRequest, SchedulingResponse};
use super::sequence::{ActiveSequencesMulti, SequenceRequest};
use crate::discovery::RuntimeConfigWatch;
/// Large default for max_num_batched_tokens when not configured (effectively disables queueing for that worker)
const DEFAULT_MAX_BATCHED_TOKENS: u64 = 10_000_000;
/// Entry in the priority queue, ordered by effective arrival time (lower = higher priority).
/// Effective arrival = elapsed time since queue start minus `priority_jump`.
struct QueueEntry {
effective_offset: Duration,
request: SchedulingRequest,
}
impl Eq for QueueEntry {}
impl PartialEq for QueueEntry {
fn eq(&self, other: &Self) -> bool {
self.effective_offset == other.effective_offset
}
}
impl Ord for QueueEntry {
fn cmp(&self, other: &Self) -> Ordering {
// BinaryHeap is a max-heap; reverse so lower effective_offset = higher priority
other.effective_offset.cmp(&self.effective_offset)
}
}
impl PartialOrd for QueueEntry {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
/// Queue that gates scheduling requests behind a capacity check.
/// When all workers exceed `threshold_frac` utilisation the request is parked in `pending`.
/// When capacity frees up (`update()`), pending requests are scheduled in priority order.
/// If queueing is disabled (threshold_frac is None), requests are scheduled immediately.
pub struct SchedulerQueue {
pending: Mutex<BinaryHeap<QueueEntry>>,
slots: Arc<ActiveSequencesMulti>,
workers_with_configs: RuntimeConfigWatch,
/// Cached threshold fraction; None means queueing is disabled.
threshold_frac: Option<f64>,
/// Reference instant for computing arrival offsets.
start_time: Instant,
block_size: u32,
selector: Box<dyn WorkerSelector + Send + Sync>,
}
impl SchedulerQueue {
pub fn new(
slots: Arc<ActiveSequencesMulti>,
workers_with_configs: RuntimeConfigWatch,
threshold_frac: Option<f64>,
block_size: u32,
selector: Box<dyn WorkerSelector + Send + Sync>,
) -> Self {
if let Some(frac) = threshold_frac {
tracing::info!("Router queue enabled with threshold fraction {frac}");
}
Self {
pending: Mutex::new(BinaryHeap::new()),
slots,
workers_with_configs,
threshold_frac,
start_time: Instant::now(),
block_size,
selector,
}
}
/// Build a QueueEntry for a request, computing its effective arrival offset.
fn make_entry(&self, request: SchedulingRequest) -> QueueEntry {
let arrival_offset = self.start_time.elapsed();
let jump = Duration::from_secs_f64(request.priority_jump.max(0.0));
let effective_offset = arrival_offset.saturating_sub(jump);
QueueEntry {
effective_offset,
request,
}
}
/// Enqueue a new request.
/// If queueing is disabled or workers have capacity, schedule immediately.
/// Otherwise park in the pending heap.
pub async fn enqueue(&self, request: SchedulingRequest) {
let Some(threshold) = self.threshold_frac else {
self.schedule(request).await;
return;
};
if self.all_workers_busy(threshold) {
tracing::debug!("all workers busy, queueing request");
let entry = self.make_entry(request);
self.pending.lock().await.push(entry);
} else {
self.schedule(request).await;
}
}
/// Called on prefill_complete/free. Drains pending requests while workers have capacity.
/// Each scheduled request updates active_tokens via add_request, so the busy check
/// sees fresh state on the next iteration.
pub async fn update(&self) {
let Some(threshold) = self.threshold_frac else {
return;
};
loop {
if self.all_workers_busy(threshold) {
break;
}
let Some(entry) = self.pending.lock().await.pop() else {
break;
};
tracing::debug!("scheduling request from pending queue");
self.schedule(entry.request).await;
}
}
/// Run the full scheduling pipeline for a single request:
/// compute potential load → select worker → respond → book via add_request.
async fn schedule(&self, mut request: SchedulingRequest) {
let (decode_blocks, prefill_tokens) = self.slots.potential_blocks_and_tokens(
request.token_seq.clone(),
request.isl_tokens,
request.overlaps.clone(),
);
request.decode_blocks = decode_blocks;
request.prefill_tokens = prefill_tokens;
let selection = {
let workers = self.workers_with_configs.borrow();
self.selector
.select_worker(&workers, &request, self.block_size)
};
let selection = match selection {
Ok(s) => s,
Err(e) => {
tracing::warn!("scheduling failed: {e}");
request.respond(Err(e));
return;
}
};
request.respond(Ok(SchedulingResponse {
best_worker: selection.worker,
overlap_blocks: selection.overlap_blocks,
}));
if !request.update_states {
return;
}
let Some(request_id) = request.maybe_request_id else {
tracing::error!("No request_id provided to add_request to the slot tracker");
return;
};
if let Err(e) = self
.slots
.add_request(SequenceRequest {
request_id: request_id.clone(),
token_sequence: request.token_seq,
isl: request.isl_tokens,
overlap: selection.overlap_blocks,
expected_output_tokens: None,
worker: selection.worker,
lora_name: request.lora_name.clone(),
})
.await
{
tracing::warn!("Failed to add request {request_id}: {e}");
}
}
/// Check if all workers are busy based on threshold.
/// Returns true only if ALL workers exceed the threshold (no worker has capacity).
fn all_workers_busy(&self, threshold: f64) -> bool {
let active_tokens = self.slots.active_tokens();
let configs = self.workers_with_configs.borrow();
for (&worker_id, config) in configs.iter() {
let dp_size = config.data_parallel_size;
let max_batched = config
.max_num_batched_tokens
.unwrap_or(DEFAULT_MAX_BATCHED_TOKENS);
for dp_rank in 0..dp_size {
let worker = WorkerWithDpRank::new(worker_id, dp_rank);
let tokens = active_tokens.get(&worker).copied().unwrap_or(0);
if (tokens as f64) <= threshold * (max_batched as f64) {
return false;
}
}
}
true
}
}
/// Concrete `SchedulerQueue` wired to the runtime publisher and config types.
pub type SchedulerQueue =
dynamo_kv_router::queue::SchedulerQueue<RuntimeSequencePublisher, ModelRuntimeConfig>;
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub use dynamo_kv_router::scheduling::{
KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse,
};
pub use dynamo_kv_router::selector::DefaultWorkerSelector;
use super::KvRouterConfig;
use super::RouterConfigOverride;
use super::WorkerSelector;
use super::protocols::{DpRank, OverlapScores, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
use super::protocols::{OverlapScores, WorkerId};
use super::queue::SchedulerQueue;
use super::sequence::{
ActiveSequencesMulti, SequenceError, SequenceRequest, create_multi_worker_sequences,
......@@ -14,8 +19,6 @@ use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result;
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
......@@ -24,64 +27,6 @@ use std::time::Instant;
use dynamo_tokens::SequenceHash;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad {
pub worker_id: WorkerId,
pub dp_rank: DpRank,
pub potential_prefill_tokens: usize,
pub potential_decode_blocks: usize,
}
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
#[error("no endpoints available to route work")]
NoEndpoints,
#[error("endpoint subscriber shutdown")]
SubscriberShutdown,
#[error("failed to initialize event publisher: {0}")]
InitFailed(String),
}
#[derive(Debug)]
pub struct SchedulingResponse {
pub best_worker: WorkerWithDpRank,
pub overlap_blocks: u32,
}
pub struct SchedulingRequest {
pub maybe_request_id: Option<String>,
pub token_seq: Option<Vec<SequenceHash>>,
pub isl_tokens: usize,
pub overlaps: OverlapScores,
pub decode_blocks: HashMap<WorkerWithDpRank, usize>,
pub prefill_tokens: HashMap<WorkerWithDpRank, usize>,
// Router config overrides for this specific request
pub router_config_override: Option<RouterConfigOverride>,
// Whether to update scheduler states (false for query_instance_id requests)
pub update_states: bool,
// LORA adapter name extracted from request.model field
pub lora_name: Option<String>,
/// Priority jump in seconds; decreases effective arrival time in the queue.
pub priority_jump: f64,
/// Optional set of allowed worker IDs to restrict routing decisions (EPP).
pub allowed_worker_ids: Option<HashSet<WorkerId>>,
resp_tx: Option<tokio::sync::oneshot::Sender<Result<SchedulingResponse, KvSchedulerError>>>,
}
impl SchedulingRequest {
pub fn respond(&mut self, result: Result<SchedulingResponse, KvSchedulerError>) {
let Some(tx) = self.resp_tx.take() else {
tracing::error!("respond called multiple times on same request");
return;
};
if tx.send(result).is_err() {
tracing::error!("failed to send response to requestor");
}
}
}
pub struct KvScheduler {
request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
slots: Arc<ActiveSequencesMulti>,
......@@ -93,7 +38,7 @@ impl KvScheduler {
component: Component,
block_size: u32,
workers_with_configs: RuntimeConfigWatch,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
selector: Option<Box<WorkerSelector>>,
kv_router_config: &KvRouterConfig,
worker_type: &'static str,
) -> Result<Self, KvSchedulerError> {
......@@ -145,7 +90,7 @@ impl KvScheduler {
.iter()
.map(|(&id, c)| (id, c.data_parallel_size))
.collect();
slots_monitor.update_workers(dp_sizes);
slots_monitor.update_workers(&dp_sizes);
last_workers = current_workers;
}
}
......@@ -295,12 +240,12 @@ impl KvScheduler {
isl_tokens: usize,
overlaps: OverlapScores,
) -> Vec<PotentialLoad> {
let (decode_blocks, prefill_tokens) = self
.slots
.potential_blocks_and_tokens(token_seq, isl_tokens, overlaps);
let (decode_blocks, prefill_tokens) =
self.slots
.potential_blocks_and_tokens(token_seq.as_deref(), isl_tokens, overlaps);
// Get all unique WorkerWithDpRank from both hashmaps
let mut workers: HashSet<WorkerWithDpRank> = HashSet::new();
let mut workers: HashSet<dynamo_kv_router::protocols::WorkerWithDpRank> = HashSet::new();
workers.extend(decode_blocks.keys().copied());
workers.extend(prefill_tokens.keys().copied());
......@@ -326,314 +271,3 @@ impl KvScheduler {
self.slots.get_active_lora_counts()
}
}
// Helper function for softmax sampling
// Returns a vec of workers: multiple if tied, single if sampled
fn softmax_sample(
logits: &HashMap<WorkerWithDpRank, f64>,
temperature: f64,
) -> Vec<WorkerWithDpRank> {
if logits.is_empty() {
panic!("Empty logits for softmax sampling");
}
// Guard: if temperature is 0, return all keys with the smallest logit value (ties)
if temperature == 0.0 {
// Find the minimum logit value
let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b));
// Collect all keys with the minimum logit value (to handle ties)
let min_keys: Vec<_> = logits
.iter()
.filter(|&(_, &v)| v == min_logit)
.map(|(k, _)| *k)
.collect();
return min_keys;
}
let keys: Vec<_> = logits.keys().copied().collect();
let values: Vec<_> = logits.values().copied().collect();
// Find min and max for normalization
let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let probabilities = if min_val == max_val {
// All values are the same, uniform probability
vec![1.0 / keys.len() as f64; keys.len()]
} else {
// Fused normalize → negate → scale → exp, then normalize probabilities
let range = max_val - min_val;
let scaled: Vec<f64> = values.iter().map(|&v| -(v / range) / temperature).collect();
let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let mut probs: Vec<f64> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();
let sum: f64 = probs.iter().sum();
probs.iter_mut().for_each(|p| *p /= sum);
probs
};
// Sample from the probability distribution
let mut rng = rand::rng();
let sample: f64 = rng.random();
let mut cumsum = 0.0;
for (i, &prob) in probabilities.iter().enumerate() {
cumsum += prob;
if sample <= cumsum {
return vec![keys[i]];
}
}
// Fallback to last key (shouldn't normally reach here)
vec![keys[keys.len() - 1]]
}
// Default implementation matching the Python _cost_function
#[derive(Debug, Clone, Default)]
pub struct DefaultWorkerSelector {
pub kv_router_config: KvRouterConfig,
}
impl DefaultWorkerSelector {
pub fn new(kv_router_config: Option<KvRouterConfig>) -> Self {
Self {
kv_router_config: kv_router_config.unwrap_or_default(),
}
}
}
impl WorkerSelector for DefaultWorkerSelector {
fn select_worker(
&self,
workers: &HashMap<WorkerId, ModelRuntimeConfig>,
request: &SchedulingRequest,
block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> {
assert!(request.isl_tokens > 0);
let allowed_ids = request.allowed_worker_ids.as_ref();
if allowed_ids.map_or(workers.is_empty(), |ids| {
!workers.keys().any(|wid| ids.contains(wid))
}) {
return Err(KvSchedulerError::NoEndpoints);
}
let isl = request.isl_tokens;
let request_blocks = isl.div_ceil(block_size as usize);
let overlaps = &request.overlaps.scores;
let decode_blocks = &request.decode_blocks;
let prefill_tokens = &request.prefill_tokens;
let mut worker_logits = HashMap::new();
// Use override if provided, otherwise use default config
let overlap_weight = request
.router_config_override
.as_ref()
.and_then(|cfg| cfg.overlap_score_weight)
.unwrap_or(self.kv_router_config.overlap_score_weight);
for (worker_id, config) in workers
.iter()
.filter(|(wid, _)| allowed_ids.is_none_or(|ids| ids.contains(wid)))
{
let data_parallel_size = config.data_parallel_size;
for dp_rank in 0..data_parallel_size {
let worker = WorkerWithDpRank::new(*worker_id, dp_rank);
// Get overlap for this worker (defaults to 0 if not in overlaps)
let overlap = *overlaps.get(&worker).unwrap_or(&0);
// this is the number of prefill tokens the worker would have if the request were scheduled there
let prefill_token = *prefill_tokens.get(&worker).unwrap_or(&isl);
let potential_prefill_block = (prefill_token as f64) / (block_size as f64);
// this is the number of decode blocks the worker would have if the request were scheduled there
let decode_block = *decode_blocks
.get(&worker)
.unwrap_or(&(potential_prefill_block.floor() as usize))
as f64;
// Calculate logit (lower is better)
let logit = overlap_weight * potential_prefill_block + decode_block;
worker_logits.insert(worker, logit);
tracing::info!(
"Formula for worker_id={} dp_rank={:?} with {overlap} cached blocks: {logit:.3} \
= {overlap_weight:.1} * prefill_blocks + decode_blocks \
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}",
worker.worker_id,
worker.dp_rank
);
}
}
// Use softmax sampling to select worker(s)
// Use override if provided, otherwise use default config
let temperature = request
.router_config_override
.as_ref()
.and_then(|cfg| cfg.router_temperature)
.unwrap_or(self.kv_router_config.router_temperature);
let candidates = softmax_sample(&worker_logits, temperature);
// If multiple candidates (tied), use tree size as tie-breaker
// If tree sizes are also equal, use random selection to avoid bias
let best_worker = if candidates.len() > 1 {
tracing::info!("Multiple workers tied with same logit, using tree size as tie-breaker");
let tree_sizes: Vec<(usize, &WorkerWithDpRank)> = candidates
.iter()
.map(|w| (request.overlaps.tree_sizes.get(w).copied().unwrap_or(0), w))
.collect();
if tree_sizes.iter().all(|(s, _)| *s == tree_sizes[0].0) {
let idx = rand::rng().random_range(0..candidates.len());
candidates[idx]
} else {
*tree_sizes.iter().min_by_key(|(s, _)| *s).unwrap().1
}
} else {
candidates[0]
};
let best_logit = worker_logits[&best_worker];
let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0);
// this is a runtime config set on a per worker basis, not per dp-rank
let total_blocks_info = workers
.get(&best_worker.worker_id)
.and_then(|cfg| cfg.total_kv_blocks)
.map(|blocks| format!(", total blocks: {}", blocks))
.unwrap_or_default();
let tree_size = request
.overlaps
.tree_sizes
.get(&best_worker)
.copied()
.unwrap_or(0);
tracing::info!(
"Selected worker: worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}, tree size: {}{}",
best_worker.worker_id,
best_worker.dp_rank,
best_logit,
best_overlap,
tree_size,
total_blocks_info
);
Ok(WorkerSelectionResult {
worker: best_worker,
required_blocks: request_blocks as u64,
overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0),
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_softmax_sample_single_key() {
// Test that with a single key, softmax_sample always returns that key
let mut logits = HashMap::new();
let worker = WorkerWithDpRank::from_worker_id(42);
logits.insert(worker, 0.5); // The value doesn't matter
// Test with different temperatures
for temperature in &[0.1, 1.0, 10.0] {
let result = softmax_sample(&logits, *temperature);
assert_eq!(result.len(), 1, "Should return exactly one worker");
assert_eq!(result[0], worker, "Should return the only available worker");
}
// Test with different logit values
logits.clear();
logits.insert(worker, -100.0); // Very negative value
let result = softmax_sample(&logits, 1.0);
assert_eq!(result.len(), 1);
assert_eq!(result[0], worker);
logits.clear();
logits.insert(worker, 100.0); // Very positive value
let result = softmax_sample(&logits, 1.0);
assert_eq!(result.len(), 1);
assert_eq!(result[0], worker);
logits.clear();
logits.insert(worker, 0.0); // Zero value
let result = softmax_sample(&logits, 1.0);
assert_eq!(result.len(), 1);
assert_eq!(result[0], worker);
}
#[test]
fn test_softmax_sample_zero_temperature() {
// Test that with temperature 0, softmax_sample returns all keys with smallest logit
let mut logits = HashMap::new();
let worker1 = WorkerWithDpRank::from_worker_id(1);
let worker2 = WorkerWithDpRank::from_worker_id(2);
let worker3 = WorkerWithDpRank::from_worker_id(3);
let worker4 = WorkerWithDpRank::from_worker_id(4);
logits.insert(worker1, 5.0);
logits.insert(worker2, 3.0); // This has the smallest logit
logits.insert(worker3, 7.0);
logits.insert(worker4, 3.5);
// With temperature 0, should always return only worker2 (smallest logit)
let result = softmax_sample(&logits, 0.0);
assert_eq!(
result.len(),
1,
"Should return one worker when there's no tie"
);
assert_eq!(
result[0], worker2,
"Should return worker with smallest logit when temperature is 0"
);
// Test with tied minimum logits
logits.clear();
let worker5 = WorkerWithDpRank::from_worker_id(5);
let worker6 = WorkerWithDpRank::from_worker_id(6);
logits.insert(worker1, 5.0);
logits.insert(worker2, 3.0); // Tied for smallest
logits.insert(worker5, 3.0); // Tied for smallest
logits.insert(worker6, 7.0);
let result = softmax_sample(&logits, 0.0);
assert_eq!(
result.len(),
2,
"Should return all workers with smallest logit when tied"
);
assert!(
result.contains(&worker2) && result.contains(&worker5),
"Should contain both tied workers"
);
// Test with negative values
logits.clear();
let worker10 = WorkerWithDpRank::from_worker_id(10);
let worker20 = WorkerWithDpRank::from_worker_id(20);
let worker30 = WorkerWithDpRank::from_worker_id(30);
logits.insert(worker10, -1.0);
logits.insert(worker20, -5.0); // This has the smallest logit
logits.insert(worker30, 0.0);
let result = softmax_sample(&logits, 0.0);
assert_eq!(result.len(), 1);
assert_eq!(
result[0], worker20,
"Should handle negative logits correctly"
);
}
}
......@@ -80,6 +80,20 @@ impl Default for ModelRuntimeConfig {
}
}
impl dynamo_kv_router::WorkerConfigLike for ModelRuntimeConfig {
fn data_parallel_size(&self) -> u32 {
self.data_parallel_size
}
fn max_num_batched_tokens(&self) -> Option<u64> {
self.max_num_batched_tokens
}
fn total_kv_blocks(&self) -> Option<u64> {
self.total_kv_blocks
}
}
impl ModelRuntimeConfig {
pub fn new() -> Self {
Self::default()
......
......@@ -214,6 +214,8 @@ pub struct MockVllmEngine {
engine_args: MockEngineArgs,
/// Bootstrap server for prefill workers in disaggregated mode
bootstrap_server: Arc<OnceCell<Arc<BootstrapServer>>>,
/// Keep schedulers alive so their CancelGuards don't fire prematurely.
_schedulers: OnceCell<Vec<Scheduler>>,
}
impl MockVllmEngine {
......@@ -225,6 +227,7 @@ impl MockVllmEngine {
senders_ready: Notify::new(),
engine_args,
bootstrap_server: Arc::new(OnceCell::new()),
_schedulers: OnceCell::new(),
}
}
......@@ -268,6 +271,8 @@ impl MockVllmEngine {
Self::start_metrics_publishing(&schedulers, component, cancel_token.clone()).await?;
let _ = self._schedulers.set(schedulers);
Ok(())
}
......
......@@ -246,11 +246,22 @@ impl SchedulerState {
}
}
/// Cancels its token when dropped. Shared via Arc so the background task is
/// only cancelled when the last Scheduler clone is dropped.
struct CancelGuard(CancellationToken);
impl Drop for CancelGuard {
fn drop(&mut self) {
self.0.cancel();
}
}
/// Manages scheduling of requests using KvManager resources
#[derive(Clone)]
pub struct Scheduler {
request_tx: mpsc::UnboundedSender<DirectRequest>,
metrics_rx: tokio::sync::watch::Receiver<MockerMetrics>,
_cancel_guard: Arc<CancelGuard>,
}
impl Scheduler {
......@@ -273,7 +284,9 @@ impl Scheduler {
let (metrics_tx, metrics_rx) =
tokio::sync::watch::channel::<MockerMetrics>(initial_metrics);
let cancel_token_clone = cancellation_token.unwrap_or_default().clone();
let cancel_token = cancellation_token.unwrap_or_default();
let cancel_token_clone = cancel_token.clone();
let cancel_guard = Arc::new(CancelGuard(cancel_token));
// Spawn main background task with cancellation token
tokio::spawn(async move {
......@@ -330,6 +343,7 @@ impl Scheduler {
Self {
request_tx,
metrics_rx,
_cancel_guard: cancel_guard,
}
}
......@@ -360,13 +374,16 @@ async fn receive_requests(
}
if state.is_empty() {
// Fully idle - block until new request arrives
// Fully idle - block until new request arrives or shutdown
tokio::select! {
biased;
_ = cancel_token.cancelled() => {
return None;
}
Some(request) = request_rx.recv() => {
result = request_rx.recv() => {
let Some(request) = result else {
return None; // channel closed
};
state.receive(request);
return Some(());
}
......
......@@ -123,7 +123,7 @@ sglang_configs = {
marks=[pytest.mark.gpu_2, pytest.mark.post_merge],
model="Qwen/Qwen3-0.6B",
env={
"DYN_LOG": "dynamo_llm::kv_router::publisher=trace,dynamo_llm::kv_router::scheduler=info",
"DYN_LOG": "dynamo_llm::kv_router::publisher=trace,dynamo_kv_router::scheduling::selector=info",
},
frontend_port=DefaultPort.FRONTEND.value,
request_payloads=[
......
......@@ -152,7 +152,7 @@ trtllm_configs = {
)
],
env={
"DYN_LOG": "dynamo_llm::kv_router::publisher=trace,dynamo_llm::kv_router::scheduler=info",
"DYN_LOG": "dynamo_llm::kv_router::publisher=trace,dynamo_kv_router::scheduling::selector=info",
},
),
"disaggregated_router": TRTLLMConfig(
......
......@@ -204,7 +204,7 @@ vllm_configs = {
)
],
env={
"DYN_LOG": "dynamo_llm::kv_router::publisher=trace,dynamo_llm::kv_router::scheduler=info",
"DYN_LOG": "dynamo_llm::kv_router::publisher=trace,dynamo_kv_router::scheduling::selector=info",
},
),
"agg-router-approx": VLLMConfig(
......@@ -235,7 +235,7 @@ vllm_configs = {
),
],
env={
"DYN_LOG": "dynamo_llm::kv_router::scheduler=info",
"DYN_LOG": "dynamo_kv_router::scheduling::selector=info",
},
),
"disaggregated": VLLMConfig(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment