// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! KV RadixTree
//!
//! This module implements a key-value (KV) store using a Radix Tree structure to efficiently manage and retrieve data blocks.
//! It is designed to support LLM (Large Language Model) inference by re-using a global KV cache.
//!
//! # Overview
//!
//! The main components of this module include:
//!
//! - **Radix Tree Structure**:
//!   - The `RadixTree` struct represents the main data structure, with nodes (`RadixBlock`) containing children and associated worker IDs.
//!   - It allows efficient storage and retrieval of data blocks based on their hashes.
//!   - The `KvIndex` enum exposes the same operations over either a `RadixTree` or a `FlatHashMap` backend.
//!
//! - **Event Handling**:
//!   - The `RouterEvent` struct represents events emitted by LLM workers, which can be applied to the Radix Tree to update its state.
//!   - The `KvIndexer` struct manages these events and match requests asynchronously using Tokio channels.
//!   - The `LocalKvIndexer` struct wraps a `KvIndexer` and keeps a bounded buffer of recent events, so that a range of events can be served by event ID.
//!
//! - **Hash Computation**:
//!   - Functions like `compute_block_hash` and `compute_block_hash_for_seq` compute hashes for data blocks and sequences of tokens, facilitating quick lookups.
//!
//! - **Concurrency and Asynchronous Operations**:
//!   - The `KvIndexer` runs a single-threaded Tokio runtime on a dedicated OS thread. Events, match requests, and other commands arrive over channels and are processed one at a time, so index updates happen off the caller's thread.
//!   - The `KvIndexerSharded` struct partitions workers across several such shards. Each worker's events go to its assigned shard, and match requests are broadcast to all shards.
//!
//! - **Match Requests**:
//!   - The `MatchRequest` struct represents requests to find matches in the Radix Tree, returning overlap scores indicating the best matches.
//!
//! # Purpose
//!
//! This module provides a scalable and efficient way to manage and retrieve data blocks for LLM inference, leveraging a global KV cache to optimize performance.
#[cfg(feature = "bench")]
use std::time::Instant;

#[cfg(feature = "metrics")]
use std::sync::OnceLock;
use std::{
    collections::{HashMap, VecDeque},
    iter,
    sync::{Arc, Mutex},
    thread::JoinHandle,
    time::Duration,
};

use async_trait::async_trait;
#[cfg(feature = "metrics")]
pub use dynamo_runtime::protocols::maybe_error::MaybeError;
#[cfg(feature = "metrics")]
use dynamo_runtime::{
    component::Component,
    metrics::{MetricsHierarchy, prometheus_names::kvrouter},
};
use dynamo_tokens::SequenceHash;
use prometheus::{IntCounterVec, Opts};
use serde::{Deserialize, Serialize};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_util::sync::CancellationToken;

use crate::approx::{BlockEntry, PruneConfig, PruneManager};
use crate::flat_hashmap::FlatHashMap;
use crate::protocols::*;
pub use crate::radix_tree::RadixTree;

/// Trait for types that may represent an error response.
/// Used for RPC-style responses that can indicate success or failure.
///
/// This local definition is compiled only when the `metrics` feature is off;
/// with `metrics` on, the trait of the same name is re-exported from
/// `dynamo_runtime::protocols::maybe_error` instead.
#[cfg(not(feature = "metrics"))]
pub trait MaybeError {
    /// Build a response value that wraps the given error.
    fn from_err(err: Box) -> Self;

    /// Return the wrapped error when this value represents a failure, otherwise `None`.
    fn err(&self) -> Option;
}

// ------
// KvIndex - Unified interface for RadixTree and FlatHashMap
// ------

/// KV cache index backed by either a [`RadixTree`] or a [`FlatHashMap`].
///
/// Each method forwards to the same-named method on whichever backend is active:
/// - `find_matches`: find workers holding cached blocks that match a sequence
/// - `apply_event`: apply a store/remove/clear event
/// - `remove_worker`: forget a worker and every block it holds
/// - `clear_all_blocks`: drop a worker's blocks but keep the worker tracked
/// - `get_workers`: list every tracked worker
/// - `dump_tree_as_events`: export the current state as events
/// - `current_size`: count the stored (worker, block) pairs
pub enum KvIndex {
    /// Prefix-tree backend.
    Tree(RadixTree),
    /// Flat hash-map backend.
    Flat(FlatHashMap),
}

impl KvIndex {
    /// Build an index backed by an empty [`RadixTree`].
    pub fn new_tree() -> Self {
        Self::Tree(RadixTree::new())
    }

    /// Build a [`RadixTree`]-backed index with frequency tracking.
    ///
    /// `expiration_duration` is handed through unchanged to `RadixTree::new_with_frequency`.
    pub fn new_tree_with_frequency(expiration_duration: Option) -> Self {
        let radix = RadixTree::new_with_frequency(expiration_duration);
        Self::Tree(radix)
    }

    /// Build an index backed by an empty [`FlatHashMap`].
    pub fn new_flat() -> Self {
        Self::Flat(FlatHashMap::new())
    }

    /// Score every worker against `sequence` (local block hashes).
    ///
    /// `early_exit` is forwarded to the backend unchanged.
    pub fn find_matches(&self, sequence: Vec, early_exit: bool) -> OverlapScores {
        match self {
            Self::Tree(radix) => radix.find_matches(sequence, early_exit),
            Self::Flat(flat) => flat.find_matches(sequence, early_exit),
        }
    }

    /// Apply one [`RouterEvent`] to the active backend.
    ///
    /// Errors are only ever produced by the tree backend. Whatever the flat
    /// backend's `apply_event` returns is discarded, and this arm reports success.
    pub fn apply_event(&mut self, event: RouterEvent) -> Result<(), KvCacheEventError> {
        if let Self::Tree(radix) = self {
            return radix.apply_event(event);
        }
        if let Self::Flat(flat) = self {
            flat.apply_event(event);
        }
        Ok(())
    }

    /// Forget `worker_id` entirely, including every block it holds.
    pub fn remove_worker(&mut self, worker_id: WorkerId) {
        match self {
            Self::Tree(radix) => radix.remove_worker(worker_id),
            Self::Flat(flat) => flat.remove_worker(worker_id),
        }
    }

    /// Drop every block held by `worker_id` while leaving the worker tracked.
    pub fn clear_all_blocks(&mut self, worker_id: WorkerId) {
        match self {
            Self::Tree(radix) => radix.clear_all_blocks(worker_id),
            Self::Flat(flat) => flat.clear_all_blocks(worker_id),
        }
    }

    /// Every worker ID the backend currently tracks.
    pub fn get_workers(&self) -> Vec {
        match self {
            Self::Tree(radix) => radix.get_workers(),
            Self::Flat(flat) => flat.get_workers(),
        }
    }

    /// Serialize the backend's current contents as a list of RouterEvents.
    pub fn dump_tree_as_events(&self) -> Vec {
        match self {
            Self::Tree(radix) => radix.dump_tree_as_events(),
            Self::Flat(flat) => flat.dump_tree_as_events(),
        }
    }

    /// Total number of (worker, block) pairs held by the backend.
    pub fn current_size(&self) -> usize {
        match self {
            Self::Tree(radix) => radix.current_size(),
            Self::Flat(flat) => flat.current_size(),
        }
    }
}

