// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 //! # KV Manager (vLLM Backend) //! A synchronous implementation of a block manager that handles MoveBlock signals for caching KV blocks. //! //! Uses [`HashCache`] for O(1) block lookups with active/inactive pool management. //! //! ## Block Operations //! The KV manager processes four types of MoveBlock signals: //! //! ### Use //! - Checks if block exists in active pool → increment reference count //! - If in inactive pool → move to active pool //! - If neither → try evicting from inactive pool to make room //! - If inactive pool is empty → pre-empt the oldest running request //! //! ### Destroy //! - Removes the block from the active pool //! //! ### Deref //! - Decrements reference count of a block in active pool //! - If count reaches zero → move block to inactive pool //! //! ### Promote //! - Converts a partial block (uuid) into a full block (global block hash) //! //! ## Preemption //! If a Use operation fails (typically due to insufficient space), a false boolean signal //! is returned to the scheduler for preemption. Initial KV block allocations for new requests //! should not fail due to the capacity checking during scheduling. //! //! ## NOTE //! For simplicity (or non-simplicity), reference counting is tracked manually instead of using //! the more idiomatic built-in Arc reference counter. This can be considered a shadow / mirror //! implementation of the main block manager. use crate::cache::HashCache; use crate::common::kv_cache_trace; use crate::common::protocols::{KvEventPublishers, MoveBlock, PrefillCost}; use crate::common::sequence::ActiveSequence; use dynamo_kv_router::protocols::{ ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData, KvCacheStoredBlockData, LocalBlockHash, }; use dynamo_tokens::blocks::UniqueBlock; use dynamo_tokens::{BlockHash, SequenceHash}; use rustc_hash::FxHashMap; pub struct KvManager { cache: HashCache, block_size: usize, kv_event_publishers: KvEventPublishers, dp_rank: u32, next_event_id: u64, } impl KvManager { pub fn new(max_capacity: usize, block_size: usize) -> Self { Self::new_with_event_sink(max_capacity, block_size, KvEventPublishers::default(), 0) } pub fn new_with_event_sink( max_capacity: usize, block_size: usize, kv_event_publishers: KvEventPublishers, dp_rank: u32, ) -> Self { debug_assert!(max_capacity > 0, "max_capacity must be > 0"); if !kv_event_publishers.is_empty() { tracing::info!( "KvManager initialized with event sink for DP rank {dp_rank} with block_size {block_size}" ); } KvManager { cache: HashCache::new(max_capacity), block_size, kv_event_publishers, dp_rank, next_event_id: 0, } } /// Converts stored/removed blocks into KvCacheEventData and publishes if sink is available. fn publish_kv_event( &mut self, full_blocks: Vec, local_hashes: &[BlockHash], parent_hash: Option, is_store: bool, token_ids: Option>>, ) { if full_blocks.is_empty() { return; } kv_cache_trace::log_vllm_trace( if is_store { "allocation" } else { "eviction" }, self.dp_rank, self.block_size, self.cache.num_active(), self.cache.num_inactive(), self.cache.max_capacity(), ); if self.kv_event_publishers.is_empty() { return; } let event_data = if is_store { let num_blocks = full_blocks.len(); let local_hashes_slice = &local_hashes[local_hashes .len() .checked_sub(num_blocks) .expect("local hashes fewer than stored blocks")..]; KvCacheEventData::Stored(KvCacheStoreData { parent_hash: parent_hash.map(ExternalSequenceBlockHash), start_position: None, blocks: full_blocks .into_iter() .zip(local_hashes_slice.iter()) .map(|(global_hash, local_hash)| KvCacheStoredBlockData { block_hash: ExternalSequenceBlockHash(global_hash), tokens_hash: LocalBlockHash(*local_hash), mm_extra_info: None, }) .collect(), }) } else { KvCacheEventData::Removed(KvCacheRemoveData { block_hashes: full_blocks .into_iter() .map(ExternalSequenceBlockHash) .collect(), }) }; // Use incremental event ID starting from 0 and incrementing by 1 for each event. let event_id = self.next_event_id; self.next_event_id += 1; let event = KvCacheEvent { event_id, data: event_data, dp_rank: self.dp_rank, }; if let Err(e) = self .kv_event_publishers .publish(event, token_ids.as_deref()) { tracing::warn!("Failed to publish KV event: {e}"); } } /// Process a MoveBlock instruction synchronously. /// /// For `MoveBlock::Use`, returns the number of blocks successfully allocated. /// On partial failure, blocks 0..N are committed but block N+1 could not be /// allocated. Callers should use the return value to track partial progress. /// /// For other variants, returns the total block count (they always succeed or panic). pub fn process(&mut self, event: &MoveBlock) -> usize { match event { MoveBlock::Use(hashes, local_hashes, token_ids, parent) => { let mut blocks_stored = Vec::::new(); let mut stored_token_ids: Option>> = token_ids.as_ref().map(|_| Vec::new()); let mut parent_block: Option<&UniqueBlock> = parent.as_ref(); let mut allocated = 0; for (i, hash) in hashes.iter().enumerate() { // First check if it already exists in active blocks if self.cache.contains_active(hash) { // Block already active, just increment reference count self.cache.increment_ref(hash); parent_block = Some(hash); allocated += 1; continue; } // Then check if it exists in inactive and move it to active if found if self.cache.reactivate(hash) { parent_block = Some(hash); allocated += 1; continue; } // If at max capacity, evict the oldest entry from inactive blocks if self.cache.is_at_capacity() { let Some(evicted) = self.cache.evict_inactive() else { break; }; tracing::trace!( "Evicting block from inactive pool: {evicted:?}, dp_rank={}", self.dp_rank ); if let UniqueBlock::FullBlock(evicted_full_block) = evicted { self.publish_kv_event(vec![evicted_full_block], &[], None, false, None); } } // Now insert the new block in active blocks with reference count 1 self.cache.insert_active(hash.clone(), 1); allocated += 1; // Track blocks for trace/event if let UniqueBlock::FullBlock(stored_full_block) = hash { blocks_stored.push(*stored_full_block); if let Some(ref mut stids) = stored_token_ids { stids.push(token_ids.as_ref().unwrap()[i].clone()); } } } let parent_hash = match parent_block { None => None, Some(UniqueBlock::FullBlock(block)) => Some(*block), Some(UniqueBlock::PartialBlock(_)) => panic!("parent block cannot be partial"), }; self.publish_kv_event( blocks_stored, local_hashes, parent_hash, true, stored_token_ids, ); return allocated; } MoveBlock::Destroy(hashes) => { let mut blocks_destroyed = Vec::::new(); // Process blocks in order (already reversed by caller if needed) for hash in hashes.iter() { self.cache.remove_active(hash).unwrap(); // Track blocks for batch sending if let UniqueBlock::FullBlock(destroyed_full_block) = hash { blocks_destroyed.push(*destroyed_full_block); } } self.publish_kv_event(blocks_destroyed, &[], None, false, None); } MoveBlock::Deref(hashes) => { // Process blocks in order (already reversed by caller if needed) for hash in hashes.iter() { // Decrement reference count and check if we need to move to inactive if let Some(ref_count) = self.cache.get_active_ref_count(hash) { if ref_count == 0 { panic!("Negative reference count would be encountered after Deref."); } // If reference count reaches zero, remove from active and move to inactive if ref_count == 1 { self.cache.deactivate(hash); } else { self.cache.decrement_ref(hash); } } } } MoveBlock::Promote(uuid, hash, parent_hash, local_hash, promote_token_ids) => { let uuid_block = UniqueBlock::PartialBlock(*uuid); let hash_block = UniqueBlock::FullBlock(*hash); assert_eq!( self.cache.remove_active(&uuid_block), Some(1), "uuid_block {uuid_block:?} should exist and be unique with ref_count=1" ); let hash_ref_count = self.cache.get_active_ref_count(&hash_block); // Block is new if it's not in active and not in inactive let is_new = if hash_ref_count.is_some() { false } else { !self.cache.remove_inactive(&hash_block) }; self.cache .insert_active(hash_block, hash_ref_count.unwrap_or(0) + 1); if is_new { self.publish_kv_event( vec![*hash], &[*local_hash], *parent_hash, true, promote_token_ids.as_ref().map(|t| vec![t.clone()]), ); } } } 1 } /// Get the count of blocks that aren't in active or inactive pools pub fn probe_new_blocks(&self, blocks: &[UniqueBlock]) -> usize { blocks .iter() .filter(|&block| !self.cache.contains(block)) .count() } /// Get the current capacity (active blocks + inactive blocks) pub fn current_capacity(&self) -> usize { self.cache.current_capacity() } /// Get the current capacity as a percentage of the maximum capacity pub fn current_capacity_perc(&self) -> f64 { self.cache.current_capacity() as f64 / self.cache.max_capacity() as f64 } /// Get the number of active blocks pub fn num_active_blocks(&self) -> usize { self.cache.num_active() } /// Get the percentage of active blocks relative to maximum capacity pub fn get_active_perc(&self) -> f64 { self.cache.num_active() as f64 / self.cache.max_capacity() as f64 } /// Get the number of inactive blocks pub fn num_inactive_blocks(&self) -> usize { self.cache.num_inactive() } /// Get the keys of inactive blocks pub fn get_inactive_blocks(&self) -> Vec<&UniqueBlock> { self.cache.inactive_keys().collect() } /// Get the keys of active blocks pub fn get_active_blocks(&self) -> Vec<&UniqueBlock> { self.cache.active_keys().collect() } pub fn max_capacity(&self) -> usize { self.cache.max_capacity() } pub fn block_size(&self) -> usize { self.block_size } pub fn dp_rank(&self) -> u32 { self.dp_rank } /// Direct access to active blocks map (for tests). pub fn active_blocks(&self) -> &FxHashMap { self.cache.active_blocks() } /// Check if a sequence can be scheduled and calculate cost if possible pub fn get_prefill_cost(&self, sequence: &ActiveSequence) -> PrefillCost { let seq_blocks = sequence.unique_blocks(); // Find the longest prefix that exists in cache // We must stop at the first cache miss since KV states are computed sequentially let mut overlap_blocks = 0; for block in seq_blocks { if !self.cache.contains(block) { // First cache miss - can't use anything after this point break; } overlap_blocks += 1; } let new_blocks = seq_blocks.len() - overlap_blocks; // Clamp cached_tokens to handle partial blocks (last block may have < block_size tokens) let cached_tokens = (overlap_blocks * self.block_size).min(sequence.num_input_tokens()); let new_tokens = sequence.num_input_tokens() - cached_tokens; PrefillCost { new_blocks, new_tokens, cached_tokens, } } } #[cfg(test)] mod tests { use super::*; use std::sync::Arc; use crate::common::protocols::KvCacheEventSink; #[test] fn test_failure_on_max_capacity() { // Create a KvManager with 10 blocks capacity let mut manager = KvManager::new(10, 16); // Helper function to use multiple blocks that returns the count allocated fn use_blocks(manager: &mut KvManager, ids: Vec) -> usize { let blocks: Vec<_> = ids.iter().map(|&id| UniqueBlock::FullBlock(id)).collect(); let hashes: Vec<_> = ids.into_iter().collect(); manager.process(&MoveBlock::Use(blocks, hashes, None, None)) } // First use 10 blocks (0 to 9) in a batch let response = use_blocks(&mut manager, (0..10).collect()); assert_eq!(response, 10, "Expected all 10 blocks allocated"); // Verify we are at capacity assert_eq!(manager.current_capacity(), 10); // The 11th block should return 0, not panic let response = use_blocks(&mut manager, vec![10]); assert_eq!( response, 0, "Expected 0 blocks allocated when exceeding max capacity" ); } #[test] fn test_block_lifecycle_stringent() { // Create a KvManager with 10 blocks capacity (no KV event publisher for tests) let mut manager = KvManager::new(10, 16); // Helper function to use multiple blocks fn use_blocks(manager: &mut KvManager, ids: Vec) { let blocks: Vec<_> = ids.iter().map(|&id| UniqueBlock::FullBlock(id)).collect(); let hashes: Vec<_> = ids.into_iter().collect(); manager.process(&MoveBlock::Use(blocks, hashes, None, None)); } // Helper function to destroy multiple blocks fn destroy_blocks(manager: &mut KvManager, ids: Vec) { let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect(); manager.process(&MoveBlock::Destroy(blocks)); } // Helper function to deref multiple blocks fn deref_blocks(manager: &mut KvManager, ids: Vec) { let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect(); manager.process(&MoveBlock::Deref(blocks)); } // Helper function to check if active blocks contain expected blocks with expected ref counts fn assert_active_blocks(manager: &KvManager, expected_blocks: &[(u64, usize)]) { assert_eq!( manager.active_blocks().len(), expected_blocks.len(), "Active blocks count doesn't match expected" ); for &(id, ref_count) in expected_blocks { let block = UniqueBlock::FullBlock(id); assert!( manager.active_blocks().contains_key(&block), "Block {id} not found in active blocks", ); assert_eq!( manager.active_blocks().get(&block), Some(&ref_count), "Block {id} has wrong reference count", ); } } // Helper function to check if inactive blocks contain expected blocks fn assert_inactive_blocks( manager: &KvManager, expected_size: usize, expected_blocks: &[u64], ) { let inactive_blocks = manager.get_inactive_blocks(); let inactive_blocks_count = manager.num_inactive_blocks(); assert_eq!( inactive_blocks_count, expected_size, "Inactive blocks count doesn't match expected" ); for &id in expected_blocks { let block = UniqueBlock::FullBlock(id); assert!( inactive_blocks.iter().any(|&b| *b == block), "Block {id} not found in inactive blocks", ); } } // First use blocks 0, 1, 2, 3, 4 in a batch use_blocks(&mut manager, (0..5).collect()); // Then use blocks 0, 1, 5, 6 in a batch use_blocks(&mut manager, vec![0, 1, 5, 6]); // Check that the blocks 0 and 1 are in active blocks, both with reference counts of 2 assert_active_blocks( &manager, &[(0, 2), (1, 2), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1)], ); // Now destroy block 4 destroy_blocks(&mut manager, vec![4]); // And deref blocks 3, 2, 1, 0 in this order as a batch deref_blocks(&mut manager, vec![0, 1, 2, 3]); // Check that the inactive_blocks is size 2 (via num_objects) and contains 3 and 2 assert_inactive_blocks(&manager, 2, &[3, 2]); assert_active_blocks(&manager, &[(0, 1), (1, 1), (5, 1), (6, 1)]); // Now destroy block 6 destroy_blocks(&mut manager, vec![6]); // And deref blocks 5, 1, 0 as a batch deref_blocks(&mut manager, vec![0, 1, 5]); // Check that the inactive_blocks is size 5, and contains 0, 1, 2, 3, 5 assert_inactive_blocks(&manager, 5, &[0, 1, 2, 3, 5]); assert_active_blocks(&manager, &[]); // Now use 0, 1, 2, 7, 8, 9 as a batch use_blocks(&mut manager, vec![0, 1, 2, 7, 8, 9]); // Check that the inactive_blocks is size 2, and contains 3 and 5 assert_inactive_blocks(&manager, 2, &[3, 5]); assert_active_blocks(&manager, &[(0, 1), (1, 1), (2, 1), (7, 1), (8, 1), (9, 1)]); // Test the new_blocks method - only block 4 should be new out of [0,1,2,3,4] let blocks_to_check: Vec = vec![0, 1, 2, 3, 4] .into_iter() .map(UniqueBlock::FullBlock) .collect(); assert_eq!(manager.probe_new_blocks(&blocks_to_check), 1); // Now use blocks 10, 11, 12 as a batch use_blocks(&mut manager, vec![10, 11, 12]); // Check that the inactive_blocks is size 1 and contains only 5 assert_inactive_blocks(&manager, 1, &[5]); use_blocks(&mut manager, vec![13]); } #[test] fn test_chunked_prefill_parent_hash() { use std::sync::Mutex; use crate::common::sequence::ActiveSequence; #[derive(Default)] struct CapturingSink { events: Mutex>, } impl KvCacheEventSink for CapturingSink { fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()> { self.events.lock().unwrap().push(event); Ok(()) } } let block_size = 64; let tokens: Vec = (0..512).collect(); // 8 blocks let mut seq = ActiveSequence::new(tokens, 100, Some(block_size), true, false); let sink = Arc::new(CapturingSink::default()); let mut manager = KvManager::new_with_event_sink( 256, block_size, KvEventPublishers::new(Some(sink.clone() as _), None), 0, ); // Chunk 1: allocate blocks 0-3 let signal = seq.prepare_allocation(256).unwrap(); manager.process(&signal); seq.commit_allocation(256); // Chunk 2: allocate blocks 4-7 let signal = seq.prepare_allocation(512).unwrap(); manager.process(&signal); seq.commit_allocation(512); let events = sink.events.lock().unwrap(); assert_eq!(events.len(), 2, "expected two store events"); // First event: parent_hash should be None (starts from root) let KvCacheEventData::Stored(ref store1) = events[0].data else { panic!("expected store event"); }; assert!( store1.parent_hash.is_none(), "first chunk should have no parent" ); // Second event: parent_hash should be the seq_hash of block 3 // (the last block from the first chunk) let KvCacheEventData::Stored(ref store2) = events[1].data else { panic!("expected store event"); }; let expected_parent = seq.unique_blocks()[3].clone(); let UniqueBlock::FullBlock(expected_hash) = expected_parent else { panic!("expected full block"); }; assert_eq!( store2.parent_hash, Some(ExternalSequenceBlockHash(expected_hash)), "second chunk's parent should be block 3's seq_hash" ); } #[test] fn test_repreempt_after_partial_recompute_only_frees_reallocated_blocks() { let mut seq = ActiveSequence::new((0..6).collect(), 16, Some(4), true, false); let mut manager = KvManager::new(16, 4); let signal = seq.take_creation_signal().unwrap(); assert_eq!(manager.process(&signal), 2); for _ in 0..3 { let signals = seq.generate(); for signal in &signals { manager.process(signal); } if seq.generated_tokens() < seq.max_output_tokens() { seq.commit_allocation(seq.len()); } } assert_eq!(manager.num_active_blocks(), 3); let first_reset = seq.reset_with_signal(); for signal in &first_reset { manager.process(signal); } assert_eq!(manager.num_active_blocks(), 0); let prompt_only = seq.prepare_allocation(seq.num_input_tokens()).unwrap(); assert_eq!(manager.process(&prompt_only), 2); seq.commit_allocation(seq.num_input_tokens()); assert_eq!(manager.num_active_blocks(), 2); let second_reset = seq.reset_with_signal(); for signal in &second_reset { manager.process(signal); } assert_eq!(manager.num_active_blocks(), 0); } }