// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use std::sync::{Arc, Mutex}; use anyhow::anyhow; use dynamo_kv_router::indexer::{ KvIndexerInterface, KvIndexerMetrics, LocalKvIndexer, METRIC_EVENT_REMOVED, METRIC_EVENT_STORED, METRIC_STATUS_BLOCK_NOT_FOUND, METRIC_STATUS_INVALID_BLOCK, METRIC_STATUS_OK, METRIC_STATUS_PARENT_NOT_FOUND, METRIC_WARNING_DUPLICATE_STORE, }; use dynamo_kv_router::protocols::{ KvCacheEvent, KvCacheEventData, LocalBlockHash, RouterEvent, WorkerId, WorkerWithDpRank, }; use tokio::sync::mpsc; use tokio::task::JoinHandle; use tokio::time::Duration; use tokio_util::sync::CancellationToken; use super::{DirectRequest, ForwardPassSnapshot, OutputSignal, SchedulerHandle}; use crate::common::protocols::{FpmSink, KvCacheEventSink}; pub(crate) struct RouterIndexerHarness { indexer: Arc, metrics: Arc, worker: WorkerWithDpRank, } impl RouterIndexerHarness { pub(crate) fn new(block_size: u32, worker_id: WorkerId) -> Self { let metrics = Arc::new(KvIndexerMetrics::new_unregistered()); let indexer = Arc::new(LocalKvIndexer::new( CancellationToken::new(), block_size, metrics.clone(), 4096, )); Self { indexer, metrics, worker: WorkerWithDpRank::new(worker_id, 0), } } pub(crate) async fn apply_events(&self, events: I) where I: IntoIterator, { for event in events { self.indexer.apply_event_with_buffer(event).await.unwrap(); } let _ = self.indexer.flush().await; self.assert_no_event_errors(); } pub(crate) async fn overlap_for_hashes(&self, local_hashes: Vec) -> u32 { self.indexer .find_matches(local_hashes) .await .unwrap() .scores .get(&self.worker) .copied() .unwrap_or(0) } pub(crate) fn ok_count(&self, event_type: &'static str) -> u64 { metric_value(&self.metrics, event_type, METRIC_STATUS_OK) } pub(crate) fn status_count(&self, event_type: &'static str, status: &'static str) -> u64 { metric_value(&self.metrics, event_type, status) } pub(crate) fn invalid_counts(&self) -> Vec<(&'static str, &'static str, u64)> { [METRIC_EVENT_STORED, METRIC_EVENT_REMOVED] .into_iter() .flat_map(|event_type| { [ METRIC_STATUS_PARENT_NOT_FOUND, METRIC_STATUS_BLOCK_NOT_FOUND, METRIC_STATUS_INVALID_BLOCK, ] .into_iter() .map(move |status| (event_type, status, self.status_count(event_type, status))) }) .collect() } pub(crate) fn invalid_event_count(&self) -> u64 { self.invalid_counts() .into_iter() .map(|(_, _, count)| count) .sum() } pub(crate) fn warning_count(&self, warning_kind: &'static str) -> u64 { warning_metric_value(&self.metrics, warning_kind) } pub(crate) fn warning_counts(&self) -> Vec<(&'static str, u64)> { [METRIC_WARNING_DUPLICATE_STORE] .into_iter() .map(|warning_kind| (warning_kind, self.warning_count(warning_kind))) .collect() } pub(crate) fn total_warning_count(&self) -> u64 { self.warning_counts() .into_iter() .map(|(_, count)| count) .sum() } pub(crate) fn spawn_forwarder(&self) -> (Arc, JoinHandle<()>) { let (event_tx, mut event_rx) = mpsc::unbounded_channel::(); let sink = Arc::new(TestKvEventSink { worker_id: self.worker.worker_id, event_tx, }); let indexer = self.indexer.clone(); let forwarder = tokio::spawn(async move { while let Some(event) = event_rx.recv().await { indexer.apply_event_with_buffer(event).await.unwrap(); } let _ = indexer.flush().await; }); (sink, forwarder) } pub(crate) async fn flush(&self) { let _ = self.indexer.flush().await; } pub(crate) fn assert_no_event_errors(&self) { let breakdown = self .invalid_counts() .into_iter() .filter(|(_, _, count)| *count > 0) .map(|(event_type, status, count)| format!("{event_type}/{status}={count}")) .collect::>() .join(", "); assert_eq!( self.invalid_event_count(), 0, "router indexer reported invalid KV events{}", if breakdown.is_empty() { String::new() } else { format!(": {breakdown}") } ); } pub(crate) fn assert_no_event_warnings(&self) { let breakdown = self .warning_counts() .into_iter() .filter(|(_, count)| *count > 0) .map(|(warning_kind, count)| format!("{warning_kind}={count}")) .collect::>() .join(", "); assert_eq!( self.total_warning_count(), 0, "router indexer reported suspicious KV events{}", if breakdown.is_empty() { String::new() } else { format!(": {breakdown}") } ); } pub(crate) fn shutdown(&self) { self.indexer.shutdown(); } } #[derive(Clone)] pub(crate) struct TestKvEventSink { worker_id: WorkerId, event_tx: mpsc::UnboundedSender, } impl KvCacheEventSink for TestKvEventSink { fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()> { self.event_tx .send(RouterEvent::new(self.worker_id, event)) .map_err(|_| anyhow!("router test event channel closed")) } } pub(crate) fn metric_value( metrics: &KvIndexerMetrics, event_type: &'static str, status: &'static str, ) -> u64 { metrics .kv_cache_events_applied .get_metric_with_label_values(&[event_type, status]) .unwrap() .get() } pub(crate) fn warning_metric_value(metrics: &KvIndexerMetrics, warning_kind: &'static str) -> u64 { metrics .kv_cache_event_warnings .get_metric_with_label_values(&[warning_kind]) .unwrap() .get() } pub(crate) fn stored_hashes(events: &[RouterEvent]) -> Vec { events .iter() .filter_map(|event| match &event.event.data { KvCacheEventData::Stored(store) => Some( store .blocks .iter() .map(|block| block.tokens_hash) .collect::>(), ), _ => None, }) .flatten() .collect() } pub(crate) fn nth_stored_hashes(events: &[RouterEvent], nth: usize) -> Vec { events .iter() .filter_map(|event| match &event.event.data { KvCacheEventData::Stored(store) => Some( store .blocks .iter() .map(|block| block.tokens_hash) .collect::>(), ), _ => None, }) .nth(nth) .unwrap_or_default() } pub(crate) fn removed_event_count(events: &[RouterEvent]) -> usize { events .iter() .filter(|event| matches!(event.event.data, KvCacheEventData::Removed(_))) .count() } /// Test sink that captures FPM snapshots for assertion. #[derive(Default)] pub(crate) struct CapturingFpmSink { snapshots: Mutex>, } impl FpmSink for CapturingFpmSink { fn publish(&self, snapshot: ForwardPassSnapshot) -> anyhow::Result<()> { self.snapshots.lock().unwrap().push(snapshot); Ok(()) } } impl CapturingFpmSink { pub(crate) fn take(&self) -> Vec { std::mem::take(&mut *self.snapshots.lock().unwrap()) } } /// Send `num_requests` to a scheduler, collect all output signals, and assert /// that the scheduler produces exactly `num_requests * max_output_tokens` signals /// and returns to idle (0 active decode blocks). /// /// When `use_shared_tokens` is true, the first half of each request shares a /// common prefix to exercise prefix caching / radix tree reuse. pub(crate) async fn assert_scheduler_completes_all( scheduler: &dyn SchedulerHandle, output_rx: &mut mpsc::UnboundedReceiver>, num_requests: usize, input_len: usize, max_output_tokens: usize, use_shared_tokens: bool, ) { let shared_tokens = if use_shared_tokens { Some( (0..input_len / 2) .map(|_| rand::random::() % 50000) .collect::>(), ) } else { None }; for _ in 0..num_requests { let input_tokens = if let Some(ref shared) = shared_tokens { let mut tokens = shared.clone(); tokens.extend((0..input_len / 2).map(|_| rand::random::() % 50000)); tokens } else { (0..input_len) .map(|_| rand::random::() % 50000) .collect::>() }; scheduler.receive(DirectRequest { tokens: input_tokens, max_output_tokens, uuid: None, dp_rank: 0, arrival_timestamp_ms: None, }); } let expected_tokens = num_requests * max_output_tokens; let mut received_tokens = 0; let timeout = tokio::time::sleep(Duration::from_millis(200)); tokio::pin!(timeout); loop { tokio::select! { biased; Some(output_batch) = output_rx.recv() => { received_tokens += output_batch.len(); if received_tokens >= expected_tokens { break; } timeout.set(tokio::time::sleep(Duration::from_millis(200))); } _ = &mut timeout => break, } } assert_eq!( received_tokens, expected_tokens, "Expected {expected_tokens} output signals, got {received_tokens}" ); let metrics = scheduler.metrics_receiver().borrow().clone(); assert_eq!( metrics.active_decode_blocks, 0, "Scheduler should be idle after all requests complete, got {} active blocks", metrics.active_decode_blocks ); assert_eq!( metrics.gpu_cache_usage_perc, 0.0, "Scheduler should report zero cache usage after draining, got {}", metrics.gpu_cache_usage_perc ); assert!( metrics.total_blocks > 0, "Scheduler should populate total_blocks, got {}", metrics.total_blocks ); }