/// Errors that can occur in the KV Router.
#[derive(Debug, thiserror::Error)] pub enum KvRouterError { #[error("Block not found")] BlockNotFound, #[error("Indexer is offline")] IndexerOffline, #[error("Indexer is dropped request")] IndexerDroppedRequest, #[error("Prune operation failed: {0}")] PruneFailed(String), } // ------- // Distributed router - Worker KV Query types // ------- /// Request to query a worker's local KV indexer. #[derive(Serialize, Deserialize, Debug, Clone)] pub struct WorkerKvQueryRequest { /// The worker ID of the worker to query. pub worker_id: WorkerId, /// Start event ID (inclusive). If `None`, dumps entire tree. pub start_event_id: Option, /// End event ID (inclusive). If `None`, returns up to newest available. pub end_event_id: Option, } /// Response from a worker's local KV indexer. #[derive(Serialize, Deserialize, Debug, Clone)] pub enum WorkerKvQueryResponse { /// Events served from the circular buffer (with original event IDs) Events(Vec), /// Full tree dump (with synthetic 0-indexed event IDs) TreeDump(Vec), /// Requested range is newer than available data TooNew { requested_start: Option, requested_end: Option, newest_available: u64, }, /// Invalid range: end_id < start_id InvalidRange { start_id: u64, end_id: u64 }, /// Query failed on worker (serialized error) Error(String), } impl MaybeError for WorkerKvQueryResponse { fn from_err(err: Box) -> Self { WorkerKvQueryResponse::Error(err.to_string()) } fn err(&self) -> Option { match self { WorkerKvQueryResponse::Error(msg) => Some(anyhow::Error::msg(msg.clone())), _ => None, } } } /// Metrics for the KV Indexer. #[derive(Clone)] pub struct KvIndexerMetrics { /// Counter of events applied. pub kv_cache_events_applied: IntCounterVec, } /// Metric status labels. pub const METRIC_STATUS_OK: &str = "ok"; pub const METRIC_STATUS_PARENT_NOT_FOUND: &str = "parent_block_not_found"; pub const METRIC_STATUS_BLOCK_NOT_FOUND: &str = "block_not_found"; pub const METRIC_STATUS_INVALID_BLOCK: &str = "invalid_block"; /// Metric event labels. 
pub const METRIC_EVENT_STORED: &str = "stored"; pub const METRIC_EVENT_REMOVED: &str = "removed"; pub const METRIC_EVENT_CLEARED: &str = "cleared"; /// Metric name for KV cache events applied counter. const KV_CACHE_EVENTS_APPLIED_NAME: &str = "dynamo_kvrouter_kv_cache_events_applied"; #[cfg(feature = "metrics")] static KV_INDEXER_METRICS: OnceLock> = OnceLock::new(); impl KvIndexerMetrics { #[cfg(feature = "metrics")] fn new(kv_cache_events_applied: IntCounterVec) -> Self { Self { kv_cache_events_applied, } } /// Creates a new KvIndexerMetrics from a Component, memoizing the result in /// KV_INDEXER_METRICS to avoid duplicate registration issues. #[cfg(feature = "metrics")] pub fn from_component(component: &Component) -> Arc { KV_INDEXER_METRICS.get_or_init(|| { match component.metrics().create_intcountervec( kvrouter::KV_CACHE_EVENTS_APPLIED, "Total number of KV cache events applied to index", &["event_type", "status"], &[], ) { Ok(kv_cache_events_applied) => Arc::new(Self::new(kv_cache_events_applied)), Err(e) => { tracing::warn!("Failed to create kv indexer metrics from component: {}. Using unregistered metrics as fallback.", e); Arc::new(Self::new_unregistered()) } } }).clone() } /// Creates a new KvIndexerMetrics which is not registered with a MetricsRegistry. /// This may be used for tests or as a fallback for when a MetricsRegistry is not available / has errored. 
pub fn new_unregistered() -> Self { Self { kv_cache_events_applied: IntCounterVec::new( Opts::new( KV_CACHE_EVENTS_APPLIED_NAME, "Total number of KV cache events applied to index", ), &["event_type", "status"], ) .unwrap(), } } pub fn get_event_type(event_data: &KvCacheEventData) -> &'static str { match event_data { KvCacheEventData::Stored(_) => METRIC_EVENT_STORED, KvCacheEventData::Removed(_) => METRIC_EVENT_REMOVED, KvCacheEventData::Cleared => METRIC_EVENT_CLEARED, } } pub fn increment_event_applied( &self, event_type: &'static str, result: Result<(), KvCacheEventError>, ) { match result { Ok(_) => { self.kv_cache_events_applied .with_label_values(&[event_type, METRIC_STATUS_OK]) .inc_by(1); } Err(e) => { let error_label = match e { KvCacheEventError::ParentBlockNotFound => METRIC_STATUS_PARENT_NOT_FOUND, KvCacheEventError::BlockNotFound => METRIC_STATUS_BLOCK_NOT_FOUND, KvCacheEventError::InvalidBlockSequence => METRIC_STATUS_INVALID_BLOCK, }; self.kv_cache_events_applied .with_label_values(&[event_type, error_label]) .inc_by(1); } } } } /// A request to find matches in the Radix Tree. pub struct MatchRequest { /// A vector of `LocalBlockHash` representing the sequence to match. sequence: Vec, /// A boolean indicating whether to exit early if a single match is found. early_exit: bool, /// A channel sender to send the `OverlapScores` response. 
resp: oneshot::Sender, /// Timestamp when the request was created (for queue wait time measurement) #[cfg(feature = "bench")] created_at: Instant, } impl MatchRequest { fn new( sequence: Vec, early_exit: bool, resp: oneshot::Sender, ) -> Self { Self { sequence, early_exit, resp, #[cfg(feature = "bench")] created_at: Instant::now(), } } } /// A request to dump the tree as events pub struct DumpRequest { /// Channel to send the dumped events pub resp: oneshot::Sender>, } /// A request to get all workers currently tracked pub struct GetWorkersRequest { /// Channel to send the worker IDs pub resp: oneshot::Sender>, } #[async_trait] pub trait KvIndexerInterface { /// Find matches for a given sequence of `LocalBlockHash`es. /// /// ### Arguments /// /// * `sequence` - A vector of `LocalBlockHash` representing the sequence to match. /// /// ### Returns /// /// An `OverlapScores` representing the match scores. async fn find_matches( &self, sequence: Vec, ) -> Result; /// Find matches for a given sequence of tokens. /// /// ### Arguments /// /// * `tokens` - A vector of `u32` tokens. /// /// ### Returns /// /// An `OverlapScores` representing the match scores. async fn find_matches_for_request( &self, tokens: &[u32], ) -> Result; /// Apply a `RouterEvent` to the KV store. /// /// ### Arguments /// /// * `event` - The `RouterEvent` to apply. async fn apply_event(&mut self, event: RouterEvent); /// Remove a worker's entries from the trie. /// /// ### Arguments /// /// * `worker` - The worker to remove from the trie. async fn remove_worker(&mut self, worker: WorkerId); /// Shutdown the KV Indexer. fn shutdown(&mut self); /// Dump the entire tree as RouterEvents. /// /// ### Returns /// /// A vector of RouterEvents representing the current state of the tree. async fn dump_events(&self) -> Result, KvRouterError>; /// Process a routing decision for a request with tokens. 
/// /// Uses TokensWithHashes for lazy hash computation - if hashes were already /// computed (e.g., by find_best_match), they will be reused. /// /// ### Arguments /// /// * `tokens_with_hashes` - Tokens with lazily computed hashes. /// * `worker` - The worker (with dp_rank) that was selected. async fn process_routing_decision_for_request( &self, tokens_with_hashes: &mut TokensWithHashes, worker: WorkerWithDpRank, ) -> Result<(), KvRouterError>; } /// A request to process a routing decision. struct RoutingDecisionRequest { worker: WorkerWithDpRank, local_hashes: Vec, sequence_hashes: Vec, } /// The KV Indexer, managing the KV store and handling events and match requests. #[derive(Clone)] pub struct KvIndexer { /// A `CancellationToken` for managing shutdown. cancel: CancellationToken, /// A sender for `RouterEvent`s. event_tx: mpsc::Sender, /// A sender for `MatchRequest`s. match_tx: mpsc::Sender, /// A sender for remove worker requests. remove_worker_tx: mpsc::Sender, /// A sender for get workers requests. get_workers_tx: mpsc::Sender, /// A sender for dump requests. dump_tx: mpsc::Sender, /// A sender for routing decision requests. routing_tx: mpsc::Sender, /// The size of the KV block this indexer can handle. kv_block_size: u32, /// Reference counter for Clone-aware Drop. /// Only the last clone should cancel the token on drop. _ref_count: Arc<()>, } impl KvIndexer { /// Create a new `KvIndexer`. /// /// ### Arguments /// /// * `token` - A `CancellationToken` for managing shutdown. /// * `expiration_duration` - The amount of time that block usage should be buffered. /// * `ttl` - The time-to-live for blocks before they expire. /// * `prune_config` - Configuration for tree-size based pruning. /// /// ### Returns /// /// A new `KvIndexer`. 
pub fn new_with_frequency( token: CancellationToken, expiration_duration: Option, kv_block_size: u32, metrics: Arc, prune_config: Option, ) -> Self { let (event_tx, event_rx) = mpsc::channel::(2048); let (match_tx, match_rx) = mpsc::channel::(128); let (remove_worker_tx, remove_worker_rx) = mpsc::channel::(16); let (get_workers_tx, get_workers_rx) = mpsc::channel::(16); let (dump_tx, dump_rx) = mpsc::channel::(16); let (routing_tx, mut routing_rx) = mpsc::channel::(2048); let (prune_tx, mut prune_rx) = mpsc::channel::<()>(1); let cancel_clone = token.clone(); std::thread::spawn(move || { // Create a single-threaded tokio runtime let runtime = tokio::runtime::Builder::new_current_thread() .enable_all() .build() .unwrap(); runtime.block_on(async move { let cancel = cancel_clone; let mut match_rx = match_rx; let mut event_rx = event_rx; let mut remove_worker_rx = remove_worker_rx; let mut get_workers_rx = get_workers_rx; let mut dump_rx = dump_rx; let mut trie = RadixTree::new_with_frequency(expiration_duration); // Create PruneManager if prune_config is specified let mut prune_manager = prune_config.map(|config| { PruneManager::::new(50, config) }); let mut event_id_counter = 0u64; loop { // Create a future that sleeps until the next expiration time let expiry_fut = if let Some(ref pm) = prune_manager && let Some(next_expiry) = pm.peek_next_expiry() { tokio::time::sleep_until(next_expiry) } else { tokio::time::sleep(Duration::MAX) }; tokio::select! 
{ biased; _ = cancel.cancelled() => { tracing::debug!("KvCacheIndexer progress loop shutting down"); return; } Some(worker) = remove_worker_rx.recv() => { trie.remove_worker(worker); } Some(get_workers_req) = get_workers_rx.recv() => { let workers = trie.get_workers(); let _ = get_workers_req.resp.send(workers); } Some(_) = prune_rx.recv() => { // Tree size-based pruning triggered let Some(ref mut pm) = prune_manager else { continue }; let Ok(pruned) = pm.prune(trie.current_size()) else { continue }; for p in pruned { event_id_counter += 1; let event = RouterEvent::new( p.worker.worker_id, KvCacheEvent { event_id: event_id_counter, data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes: vec![p.key], }), dp_rank: p.worker.dp_rank, } ); let _ = trie.apply_event(event); } } Some(event) = event_rx.recv() => { let event_type = KvIndexerMetrics::get_event_type(&event.event.data); let event_id = event.event.event_id; let worker_id = event.worker_id; // Only clone if we need the event for prune_manager afterward let event_for_prune = prune_manager.is_some().then(|| event.clone()); let result = trie.apply_event(event); let result_is_ok = result.is_ok(); let tree_size = trie.current_size(); tracing::trace!( "Applied KV event to global radix tree: event_type={event_type}, event_id={event_id}, worker_id={worker_id}, success={result_is_ok}, global_radix_tree_size={tree_size}" ); metrics.increment_event_applied(event_type, result); // Track blocks in PruneManager if TTL is enabled and event was stored successfully let Some(ref mut pm) = prune_manager else { continue }; if !result_is_ok { continue }; let Some(ref event) = event_for_prune else { continue }; let KvCacheEventData::Stored(ref store_data) = event.event.data else { continue }; let worker = WorkerWithDpRank::new(event.worker_id, event.event.dp_rank); let block_entries: Vec = store_data.blocks.iter().enumerate().map(|(idx, block)| { BlockEntry { key: block.block_hash, worker, seq_position: idx, } }).collect(); 
pm.insert(block_entries); // Check if we need to prune due to tree size let Some(ref pc) = pm.prune_config else { continue }; let current_size = trie.current_size(); if current_size > pc.max_tree_size { tracing::info!( "Pruning: tree size ({}) exceeded max tree size ({}), scheduling pruning", current_size, pc.max_tree_size ); let _ = prune_tx.try_send(()); } } Some(dump_req) = dump_rx.recv() => { let events = trie.dump_tree_as_events(); let _ = dump_req.resp.send(events); } Some(routing_req) = routing_rx.recv() => { // Process routing decisions when TTL/pruning is enabled let Some(ref mut pm) = prune_manager else { continue }; event_id_counter += 1; let hashes = routing_req.local_hashes.iter().zip(routing_req.sequence_hashes.iter()); let stored_event = KvCacheEventData::Stored(KvCacheStoreData { parent_hash: None, blocks: hashes.map(|(local_hash, sequence_hash)| KvCacheStoredBlockData { tokens_hash: *local_hash, block_hash: ExternalSequenceBlockHash(*sequence_hash), mm_extra_info: None, }).collect(), }); let event = RouterEvent::new( routing_req.worker.worker_id, KvCacheEvent { event_id: event_id_counter, data: stored_event, dp_rank: routing_req.worker.dp_rank, } ); if trie.apply_event(event).is_err() { continue; } let block_entries: Vec = routing_req.sequence_hashes.iter().enumerate().map(|(idx, h)| { BlockEntry { key: ExternalSequenceBlockHash(*h), worker: routing_req.worker, seq_position: idx, } }).collect(); pm.insert(block_entries); // Check if we need to prune due to tree size let Some(ref pc) = pm.prune_config else { continue }; let current_size = trie.current_size(); if current_size > pc.max_tree_size { tracing::info!( "Pruning: tree size ({}) exceeded max tree size ({}), scheduling pruning", current_size, pc.max_tree_size ); let _ = prune_tx.try_send(()); } } Some(req) = match_rx.recv() => { #[cfg(feature = "bench")] let queue_wait = req.created_at.elapsed(); #[cfg(feature = "bench")] let seq_len = req.sequence.len(); #[cfg(feature = "bench")] let 
process_start = Instant::now(); let matches = trie.find_matches(req.sequence, req.early_exit); #[cfg(feature = "bench")] let process_time = process_start.elapsed(); #[cfg(feature = "bench")] tracing::info!( seq_len, queue_wait_us = queue_wait.as_micros() as u64, process_us = process_time.as_micros() as u64, "indexer: processed find_matches" ); let _ = req.resp.send(matches); } _ = expiry_fut => { // TTL-based expiry triggered let Some(ref mut pm) = prune_manager else { continue }; let expired = pm.pop_expired(); for e in expired { event_id_counter += 1; let event = RouterEvent::new( e.worker.worker_id, KvCacheEvent { event_id: event_id_counter, data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes: vec![e.key], }), dp_rank: e.worker.dp_rank, } ); let _ = trie.apply_event(event); } } } } }); tracing::debug!("KvCacheIndexer task completed"); }); Self { cancel: token, event_tx, match_tx, remove_worker_tx, get_workers_tx, dump_tx, routing_tx, kv_block_size, _ref_count: Arc::new(()), } } pub fn block_size(&self) -> u32 { self.kv_block_size } pub fn new( token: CancellationToken, kv_block_size: u32, metrics: Arc, ) -> Self { Self::new_with_frequency(token, None, kv_block_size, metrics, None) } /// Get a sender for `RouterEvent`s. /// /// ### Returns /// /// A `mpsc::Sender` for `RouterEvent`s. pub fn event_sender(&self) -> mpsc::Sender { self.event_tx.clone() } /// Get a sender for dump requests (snapshot events). /// /// ### Returns /// /// A `mpsc::Sender` for `DumpRequest`s. pub fn snapshot_event_sender(&self) -> mpsc::Sender { self.dump_tx.clone() } /// Get a sender for worker removal requests. /// /// ### Returns /// /// A `mpsc::Sender` for `WorkerId`s. pub fn remove_worker_sender(&self) -> mpsc::Sender { self.remove_worker_tx.clone() } /// Get a sender for get workers requests. /// /// ### Returns /// /// A `mpsc::Sender` for `GetWorkersRequest`s. 
pub fn get_workers_sender(&self) -> mpsc::Sender { self.get_workers_tx.clone() } } #[async_trait] impl KvIndexerInterface for KvIndexer { async fn find_matches( &self, sequence: Vec, ) -> Result { #[cfg(feature = "bench")] let start = Instant::now(); let seq_len = sequence.len(); let (resp_tx, resp_rx) = oneshot::channel(); let req = MatchRequest::new(sequence, false, resp_tx); if let Err(e) = self.match_tx.send(req).await { tracing::error!( "Failed to send match request: {:?}; the indexer maybe offline", e ); return Err(KvRouterError::IndexerOffline); } let result = resp_rx .await .map_err(|_| KvRouterError::IndexerDroppedRequest); #[cfg(feature = "bench")] { let elapsed = start.elapsed(); tracing::info!( seq_len, elapsed_us = elapsed.as_micros() as u64, "find_matches completed" ); } #[cfg(not(feature = "bench"))] let _ = seq_len; result } async fn find_matches_for_request( &self, tokens: &[u32], ) -> Result { tracing::debug!( "Finding matches for request tokens: {:?} / len: {}", tokens, tokens.len() ); let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None); tracing::debug!("Computed sequence: {:?}", sequence); self.find_matches(sequence).await } async fn apply_event(&mut self, event: RouterEvent) { self.event_tx.send(event).await.unwrap(); } async fn remove_worker(&mut self, worker: WorkerId) { self.remove_worker_tx.send(worker).await.unwrap(); } fn shutdown(&mut self) { self.cancel.cancel(); } async fn dump_events(&self) -> Result, KvRouterError> { let (resp_tx, resp_rx) = oneshot::channel(); let dump_req = DumpRequest { resp: resp_tx }; if let Err(e) = self.dump_tx.send(dump_req).await { tracing::error!("Failed to send dump request: {:?}", e); return Err(KvRouterError::IndexerOffline); } resp_rx .await .map_err(|_| KvRouterError::IndexerDroppedRequest) } async fn process_routing_decision_for_request( &self, tokens_with_hashes: &mut TokensWithHashes, worker: WorkerWithDpRank, ) -> Result<(), KvRouterError> { let local_hashes = 
tokens_with_hashes.get_or_compute_block_hashes().to_vec(); let sequence_hashes = tokens_with_hashes.get_or_compute_seq_hashes().to_vec(); self.process_routing_decision_internal(worker, local_hashes, sequence_hashes) .await } } impl KvIndexer { /// Internal method to process a routing decision with pre-computed hashes. async fn process_routing_decision_internal( &self, worker: WorkerWithDpRank, local_hashes: Vec, sequence_hashes: Vec, ) -> Result<(), KvRouterError> { self.routing_tx .send(RoutingDecisionRequest { worker, local_hashes, sequence_hashes, }) .await .map_err(|_| KvRouterError::IndexerDroppedRequest)?; Ok(()) } } impl Drop for KvIndexer { fn drop(&mut self) { // Only cancel the token if we're the last reference. // This allows clones to be dropped without killing the background task. if Arc::strong_count(&self._ref_count) == 1 { self.shutdown(); } } } // ------------------------------------------------- // Decentralized router: LocalKvIndexer for workers // ------------------------------------------------- /// A thin wrapper around KvIndexer that buffers recent events /// (e.g. which may be queued by router upon startup) /// pub struct LocalKvIndexer { /// The underlying indexer indexer: KvIndexer, /// Circular buffer of recent events event_buffer: Mutex>, /// Maximum number of events to keep in buffer max_buffer_size: usize, // Router sets this to WORKER_KV_INDEXER_BUFFER_SIZE } impl LocalKvIndexer { /// create a new LocalKvIndexer pointing to a KvIndexer. pub fn new( token: CancellationToken, kv_block_size: u32, metrics: Arc, max_buffer_size: usize, ) -> Self { Self { indexer: KvIndexer::new(token, kv_block_size, metrics), event_buffer: Mutex::new(VecDeque::with_capacity(max_buffer_size)), max_buffer_size, } } /// Get all buffered events (oldest first). 
pub fn get_all_events_in_buffer(&self) -> Vec { let buffer = self.event_buffer.lock().unwrap(); buffer.iter().cloned().collect() } /// Query events by ID range, returning events in `[start_id, end_id]` (both inclusive). /// /// ### Arguments /// /// * `start_id` - Starting event ID (inclusive). If `None`, dumps entire tree. /// * `end_id` - Ending event ID (inclusive). If `None`, returns up to newest available. /// /// ### Returns /// /// - `Events`: Buffered events with original IDs (when range is within buffer) /// - `TreeDump`: Full tree dump with synthetic IDs (when range is too old or unspecified) /// - `TooNew`: Error when requested range is newer than available data /// - `InvalidRange`: Error when end_id < start_id pub async fn get_events_in_id_range( &self, start_id: Option, end_id: Option, ) -> WorkerKvQueryResponse { // Validate range if both specified if let (Some(s), Some(e)) = (start_id, end_id) && e < s { tracing::warn!(start_id = s, end_id = e, "Invalid range: end_id < start_id"); return WorkerKvQueryResponse::InvalidRange { start_id: s, end_id: e, }; } // Get buffer state let (first_id, last_id) = { let buffer = self.event_buffer.lock().unwrap(); if buffer.is_empty() { (None, None) } else { ( Some(buffer.front().unwrap().event.event_id), Some(buffer.back().unwrap().event.event_id), ) } }; // If no start_id specified, dump entire tree if start_id.is_none() { tracing::debug!("No start_id specified, dumping entire tree"); let events = self.dump_events().await.unwrap_or_default(); return WorkerKvQueryResponse::TreeDump(events); } let start_id = start_id.unwrap(); let end_id = end_id.unwrap_or_else(|| last_id.unwrap_or(start_id)); // Check for empty buffer let Some(first_buffered) = first_id else { tracing::debug!("Buffer empty, dumping entire tree"); let events = self.dump_events().await.unwrap_or_default(); return WorkerKvQueryResponse::TreeDump(events); }; let last_buffered = last_id.unwrap(); // Check if request is too new if start_id > 
last_buffered { tracing::warn!( start_id, last_buffered, "Requested start_id is newer than buffer" ); return WorkerKvQueryResponse::TooNew { requested_start: Some(start_id), requested_end: Some(end_id), newest_available: last_buffered, }; } // Check if start_id is too old (before buffer) -> tree dump if start_id < first_buffered { tracing::info!( start_id, first_buffered, "Requested start_id is older than buffer, dumping entire tree" ); let events = self.dump_events().await.unwrap_or_default(); return WorkerKvQueryResponse::TreeDump(events); } // Serve from buffer let buffer = self.event_buffer.lock().unwrap(); let start_idx = match buffer.binary_search_by_key(&start_id, |e| e.event.event_id) { Ok(idx) => idx, Err(insertion_point) => insertion_point, }; // Clamp end_id to buffer bounds let clamped_end_id = end_id.min(last_buffered); let end_idx = match buffer.binary_search_by_key(&clamped_end_id, |e| e.event.event_id) { Ok(idx) => idx + 1, // Include the matched element Err(insertion_point) => insertion_point, }; let events: Vec = buffer .iter() .skip(start_idx) .take(end_idx.saturating_sub(start_idx)) .cloned() .collect(); WorkerKvQueryResponse::Events(events) } /// Record an event in the buffer fn record_event(&self, event: RouterEvent) { let mut buffer = self.event_buffer.lock().unwrap(); // Check that event id is consecutive to last one if let Some(last_event) = buffer.back() && event.event.event_id != last_event.event.event_id + 1 { let expected = last_event.event.event_id + 1; tracing::error!( worker_id = event.worker_id, expected, got = event.event.event_id, "Non-consecutive KV event id; buffer may have gaps" ); } tracing::debug!( "Recorded event {:?} in buffer, now size is {}", event, buffer.len() ); // Add to back buffer.push_back(event); // Remove from front if over capacity (circular buffer behavior) while buffer.len() > self.max_buffer_size { buffer.pop_front(); } } /// Apply event with buffering. 
/// /// This records the event in the buffer and forwards it to the underlying indexer. pub async fn apply_event_with_buffer(&self, event: RouterEvent) -> Result<(), KvRouterError> { // Record in buffer self.record_event(event.clone()); // Forward to underlying indexer self.indexer .event_sender() .send(event) .await .map_err(|_| KvRouterError::IndexerOffline) } /// Clear the event buffer. pub fn clear_buffer(&self) { let mut buffer = self.event_buffer.lock().unwrap(); buffer.clear(); } /// Get the current buffer size. pub fn buffer_len(&self) -> usize { let buffer = self.event_buffer.lock().unwrap(); buffer.len() } // Delegation methods to underlying KvIndexer /// Get a sender for `RouterEvent`s. pub fn event_sender(&self) -> mpsc::Sender { self.indexer.event_sender() } /// Get a sender for dump requests (snapshot events). pub fn snapshot_event_sender(&self) -> mpsc::Sender { self.indexer.snapshot_event_sender() } /// Get a sender for worker removal requests. pub fn remove_worker_sender(&self) -> mpsc::Sender { self.indexer.remove_worker_sender() } /// Get a sender for get workers requests. pub fn get_workers_sender(&self) -> mpsc::Sender { self.indexer.get_workers_sender() } /// Get the KV block size. pub fn block_size(&self) -> u32 { self.indexer.block_size() } } // Implement KvIndexerInterface by delegating to the underlying indexer #[async_trait] impl KvIndexerInterface for LocalKvIndexer { async fn find_matches( &self, sequence: Vec, ) -> Result { self.indexer.find_matches(sequence).await } async fn find_matches_for_request( &self, tokens: &[u32], ) -> Result { self.indexer.find_matches_for_request(tokens).await } async fn apply_event(&mut self, event: RouterEvent) { // Use the buffering version let _ = self.apply_event_with_buffer(event).await; } async fn remove_worker(&mut self, worker: WorkerId) { let _ = self.indexer.remove_worker_sender().send(worker).await; } fn shutdown(&mut self) { // Note: Since indexer is Arc, we can't call mutable methods directly. 
// The indexer will be shut down when the CancellationToken is cancelled // or when the last Arc reference is dropped. } async fn dump_events(&self) -> Result, KvRouterError> { self.indexer.dump_events().await } async fn process_routing_decision_for_request( &self, tokens_with_hashes: &mut TokensWithHashes, worker: WorkerWithDpRank, ) -> Result<(), KvRouterError> { // TODO I guess the local kvindexers have little use for this method? // Keeping it here now to implement the trait fully self.indexer .process_routing_decision_for_request(tokens_with_hashes, worker) .await } } #[derive(Debug, Clone)] pub struct ShardedMatchRequest { sequence: Vec, early_exit: bool, resp: mpsc::Sender, #[cfg(feature = "bench")] created_at: Instant, } impl ShardedMatchRequest { fn new( sequence: Vec, early_exit: bool, resp: mpsc::Sender, ) -> Self { Self { sequence, early_exit, resp, #[cfg(feature = "bench")] created_at: Instant::now(), } } } /// A sharded KV Indexer that partitions the RadixTree across multiple independent shards. /// /// ## Sharding Strategy /// - Each worker is **permanently assigned** to a single shard on first event /// - All KV blocks from a worker exist only in that worker's assigned shard /// - New workers are assigned to the shard with the fewest workers (load balancing) /// /// ## Operation /// - **Events**: Routed directly to the worker's assigned shard /// - **Match requests**: Broadcast to all shards (scatter-gather pattern) /// - **Threading**: Each shard runs in its own thread with a single-threaded runtime /// /// This design ensures no cross-shard synchronization for writes while enabling /// parallel processing and better scalability. pub struct KvIndexerSharded { /// A `CancellationToken` for managing shutdown. cancel: CancellationToken, /// The size of the KV block this indexer can handle. 
kv_block_size: u32,
    /// Shard index each known worker is permanently assigned to.
    worker_assignments: HashMap<WorkerId, usize>,
    /// Number of workers currently assigned to each shard (indexed by shard).
    worker_counts: Vec<usize>,
    /// Per-shard senders for KV cache events.
    event_tx: Vec<mpsc::Sender<RouterEvent>>,
    /// Broadcast sender used to scatter match requests to every shard.
    request_broadcast_tx: broadcast::Sender<ShardedMatchRequest>,
    /// Per-shard senders for worker removal requests.
    remove_worker_tx: Vec<mpsc::Sender<WorkerId>>,
    /// Per-shard senders for tree dump requests.
    dump_tx: Vec<mpsc::Sender<DumpRequest>>,
    /// Per-shard senders for routing decisions (applied as synthetic store events).
    routing_tx: Vec<mpsc::Sender<RoutingDecisionRequest>>,
    /// Join handles of the shard threads.
    tasks: Vec<JoinHandle<()>>,
}

impl KvIndexerSharded {
    /// Create a new `KvIndexerSharded` with optional frequency tracking and pruning.
    ///
    /// Spawns `num_shards` OS threads. Each thread drives its own single-threaded Tokio
    /// runtime that owns one `RadixTree`.
    ///
    /// ### Arguments
    ///
    /// * `token` - A `CancellationToken` for managing shutdown.
    /// * `num_shards` - The number of independent shards (and threads) to spawn.
    /// * `expiration_duration` - The amount of time that block usage should be buffered.
    /// * `kv_block_size` - The size of the KV block this indexer can handle.
    /// * `metrics` - Metrics sink updated for every applied event.
    /// * `prune_config` - Configuration for TTL-based expiry and tree-size based pruning.
    ///
    /// ### Returns
    ///
    /// A new `KvIndexerSharded`.
    ///
    /// ### Panics
    ///
    /// Panics if a shard's Tokio runtime cannot be built.
    pub fn new_with_frequency(
        token: CancellationToken,
        num_shards: usize,
        expiration_duration: Option<Duration>,
        kv_block_size: u32,
        metrics: Arc<KvIndexerMetrics>,
        prune_config: Option<PruneConfig>,
    ) -> Self {
        let worker_assignments: HashMap<WorkerId, usize> = HashMap::new();
        let worker_counts: Vec<usize> = vec![0; num_shards];

        let mut event_tx = Vec::with_capacity(num_shards);
        let mut remove_worker_tx = Vec::with_capacity(num_shards);
        let mut dump_tx = Vec::with_capacity(num_shards);
        let mut routing_tx = Vec::with_capacity(num_shards);
        let mut tasks = Vec::with_capacity(num_shards);

        let (request_broadcast_tx, _) = broadcast::channel::<ShardedMatchRequest>(1048576);

        for _ in 0..num_shards {
            // NOTE: the previous version also created a per-shard "get workers" channel.
            // Its senders were never stored, so the receiving select! arm could never
            // fire. That channel has been removed.
            let (shard_event_tx, mut shard_event_rx) = mpsc::channel::<RouterEvent>(2048);
            let (shard_remove_worker_tx, mut shard_remove_worker_rx) =
                mpsc::channel::<WorkerId>(16);
            let (shard_dump_tx, mut shard_dump_rx) = mpsc::channel::<DumpRequest>(16);
            let (shard_routing_tx, mut shard_routing_rx) =
                mpsc::channel::<RoutingDecisionRequest>(2048);
            // Capacity 1: at most one pending prune trigger; extra triggers are dropped by try_send.
            let (shard_prune_tx, mut shard_prune_rx) = mpsc::channel::<()>(1);
            let mut shard_broadcast_rx = request_broadcast_tx.subscribe();
            let cancel = token.clone();
            let metrics = metrics.clone();
            let prune_config_clone = prune_config.clone();

            event_tx.push(shard_event_tx);
            remove_worker_tx.push(shard_remove_worker_tx);
            dump_tx.push(shard_dump_tx);
            routing_tx.push(shard_routing_tx);

            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .expect("failed to build Tokio runtime for KvIndexerSharded shard");

            tasks.push(std::thread::spawn(move || {
                runtime.block_on(async move {
                    let mut trie = RadixTree::new_with_frequency(expiration_duration);

                    // Create PruneManager if prune_config is specified
                    let mut prune_manager = prune_config_clone
                        .map(|config| PruneManager::<BlockEntry>::new(50, config));
                    let mut event_id_counter = 0u64;

                    loop {
                        // Create a future that sleeps until the next expiration time
                        let expiry_fut = if let Some(ref pm) = prune_manager
                            && let Some(next_expiry) = pm.peek_next_expiry()
                        {
                            tokio::time::sleep_until(next_expiry)
                        } else {
                            tokio::time::sleep(Duration::MAX)
                        };

                        tokio::select! {
                            biased;

                            _ = cancel.cancelled() => {
                                tracing::trace!("KvCacheIndexer progress loop shutting down");
                                return;
                            }

                            Some(worker) = shard_remove_worker_rx.recv() => {
                                trie.remove_worker(worker);
                            }

                            Some(_) = shard_prune_rx.recv() => {
                                // Tree size-based pruning triggered
                                let Some(ref mut pm) = prune_manager else { continue };
                                let Ok(pruned) = pm.prune(trie.current_size()) else { continue };

                                for p in pruned {
                                    event_id_counter += 1;
                                    let event = RouterEvent::new(
                                        p.worker.worker_id,
                                        KvCacheEvent {
                                            event_id: event_id_counter,
                                            data: KvCacheEventData::Removed(KvCacheRemoveData {
                                                block_hashes: vec![p.key],
                                            }),
                                            dp_rank: p.worker.dp_rank,
                                        }
                                    );
                                    let _ = trie.apply_event(event);
                                }
                            }

                            Some(event) = shard_event_rx.recv() => {
                                let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
                                // Only clone if we need the event for prune_manager afterward
                                let event_for_prune = prune_manager.is_some().then(|| event.clone());
                                let result = trie.apply_event(event);
                                let result_is_ok = result.is_ok();
                                metrics.increment_event_applied(event_type, result);

                                // Track blocks in PruneManager if TTL is enabled and event was stored successfully
                                let Some(ref mut pm) = prune_manager else { continue };
                                if !result_is_ok {
                                    continue;
                                }
                                let Some(ref event) = event_for_prune else { continue };
                                let KvCacheEventData::Stored(ref store_data) = event.event.data else { continue };

                                let worker = WorkerWithDpRank::new(event.worker_id, event.event.dp_rank);
                                let block_entries: Vec<BlockEntry> = store_data
                                    .blocks
                                    .iter()
                                    .enumerate()
                                    .map(|(idx, block)| BlockEntry {
                                        key: block.block_hash,
                                        worker,
                                        seq_position: idx,
                                    })
                                    .collect();
                                pm.insert(block_entries);

                                // Check if we need to prune due to tree size
                                let Some(ref pc) = pm.prune_config else { continue };
                                let current_size = trie.current_size();
                                if current_size > pc.max_tree_size {
                                    tracing::info!(
                                        "Pruning: tree size ({}) exceeded max tree size ({}), scheduling pruning",
                                        current_size,
                                        pc.max_tree_size
                                    );
                                    let _ = shard_prune_tx.try_send(());
                                }
                            }

                            Some(routing_req) = shard_routing_rx.recv() => {
                                // Process routing decisions when TTL/pruning is enabled
                                let Some(ref mut pm) = prune_manager else { continue };

                                event_id_counter += 1;

                                let hashes = routing_req
                                    .local_hashes
                                    .iter()
                                    .zip(routing_req.sequence_hashes.iter());
                                let stored_event = KvCacheEventData::Stored(KvCacheStoreData {
                                    parent_hash: None,
                                    blocks: hashes
                                        .map(|(local_hash, sequence_hash)| KvCacheStoredBlockData {
                                            tokens_hash: *local_hash,
                                            block_hash: ExternalSequenceBlockHash(*sequence_hash),
                                            mm_extra_info: None,
                                        })
                                        .collect(),
                                });

                                let event = RouterEvent::new(
                                    routing_req.worker.worker_id,
                                    KvCacheEvent {
                                        event_id: event_id_counter,
                                        data: stored_event,
                                        dp_rank: routing_req.worker.dp_rank,
                                    }
                                );

                                if trie.apply_event(event).is_err() {
                                    continue;
                                }

                                let block_entries: Vec<BlockEntry> = routing_req
                                    .sequence_hashes
                                    .iter()
                                    .enumerate()
                                    .map(|(idx, h)| BlockEntry {
                                        key: ExternalSequenceBlockHash(*h),
                                        worker: routing_req.worker,
                                        seq_position: idx,
                                    })
                                    .collect();
                                pm.insert(block_entries);

                                // Check if we need to prune due to tree size
                                let Some(ref pc) = pm.prune_config else { continue };
                                let current_size = trie.current_size();
                                if current_size > pc.max_tree_size {
                                    tracing::info!(
                                        "Pruning: tree size ({}) exceeded max tree size ({}), scheduling pruning",
                                        current_size,
                                        pc.max_tree_size
                                    );
                                    let _ = shard_prune_tx.try_send(());
                                }
                            }

                            Some(dump_req) = shard_dump_rx.recv() => {
                                let events = trie.dump_tree_as_events();
                                let _ = dump_req.resp.send(events);
                            }

                            // A lagged broadcast receiver yields Err(..), which does not match
                            // this arm. Presumably find_matches then observes the dropped
                            // response senders and retries (see its comment).
                            Ok(req) = shard_broadcast_rx.recv() => {
                                #[cfg(feature = "bench")]
                                let queue_wait = req.created_at.elapsed();
                                #[cfg(feature = "bench")]
                                let seq_len = req.sequence.len();
                                #[cfg(feature = "bench")]
                                let process_start = Instant::now();

                                let matches = trie.find_matches(req.sequence, req.early_exit);

                                #[cfg(feature = "bench")]
                                let process_time = process_start.elapsed();
                                #[cfg(feature = "bench")]
                                tracing::info!(
                                    seq_len,
                                    queue_wait_us = queue_wait.as_micros() as u64,
                                    process_us = process_time.as_micros() as u64,
                                    "sharded indexer: processed find_matches"
                                );

                                if let Err(e) = req.resp.send(matches).await {
                                    tracing::trace!("Failed to send match response: {:?}", e);
                                }
                            }

                            _ = expiry_fut => {
                                // TTL-based expiry triggered
                                let Some(ref mut pm) = prune_manager else { continue };

                                let expired = pm.pop_expired();
                                for e in expired {
                                    event_id_counter += 1;
                                    let event = RouterEvent::new(
                                        e.worker.worker_id,
                                        KvCacheEvent {
                                            event_id: event_id_counter,
                                            data: KvCacheEventData::Removed(KvCacheRemoveData {
                                                block_hashes: vec![e.key],
                                            }),
                                            dp_rank: e.worker.dp_rank,
                                        }
                                    );
                                    let _ = trie.apply_event(event);
                                }
                            }
                        }
                    }
                });

                tracing::debug!("KvCacheIndexer task completed");
            }));
        }

        Self {
            cancel: token,
            kv_block_size,
            worker_assignments,
            worker_counts,
            event_tx,
            request_broadcast_tx,
            remove_worker_tx,
            dump_tx,
            routing_tx,
            tasks,
        }
    }

    /// Get the KV block size.
    pub fn block_size(&self) -> u32 {
        self.kv_block_size
    }

    /// Create a new `KvIndexerSharded` without frequency tracking or pruning.
    pub fn new(
        token: CancellationToken,
        num_shards: usize,
        kv_block_size: u32,
        metrics: Arc<KvIndexerMetrics>,
    ) -> Self {
        Self::new_with_frequency(token, num_shards, None, kv_block_size, metrics, None)
    }
}

#[async_trait]
impl KvIndexerInterface for KvIndexerSharded {
    /// Scatter the match request to every shard and merge the partial scores.
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        #[cfg(feature = "bench")]
        let start = Instant::now();
        #[cfg(feature = "bench")]
        let seq_len =
sequence.len(); #[cfg(feature = "bench")] let num_shards = self.event_tx.len(); 'match_loop: loop { let (match_tx, mut match_rx) = mpsc::channel(self.event_tx.len()); let sharded_req = ShardedMatchRequest::new(sequence.clone(), false, match_tx); self.request_broadcast_tx .send(sharded_req) .map_err(|_| KvRouterError::IndexerOffline)?; let mut scores = OverlapScores::new(); for response_num in 0..self.event_tx.len() { match match_rx.recv().await { Some(response) => { scores.scores.extend(response.scores); scores.tree_sizes.extend(response.tree_sizes); if response_num == 0 { scores.frequencies = response.frequencies; } else { let diff = (response.frequencies.len() as i64) - (scores.frequencies.len() as i64); if diff > 0 { scores.frequencies.extend(iter::repeat_n(0, diff as usize)); } for i in 0..response.frequencies.len() { scores.frequencies[i] += response.frequencies[i]; } } } None => { // This can only happen if the broadcast channel overflows. // In this case, we don't want to recursively call find_matches again. Otherwise, we could overflow the stack. continue 'match_loop; } } } #[cfg(feature = "bench")] { let elapsed = start.elapsed(); tracing::info!( seq_len, num_shards, elapsed_us = elapsed.as_micros() as u64, "find_matches (sharded) completed" ); } return Ok(scores); } } async fn find_matches_for_request( &self, tokens: &[u32], ) -> Result { let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None); self.find_matches(sequence).await } async fn apply_event(&mut self, event: RouterEvent) { #[allow(clippy::map_entry)] if !self.worker_assignments.contains_key(&event.worker_id) { // Get the shard with the smallest amount of workers. 
let selected_shard = self .worker_counts .iter() .enumerate() .min_by_key(|&(_, value)| value) .unwrap() .0; self.worker_assignments .insert(event.worker_id, selected_shard); self.worker_counts[selected_shard] += 1; } self.event_tx[self.worker_assignments[&event.worker_id]] .send(event) .await .unwrap(); } async fn remove_worker(&mut self, worker: WorkerId) { if let Some((_, shard)) = self.worker_assignments.remove_entry(&worker) { self.worker_counts[shard] -= 1; self.remove_worker_tx[shard].send(worker).await.unwrap(); } } /// Shutdown the KV Indexer. fn shutdown(&mut self) { self.cancel.cancel(); while !self.tasks.is_empty() { self.tasks.pop().unwrap().join().unwrap(); } } async fn dump_events(&self) -> Result, KvRouterError> { let mut all_events = Vec::new(); // Create channels for each shard let mut receivers = Vec::new(); for shard_dump_tx in &self.dump_tx { let (resp_tx, resp_rx) = oneshot::channel(); let dump_req = DumpRequest { resp: resp_tx }; if let Err(e) = shard_dump_tx.send(dump_req).await { tracing::error!("Failed to send dump request to shard: {:?}", e); return Err(KvRouterError::IndexerOffline); } receivers.push(resp_rx); } // Collect results from all shards for resp_rx in receivers { match resp_rx.await { Ok(events) => all_events.extend(events), Err(_) => return Err(KvRouterError::IndexerDroppedRequest), } } Ok(all_events) } async fn process_routing_decision_for_request( &self, tokens_with_hashes: &mut TokensWithHashes, worker: WorkerWithDpRank, ) -> Result<(), KvRouterError> { let local_hashes = tokens_with_hashes.get_or_compute_block_hashes().to_vec(); let sequence_hashes = tokens_with_hashes.get_or_compute_seq_hashes().to_vec(); self.process_routing_decision_internal(worker, local_hashes, sequence_hashes) .await } } impl KvIndexerSharded { /// Internal method to process a routing decision with pre-computed hashes. 
async fn process_routing_decision_internal( &self, worker: WorkerWithDpRank, local_hashes: Vec, sequence_hashes: Vec, ) -> Result<(), KvRouterError> { // Route to the appropriate shard based on worker assignment let shard_idx = self .worker_assignments .get(&worker.worker_id) .copied() .unwrap_or(0); self.routing_tx[shard_idx] .send(RoutingDecisionRequest { worker, local_hashes, sequence_hashes, }) .await .map_err(|_| KvRouterError::IndexerDroppedRequest)?; Ok(()) } } impl Drop for KvIndexerSharded { fn drop(&mut self) { self.shutdown(); } } #[cfg(test)] mod tests { use super::*; use crate::protocols::{ExternalSequenceBlockHash, LocalBlockHash}; use rstest::rstest; use rstest_reuse::{self, *}; use std::time::Instant; use tokio::time; use tokio_util::sync::CancellationToken; fn setup() { // Logging init removed to avoid dynamo-runtime dependency } fn make_blocks(hashes: Vec) -> Vec { hashes .iter() .map(|i| KvCacheStoredBlockData { tokens_hash: LocalBlockHash(*i), block_hash: ExternalSequenceBlockHash(*i * 100), mm_extra_info: None, }) .collect() } fn add_blocks( hashes: Vec, parent_hash: Option, ) -> KvCacheEventData { KvCacheEventData::Stored(KvCacheStoreData { parent_hash, blocks: make_blocks(hashes), }) } fn create_store_event( worker_id: WorkerId, event_id: u64, hashes: Vec, parent: Option, ) -> RouterEvent { RouterEvent { worker_id, event: KvCacheEvent { event_id, data: add_blocks(hashes, parent), dp_rank: 0, }, } } fn make_indexer( token: &CancellationToken, num_shards: usize, kv_block_size: u32, ) -> Box { let metrics = KvIndexerMetrics::new_unregistered(); if num_shards == 1 { Box::new(KvIndexer::new(token.clone(), kv_block_size, metrics.into())) } else { Box::new(KvIndexerSharded::new( token.clone(), num_shards, kv_block_size, metrics.into(), )) } } #[template] #[rstest] fn indexer_template( #[values(1, 3, 8)] num_shards: usize, #[values(11, 32, 64)] kv_block_size: usize, ) { } #[tokio::test] #[apply(indexer_template)] async fn 
test_kv_indexer_new(num_shards: usize, kv_block_size: u32) {
        setup();
        let token: CancellationToken = CancellationToken::new();
        let _ = make_indexer(&token, num_shards, kv_block_size);
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_find_matches(num_shards: usize, kv_block_size: u32) {
        setup();
        let token = CancellationToken::new();
        let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
        let sequence = vec![compute_block_hash(b"test data")];
        let scores = kv_indexer.find_matches(sequence).await;

        assert!(scores.unwrap().scores.is_empty());
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_find_matches_for_request(num_shards: usize, kv_block_size: u32) {
        setup();
        let token = CancellationToken::new();
        let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
        let tokens = vec![1, 2, 3, 4];
        let scores = kv_indexer.find_matches_for_request(&tokens).await;

        assert!(scores.unwrap().scores.is_empty());
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_apply_event(num_shards: usize, kv_block_size: u32) {
        setup();
        let worker_id = 0;
        let token = CancellationToken::new();
        let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
        let event = create_store_event(worker_id, 1, vec![1, 2, 3], None);
        kv_indexer.apply_event(event).await;
        // No assertion here, just ensuring it runs without panic
    }

    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_shutdown(num_shards: usize, kv_block_size: u32) {
        setup();
        let token = CancellationToken::new();
        let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
        kv_indexer.shutdown();
    }

    /// Checks that access frequencies are reported only for prior accesses
    /// within the `expiration` window, for both single-shard and sharded indexers.
    #[tokio::test]
    #[apply(indexer_template)]
    async fn test_frequency(num_shards: usize, kv_block_size: u32) {
        const ONE_MILLIS: Duration = Duration::from_millis(1);
        setup();
        let mut kv_indexer: Box<dyn KvIndexerInterface>;
        let token = CancellationToken::new();
        let expiration = Duration::from_millis(50);
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        if num_shards == 1 {
            kv_indexer = Box::new(KvIndexer::new_with_frequency(
                token,
                Some(expiration),
                kv_block_size,
                metrics,
                None,
            ));
        } else {
            kv_indexer = Box::new(KvIndexerSharded::new_with_frequency(
                token,
                num_shards,
                Some(expiration),
                kv_block_size,
                metrics,
                None,
            ));
        }
        // The blocks
        let block_hashes = vec![
            LocalBlockHash(1),
            LocalBlockHash(2),
            LocalBlockHash(3),
            LocalBlockHash(4),
        ];
        let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(
            overlap.frequencies.len(),
            0,
            "Should be no cached blocks yet"
        );
        // Blocks go in cache
        let worker_id = 0;
        let event = create_store_event(worker_id, 0, vec![1, 2, 3, 4], None);
        kv_indexer.apply_event(event).await;
        // First access
        // The store event is applied async so poll briefly
        let mut overlap = OverlapScores::default();
        let timeout = Duration::from_millis(10);
        let start = Instant::now();
        while overlap.scores.is_empty() && Instant::now().duration_since(start) < timeout {
            time::sleep(ONE_MILLIS).await;
            overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        }
        assert_eq!(
            overlap.scores.len(),
            1,
            "One worker has these blocks cached"
        );
        assert_eq!(
            overlap.frequencies.len(),
            0,
            "Blocks have not previously been accessed"
        );
        // Second access
        let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(overlap.scores.len(), 1, "Still one worker matches");
        assert_eq!(
            overlap.frequencies,
            vec![1, 1, 1, 1],
            "We should see the first access now"
        );
        // Let those two accesses expire
        time::sleep(expiration + Duration::from_millis(10)).await;
        // New first access
        let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(
            overlap.frequencies.len(),
            0,
            "Blocks were accessed too long ago"
        );
        // New second access
        let _ = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        // Access only the first three blocks
        let overlap = kv_indexer
            .find_matches(block_hashes[0..3].to_vec())
            .await
            .unwrap();
        // We see the previous two new accesses
        assert_eq!(overlap.frequencies, vec![2, 2, 2]);
        // The third access did not touch the last block
        let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(overlap.frequencies, vec![3, 3, 3, 2]);
    }

    /// Dumps a populated sharded indexer and replays the dump into a fresh indexer.
    /// It then checks that both the dumps and the match results agree.
    #[tokio::test]
    async fn test_dump_tree_as_events_round_trip() {
        setup();
        // Configuration
        let kv_block_size = 32;
        let num_shards = 2;
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        // Build a non-trivial indexer with events
        let token1 = CancellationToken::new();
        let mut original_indexer =
            KvIndexerSharded::new(token1.clone(), num_shards, kv_block_size, metrics.clone());
        let worker_0 = 0;
        let worker_1 = 1;
        let worker_2 = 2;
        // Apply events to the original indexer
        original_indexer
            .apply_event(create_store_event(worker_0, 0, vec![1, 2, 3], None))
            .await;
        original_indexer
            .apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
            .await;
        original_indexer
            .apply_event(create_store_event(
                worker_1,
                2,
                vec![4, 5],
                Some(ExternalSequenceBlockHash(100)),
            ))
            .await;
        original_indexer
            .apply_event(create_store_event(worker_2, 3, vec![6, 7], None))
            .await;
        original_indexer
            .apply_event(create_store_event(
                worker_0,
                4,
                vec![4],
                Some(ExternalSequenceBlockHash(100)),
            ))
            .await;
        // Allow some time for events to be processed
        tokio::time::sleep(Duration::from_millis(50)).await;
        // Dump the original indexer
        let dump1 = original_indexer.dump_events().await.unwrap();
        println!("Dumped {} events", dump1.len());
        // Create a new indexer and apply all dumped events
        let token2 = CancellationToken::new();
        let mut reconstructed_indexer =
            KvIndexerSharded::new(token2.clone(), num_shards, kv_block_size, metrics);
        for event in &dump1 {
            reconstructed_indexer.apply_event(event.clone()).await;
        }
        // Allow some time for events to be processed
        tokio::time::sleep(Duration::from_millis(50)).await;
        // Dump the reconstructed indexer
        let dump2 = reconstructed_indexer.dump_events().await.unwrap();
        // Sort both dumps for comparison (order might differ due to HashMap iteration and sharding)
        let mut sorted_dump1 = dump1.clone();
        let mut sorted_dump2 = dump2.clone();
        // Sort by (worker_id, tokens_hash, parent_hash)
        let sort_key = |event: &RouterEvent| {
            if let KvCacheEventData::Stored(ref data) = event.event.data {
                (
                    event.worker_id,
                    data.blocks.first().map(|b| b.tokens_hash.0).unwrap_or(0),
                    data.parent_hash.map(|h| h.0).unwrap_or(0),
                )
            } else {
                (event.worker_id, 0, 0)
            }
        };
        sorted_dump1.sort_by_key(sort_key);
        sorted_dump2.sort_by_key(sort_key);
        // Verify the dumps have the same length
        assert_eq!(
            sorted_dump1.len(),
            sorted_dump2.len(),
            "Dumps have different lengths: {} vs {}",
            sorted_dump1.len(),
            sorted_dump2.len()
        );
        // Verify each event matches
        for (i, (event1, event2)) in sorted_dump1.iter().zip(sorted_dump2.iter()).enumerate() {
            assert_eq!(
                event1.worker_id, event2.worker_id,
                "Event {} worker_id mismatch",
                i
            );
            if let (KvCacheEventData::Stored(data1), KvCacheEventData::Stored(data2)) =
                (&event1.event.data, &event2.event.data)
            {
                assert_eq!(
                    data1.parent_hash, data2.parent_hash,
                    "Event {} parent_hash mismatch",
                    i
                );
                assert_eq!(
                    data1.blocks.len(),
                    data2.blocks.len(),
                    "Event {} blocks length mismatch",
                    i
                );
                for (j, (block1, block2)) in
                    data1.blocks.iter().zip(data2.blocks.iter()).enumerate()
                {
                    assert_eq!(
                        block1.tokens_hash, block2.tokens_hash,
                        "Event {} block {} tokens_hash mismatch",
                        i, j
                    );
                    assert_eq!(
                        block1.block_hash, block2.block_hash,
                        "Event {} block {} block_hash mismatch",
                        i, j
                    );
                }
            } else {
                panic!("Expected Stored events in both dumps");
            }
        }
        // Also verify that both indexers produce the same match results
        for test_seq in [
            vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
            vec![LocalBlockHash(1), LocalBlockHash(4), LocalBlockHash(5)],
            vec![LocalBlockHash(6), LocalBlockHash(7)],
            vec![LocalBlockHash(1)],
        ] {
            let scores1 = original_indexer
                .find_matches(test_seq.clone())
                .await
                .unwrap();
            let scores2 = reconstructed_indexer
                .find_matches(test_seq.clone())
                .await
                .unwrap();
            // Sort the scores to compare
            let mut scores1_sorted: Vec<_> = scores1.scores.iter().collect();
            let mut scores2_sorted: Vec<_> = scores2.scores.iter().collect();
            scores1_sorted.sort_by_key(|(k, _)| *k);
            scores2_sorted.sort_by_key(|(k, _)| *k);
            assert_eq!(
                scores1_sorted, scores2_sorted,
                "Match scores differ for sequence {:?}",
                test_seq
            );
        }
        // Clean up
        original_indexer.shutdown();
        reconstructed_indexer.shutdown();
    }

    /// Checks that each (event type, outcome) pair increments its own labelled counter.
    #[test]
    fn test_increment_event_applied() {
        let metrics = KvIndexerMetrics::new_unregistered();
        metrics.increment_event_applied(METRIC_EVENT_STORED, Ok(()));
        assert_eq!(
            metrics
                .kv_cache_events_applied
                .get_metric_with_label_values(&[METRIC_EVENT_STORED, METRIC_STATUS_OK])
                .unwrap()
                .get(),
            1
        );
        metrics.increment_event_applied(
            METRIC_EVENT_STORED,
            Err(KvCacheEventError::ParentBlockNotFound),
        );
        assert_eq!(
            metrics
                .kv_cache_events_applied
                .get_metric_with_label_values(&[
                    METRIC_EVENT_STORED,
                    METRIC_STATUS_PARENT_NOT_FOUND
                ])
                .unwrap()
                .get(),
            1
        );
        metrics
            .increment_event_applied(METRIC_EVENT_REMOVED, Err(KvCacheEventError::BlockNotFound));
        assert_eq!(
            metrics
                .kv_cache_events_applied
                .get_metric_with_label_values(&[
                    METRIC_EVENT_REMOVED,
                    METRIC_STATUS_BLOCK_NOT_FOUND
                ])
                .unwrap()
                .get(),
            1
        );
    }

    // LocalKvIndexer tests

    /// Build a `LocalKvIndexer` and push `Cleared` events with the given ids directly
    /// into its buffer, bypassing the underlying tree.
    fn make_indexer_with_events(ids: &[u64]) -> LocalKvIndexer {
        let indexer = LocalKvIndexer::new(
            CancellationToken::new(),
            4,
            Arc::new(KvIndexerMetrics::new_unregistered()),
            32,
        );
        {
            let mut buffer = indexer.event_buffer.lock().unwrap();
            for &id in ids {
                buffer.push_back(RouterEvent::new(
                    0,
                    KvCacheEvent {
                        event_id: id,
                        data: KvCacheEventData::Cleared,
                        dp_rank: 0,
                    },
                ));
            }
        }
        indexer
    }

    #[tokio::test]
    async fn returns_slice_within_range() {
        let indexer = make_indexer_with_events(&[1, 2, 3, 4, 5]);
        // Helper to extract events from response
        let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> {
            match resp {
                WorkerKvQueryResponse::Events(e) => e,
                WorkerKvQueryResponse::TreeDump(e) => e,
                _ => panic!("Unexpected response type"),
            }
        };
        let get_ids = |events: Vec<RouterEvent>| -> Vec<u64> {
            events.iter().map(|e| e.event.event_id).collect()
        };
        // Test get_events_in_id_range (buffer queries)
        // Range is [start, end] inclusive
        let result = indexer.get_events_in_id_range(Some(2), Some(4)).await;
        let ids = get_ids(extract_events(result));
        assert_eq!(ids, vec![2, 3, 4]); // inclusive range [2, 4]
        let result = indexer.get_events_in_id_range(Some(2), Some(6)).await;
        let ids = get_ids(extract_events(result));
        assert_eq!(ids, vec![2, 3, 4, 5]); // clamp end to buffer max
        // start_id=0 is before buffer (first is 1), so should trigger tree dump
        let result = indexer.get_events_in_id_range(Some(0), Some(4)).await;
        assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
        let result = indexer.get_events_in_id_range(Some(3), Some(3)).await;
        let ids = get_ids(extract_events(result));
        assert_eq!(ids, vec![3]); // single element when start == end
        // Invalid range: end < start
        let result = indexer.get_events_in_id_range(Some(5), Some(2)).await;
        assert!(matches!(result, WorkerKvQueryResponse::InvalidRange { .. }));
    }

    /// Covers the buffer path, the tree-dump path, and the edge cases
    /// (InvalidRange, TooNew, end past buffer) of `get_events_in_id_range`.
    #[tokio::test]
    async fn test_get_events_in_id_range_all_cases() {
        // Create indexer with small buffer (5 events max)
        // This way older events will only be in the tree, not the buffer
        let indexer = LocalKvIndexer::new(
            CancellationToken::new(),
            4, // block_size
            Arc::new(KvIndexerMetrics::new_unregistered()),
            5, // max_buffer_size - only keeps 5 most recent events
        );
        // Helper to create a test event
        let make_event = |id: u64| {
            RouterEvent::new(
                0, // worker_id
                KvCacheEvent {
                    event_id: id,
                    data: KvCacheEventData::Stored(KvCacheStoreData {
                        parent_hash: None,
                        blocks: vec![KvCacheStoredBlockData {
                            block_hash: ExternalSequenceBlockHash(id * 100),
                            tokens_hash: LocalBlockHash(id * 200),
                            mm_extra_info: None,
                        }],
                    }),
                    dp_rank: 0,
                },
            )
        };
        // Add 10 events (IDs 5-14)
        // Buffer will only keep the last 5: events 10-14
        // Tree will have all blocks
        for id in 5..15 {
            indexer
                .apply_event_with_buffer(make_event(id))
                .await
                .unwrap();
        }
        // Wait for events to be processed by the tree
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        // Helper to extract events from response
        let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> {
            match resp {
                WorkerKvQueryResponse::Events(e) => e,
                WorkerKvQueryResponse::TreeDump(e) => e,
                _ => panic!("Unexpected response type: {:?}", resp),
            }
        };
        // Helper to extract event IDs from result
        let get_ids = |events: Vec<RouterEvent>| -> Vec<u64> {
            events.iter().map(|e| e.event.event_id).collect()
        };
        // Verify buffer state: should have events 10-14 (last 5)
        let buffer_events = indexer.get_all_events_in_buffer();
        assert_eq!(
            get_ids(buffer_events),
            vec![10, 11, 12, 13, 14],
            "Buffer should have events 10-14"
        );
        // ========== BUFFER PATH TESTS (start_id >= first_buffered) ==========
        // Range is [start, end] inclusive
        // Test: start_id within buffer, no end
        let result = indexer.get_events_in_id_range(Some(11), None).await;
        assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
        assert_eq!(
            get_ids(extract_events(result)),
            vec![11, 12, 13, 14],
            "start_id=11 (in buffer) should return [11, 14]"
        );
        // Test: start_id at buffer boundary
        let result = indexer.get_events_in_id_range(Some(10), None).await;
        assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
        assert_eq!(
            get_ids(extract_events(result)),
            vec![10, 11, 12, 13, 14],
            "start_id=10 (buffer start) should return [10, 14]"
        );
        // Test: both start and end within buffer (inclusive)
        let result = indexer.get_events_in_id_range(Some(11), Some(13)).await;
        assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
        assert_eq!(
            get_ids(extract_events(result)),
            vec![11, 12, 13],
            "range [11, 13] inclusive should return 3 events"
        );
        let result = indexer.get_events_in_id_range(Some(10), Some(14)).await;
        assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
        assert_eq!(
            get_ids(extract_events(result)),
            vec![10, 11, 12, 13, 14],
            "range [10, 14] should return all buffer events"
        );
        // ========== TREE DUMP PATH TESTS (range extends before buffer) ==========
        // Note: Tree dumps return synthetic 0-indexed event IDs, so we just check
        // that we get events back (the IDs won't match original IDs)
        // Test: (None, None) dumps entire tree
        let result = indexer.get_events_in_id_range(None, None).await;
        assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
        assert_eq!(
            extract_events(result).len(),
            10,
            "(None, None) should dump entire tree (10 events)"
        );
        // Test: (None, Some(_)) dumps entire tree
        let result = indexer.get_events_in_id_range(None, Some(8)).await;
        assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
        assert_eq!(
            extract_events(result).len(),
            10,
            "(None, Some(_)) dumps entire tree - end_id is ignored for tree dumps"
        );
        // Test: start_id before buffer triggers tree dump
        let result = indexer.get_events_in_id_range(Some(7), None).await;
        assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
        assert_eq!(
            extract_events(result).len(),
            10,
            "start_id=7 (before buffer) should dump entire tree"
        );
        let result = indexer.get_events_in_id_range(Some(5), Some(12)).await;
        assert!(matches!(result, WorkerKvQueryResponse::TreeDump(_)));
        assert_eq!(
            extract_events(result).len(),
            10,
            "range [5, 12] extending before buffer should dump entire tree"
        );
        // ========== EDGE CASES ==========
        // Single element when start == end (inclusive range)
        let result = indexer.get_events_in_id_range(Some(12), Some(12)).await;
        assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
        assert_eq!(
            get_ids(extract_events(result)),
            vec![12],
            "start == end should return single event"
        );
        // InvalidRange when start > end
        let result = indexer.get_events_in_id_range(Some(15), Some(10)).await;
        assert!(
            matches!(result, WorkerKvQueryResponse::InvalidRange { .. }),
            "start > end should return InvalidRange"
        );
        // TooNew when start_id is beyond buffer
        let result = indexer.get_events_in_id_range(Some(100), Some(200)).await;
        assert!(
            matches!(result, WorkerKvQueryResponse::TooNew { .. }),
            "start_id beyond buffer should return TooNew"
        );
        // Request with end beyond buffer but valid start -> buffer returns what it has
        let result = indexer.get_events_in_id_range(Some(12), Some(100)).await;
        assert!(matches!(result, WorkerKvQueryResponse::Events(_)));
        assert_eq!(
            get_ids(extract_events(result)),
            vec![12, 13, 14],
            "range with end beyond buffer should return available buffer events"
        );
    }

    /// Checks that a buffered event survives a JSON round-trip of the `Events` response.
    #[tokio::test]
    async fn test_local_indexer_buffer_and_serialization() {
        // Tests components of the LocalKvIndexer query without using nats
        let worker_id = 42u64;
        // Create a local indexer
        let token = CancellationToken::new();
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100));
        // Add events to local indexer's buffer
        let test_event_1 = RouterEvent::new(
            worker_id,
            KvCacheEvent {
                event_id: 1,
                data: KvCacheEventData::Stored(KvCacheStoreData {
                    parent_hash: None,
                    blocks: vec![KvCacheStoredBlockData {
                        block_hash: ExternalSequenceBlockHash(100),
                        tokens_hash: LocalBlockHash(200),
                        mm_extra_info: None,
                    }],
                }),
                dp_rank: 0,
            },
        );
        // Apply events with buffer
        local_indexer
            .apply_event_with_buffer(test_event_1)
            .await
            .unwrap();
        // Wait for events to be processed
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        // Get buffered events (what the query service would return)
        let buffered_events = local_indexer.get_all_events_in_buffer();
        // Verify buffer contents
        assert_eq!(buffered_events.len(), 1, "Buffer should have 1 event");
        assert_eq!(buffered_events[0].worker_id, worker_id);
        assert_eq!(buffered_events[0].event.event_id, 1);
        // Build the response that would be sent (Events variant)
        let response = WorkerKvQueryResponse::Events(buffered_events.clone());
        // Test serialization/deserialization (simulating NATS round-trip)
        let serialized = serde_json::to_vec(&response).unwrap();
        let deserialized: WorkerKvQueryResponse = serde_json::from_slice(&serialized).unwrap();
        // Verify response correctness
        let events = match deserialized {
            WorkerKvQueryResponse::Events(e) => e,
            _ => panic!("Expected Events variant"),
        };
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].worker_id, worker_id);
        assert_eq!(events[0].event.event_id, 1);
        // Verify event data
        match &events[0].event.data {
            KvCacheEventData::Stored(store_data) => {
                assert_eq!(store_data.blocks.len(), 1);
                assert_eq!(store_data.blocks[0].block_hash.0, 100);
                assert_eq!(store_data.blocks[0].tokens_hash.0, 200);
            }
            _ => panic!("Expected Stored event"),
        }
    }
}

