// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 //! Positional HashMap-based KV cache index with nested structure. //! //! This module provides a `PositionalIndexer` that uses nested HashMaps //! keyed by position for better cache locality and enables jump/binary-search //! optimizations in find_matches. //! //! # Structure //! //! - `index`: position -> local_hash -> seq_hash -> workers //! The main lookup structure. Position-first nesting enables O(1) position access. //! - `worker_blocks`: worker -> seq_hash -> (position, local_hash) //! Per-worker reverse lookup for efficient remove operations. //! //! # Threading //! //! `PositionalIndexer` implements `SyncIndexer`, meaning all its methods are //! synchronous and thread-safe (via `DashMap` and `RwLock`). To get the full //! `KvIndexerInterface` with sticky event routing and worker threads, wrap it //! in a `ThreadPoolIndexer`. use dashmap::DashMap; use dashmap::mapref::entry::Entry; use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet}; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use super::{ EventKind, EventWarningKind, KvIndexerMetrics, PreBoundEventCounters, SyncIndexer, WorkerTask, }; use crate::active_set::reconcile_active_workers; use crate::protocols::{ DpRank, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheEventError, KvCacheStoreData, KvCacheStoredBlockData, LocalBlockHash, OverlapScores, RouterEvent, WorkerId, WorkerWithDpRank, }; /// Entry for the innermost level of the index. /// /// Optimizes for the common case where there's only one sequence hash /// at a given (position, local_hash) pair, avoiding HashMap allocation. #[derive(Debug, Clone)] enum SeqEntry { /// Single seq_hash -> workers mapping (common case, no HashMap allocation) Single(ExternalSequenceBlockHash, FxHashSet), /// Multiple seq_hash -> workers mappings (rare case, different prefixes) Multi(FxHashMap>), } impl SeqEntry { /// Create a new entry with a single worker. fn new(seq_hash: ExternalSequenceBlockHash, worker: WorkerWithDpRank) -> Self { let mut workers = FxHashSet::default(); workers.insert(worker); Self::Single(seq_hash, workers) } /// Insert a worker for a given seq_hash, upgrading to Multi if needed. fn insert(&mut self, seq_hash: ExternalSequenceBlockHash, worker: WorkerWithDpRank) -> bool { match self { Self::Single(existing_hash, workers) if *existing_hash == seq_hash => { workers.insert(worker) } Self::Single(existing_hash, existing_workers) => { // Upgrade to Multi let mut map = FxHashMap::with_capacity_and_hasher(2, FxBuildHasher); map.insert(*existing_hash, std::mem::take(existing_workers)); map.entry(seq_hash).or_default().insert(worker); *self = Self::Multi(map); true } Self::Multi(map) => map.entry(seq_hash).or_default().insert(worker), } } /// Remove a worker from a given seq_hash. /// Returns true if the entry is now completely empty and should be removed. fn remove(&mut self, seq_hash: ExternalSequenceBlockHash, worker: WorkerWithDpRank) -> bool { match self { Self::Single(existing_hash, workers) if *existing_hash == seq_hash => { workers.remove(&worker); workers.is_empty() } Self::Single(_, _) => false, // Different hash, nothing to remove Self::Multi(map) => { if let Some(workers) = map.get_mut(&seq_hash) { workers.remove(&worker); if workers.is_empty() { map.remove(&seq_hash); } } map.is_empty() } } } /// Get workers for a specific seq_hash. fn get(&self, seq_hash: ExternalSequenceBlockHash) -> Option<&FxHashSet> { match self { Self::Single(existing_hash, workers) if *existing_hash == seq_hash => Some(workers), Self::Single(_, _) => None, Self::Multi(map) => map.get(&seq_hash), } } } pub type LevelIndex = FxHashMap; /// Positional HashMap-based KV cache index. /// /// Implements [`SyncIndexer`] for use with [`ThreadPoolIndexer`](crate::indexer::ThreadPoolIndexer). /// All methods are synchronous and thread-safe. pub struct PositionalIndexer { index: DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>, tree_sizes: DashMap, jump_size: usize, } impl PositionalIndexer { /// Create a new PositionalIndexer. /// /// # Arguments /// * `jump_size` - Jump size for find_matches optimization (e.g., 32). /// The algorithm jumps by this many positions at a time, only scanning /// intermediate positions when workers drain (stop matching). pub fn new(jump_size: usize) -> Self { assert!(jump_size > 0, "jump_size must be greater than 0"); Self { index: DashMap::with_hasher(FxBuildHasher), tree_sizes: DashMap::with_hasher(FxBuildHasher), jump_size, } } } // ============================================================================ // SyncIndexer implementation // ============================================================================ impl SyncIndexer for PositionalIndexer { fn worker( &self, event_receiver: flume::Receiver, metrics: Option>, ) -> anyhow::Result<()> { let mut worker_blocks = FxHashMap::default(); let counters = metrics.as_ref().map(|m| m.prebind()); while let Ok(task) = event_receiver.recv() { match task { WorkerTask::Event(event) => { let kind = EventKind::of(&event.event.data); let result = self.apply_event(&mut worker_blocks, event, counters.as_ref()); if result.is_err() { tracing::warn!("Failed to apply event: {:?}", result.as_ref().err()); } if let Some(ref c) = counters { c.inc(kind, result); } } WorkerTask::RemoveWorker(worker_id) => { self.remove_or_clear_worker_blocks_impl(&mut worker_blocks, worker_id, false); } WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank) => { self.remove_worker_dp_rank_impl(&mut worker_blocks, worker_id, dp_rank); } WorkerTask::CleanupStaleChildren => { self.run_cleanup_task(); } WorkerTask::DumpEvents(sender) => { let events = self.dump_events(&worker_blocks); if let Err(e) = sender.send(Ok(events)) { tracing::warn!("Failed to send events: {:?}", e); } } WorkerTask::Terminate => { break; } } } tracing::debug!("PositionalIndexer worker thread shutting down"); Ok(()) } fn find_matches(&self, sequence: &[LocalBlockHash], early_exit: bool) -> OverlapScores { self.jump_search_matches(sequence, early_exit) } fn worker_count(&self) -> usize { self.tree_sizes.len() } fn block_count(&self) -> usize { self.tree_sizes .iter() .map(|e| e.value().load(Ordering::Relaxed)) .sum() } } // ============================================================================ // Event processing (write operations) // ============================================================================ impl PositionalIndexer { /// Process an event using the provided index and worker_blocks. /// This is called from worker threads. pub fn apply_event( &self, worker_blocks: &mut FxHashMap, event: RouterEvent, counters: Option<&PreBoundEventCounters>, ) -> Result<(), KvCacheEventError> { let (worker_id, kv_event) = (event.worker_id, event.event); let (id, op) = (kv_event.event_id, kv_event.data); let worker = WorkerWithDpRank::new(worker_id, kv_event.dp_rank); tracing::trace!( id, "PositionalIndexer::apply_event_impl: operation: {:?}", op ); match op { KvCacheEventData::Stored(store_data) => { self.store_blocks_impl(worker_blocks, worker, store_data, id, counters)?; Ok(()) } KvCacheEventData::Removed(remove_data) => { self.remove_blocks_impl(worker_blocks, worker, &remove_data.block_hashes, id)?; Ok(()) } KvCacheEventData::Cleared => { self.clear_worker_blocks_impl(worker_blocks, worker_id); Ok(()) } } } fn store_blocks_impl( &self, worker_blocks: &mut FxHashMap, worker: WorkerWithDpRank, store_data: KvCacheStoreData, event_id: u64, counters: Option<&PreBoundEventCounters>, ) -> Result<(), KvCacheEventError> { let KvCacheStoreData { parent_hash, start_position, blocks, } = store_data; let worker_map = worker_blocks.entry(worker).or_default(); let start_pos = match start_position { Some(start_position) => start_position as usize, None => match parent_hash { Some(parent_hash) => { let Some(entry) = worker_map.get(&parent_hash) else { tracing::warn!( worker_id = worker.worker_id.to_string(), dp_rank = worker.dp_rank, event_id, parent_hash = ?parent_hash, ); return Err(KvCacheEventError::ParentBlockNotFound); }; entry.0 + 1 // parent position + 1 } None => 0, // Start from position 0 }, }; let worker_blocks_entry = worker_blocks.entry(worker).or_default(); let mut num_blocks_added = 0usize; let mut duplicate_store = !blocks.is_empty(); for (i, block_data) in blocks.into_iter().enumerate() { let position = start_pos + i; let local_hash = block_data.tokens_hash; let seq_hash = block_data.block_hash; match self.index.entry((position, local_hash)) { Entry::Occupied(mut entry) => { if entry.get_mut().insert(seq_hash, worker) { duplicate_store = false; } } Entry::Vacant(entry) => { entry.insert(SeqEntry::new(seq_hash, worker)); duplicate_store = false; } } // Insert into worker_blocks: worker -> seq_hash -> (position, local_hash) match worker_blocks_entry.insert(seq_hash, (position, local_hash)) { Some(existing) if existing == (position, local_hash) => {} Some(_) => duplicate_store = false, None => { num_blocks_added += 1; duplicate_store = false; } } } match self.tree_sizes.get(&worker) { Some(size) => { size.fetch_add(num_blocks_added, Ordering::Relaxed); } None => { self.tree_sizes .insert(worker, AtomicUsize::new(num_blocks_added)); } } if duplicate_store && let Some(counters) = counters { counters.inc_warning(EventWarningKind::DuplicateStore); } Ok(()) } fn remove_blocks_impl( &self, worker_blocks: &mut FxHashMap, worker: WorkerWithDpRank, seq_hashes: &Vec, event_id: u64, ) -> Result<(), KvCacheEventError> { let worker_map = worker_blocks.get_mut(&worker).ok_or_else(|| { tracing::warn!( worker_id = worker.worker_id.to_string(), dp_rank = worker.dp_rank, event_id, block_hashes = ?seq_hashes, "Failed to find worker blocks to remove" ); KvCacheEventError::BlockNotFound })?; let mut num_removed_blocks = 0; for seq_hash in seq_hashes { let Some((position, local_hash)) = worker_map.remove(seq_hash) else { tracing::warn!( worker_id = worker.worker_id.to_string(), dp_rank = worker.dp_rank, event_id, block_hash = ?seq_hash, "Failed to find block to remove; skipping remove operation" ); if let Some(size) = self.tree_sizes.get(&worker) { size.fetch_sub(num_removed_blocks, Ordering::Relaxed); } return Err(KvCacheEventError::BlockNotFound); }; if let Some(mut entry) = self.index.get_mut(&(position, local_hash)) { let _ = entry.remove(*seq_hash, worker); } num_removed_blocks += 1; } if let Some(size) = self.tree_sizes.get(&worker) { size.fetch_sub(num_removed_blocks, Ordering::Relaxed); } Ok(()) } /// Clear all blocks for a specific worker_id (all dp_ranks), but keep worker tracked. /// Static version for use in worker threads. fn clear_worker_blocks_impl( &self, worker_blocks: &mut FxHashMap, worker_id: WorkerId, ) { self.remove_or_clear_worker_blocks_impl(worker_blocks, worker_id, true); } fn remove_worker_dp_rank_impl( &self, worker_blocks: &mut FxHashMap, worker_id: WorkerId, dp_rank: DpRank, ) { let key = WorkerWithDpRank { worker_id, dp_rank }; if let Some(worker_map) = worker_blocks.remove(&key) { for (seq_hash, (position, local_hash)) in worker_map.iter() { if let Some(mut entry) = self.index.get_mut(&(*position, *local_hash)) { let _ = entry.remove(*seq_hash, key); } } self.tree_sizes.remove(&key); } } /// Helper function to remove or clear blocks for a worker. /// If `keep_worker` is true, the worker remains tracked with empty blocks. /// If `keep_worker` is false, the worker is completely removed. fn remove_or_clear_worker_blocks_impl( &self, worker_blocks: &mut FxHashMap, worker_id: WorkerId, keep_worker: bool, ) { let workers: Vec = worker_blocks .iter() .filter(|entry| entry.0.worker_id == worker_id) .map(|entry| *entry.0) .collect(); for worker in workers { if let Some(worker_map) = worker_blocks.remove(&worker) { for (seq_hash, (position, local_hash)) in worker_map.iter() { if let Some(mut entry) = self.index.get_mut(&(*position, *local_hash)) { let _ = entry.remove(*seq_hash, worker); } } } if keep_worker { // Re-insert worker with empty map to keep it tracked worker_blocks.insert(worker, FxHashMap::default()); // Reset tree size to 0 but keep the entry so scoring remains consistent. if let Some(size) = self.tree_sizes.get(&worker) { size.store(0, Ordering::Relaxed); } } else { // Fully remove the worker from tree_sizes. self.tree_sizes.remove(&worker); } } } fn dump_events( &self, worker_blocks: &FxHashMap, ) -> Vec { let mut events = Vec::new(); let mut event_id = 0u64; for (worker, worker_map) in worker_blocks.iter() { // Collect (position, local_hash, seq_hash) and sort by position. let mut blocks: Vec<_> = worker_map .iter() .map(|(seq_hash, (pos, local_hash))| (*pos, *local_hash, *seq_hash)) .collect(); blocks.sort_unstable_by_key(|(pos, _, _)| *pos); for (pos, local_hash, seq_hash) in blocks { events.push(RouterEvent { worker_id: worker.worker_id, storage_tier: crate::protocols::StorageTier::Device, event: KvCacheEvent { event_id, data: KvCacheEventData::Stored(KvCacheStoreData { parent_hash: None, start_position: Some(pos as u32), blocks: vec![KvCacheStoredBlockData { block_hash: seq_hash, tokens_hash: local_hash, mm_extra_info: None, }], }), dp_rank: worker.dp_rank, }, }); event_id += 1; } } events } } // ----------------------------------------------------------------------------- // Jump-based search methods (associated functions for use in worker threads) // ----------------------------------------------------------------------------- impl PositionalIndexer { /// Compute sequence hash incrementally from previous hash and current local hash. #[inline] fn compute_next_seq_hash(prev_seq_hash: u64, current_local_hash: u64) -> u64 { let mut bytes = [0u8; 16]; bytes[..8].copy_from_slice(&prev_seq_hash.to_le_bytes()); bytes[8..].copy_from_slice(¤t_local_hash.to_le_bytes()); crate::protocols::compute_hash(&bytes) } /// Ensure seq_hashes is computed up to and including target_pos. /// Lazily extends the seq_hashes vector as needed. #[inline] fn ensure_seq_hash_computed( seq_hashes: &mut Vec, target_pos: usize, sequence: &[LocalBlockHash], ) { while seq_hashes.len() <= target_pos { let pos = seq_hashes.len(); if pos == 0 { // First block's seq_hash equals its local_hash seq_hashes.push(ExternalSequenceBlockHash::from(sequence[0].0)); } else { let prev_seq_hash = seq_hashes[pos - 1].0; let current_local_hash = sequence[pos].0; let next_hash = Self::compute_next_seq_hash(prev_seq_hash, current_local_hash); seq_hashes.push(ExternalSequenceBlockHash::from(next_hash)); } } } /// Get workers at a position by verifying both local_hash and seq_hash match. /// /// Returns None if no workers match at this position. /// Always computes and verifies the seq_hash to ensure correctness when /// the query may have diverged from stored sequences at earlier positions. fn get_workers_lazy( &self, position: usize, local_hash: LocalBlockHash, seq_hashes: &mut Vec, sequence: &[LocalBlockHash], ) -> Option> { let entry = self.index.get(&(position, local_hash))?; // Always compute and verify seq_hash to handle divergent queries correctly. // Even if there's only one seq_hash entry, the query's seq_hash might differ // if the query diverged from the stored sequence at an earlier position. Self::ensure_seq_hash_computed(seq_hashes, position, sequence); let seq_hash = seq_hashes[position]; entry.get(seq_hash).cloned() } fn count_workers_at( &self, position: usize, local_hash: LocalBlockHash, seq_hashes: &mut Vec, sequence: &[LocalBlockHash], ) -> Option { let entry = self.index.get(&(position, local_hash))?; // Always compute and verify seq_hash to handle divergent queries correctly. // Even if there's only one seq_hash entry, the query's seq_hash might differ // if the query diverged from the stored sequence at an earlier position. Self::ensure_seq_hash_computed(seq_hashes, position, sequence); let seq_hash = seq_hashes[position]; Some( entry .get(seq_hash) .map(|workers| workers.len()) .unwrap_or(0), ) } /// Scan positions sequentially, updating active set and recording drain scores. #[expect(clippy::too_many_arguments)] fn linear_scan_drain( &self, sequence: &[LocalBlockHash], seq_hashes: &mut Vec, active: &mut FxHashSet, scores: &mut OverlapScores, lo: usize, hi: usize, early_exit: bool, ) { if active.is_empty() { return; } for pos in lo..hi { if active.is_empty() { break; } let Some(entry) = self.index.get(&(pos, sequence[pos])) else { for worker in active.drain() { scores.scores.insert(worker, pos as u32); } break; }; Self::ensure_seq_hash_computed(seq_hashes, pos, sequence); let Some(workers) = entry.get(seq_hashes[pos]) else { for worker in active.drain() { scores.scores.insert(worker, pos as u32); } break; }; if workers.len() != active.len() { reconcile_active_workers(active, workers, |worker| { scores.scores.insert(worker, pos as u32); }); } if early_exit && !active.is_empty() { break; } } } /// Jump-based search to find matches for a sequence of block hashes. /// /// # Algorithm /// /// 1. Check first position - initialize active set with matching workers /// 2. Initialize seq_hashes with first block's hash (seq_hash[0] = local_hash[0]) /// 3. Loop: jump by jump_size positions /// - At each jump, check if active workers still match: /// - All match: Continue jumping (skip intermediate positions) /// - None match: Scan range with linear_scan_drain /// - Partial match: Scan range to find exact drain points /// 4. Record final scores for remaining active workers /// 5. Populate tree_sizes from worker_blocks /// /// # Arguments /// * `index` - The position -> local_hash -> SeqEntry index /// * `worker_blocks` - Per-worker reverse lookup for tree sizes /// * `local_hashes` - Sequence of LocalBlockHash to match /// * `jump_size` - Number of positions to jump at a time /// * `early_exit` - If true, stop after finding any match fn jump_search_matches( &self, local_hashes: &[LocalBlockHash], early_exit: bool, ) -> OverlapScores { let mut scores = OverlapScores::new(); if local_hashes.is_empty() { return scores; } // Lazily computed sequence hashes let mut seq_hashes: Vec = Vec::with_capacity(local_hashes.len()); // Check first position to initialize active set let Some(initial_workers) = self.get_workers_lazy(0, local_hashes[0], &mut seq_hashes, local_hashes) else { return scores; }; let mut active = initial_workers; if active.is_empty() { return scores; } if early_exit { // For early exit, just record that these workers matched at least position 0 for worker in &active { scores.scores.insert(*worker, 1); } // Populate tree_sizes for worker in scores.scores.keys() { if let Some(worker_tree_size) = self.tree_sizes.get(worker) { scores .tree_sizes .insert(*worker, worker_tree_size.load(Ordering::Relaxed)); } } return scores; } let len = local_hashes.len(); let mut current_pos = 0; // Jump through positions while current_pos < len - 1 && !active.is_empty() { let next_pos = (current_pos + self.jump_size).min(len - 1); // Check workers at jump destination let num_workers_at_next = self .count_workers_at( next_pos, local_hashes[next_pos], &mut seq_hashes, local_hashes, ) .unwrap_or(0); if num_workers_at_next == active.len() { current_pos = next_pos; } else { // No active workers match at jump destination // Scan the range to find where each worker drained self.linear_scan_drain( local_hashes, &mut seq_hashes, &mut active, &mut scores, current_pos + 1, next_pos + 1, false, ); current_pos = next_pos; } } // Record final scores for remaining active workers // They matched all positions through the end let final_score = len as u32; for worker in active { scores.scores.insert(worker, final_score); } for worker in scores.scores.keys() { if let Some(worker_tree_size) = self.tree_sizes.get(worker) { scores .tree_sizes .insert(*worker, worker_tree_size.load(Ordering::Relaxed)); } } scores } }