/// Tests for KvIndex enum (parametrized over RadixTree and FlatHashMap variants).
#[cfg(test)]
mod kv_index_tests {
    use super::*;
    use crate::protocols::{ExternalSequenceBlockHash, LocalBlockHash, compute_seq_hash_for_block};
    use rstest::rstest;
    use rstest_reuse::{self, *};

    /// Create a store event with proper sequence hashes computed from local hashes.
///
/// Each block carries its local hash as `tokens_hash` and the cumulative sequence
/// hash from `compute_seq_hash_for_block` as `block_hash`. The event is emitted with
/// `parent_hash: None`, `event_id: 0`, and `dp_rank: 0`.
fn make_store_event(worker_id: u64, local_hashes: &[u64]) -> RouterEvent {
    RouterEvent {
        worker_id,
        event: KvCacheEvent {
            event_id: 0,
            data: KvCacheEventData::Stored(KvCacheStoreData {
                parent_hash: None,
                blocks: hashed_blocks(local_hashes)
                    .into_iter()
                    .map(|(tokens_hash, block_hash)| KvCacheStoredBlockData {
                        tokens_hash,
                        block_hash,
                        mm_extra_info: None,
                    })
                    .collect(),
            }),
            dp_rank: 0,
        },
    }
}

/// Create a remove event for blocks with given local hashes.
///
/// Blocks are identified by the same cumulative sequence hashes that
/// `make_store_event` assigns for identical `local_hashes`, so a store followed by a
/// remove of the same hashes targets the same blocks.
fn make_remove_event(worker_id: u64, local_hashes: &[u64]) -> RouterEvent {
    RouterEvent {
        worker_id,
        event: KvCacheEvent {
            event_id: 0,
            data: KvCacheEventData::Removed(KvCacheRemoveData {
                block_hashes: hashed_blocks(local_hashes)
                    .into_iter()
                    .map(|(_, block_hash)| block_hash)
                    .collect(),
            }),
            dp_rank: 0,
        },
    }
}

/// Pair each local block hash with its cumulative sequence hash.
///
/// Shared by `make_store_event` and `make_remove_event` so both derive the
/// `ExternalSequenceBlockHash` values in exactly the same way.
fn hashed_blocks(local_hashes: &[u64]) -> Vec<(LocalBlockHash, ExternalSequenceBlockHash)> {
    let local_block_hashes: Vec<LocalBlockHash> =
        local_hashes.iter().map(|&h| LocalBlockHash(h)).collect();
    let seq_hashes = compute_seq_hash_for_block(&local_block_hashes);
    local_block_hashes
        .into_iter()
        .zip(seq_hashes)
        .map(|(local, seq)| (local, ExternalSequenceBlockHash(seq)))
        .collect()
}

// rstest_reuse template: every `#[apply(kv_index_template)]` test below runs once with
// `variant = "tree"` and once with `variant = "flat"`.
#[template]
#[rstest]
fn kv_index_template(#[values("tree", "flat")] variant: &str) {}

/// Build an empty `KvIndex` for the rstest variant name (`"tree"` or `"flat"`).
///
/// # Panics
/// Panics on any other variant name.
fn make_kv_index(variant: &str) -> KvIndex {
    match variant {
        "tree" => KvIndex::new_tree(),
        "flat" => KvIndex::new_flat(),
        _ => panic!("Unknown variant: {}", variant),
    }
}

/// A stored sequence is fully matched by a query for the same local hashes.
#[apply(kv_index_template)]
fn test_store_and_find(variant: &str) {
    let mut index = make_kv_index(variant);

    // Store a sequence for worker 0
    index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();
    assert_eq!(index.current_size(), 3);

    // Find matches using local hashes
    let scores = index.find_matches(
        vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
        false,
    );
    assert_eq!(scores.scores.len(), 1);
    assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}

/// Matching stops at the first query block that diverges from the stored prefix.
#[apply(kv_index_template)]
fn test_partial_match(variant: &str) {
    let mut index = make_kv_index(variant);

    // Store [1, 2, 3] for worker 0
    index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();

    // Find matches for [1, 2, 999] - should match first 2 then stop
    let scores = index.find_matches(
        vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(999)],
        false,
    );
    assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
}

/// Removing every stored block empties the index and yields no matches.
#[apply(kv_index_template)]
fn test_remove(variant: &str) {
    let mut index = make_kv_index(variant);

    // Store sequence for worker 0
    index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();
    assert_eq!(index.current_size(), 3);

    // Remove all blocks
    index.apply_event(make_remove_event(0, &[1, 2, 3])).unwrap();
    assert_eq!(index.current_size(), 0);

    // Find should return nothing
    let scores = index.find_matches(
        vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
        false,
    );
    assert!(scores.scores.is_empty());
}

/// Workers sharing a first block both match it, but only the owner of the
/// diverging second block matches past it.
#[apply(kv_index_template)]
fn test_multiple_workers_shared_prefix(variant: &str) {
    let mut index = make_kv_index(variant);

    // Worker 0 has [1, 2], Worker 1 has [1, 3]
    // Since sequence hashes are cumulative, [1] has same hash for both,
    // but [1, 2] and [1, 3] have different hashes.
    index.apply_event(make_store_event(0, &[1, 2])).unwrap();
    index.apply_event(make_store_event(1, &[1, 3])).unwrap();

    // Query [1] - both workers should match
    let scores = index.find_matches(vec![LocalBlockHash(1)], false);
    assert_eq!(scores.scores.len(), 2);
    assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 1);
    assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), 1);

    // Query [1, 2] - worker 0 matches both, worker 1 matches only first block
    let scores = index.find_matches(vec![LocalBlockHash(1), LocalBlockHash(2)], false);
    assert_eq!(scores.scores.len(), 2);
    assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
    assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), 1);
}

/// `remove_worker` drops only that worker's (worker, block) pairs.
#[apply(kv_index_template)]
fn test_remove_worker(variant: &str) {
    let mut index = make_kv_index(variant);

    index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();
    index.apply_event(make_store_event(1, &[1, 2, 3])).unwrap();
    assert_eq!(index.current_size(), 6);

    index.remove_worker(0);
    assert_eq!(index.current_size(), 3);

    let scores = index.find_matches(
        vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
        false,
    );
    assert_eq!(scores.scores.len(), 1);
    assert!(scores.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
}

/// Workers are inserted out of order (0, 2, 1); the assertion expects
/// `get_workers` to report them in ascending order.
#[apply(kv_index_template)]
fn test_get_workers(variant: &str) {
    let mut index = make_kv_index(variant);

    index.apply_event(make_store_event(0, &[1])).unwrap();
    index.apply_event(make_store_event(2, &[1])).unwrap();
    index.apply_event(make_store_event(1, &[1])).unwrap();

    let workers = index.get_workers();
    assert_eq!(workers, vec![0, 1, 2]);
}

/// With `early_exit = true`, scoring stops once only one worker remains; without
/// it, the remaining worker is scored over the whole query.
#[apply(kv_index_template)]
fn test_early_exit(variant: &str) {
    let mut index = make_kv_index(variant);

    // Worker 0 has [0, 1, 2], Worker 1 has [0] only
    index.apply_event(make_store_event(0, &[0, 1, 2])).unwrap();
    index.apply_event(make_store_event(1, &[0])).unwrap();

    // Query [0, 1, 2] with early_exit=true
    // Should stop after [0, 1] since only worker 0 has block 1
    let scores = index.find_matches(
        vec![LocalBlockHash(0), LocalBlockHash(1), LocalBlockHash(2)],
        true,
    );

    // Both workers should appear in results
    assert_eq!(scores.scores.len(), 2);
    // Worker 0 got 2 points (blocks 0 and 1, stopped early)
    assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
    // Worker 1 got 1 point (block 0 only)
    assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), 1);

    // Without early_exit, worker 0 should get all 3 blocks
    let scores = index.find_matches(
        vec![LocalBlockHash(0), LocalBlockHash(1), LocalBlockHash(2)],
        false,
    );
    assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}

/// Stores disjoint sequences of 1, 2, 4, ..., 512 blocks (one worker each) and checks
/// that every block is counted and every sequence is fully matchable afterwards.
#[apply(kv_index_template)]
fn test_large_stores(variant: &str) {
    let mut index = make_kv_index(variant);

    // Test sequences of increasing sizes
    for i in 0..10u64 {
        let len = 1u64 << i; // 1, 2, 4, 8, ..., 512
        let worker_id = i;
        // Offset by i * 10000 so no two workers share a local hash (len <= 512 < 10000).
        let sequence: Vec<u64> = (1..=len).map(|x| x + i * 10000).collect();
        index
            .apply_event(make_store_event(worker_id, &sequence))
            .unwrap();

        // Workers 0..=i have stored 1 + 2 + ... + 2^i = 2^(i+1) - 1 (worker, block) pairs.
        assert_eq!(index.current_size(), (2 << i) - 1);

        // The worker that just stored this sequence matches all of it.
        let query: Vec<LocalBlockHash> = sequence.iter().map(|&h| LocalBlockHash(h)).collect();
        let scores = index.find_matches(query, false);
        assert_eq!(
            *scores
                .scores
                .get(&WorkerWithDpRank::new(worker_id, 0))
                .unwrap(),
            1 << i
        );
    }
}

/// Replaying `dump_tree_as_events` into a fresh index reproduces its size,
/// worker set, and match scores.
#[apply(kv_index_template)]
fn test_dump_and_restore(variant: &str) {
    let mut index = make_kv_index(variant);

    // Store some data
    index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();
    index.apply_event(make_store_event(1, &[1, 2, 4])).unwrap();

    let original_size = index.current_size();
    let workers_before = index.get_workers();

    // Dump the tree as events
    let events = index.dump_tree_as_events();
    assert!(!events.is_empty());

    // Create a new index and replay events
    let mut restored = make_kv_index(variant);
    for event in events {
        // NOTE(review): replay errors are ignored here (`let _`); the size, worker and
        // score comparisons below are what detect a lossy round-trip. Consider
        // unwrapping if dumped events are guaranteed to replay cleanly — confirm.
        let _ = restored.apply_event(event);
    }

    // Verify the restored index has same size and workers
    assert_eq!(restored.current_size(), original_size);
    assert_eq!(restored.get_workers(), workers_before);

    // Verify find_matches produces same results
    let original_scores = index.find_matches(vec![LocalBlockHash(1), LocalBlockHash(2)], false);
    let restored_scores =
        restored.find_matches(vec![LocalBlockHash(1), LocalBlockHash(2)], false);
    assert_eq!(original_scores.scores, restored_scores.scores);
}

/// `clear_all_blocks` drops only the given worker's blocks.
#[apply(kv_index_template)]
fn test_clear_all_blocks(variant: &str) {
    let mut index = make_kv_index(variant);

    // Store some data for two workers
    index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();
    index.apply_event(make_store_event(1, &[1, 2, 3])).unwrap();
    assert_eq!(index.current_size(), 6);

    // Clear worker 0's blocks
    index.clear_all_blocks(0);

    // Worker 0's blocks should be gone, worker 1's remain
    assert_eq!(index.current_size(), 3);
    let scores = index.find_matches(
        vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
        false,
    );
    assert_eq!(scores.scores.len(), 1);
    assert!(scores.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
}

/// An empty query produces no scores.
#[apply(kv_index_template)]
fn test_empty_query(variant: &str) {
    let mut index = make_kv_index(variant);
    index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();

    // Empty query should return empty scores
    let scores = index.find_matches(vec![], false);
    assert!(scores.scores.is_empty());
}

/// A query whose first block is unknown produces no scores.
#[apply(kv_index_template)]
fn test_miss_query(variant: &str) {
    let mut index = make_kv_index(variant);
    index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();

    // Query for non-existent blocks
    let scores = index.find_matches(vec![LocalBlockHash(999), LocalBlockHash(998)], false);
    assert!(scores.scores.is_empty());
}
}