Unverified Commit fc92fc18 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: clean ups in kv-router (#5771)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 842f0f15
...@@ -700,6 +700,15 @@ dependencies = [ ...@@ -700,6 +700,15 @@ dependencies = [
"objc2", "objc2",
] ]
[[package]]
name = "bs58"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf88ba1141d185c399bee5288d850d63b8369520c1eafc32a0430b5b6c287bf4"
dependencies = [
"tinyvec",
]
[[package]] [[package]]
name = "bs62" name = "bs62"
version = "0.1.4" version = "0.1.4"
...@@ -1598,6 +1607,24 @@ dependencies = [ ...@@ -1598,6 +1607,24 @@ dependencies = [
"anyhow", "anyhow",
] ]
[[package]]
name = "dynamo-kv-router"
version = "0.9.0"
dependencies = [
"anyhow",
"async-trait",
"dynamo-runtime",
"dynamo-tokens",
"prometheus",
"rand 0.9.2",
"serde",
"thiserror 2.0.17",
"tokio",
"tokio-util",
"tracing",
"xxhash-rust",
]
[[package]] [[package]]
name = "dynamo-llm" name = "dynamo-llm"
version = "0.9.0" version = "0.9.0"
...@@ -1626,9 +1653,11 @@ dependencies = [ ...@@ -1626,9 +1653,11 @@ dependencies = [
"derive_builder", "derive_builder",
"dialoguer", "dialoguer",
"dynamo-async-openai", "dynamo-async-openai",
"dynamo-kv-router",
"dynamo-memory", "dynamo-memory",
"dynamo-parsers", "dynamo-parsers",
"dynamo-runtime", "dynamo-runtime",
"dynamo-tokens",
"either", "either",
"erased-serde", "erased-serde",
"etcd-client", "etcd-client",
...@@ -1833,6 +1862,20 @@ dependencies = [ ...@@ -1833,6 +1862,20 @@ dependencies = [
"zmq", "zmq",
] ]
[[package]]
name = "dynamo-tokens"
version = "0.9.0"
dependencies = [
"bs58",
"bytemuck",
"dashmap 6.1.0",
"derive-getters",
"serde",
"thiserror 2.0.17",
"uuid",
"xxhash-rust",
]
[[package]] [[package]]
name = "ed25519" name = "ed25519"
version = "2.2.3" version = "2.2.3"
......
...@@ -360,7 +360,7 @@ impl KvEventPublisher { ...@@ -360,7 +360,7 @@ impl KvEventPublisher {
#[pyclass] #[pyclass]
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct OverlapScores { pub(crate) struct OverlapScores {
inner: llm_rs::kv_router::indexer::OverlapScores, inner: llm_rs::kv_router::protocols::OverlapScores,
} }
#[pymethods] #[pymethods]
...@@ -386,7 +386,7 @@ enum RadixTreeRequest { ...@@ -386,7 +386,7 @@ enum RadixTreeRequest {
FindMatches { FindMatches {
local_block_hashes: Vec<llm_rs::kv_router::protocols::LocalBlockHash>, local_block_hashes: Vec<llm_rs::kv_router::protocols::LocalBlockHash>,
early_exit: bool, early_exit: bool,
response_tx: mpsc::SyncSender<llm_rs::kv_router::indexer::OverlapScores>, response_tx: mpsc::SyncSender<llm_rs::kv_router::protocols::OverlapScores>,
}, },
ApplyEvent { ApplyEvent {
worker_id: WorkerId, worker_id: WorkerId,
...@@ -402,7 +402,7 @@ enum RadixTreeRequest { ...@@ -402,7 +402,7 @@ enum RadixTreeRequest {
response_tx: mpsc::SyncSender<()>, response_tx: mpsc::SyncSender<()>,
}, },
DumpTreeAsEvents { DumpTreeAsEvents {
response_tx: mpsc::SyncSender<Vec<llm_rs::kv_router::indexer::RouterEvent>>, response_tx: mpsc::SyncSender<Vec<llm_rs::kv_router::protocols::RouterEvent>>,
}, },
Shutdown, Shutdown,
} }
...@@ -616,8 +616,10 @@ impl RadixTree { ...@@ -616,8 +616,10 @@ impl RadixTree {
>(&kv_cache_event_bytes) >(&kv_cache_event_bytes)
{ {
Ok(kv_cache_event) => { Ok(kv_cache_event) => {
let router_event = let router_event = llm_rs::kv_router::protocols::RouterEvent::new(
llm_rs::kv_router::indexer::RouterEvent::new(worker_id, kv_cache_event); worker_id,
kv_cache_event,
);
match radix_tree.apply_event(router_event) { match radix_tree.apply_event(router_event) {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>( Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
...@@ -898,7 +900,7 @@ impl KvRecorder { ...@@ -898,7 +900,7 @@ impl KvRecorder {
// Spawn a task to forward events to the recorder // Spawn a task to forward events to the recorder
tokio::spawn(async move { tokio::spawn(async move {
while let Some(event) = kv_events_rx.next().await { while let Some(event) = kv_events_rx.next().await {
let event: llm_rs::kv_router::indexer::RouterEvent = let event: llm_rs::kv_router::protocols::RouterEvent =
serde_json::from_slice(&event.payload).unwrap(); serde_json::from_slice(&event.payload).unwrap();
tracing::debug!("KvRecorder received kv event: {:?}", event); tracing::debug!("KvRecorder received kv event: {:?}", event);
if let Err(e) = event_tx.send(event).await { if let Err(e) = event_tx.send(event).await {
......
...@@ -15,17 +15,97 @@ ...@@ -15,17 +15,97 @@
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
use dynamo_kv_router::{ use dynamo_kv_router::{
compute_block_hash_for_seq, OverlapScores, RadixTree, RouterEvent, compute_block_hash_for_seq,
indexer::{RadixTree, RouterEvent}, flat_hashmap::FlatHashMap,
protocols::{ protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData,
KvCacheStoreData, KvCacheStoredBlockData, LocalBlockHash, WorkerId, KvCacheStoreData, KvCacheStoredBlockData, LocalBlockHash, WorkerId,
compute_seq_hash_for_block,
}, },
}; };
use rand::rngs::StdRng; use rand::rngs::StdRng;
use rand::{Rng, SeedableRng}; use rand::{Rng, SeedableRng};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
/// Unified interface for RadixTree and FlatHashMap benchmarking.
///
/// Wraps either index so the benchmarks can drive both through the same calls
/// (apply_event, find_matches, current_size).
///
/// Both variants are queried with the same `LocalBlockHash` sequence: the
/// benchmark passes `local_hashes` to `find_matches` regardless of the variant,
/// and `FlatHashMap::find_matches` derives the cumulative sequence hashes
/// it needs internally.
enum KvIndex {
Tree(RadixTree),
Flat(FlatHashMap),
}
impl KvIndex {
    /// Human-readable name of the wrapped index, used in benchmark output.
    fn name(&self) -> &'static str {
        match self {
            KvIndex::Tree(_) => "RadixTree",
            KvIndex::Flat(_) => "FlatHashMap",
        }
    }

    /// Applies `event` to the wrapped index.
    ///
    /// The value returned by `RadixTree::apply_event` is deliberately discarded:
    /// the benchmark only measures latency and `FlatHashMap::apply_event`
    /// returns nothing, so both variants behave the same from the caller's view.
    fn apply_event(&mut self, event: RouterEvent) {
        match self {
            KvIndex::Tree(tree) => {
                let _ = tree.apply_event(event);
            }
            KvIndex::Flat(map) => map.apply_event(event),
        }
    }

    /// Runs a single `find_matches` call on the wrapped index and returns how
    /// long it took.
    ///
    /// `hashes` must be fully built by the caller so that only the lookup
    /// (and the drop of its result) falls inside the timed region. The result
    /// is routed through `std::hint::black_box` so the optimizer cannot treat
    /// the lookup as dead code.
    fn timed_lookup(&self, hashes: Vec<LocalBlockHash>, early_exit: bool) -> Duration {
        let start = Instant::now();
        let _ = std::hint::black_box(match self {
            KvIndex::Tree(tree) => tree.find_matches(hashes, early_exit),
            KvIndex::Flat(map) => map.find_matches(hashes, early_exit),
        });
        start.elapsed()
    }

    /// Times a lookup of the complete hash sequence of `seq`.
    ///
    /// The clone of `seq.local_hashes` happens before the timer starts.
    fn find_matches_timed(&self, seq: &SequenceData, early_exit: bool) -> Duration {
        self.timed_lookup(seq.local_hashes.clone(), early_exit)
    }

    /// Times a lookup of `depth` synthetic hashes derived from the iteration
    /// index `i`, tagged with `0xBAD_C0DE_0000_0000` so they are not expected
    /// to match stored sequences.
    fn find_matches_miss_timed(&self, depth: usize, i: usize, early_exit: bool) -> Duration {
        let miss_hashes: Vec<LocalBlockHash> = (0..depth)
            .map(|j| LocalBlockHash(0xBAD_C0DE_0000_0000 | ((i as u64) << 16) | (j as u64)))
            .collect();
        self.timed_lookup(miss_hashes, early_exit)
    }

    /// Times a lookup whose first `half` hashes come from `seq` and whose next
    /// `half` hashes are synthetic filler derived from `i`.
    ///
    /// Panics if `half` exceeds `seq.local_hashes.len()` (slice indexing).
    fn find_matches_partial_timed(
        &self,
        seq: &SequenceData,
        half: usize,
        i: usize,
        early_exit: bool,
    ) -> Duration {
        let mut partial = seq.local_hashes[..half].to_vec();
        partial.extend(
            (0..half).map(|j| LocalBlockHash(0xDEAD_0000 | ((i as u64) << 16) | (j as u64))),
        );
        self.timed_lookup(partial, early_exit)
    }

    /// Size reported by the wrapped index's own `current_size`.
    fn current_size(&self) -> usize {
        match self {
            KvIndex::Tree(tree) => tree.current_size(),
            KvIndex::Flat(map) => map.current_size(),
        }
    }
}
/// Sweep benchmark mode /// Sweep benchmark mode
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] #[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
enum SweepMode { enum SweepMode {
...@@ -61,6 +141,10 @@ struct Args { ...@@ -61,6 +141,10 @@ struct Args {
#[arg(long, default_value = "1000")] #[arg(long, default_value = "1000")]
iterations: usize, iterations: usize,
/// Warmup ratio (0.0 to 1.0) - fraction of iterations to discard for warmup
#[arg(long, default_value = "0.1")]
warmup_ratio: f64,
/// Prefix prompt ratio (0.0 to 1.0) - portion of sequence from the beginning that is a shared prefix /// Prefix prompt ratio (0.0 to 1.0) - portion of sequence from the beginning that is a shared prefix
#[arg(long, default_value = "0.25")] #[arg(long, default_value = "0.25")]
prefix_prompt_ratio: f64, prefix_prompt_ratio: f64,
...@@ -116,6 +200,10 @@ struct Args { ...@@ -116,6 +200,10 @@ struct Args {
/// Random seed for reproducibility /// Random seed for reproducibility
#[arg(long, default_value = "42")] #[arg(long, default_value = "42")]
seed: u64, seed: u64,
/// Use flat HashMap baseline instead of radix tree (for comparison)
#[arg(long)]
flat_hashmap: bool,
} }
/// Pre-generated sequence data for benchmarking /// Pre-generated sequence data for benchmarking
...@@ -127,13 +215,14 @@ struct SequenceData { ...@@ -127,13 +215,14 @@ struct SequenceData {
} }
impl SequenceData { impl SequenceData {
fn new(seq_id: u64, worker_id: WorkerId, depth: usize) -> Self { /// Create a new SequenceData from local_hashes.
let local_hashes: Vec<LocalBlockHash> = (0..depth) /// Automatically computes external_hashes using compute_seq_hash_for_block (cumulative hashes).
.map(|block_idx| LocalBlockHash((seq_id << 32) | (block_idx as u64))) /// This ensures FlatHashMap can correctly identify block positions.
.collect(); fn from_local_hashes(worker_id: WorkerId, local_hashes: Vec<LocalBlockHash>) -> Self {
let seq_hashes = compute_seq_hash_for_block(&local_hashes);
let external_hashes: Vec<ExternalSequenceBlockHash> = (0..depth) let external_hashes = seq_hashes
.map(|block_idx| ExternalSequenceBlockHash((seq_id << 32) | (block_idx as u64))) .into_iter()
.map(ExternalSequenceBlockHash)
.collect(); .collect();
Self { Self {
...@@ -190,28 +279,42 @@ fn generate_sequences( ...@@ -190,28 +279,42 @@ fn generate_sequences(
seed: u64, seed: u64,
) -> Vec<SequenceData> { ) -> Vec<SequenceData> {
let mut sequences = Vec::with_capacity(num_sequences); let mut sequences = Vec::with_capacity(num_sequences);
let prefix_length: usize = (depth as f64 * prefix_prompt_ratio).round() as usize; let prefix_length = (depth as f64 * prefix_prompt_ratio).round() as usize;
let mut rng: StdRng = StdRng::seed_from_u64(seed); let mut rng: StdRng = StdRng::seed_from_u64(seed);
for seq_id in 0..num_sequences { for seq_id in 0..num_sequences {
let seq_id_u64 = seq_id as u64;
let worker_id = (seq_id % num_workers) as WorkerId; let worker_id = (seq_id % num_workers) as WorkerId;
let mut seq = SequenceData::new(seq_id as u64, worker_id, depth);
if num_prefix_prompts > 0 && prefix_length > 0 { // Determine prefix group for this sequence
let group_id = rng.random_range(0..num_prefix_prompts); let group_id = if num_prefix_prompts > 0 && prefix_length > 0 {
for i in 0..prefix_length { Some(rng.random_range(0..num_prefix_prompts) as u64)
seq.local_hashes[i] = } else {
LocalBlockHash(0xDEAD_BEEF_0000_0000 | ((group_id as u64) << 32) | (i as u64)); None
};
// Build local_hashes: shared prefix (if applicable) + unique suffix
let local_hashes: Vec<LocalBlockHash> = (0..depth)
.map(|block_idx| {
let block_idx_u64 = block_idx as u64;
if let Some(gid) = group_id {
if block_idx < prefix_length {
// Shared prefix based on group_id
return LocalBlockHash(0xDEAD_BEEF_0000_0000 | (gid << 32) | block_idx_u64);
} }
} }
// Unique suffix (or no shared prefix)
LocalBlockHash((seq_id_u64 << 32) | block_idx_u64)
})
.collect();
sequences.push(seq); sequences.push(SequenceData::from_local_hashes(worker_id, local_hashes));
} }
sequences sequences
} }
/// Build a pre-populated tree (prints timing info) /// Build a pre-populated RadixTree (for sweep/dump benchmarks that specifically need RadixTree)
fn build_tree(sequences: &[SequenceData]) -> RadixTree { fn build_tree(sequences: &[SequenceData]) -> RadixTree {
let num_blocks: usize = sequences.iter().map(|s| s.local_hashes.len()).sum(); let num_blocks: usize = sequences.iter().map(|s| s.local_hashes.len()).sum();
print!( print!(
...@@ -239,6 +342,45 @@ fn build_tree(sequences: &[SequenceData]) -> RadixTree { ...@@ -239,6 +342,45 @@ fn build_tree(sequences: &[SequenceData]) -> RadixTree {
tree tree
} }
/// Build a pre-populated KvIndex (prints timing info).
///
/// Creates a `FlatHashMap`-backed index when `use_flat_hashmap` is true and a
/// `RadixTree`-backed one otherwise, then applies one store event per entry of
/// `sequences` (event ids are the positions in the slice).
///
/// The empty index is created before the timer starts so the reported rate
/// covers only the population loop, and so the progress line can take its
/// label from `KvIndex::name()` instead of repeating the literals here.
///
/// Panics if stdout cannot be flushed.
fn build_index(sequences: &[SequenceData], use_flat_hashmap: bool) -> KvIndex {
    let num_blocks: usize = sequences.iter().map(|s| s.local_hashes.len()).sum();

    let mut index = if use_flat_hashmap {
        KvIndex::Flat(FlatHashMap::new())
    } else {
        KvIndex::Tree(RadixTree::new())
    };

    print!(
        " Building {} with {} sequences ({} blocks)... ",
        index.name(),
        sequences.len(),
        num_blocks
    );
    // `print!` does not end the line, so flush to make the progress text
    // visible while the index is being populated.
    std::io::Write::flush(&mut std::io::stdout()).unwrap();

    let start = Instant::now();
    for (event_id, seq) in sequences.iter().enumerate() {
        index.apply_event(seq.to_store_event(event_id as u64));
    }
    let elapsed = start.elapsed();

    let secs = elapsed.as_secs_f64();
    println!(
        "done in {:.2?} ({:.2} sequences/sec, {:.2} blocks/sec)",
        elapsed,
        sequences.len() as f64 / secs,
        num_blocks as f64 / secs
    );

    index
}
/// Statistics for a set of timing measurements /// Statistics for a set of timing measurements
#[derive(Debug)] #[derive(Debug)]
struct LatencyStats { struct LatencyStats {
...@@ -304,14 +446,18 @@ fn bench_hash(args: &Args) { ...@@ -304,14 +446,18 @@ fn bench_hash(args: &Args) {
}) })
.collect(); .collect();
let mut durations = Vec::with_capacity(args.iterations); let warmup_iters = (args.iterations as f64 * args.warmup_ratio) as usize;
let measured_iters = args.iterations - warmup_iters;
let mut durations = Vec::with_capacity(measured_iters);
for (i, tokens) in token_sequences.iter().enumerate() { for (i, tokens) in token_sequences.iter().enumerate() {
let start = Instant::now(); let start = Instant::now();
let _ = compute_block_hash_for_seq(tokens, args.block_size, None); let _ = compute_block_hash_for_seq(tokens, args.block_size, None);
let elapsed = start.elapsed(); let elapsed = start.elapsed();
if i >= warmup_iters {
durations.push(elapsed); durations.push(elapsed);
}
if args.verbose && (i + 1) % 100 == 0 { if args.verbose && (i + 1) % 100 == 0 {
println!(" Completed {}/{} iterations", i + 1, args.iterations); println!(" Completed {}/{} iterations", i + 1, args.iterations);
...@@ -322,14 +468,19 @@ fn bench_hash(args: &Args) { ...@@ -322,14 +468,19 @@ fn bench_hash(args: &Args) {
stats.print("COMPUTE_BLOCK_HASH", args.depth); stats.print("COMPUTE_BLOCK_HASH", args.depth);
} }
/// Benchmark store_block operation /// Benchmark store or remove operation on a steady-state index.
fn bench_store(args: &Args) { ///
println!("\n=== Benchmarking STORE_BLOCK ==="); /// Uses a remove/store cycle to maintain size. If `time_store` is true,
/// the store operation is timed; otherwise the remove operation is timed.
fn bench_store_remove_cycle(args: &Args, time_store: bool) {
let op_name = if time_store {
"STORE_BLOCK"
} else {
"REMOVE_BLOCK"
};
let num_sequences = args.size / args.depth; let num_sequences = args.size / args.depth;
let bench_iters = args.iterations.min(num_sequences); let sequences = generate_sequences(
let all_sequences = generate_sequences(
num_sequences, num_sequences,
args.depth, args.depth,
args.num_workers, args.num_workers,
...@@ -337,87 +488,58 @@ fn bench_store(args: &Args) { ...@@ -337,87 +488,58 @@ fn bench_store(args: &Args) {
args.num_prefix_prompts, args.num_prefix_prompts,
args.seed, args.seed,
); );
let split_point = num_sequences.saturating_sub(bench_iters);
let pre_sequences = &all_sequences[..split_point];
let bench_sequences = &all_sequences[split_point..];
// Build tree once, then store sequences sequentially let mut index = build_index(&sequences, args.flat_hashmap);
// Tree grows from (size - iterations) to size over the benchmark println!("\n=== Benchmarking {} ({}) ===", op_name, index.name());
let mut tree = build_tree(&pre_sequences); println!(" Size: {} blocks", index.current_size());
println!(
" Initial tree size: {} blocks, will grow to ~{} blocks",
tree.current_size(),
tree.current_size() + bench_iters * args.depth
);
let mut durations = Vec::with_capacity(bench_iters); let warmup_iters = (args.iterations as f64 * args.warmup_ratio) as usize;
let measured_iters = args.iterations - warmup_iters;
let mut durations = Vec::with_capacity(measured_iters);
for (i, seq) in bench_sequences.iter().enumerate() { for i in 0..args.iterations {
let event = seq.to_store_event(i as u64); let seq = &sequences[i % sequences.len()];
let remove_event = seq.to_remove_event(i as u64);
let store_event = seq.to_store_event(i as u64 + args.iterations as u64);
let elapsed = if time_store {
index.apply_event(remove_event);
let start = Instant::now(); let start = Instant::now();
let _ = tree.apply_event(event); index.apply_event(store_event);
start.elapsed()
} else {
let start = Instant::now();
index.apply_event(remove_event);
let elapsed = start.elapsed(); let elapsed = start.elapsed();
index.apply_event(store_event);
elapsed
};
if i >= warmup_iters {
durations.push(elapsed); durations.push(elapsed);
}
if args.verbose && (i + 1) % 100 == 0 { if args.verbose && (i + 1) % 100 == 0 {
println!(" Completed {}/{} iterations", i + 1, bench_iters); println!(" Completed {}/{} iterations", i + 1, args.iterations);
} }
} }
let stats = LatencyStats::from_durations(durations); let stats = LatencyStats::from_durations(durations);
stats.print("STORE_BLOCK", args.depth); stats.print(op_name, args.depth);
}
/// Benchmark store_block operation
fn bench_store(args: &Args) {
bench_store_remove_cycle(args, true);
} }
/// Benchmark remove_block operation /// Benchmark remove_block operation
fn bench_remove(args: &Args) { fn bench_remove(args: &Args) {
println!("\n=== Benchmarking REMOVE_BLOCK ==="); bench_store_remove_cycle(args, false);
let num_sequences = args.size / args.depth;
let sequences = generate_sequences(
num_sequences,
args.depth,
args.num_workers,
args.prefix_prompt_ratio,
args.num_prefix_prompts,
args.seed,
);
// Build tree once, then remove/re-add to restore state after each timed removal
let mut tree = build_tree(&sequences);
println!(" Tree size: {} blocks", tree.current_size());
let mut durations = Vec::with_capacity(args.iterations);
for i in 0..args.iterations {
// Remove a sequence (timed)
let seq_to_remove = &sequences[i % sequences.len()];
let remove_event = seq_to_remove.to_remove_event(i as u64);
let start = Instant::now();
let _ = tree.apply_event(remove_event);
let elapsed = start.elapsed();
durations.push(elapsed);
// Re-add the sequence to restore tree state (untimed)
let store_event = seq_to_remove.to_store_event(i as u64 + args.iterations as u64);
let _ = tree.apply_event(store_event);
if args.verbose && (i + 1) % 100 == 0 {
println!(" Completed {}/{} iterations", i + 1, args.iterations);
}
}
let stats = LatencyStats::from_durations(durations);
stats.print("REMOVE_BLOCK", args.depth);
} }
/// Benchmark find_matches operation /// Benchmark find_matches operation
fn bench_find_matches(args: &Args) { fn bench_find_matches(args: &Args) {
println!("\n=== Benchmarking FIND_MATCHES ===");
let num_sequences = args.size / args.depth; let num_sequences = args.size / args.depth;
let sequences = generate_sequences( let sequences = generate_sequences(
num_sequences, num_sequences,
...@@ -428,104 +550,74 @@ fn bench_find_matches(args: &Args) { ...@@ -428,104 +550,74 @@ fn bench_find_matches(args: &Args) {
args.seed, args.seed,
); );
// Build tree once for all find_matches calls let index = build_index(&sequences, args.flat_hashmap);
let tree = build_tree(&sequences); println!("\n=== Benchmarking FIND_MATCHES ({}) ===", index.name());
println!( println!(
" Tree built with {} sequences, {} total blocks", " Built with {} sequences, {} total blocks",
sequences.len(), sequences.len(),
tree.current_size() index.current_size()
); );
// Benchmark hit case (lookup existing sequences) let warmup_iters = (args.iterations as f64 * args.warmup_ratio) as usize;
println!("\n --- HIT case (existing sequences) ---"); let measured_iters = args.iterations - warmup_iters;
let mut hit_durations = Vec::with_capacity(args.iterations); let half = args.depth / 2;
// HIT case
println!("\n --- HIT case (existing sequences) ---");
let mut hit_durations = Vec::with_capacity(measured_iters);
for i in 0..args.iterations { for i in 0..args.iterations {
let seq = &sequences[i % sequences.len()]; let seq = &sequences[i % sequences.len()];
let hashes_copy = seq.local_hashes.clone(); let elapsed = index.find_matches_timed(seq, false);
if i >= warmup_iters {
let start = Instant::now();
let _ = tree.find_matches(hashes_copy, false);
let elapsed = start.elapsed();
hit_durations.push(elapsed); hit_durations.push(elapsed);
}
if args.verbose && (i + 1) % 100 == 0 { if args.verbose && (i + 1) % 100 == 0 {
println!(" Completed {}/{} iterations", i + 1, args.iterations); println!(" Completed {}/{} iterations", i + 1, args.iterations);
} }
} }
LatencyStats::from_durations(hit_durations).print("FIND_MATCHES (HIT)", args.depth);
let hit_stats = LatencyStats::from_durations(hit_durations); // MISS case
hit_stats.print("FIND_MATCHES (HIT)", args.depth);
// Benchmark miss case (find_matches on non-existing sequences)
println!("\n --- MISS case (non-existing sequences) ---"); println!("\n --- MISS case (non-existing sequences) ---");
let mut miss_durations = Vec::with_capacity(args.iterations); let mut miss_durations = Vec::with_capacity(measured_iters);
for i in 0..args.iterations { for i in 0..args.iterations {
// Generate a sequence that won't match let elapsed = index.find_matches_miss_timed(args.depth, i, false);
let miss_hashes: Vec<LocalBlockHash> = (0..args.depth) if i >= warmup_iters {
.map(|j| LocalBlockHash(0xBAD_C0DE_0000_0000 | ((i as u64) << 16) | (j as u64)))
.collect();
let start = Instant::now();
let _ = tree.find_matches(miss_hashes, false);
let elapsed = start.elapsed();
miss_durations.push(elapsed); miss_durations.push(elapsed);
}
if args.verbose && (i + 1) % 100 == 0 { if args.verbose && (i + 1) % 100 == 0 {
println!(" Completed {}/{} iterations", i + 1, args.iterations); println!(" Completed {}/{} iterations", i + 1, args.iterations);
} }
} }
LatencyStats::from_durations(miss_durations).print("FIND_MATCHES (MISS)", args.depth);
let miss_stats = LatencyStats::from_durations(miss_durations); // PARTIAL case
miss_stats.print("FIND_MATCHES (MISS)", args.depth);
// Benchmark partial match case
println!("\n --- PARTIAL case (prefix match only) ---"); println!("\n --- PARTIAL case (prefix match only) ---");
let mut partial_durations = Vec::with_capacity(args.iterations); let mut partial_durations = Vec::with_capacity(measured_iters);
for i in 0..args.iterations { for i in 0..args.iterations {
let seq = &sequences[i % sequences.len()]; let seq = &sequences[i % sequences.len()];
// Use first half of real sequence, second half is garbage let elapsed = index.find_matches_partial_timed(seq, half, i, false);
let half = args.depth / 2; if i >= warmup_iters {
let mut partial_hashes = seq.local_hashes[..half].to_vec();
partial_hashes.extend(
(0..half).map(|j| LocalBlockHash(0xDEAD_0000 | ((i as u64) << 16) | (j as u64))),
);
let start = Instant::now();
let _ = tree.find_matches(partial_hashes, false);
let elapsed = start.elapsed();
partial_durations.push(elapsed); partial_durations.push(elapsed);
}
if args.verbose && (i + 1) % 100 == 0 { if args.verbose && (i + 1) % 100 == 0 {
println!(" Completed {}/{} iterations", i + 1, args.iterations); println!(" Completed {}/{} iterations", i + 1, args.iterations);
} }
} }
LatencyStats::from_durations(partial_durations).print("FIND_MATCHES (PARTIAL)", args.depth);
let partial_stats = LatencyStats::from_durations(partial_durations); // EARLY_EXIT case
partial_stats.print("FIND_MATCHES (PARTIAL)", args.depth);
// Benchmark with early_exit=true
println!("\n --- EARLY_EXIT case ---"); println!("\n --- EARLY_EXIT case ---");
let mut early_exit_durations = Vec::with_capacity(args.iterations); let mut early_exit_durations = Vec::with_capacity(measured_iters);
for i in 0..args.iterations { for i in 0..args.iterations {
let seq = &sequences[i % sequences.len()]; let seq = &sequences[i % sequences.len()];
let elapsed = index.find_matches_timed(seq, true);
let start = Instant::now(); if i >= warmup_iters {
let _ = tree.find_matches(seq.local_hashes.clone(), true);
let elapsed = start.elapsed();
early_exit_durations.push(elapsed); early_exit_durations.push(elapsed);
} }
}
let early_exit_stats = LatencyStats::from_durations(early_exit_durations); LatencyStats::from_durations(early_exit_durations)
early_exit_stats.print("FIND_MATCHES (EARLY_EXIT)", args.depth); .print("FIND_MATCHES (EARLY_EXIT)", args.depth);
} }
/// Generate logarithmically spaced values between min and max /// Generate logarithmically spaced values between min and max
...@@ -932,6 +1024,10 @@ fn main() { ...@@ -932,6 +1024,10 @@ fn main() {
eprintln!("prefix_prompt_ratio must be between 0.0 and 1.0"); eprintln!("prefix_prompt_ratio must be between 0.0 and 1.0");
std::process::exit(1); std::process::exit(1);
} }
if !(0.0..=1.0).contains(&args.warmup_ratio) {
eprintln!("warmup_ratio must be between 0.0 and 1.0");
std::process::exit(1);
}
let num_sequences = args.size / args.depth; let num_sequences = args.size / args.depth;
if matches!( if matches!(
...@@ -959,6 +1055,11 @@ fn main() { ...@@ -959,6 +1055,11 @@ fn main() {
println!(" Block size: {} tokens", args.block_size); println!(" Block size: {} tokens", args.block_size);
println!(" Workers: {}", args.num_workers); println!(" Workers: {}", args.num_workers);
println!(" Iterations: {}", args.iterations); println!(" Iterations: {}", args.iterations);
println!(
" Warmup: {:.0}% ({} iterations discarded)",
args.warmup_ratio * 100.0,
(args.iterations as f64 * args.warmup_ratio) as usize
);
println!( println!(
" Prefix prompt ratio: {:.1}% ({} blocks at depth {})", " Prefix prompt ratio: {:.1}% ({} blocks at depth {})",
args.prefix_prompt_ratio * 100.0, args.prefix_prompt_ratio * 100.0,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Flat HashMap baseline for benchmarking comparison with RadixTree.
//!
//! This module provides a `FlatHashMap` structure that has full feature parity with `RadixTree`
//! but uses flat HashMaps instead of a tree structure. This isolates the overhead of
//! tree traversal (pointer chasing) from pure HashMap operations.
//!
//! The `find_matches` API matches RadixTree exactly: it takes `LocalBlockHash` values
//! and internally computes the cumulative sequence hashes for lookup.
use std::collections::{HashMap, HashSet};
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, OverlapScores, RouterEvent, WorkerId, WorkerWithDpRank,
compute_seq_hash_for_block,
};
/// A flat HashMap-based structure for KV cache indexing.
///
/// Instead of a tree of linked nodes, the index keeps two hash maps that
/// mirror each other: one from block to workers and one from worker to blocks.
/// Every store/remove must update both so they stay consistent.
///
/// # Structure
///
/// - `block_to_workers`: Maps ExternalSequenceBlockHash -> Set of workers that have this block.
///   Read by `find_matches` to learn which workers hold a given block.
/// - `worker_to_blocks`: Maps Worker -> Set of ExternalSequenceBlockHash they have.
///   Used to undo a worker's entries on removal and to report how many blocks
///   each worker holds.
pub struct FlatHashMap {
/// Primary index: block -> workers (for find_matches).
/// Entries whose worker set becomes empty are removed by the mutating methods.
block_to_workers: HashMap<ExternalSequenceBlockHash, HashSet<WorkerWithDpRank>>,
/// Secondary index: worker -> blocks (for removal and per-worker block counts).
worker_to_blocks: HashMap<WorkerWithDpRank, HashSet<ExternalSequenceBlockHash>>,
}
impl FlatHashMap {
/// Construct an index with no workers and no blocks.
pub fn new() -> Self {
    let block_to_workers = HashMap::new();
    let worker_to_blocks = HashMap::new();
    Self {
        block_to_workers,
        worker_to_blocks,
    }
}
/// Record that `worker` holds every block in `block_hashes`.
///
/// Both indexes are updated. The worker gets an entry in the
/// worker -> blocks index even when `block_hashes` is empty.
pub fn store(&mut self, worker: WorkerWithDpRank, block_hashes: &[ExternalSequenceBlockHash]) {
    // block -> workers
    for hash in block_hashes.iter().copied() {
        self.block_to_workers.entry(hash).or_default().insert(worker);
    }

    // worker -> blocks
    self.worker_to_blocks
        .entry(worker)
        .or_default()
        .extend(block_hashes.iter().copied());
}
/// Forget that `worker` holds the blocks in `block_hashes`.
///
/// Does nothing when the worker is not tracked. Block entries left without
/// any worker are dropped, and the worker's own entry is dropped once it no
/// longer holds any block.
pub fn remove(&mut self, worker: WorkerWithDpRank, block_hashes: &[ExternalSequenceBlockHash]) {
    let Some(owned) = self.worker_to_blocks.get_mut(&worker) else {
        return;
    };

    for hash in block_hashes {
        // worker -> blocks
        owned.remove(hash);

        // block -> workers; note whether this worker was the last holder
        let orphaned = match self.block_to_workers.get_mut(hash) {
            Some(holders) => {
                holders.remove(&worker);
                holders.is_empty()
            }
            None => false,
        };
        if orphaned {
            self.block_to_workers.remove(hash);
        }
    }

    // A worker with nothing left is no longer tracked.
    if owned.is_empty() {
        self.worker_to_blocks.remove(&worker);
    }
}
/// Score workers by how long a prefix of `sequence` they hold.
///
/// Takes `LocalBlockHash` values and converts them to cumulative sequence
/// hashes with `compute_seq_hash_for_block` before any lookup, so callers can
/// pass the same input they would give to a radix-tree lookup.
///
/// # Algorithm
///
/// Blocks are visited in order. The first block seeds the candidate set with
/// every worker holding it; each later block narrows that set to the workers
/// that also hold it. A worker leaving the set is scored with the number of
/// blocks matched before the one it missed. The walk ends when:
///
/// - a block is held by no worker at all,
/// - the candidate set becomes empty, or
/// - `early_exit` is set and exactly one candidate remains.
///
/// Candidates still present at the end are scored with the number of blocks
/// visited. Workers that do not hold the first block receive no score.
///
/// # Returns
///
/// `OverlapScores` with `scores` filled as described and `tree_sizes` holding,
/// for each scored worker that is still tracked, the number of blocks that
/// worker currently has. An empty `sequence` yields empty scores.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
    let mut scores = OverlapScores::new();
    if sequence.is_empty() {
        return scores;
    }

    // `None` until the first block has been matched.
    let mut survivors: Option<HashSet<WorkerWithDpRank>> = None;
    let mut matched_blocks = 0u32;

    for seq_hash in compute_seq_hash_for_block(&sequence) {
        let key = ExternalSequenceBlockHash(seq_hash);
        let Some(holders) = self.block_to_workers.get(&key) else {
            // Nobody has this block: the shared prefix ends here.
            break;
        };

        let remaining = match &mut survivors {
            Some(current) => {
                // Single pass: score each worker that misses this block and
                // evict it from the candidate set.
                current.retain(|candidate| {
                    let still_matching = holders.contains(candidate);
                    if !still_matching {
                        scores.scores.insert(*candidate, matched_blocks);
                    }
                    still_matching
                });
                current.len()
            }
            None => {
                survivors = Some(holders.clone());
                holders.len()
            }
        };
        matched_blocks += 1;

        let nobody_left = remaining == 0;
        let single_winner = early_exit && remaining == 1;
        if nobody_left || single_winner {
            break;
        }
    }

    // Whoever is still a candidate matched every visited block.
    for worker in survivors.into_iter().flatten() {
        scores.scores.insert(worker, matched_blocks);
    }

    // Attach the current block count of every scored worker.
    for &worker in scores.scores.keys() {
        let Some(owned) = self.worker_to_blocks.get(&worker) else {
            continue;
        };
        scores.tree_sizes.insert(worker, owned.len());
    }

    scores
}
/// Apply a `RouterEvent` to the flat index (mirrors the `RadixTree` API).
///
/// * `Stored`  - records every block hash in the event for the emitting worker.
/// * `Removed` - drops the listed block hashes for the emitting worker.
/// * `Cleared` - clears all blocks for the event's `worker_id` via
///   `clear_all_blocks`, which matches on `worker_id` only.
pub fn apply_event(&mut self, event: RouterEvent) {
    let kv_event = event.event;
    let worker = WorkerWithDpRank::new(event.worker_id, kv_event.dp_rank);
    match kv_event.data {
        KvCacheEventData::Stored(stored) => {
            let mut block_hashes = Vec::with_capacity(stored.blocks.len());
            block_hashes.extend(stored.blocks.iter().map(|block| block.block_hash));
            self.store(worker, &block_hashes);
        }
        KvCacheEventData::Removed(removed) => self.remove(worker, &removed.block_hashes),
        KvCacheEventData::Cleared => self.clear_all_blocks(worker.worker_id),
    }
}
/// Shared implementation behind `remove_worker` and `clear_all_blocks`.
///
/// Every `WorkerWithDpRank` key whose `worker_id` matches is detached from
/// both maps. With `keep_worker == true` the key is re-inserted with an empty
/// block set so the worker stays tracked; with `false` it disappears entirely.
fn remove_or_clear_worker_blocks(&mut self, worker_id: WorkerId, keep_worker: bool) {
    // Snapshot the matching keys first so the map can be mutated below.
    let matching: Vec<WorkerWithDpRank> = self
        .worker_to_blocks
        .keys()
        .copied()
        .filter(|key| key.worker_id == worker_id)
        .collect();

    for key in matching {
        let Some(owned_blocks) = self.worker_to_blocks.remove(&key) else {
            continue;
        };
        for hash in owned_blocks {
            let Some(holders) = self.block_to_workers.get_mut(&hash) else {
                continue;
            };
            holders.remove(&key);
            // Drop the reverse entry once nobody holds the block any more.
            if holders.is_empty() {
                self.block_to_workers.remove(&hash);
            }
        }
        if keep_worker {
            self.worker_to_blocks.insert(key, HashSet::new());
        }
    }
}
/// Remove a worker and all their blocks from the index.
///
/// Applies to every `WorkerWithDpRank` entry whose `worker_id` matches; the
/// entries are not re-inserted, so the worker is no longer tracked afterwards.
pub fn remove_worker(&mut self, worker_id: WorkerId) {
self.remove_or_clear_worker_blocks(worker_id, false);
}
/// Clear all blocks for a worker but keep the worker tracked.
///
/// Applies to every `WorkerWithDpRank` entry whose `worker_id` matches; each
/// one is left in `worker_to_blocks` with an empty block set. A worker that is
/// not currently tracked is left untracked.
pub fn clear_all_blocks(&mut self, worker_id: WorkerId) {
self.remove_or_clear_worker_blocks(worker_id, true);
}
/// Get all worker IDs currently tracked in the index.
///
/// Returns unique `worker_id`s in ascending order; entries that differ only by
/// `dp_rank` collapse into a single id.
pub fn get_workers(&self) -> Vec<WorkerId> {
    let mut worker_ids: Vec<WorkerId> = self
        .worker_to_blocks
        .keys()
        .map(|w| w.worker_id)
        .collect();
    // Sorting first makes duplicates adjacent, so `dedup` removes all of them
    // without building an intermediate `HashSet`.
    worker_ids.sort_unstable();
    worker_ids.dedup();
    worker_ids
}
/// Dump the index as `RouterEvent`s that can rebuild the same state.
///
/// One single-block `Stored` event is produced per (worker, block) pair, with
/// `event_id`s counting up from 0. Parent links are not tracked by the flat
/// map, so `parent_hash` is always `None`, and `tokens_hash` is the placeholder
/// `LocalBlockHash(0)` because the original value is not stored.
/// Provided for API compatibility with `RadixTree`.
pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
    let pairs = self.worker_to_blocks.iter().flat_map(|(&worker, blocks)| {
        blocks.iter().map(move |&block_hash| (worker, block_hash))
    });

    (0u64..)
        .zip(pairs)
        .map(|(event_id, (worker, block_hash))| RouterEvent {
            worker_id: worker.worker_id,
            event: KvCacheEvent {
                event_id,
                data: KvCacheEventData::Stored(KvCacheStoreData {
                    parent_hash: None,
                    blocks: vec![KvCacheStoredBlockData {
                        block_hash,
                        mm_extra_info: None,
                        tokens_hash: LocalBlockHash(0),
                    }],
                }),
                dp_rank: worker.dp_rank,
            },
        })
        .collect()
}
/// Total number of (worker, block) pairs held by the index, i.e. the sum of
/// every worker's block-set size.
pub fn current_size(&self) -> usize {
    self.worker_to_blocks
        .values()
        .fold(0, |total, blocks| total + blocks.len())
}
}
/// `Default` simply delegates to [`FlatHashMap::new`].
impl Default for FlatHashMap {
fn default() -> Self {
Self::new()
}
}
...@@ -54,21 +54,115 @@ use serde::{Deserialize, Serialize}; ...@@ -54,21 +54,115 @@ use serde::{Deserialize, Serialize};
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
use std::sync::OnceLock; use std::sync::OnceLock;
use std::{ use std::{
cell::RefCell, collections::{HashMap, VecDeque},
collections::{HashMap, HashSet, VecDeque},
iter, iter,
rc::Rc,
sync::{Arc, Mutex}, sync::{Arc, Mutex},
thread::JoinHandle, thread::JoinHandle,
time::{Duration, Instant}, time::Duration,
}; };
use tokio::sync::{broadcast, mpsc, oneshot}; use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::approx::{BlockEntry, PruneConfig, PruneManager}; use crate::approx::{BlockEntry, PruneConfig, PruneManager};
use crate::flat_hashmap::FlatHashMap;
use crate::protocols::*; use crate::protocols::*;
pub use crate::radix_tree::RadixTree;
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
// ------
// KvIndex - Unified interface for RadixTree and FlatHashMap
// ------
/// Unified interface for KV cache indexing.
///
/// Both `RadixTree` and `FlatHashMap` implement the same core operations:
/// - `find_matches`: Find workers with matching cached blocks
/// - `apply_event`: Apply store/remove events
/// - `remove_worker`: Remove a worker's entries
/// - `clear_all_blocks`: Drop a worker's blocks while keeping it tracked
/// - `get_workers`: Get all tracked workers
/// - `dump_tree_as_events`: Dump state as events
/// - `current_size`: Get total (worker, block) pairs
///
/// Every method on `KvIndex` forwards to the wrapped backend.
pub enum KvIndex {
/// Index backed by a [`RadixTree`].
Tree(RadixTree),
/// Index backed by a [`FlatHashMap`].
Flat(FlatHashMap),
}
impl KvIndex {
    /// Build an index backed by a plain [`RadixTree`].
    pub fn new_tree() -> Self {
        Self::Tree(RadixTree::new())
    }

    /// Build an index backed by a [`RadixTree`] with frequency tracking;
    /// `expiration_duration` is handed to `RadixTree::new_with_frequency`.
    pub fn new_tree_with_frequency(expiration_duration: Option<std::time::Duration>) -> Self {
        let radix = RadixTree::new_with_frequency(expiration_duration);
        Self::Tree(radix)
    }

    /// Build an index backed by a [`FlatHashMap`].
    pub fn new_flat() -> Self {
        Self::Flat(FlatHashMap::new())
    }

    /// Score workers against a sequence of local block hashes using the
    /// wrapped backend.
    pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
        match self {
            Self::Tree(radix) => radix.find_matches(sequence, early_exit),
            Self::Flat(flat) => flat.find_matches(sequence, early_exit),
        }
    }

    /// Apply a `RouterEvent` to the wrapped backend.
    ///
    /// The tree backend's result is returned as-is; the flat backend's
    /// `apply_event` returns nothing, so that arm always yields `Ok(())`.
    pub fn apply_event(&mut self, event: RouterEvent) -> Result<(), KvCacheEventError> {
        match self {
            Self::Tree(radix) => radix.apply_event(event),
            Self::Flat(flat) => {
                flat.apply_event(event);
                Ok(())
            }
        }
    }

    /// Remove a worker together with all of its blocks.
    pub fn remove_worker(&mut self, worker_id: WorkerId) {
        match self {
            Self::Tree(radix) => radix.remove_worker(worker_id),
            Self::Flat(flat) => flat.remove_worker(worker_id),
        }
    }

    /// Drop all of a worker's blocks while keeping the worker tracked.
    pub fn clear_all_blocks(&mut self, worker_id: WorkerId) {
        match self {
            Self::Tree(radix) => radix.clear_all_blocks(worker_id),
            Self::Flat(flat) => flat.clear_all_blocks(worker_id),
        }
    }

    /// List the worker IDs currently tracked by the backend.
    pub fn get_workers(&self) -> Vec<WorkerId> {
        match self {
            Self::Tree(radix) => radix.get_workers(),
            Self::Flat(flat) => flat.get_workers(),
        }
    }

    /// Dump the backend's state as a series of `RouterEvent`s.
    pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
        match self {
            Self::Tree(radix) => radix.dump_tree_as_events(),
            Self::Flat(flat) => flat.dump_tree_as_events(),
        }
    }

    /// Total number of (worker, block) pairs stored in the backend.
    pub fn current_size(&self) -> usize {
        match self {
            Self::Tree(radix) => radix.current_size(),
            Self::Flat(flat) => flat.current_size(),
        }
    }
}
/// Errors that can occur in the KV Router. /// Errors that can occur in the KV Router.
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum KvRouterError { pub enum KvRouterError {
...@@ -85,47 +179,6 @@ pub enum KvRouterError { ...@@ -85,47 +179,6 @@ pub enum KvRouterError {
PruneFailed(String), PruneFailed(String),
} }
/// Errors that can occur during KV Cache Event processing.
#[derive(Debug, thiserror::Error)]
pub enum KvCacheEventError {
/// A store event named a `parent_hash` that is not in the emitting worker's lookup table.
#[error("Failed to find parent block")]
ParentBlockNotFound,
/// A remove event named a block hash that is not in the emitting worker's lookup table.
#[error("Failed to find block")]
BlockNotFound,
/// A store event contained a self-referencing block (a block listed as its own descendant).
#[error("Invalid block sequence")]
InvalidBlockSequence,
}
/// A shared reference to a [`RadixBlock`].
///
/// `Rc` gives shared ownership (the same block is referenced from its parent's
/// `children` map and from per-worker lookup tables) and `RefCell` provides
/// runtime-checked interior mutability.
type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
/// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`].
///
/// The indexes combine `worker_id` with the event's `dp_rank` to identify the
/// emitting `WorkerWithDpRank`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RouterEvent {
/// The ID of the worker emitting the event.
pub worker_id: WorkerId,
/// The cache event associated with the worker.
pub event: KvCacheEvent,
}
impl RouterEvent {
/// Create a new `RouterEvent` by pairing a worker ID with one of its cache events.
///
/// ### Arguments
///
/// * `worker_id` - The ID of the worker emitting the event.
/// * `event` - The cache event.
///
/// ### Returns
///
/// A new `RouterEvent` holding both values unchanged.
pub fn new(worker_id: WorkerId, event: KvCacheEvent) -> Self {
Self { worker_id, event }
}
}
// ------- // -------
// Distributed router - Worker KV Query types // Distributed router - Worker KV Query types
// ------- // -------
...@@ -174,450 +227,6 @@ impl MaybeError for WorkerKvQueryResponse { ...@@ -174,450 +227,6 @@ impl MaybeError for WorkerKvQueryResponse {
} }
} }
/// A block in the Radix Tree.
///
/// Blocks are shared through [`SharedRadixBlock`] handles.
#[derive(Debug)]
struct RadixBlock {
/// A map of child blocks, keyed by their local block hash.
children: HashMap<LocalBlockHash, SharedRadixBlock>,
/// The set of workers that have this block cached.
workers: HashSet<WorkerWithDpRank>,
/// The external sequence block hash for this block (None for root).
/// This is the same for all workers under the simplifying assumption.
block_hash: Option<ExternalSequenceBlockHash>,
/// A buffer of times that this block was last traversed.
/// Only appended to by `RadixTree::find_matches` when the tree has an
/// `expiration_duration`; expired entries are popped from the front there.
recent_uses: VecDeque<Instant>,
}
impl RadixBlock {
    /// Build an empty block without a `block_hash`; used for the tree root.
    ///
    /// ### Returns
    ///
    /// A `RadixBlock` with no children, no workers and no recorded uses.
    pub fn new() -> Self {
        Self {
            children: HashMap::new(),
            workers: HashSet::new(),
            block_hash: None,
            recent_uses: VecDeque::new(),
        }
    }

    /// Build an empty block that carries the given external block hash.
    ///
    /// ### Returns
    ///
    /// The same empty block as [`RadixBlock::new`], except `block_hash` is set.
    pub fn with_hash(block_hash: ExternalSequenceBlockHash) -> Self {
        Self {
            block_hash: Some(block_hash),
            ..Self::new()
        }
    }
}
/// Prefix tree over cached KV blocks, tracking which workers hold each block.
pub struct RadixTree {
/// This is the root of the radix/prefix tree
/// This will only contain root blocks
root: SharedRadixBlock,
/// Per-worker lookup table for O(1) block access.
/// Maps worker -> (block_hash -> block).
lookup: HashMap<WorkerWithDpRank, HashMap<ExternalSequenceBlockHash, SharedRadixBlock>>,
/// The time window the radix tree uses when considering the frequency of block accesses.
/// `None` disables frequency tracking in `find_matches`.
expiration_duration: Option<Duration>,
}
/// `Default` delegates to [`RadixTree::new`], i.e. a tree without frequency tracking.
impl Default for RadixTree {
fn default() -> Self {
Self::new()
}
}
// A deep chain of blocks would otherwise be dropped recursively (each block
// dropping its children), which can overflow the stack. This impl tears the
// tree down iteratively with an explicit work list instead.
impl Drop for RadixTree {
    fn drop(&mut self) {
        let mut pending: Vec<SharedRadixBlock> = Vec::new();

        // Detach the root's children first so the root holds no edges.
        pending.extend(
            self.root
                .borrow_mut()
                .children
                .drain()
                .map(|(_, child)| child),
        );

        // Lookup tables may reference blocks that are no longer reachable from root.
        pending.extend(
            self.lookup
                .drain()
                .flat_map(|(_, worker_blocks)| worker_blocks.into_values()),
        );

        while let Some(block) = pending.pop() {
            // Only the last owner can take the block apart; any other handle is
            // simply released here and the block is revisited via another handle.
            if let Ok(cell) = Rc::try_unwrap(block) {
                let mut inner: RadixBlock = cell.into_inner();
                pending.extend(inner.children.drain().map(|(_, child)| child));
            }
        }
    }
}
impl RadixTree {
/// Create a new, empty `RadixTree` with an optional access-frequency window.
///
/// ### Arguments
///
/// * `expiration_duration` - When `Some`, `find_matches` records block accesses
///   and discards those older than this duration; when `None`, no access times
///   are recorded.
///
/// ### Returns
///
/// A new `RadixTree`.
pub fn new_with_frequency(expiration_duration: Option<Duration>) -> Self {
Self {
root: Rc::new(RefCell::new(RadixBlock::new())),
lookup: HashMap::new(),
expiration_duration,
}
}
/// Create a new, empty `RadixTree` without frequency tracking
/// (equivalent to `new_with_frequency(None)`).
pub fn new() -> Self {
Self::new_with_frequency(None)
}
/// Traverse the radix tree to find the best match for a given sequence of [`LocalBlockHash`]es.
///
/// The walk starts at the root and follows one child per hash. Every worker
/// stored on a visited block gets its score incremented by one, so a worker's
/// score is the number of visited blocks it holds. The walk stops at the first
/// hash with no matching child.
///
/// When the tree has an `expiration_duration`, each visited block also has its
/// expired access times discarded, its remaining access count reported via
/// `add_frequency`, and the current time recorded. This mutates block state
/// through `RefCell` even though the method takes `&self`.
///
/// ### Arguments
///
/// * `sequence` - A vector of `LocalBlockHash` representing the sequence to match.
/// * `early_exit` - A boolean indicating whether to exit early if a single match is found.
///   The block with the single worker is still scored before the walk stops.
///
/// ### Returns
///
/// An `OverlapScores` representing the match scores, with `tree_sizes` filled
/// in for every scored worker.
///
/// ### Panics
///
/// Panics if a scored worker has no entry in the per-worker lookup table.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
let mut scores = OverlapScores::new();
let mut current = self.root.clone();
// A single timestamp is shared by every block visited in this call.
let now = Instant::now();
tracing::trace!(
"RadixTree::find_matches: looking for sequence={:?}",
sequence.iter().map(|h| h.0).collect::<Vec<_>>()
);
for (idx, block_hash) in sequence.iter().enumerate() {
// Scope the borrow of `current` so it is released before `current` is reassigned.
let next_block = {
let current_borrow = current.borrow();
current_borrow.children.get(block_hash).cloned()
};
if let Some(block) = next_block {
scores.update_scores(block.borrow().workers.iter());
if let Some(expiration_duration) = self.expiration_duration {
let mut block_mut = block.borrow_mut();
// Access times are appended in order, so expired ones sit at the front.
while let Some(access_time) = block_mut.recent_uses.front() {
if now.duration_since(*access_time) > expiration_duration {
block_mut.recent_uses.pop_front();
} else {
break;
}
}
// Report the count before recording this access.
scores.add_frequency(block_mut.recent_uses.len());
block_mut.recent_uses.push_back(now);
}
if early_exit && block.borrow().workers.len() == 1 {
break;
}
current = block;
} else {
tracing::trace!(
"RadixTree::find_matches: block not found at index {} for hash {}",
idx,
block_hash.0
);
break;
}
}
tracing::trace!("RadixTree::find_matches: final scores={:?}", scores.scores);
// Populate tree sizes for all workers that have scores
for worker in scores.scores.keys() {
let tree_size = self
.lookup
.get(worker)
.expect("worker in scores must exist in lookup table")
.len();
scores.tree_sizes.insert(*worker, tree_size);
}
scores
}
/// Apply a [`RouterEvent`] to the radix tree.
///
/// * `Stored`  - walks from the parent block (or the root when `parent_hash`
///   is `None`), creating or reusing one child per stored block, and registers
///   the worker on each of them.
/// * `Removed` - unregisters the worker from each listed block; a block left
///   without workers has its children cleared.
/// * `Cleared` - clears all blocks for the event's `worker_id`.
///
/// An entry for the emitting worker is created in the lookup table even when
/// the event is subsequently rejected.
///
/// ### Arguments
///
/// * `event` - The `RouterEvent` to apply.
///
/// ### Errors
///
/// * `ParentBlockNotFound` - the store event's `parent_hash` is unknown for
///   this worker; nothing is stored.
/// * `InvalidBlockSequence` - a stored block would become its own child. The
///   event is rejected before that edge is linked into the tree; blocks
///   processed earlier in the same event remain applied.
/// * `BlockNotFound` - at least one block of a remove event is unknown for
///   this worker; the remaining blocks of the batch are still removed.
pub fn apply_event(&mut self, event: RouterEvent) -> Result<(), KvCacheEventError> {
    let (worker_id, kv_event) = (event.worker_id, event.event);
    let (id, op) = (kv_event.event_id, kv_event.data);
    // Construct WorkerWithDpRank from worker_id and dp_rank from the event
    let worker = WorkerWithDpRank::new(worker_id, kv_event.dp_rank);
    tracing::trace!(id, "RadixTree::apply_event: Store operation: {:?}", op);
    let worker_lookup = self.lookup.entry(worker).or_default();
    match op {
        KvCacheEventData::Stored(op) => {
            // find the parent block from this worker's lookup
            let mut current = match op.parent_hash {
                Some(parent) => match worker_lookup.get(&parent) {
                    Some(current) => current.clone(),
                    None => {
                        tracing::warn!(
                            worker_id = worker.worker_id.to_string(),
                            dp_rank = worker.dp_rank,
                            id,
                            parent_hash = ?op.parent_hash,
                            num_blocks = op.blocks.len(),
                            "Failed to find parent block; skipping store operation"
                        );
                        return Err(KvCacheEventError::ParentBlockNotFound);
                    }
                },
                None => self.root.clone(),
            };
            for block_data in op.blocks {
                let mut parent_mut = current.borrow_mut();
                // Prefer the child already linked under this tokens_hash; otherwise
                // reuse the worker's block with this block_hash or create a new one.
                let existing = parent_mut.children.get(&block_data.tokens_hash).cloned();
                let is_new_edge = existing.is_none();
                let child = existing.unwrap_or_else(|| {
                    worker_lookup
                        .get(&block_data.block_hash)
                        .cloned()
                        .unwrap_or_else(|| {
                            Rc::new(RefCell::new(RadixBlock::with_hash(
                                block_data.block_hash,
                            )))
                        })
                });
                // Reject self referential blocks BEFORE touching the tree. Linking
                // first would leave a parent -> parent cycle behind after the error,
                // and borrowing the child below would panic because it is the block
                // already mutably borrowed as `parent_mut`.
                if Rc::ptr_eq(&child, &current) {
                    tracing::warn!(
                        worker_id = worker.worker_id.to_string(),
                        dp_rank = worker.dp_rank,
                        id,
                        block_hash = ?block_data.block_hash,
                        "Detected self referencing block in store event; rejecting sequence"
                    );
                    return Err(KvCacheEventError::InvalidBlockSequence);
                }
                if is_new_edge {
                    // insert into radix tree
                    parent_mut
                        .children
                        .insert(block_data.tokens_hash, child.clone());
                } else if child.borrow().block_hash != Some(block_data.block_hash) {
                    // Verify our simplifying assumption: block_hash is uniform across workers
                    tracing::warn!(
                        expected = ?block_data.block_hash,
                        actual = ?child.borrow().block_hash,
                        "block_hash mismatch: sequence hashes should be uniform across workers"
                    );
                }
                // add our worker to the block (safe: `child` is not `current`, the
                // only block borrowed at this point)
                child.borrow_mut().workers.insert(worker);
                // add the block to the worker's lookup table
                worker_lookup.insert(block_data.block_hash, child.clone());
                // release the parent borrow so we can shift current to this block
                drop(parent_mut);
                current = child;
            }
            Ok(())
        }
        KvCacheEventData::Removed(remove) => {
            let mut kv_cache_err: Option<KvCacheEventError> = None;
            for block in remove.block_hashes {
                // lookup block in worker's table
                let entry = match worker_lookup.get(&block) {
                    Some(entry) => entry.clone(),
                    None => {
                        tracing::warn!(
                            worker_id = worker.worker_id.to_string(),
                            dp_rank = worker.dp_rank,
                            id,
                            block_hash = ?block,
                            "Failed to find block to remove; skipping remove operation"
                        );
                        // Kv cache removed events may be batched; we should try to apply all
                        // operations in the batch before returning an error. Return the first
                        // error.
                        if kv_cache_err.is_none() {
                            kv_cache_err = Some(KvCacheEventError::BlockNotFound);
                        }
                        continue;
                    }
                };
                let mut guard = entry.borrow_mut();
                guard.workers.remove(&worker);
                if guard.workers.is_empty() {
                    // if no workers are using this block, that is true for all children
                    guard.children.clear();
                }
                // remove the block from the worker's lookup table
                worker_lookup.remove(&block);
            }
            kv_cache_err.map_or(Ok(()), Err)
        }
        KvCacheEventData::Cleared => {
            self.clear_all_blocks(worker.worker_id);
            Ok(())
        }
    }
}
/// Shared implementation behind `remove_worker` and `clear_all_blocks`.
///
/// Every lookup entry whose `worker_id` matches is taken out and the worker is
/// unregistered from each of its blocks. With `keep_worker == true` the entry
/// is put back with an empty block map; with `false` it stays removed.
fn remove_or_clear_worker_blocks(&mut self, worker_id: WorkerId, keep_worker: bool) {
    // Snapshot the matching keys first so the map can be mutated below.
    let matching: Vec<WorkerWithDpRank> = self
        .lookup
        .keys()
        .copied()
        .filter(|key| key.worker_id == worker_id)
        .collect();

    for key in matching {
        let Some((stored_key, owned_blocks)) = self.lookup.remove_entry(&key) else {
            continue;
        };
        for block in owned_blocks.into_values() {
            let mut node = block.borrow_mut();
            node.workers.remove(&key);
            // A block nobody holds cannot have held descendants either.
            if node.workers.is_empty() {
                node.children.clear();
            }
        }
        if keep_worker {
            self.lookup.insert(stored_key, HashMap::new());
        }
    }
}
/// Remove a worker (every `dp_rank` entry with this `worker_id`) and all of
/// its blocks; the worker is no longer present in the lookup table afterwards.
pub fn remove_worker(&mut self, worker_id: WorkerId) {
self.remove_or_clear_worker_blocks(worker_id, false);
}
/// Clear all blocks of a worker (every `dp_rank` entry with this `worker_id`)
/// while keeping its lookup entries, now empty. A worker without a lookup
/// entry is left untouched.
pub fn clear_all_blocks(&mut self, worker_id: WorkerId) {
self.remove_or_clear_worker_blocks(worker_id, true);
}
/// List the worker IDs currently tracked in the radix tree.
///
/// The result is sorted ascending and deduplicated, so entries that differ
/// only by `dp_rank` contribute a single id.
pub fn get_workers(&self) -> Vec<WorkerId> {
    let mut ids = Vec::with_capacity(self.lookup.len());
    for worker in self.lookup.keys() {
        ids.push(worker.worker_id);
    }
    ids.sort_unstable();
    ids.dedup();
    ids
}
/// Dump the radix tree as a series of RouterEvents that can reconstruct the tree.
/// Uses BFS traversal to ensure that the tree reconstruction is unique,
/// though the exact event ordering will be lost.
pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
tracing::debug!(
"Dumping radix tree as events (contains information about {:?} workers)",
self.lookup.len()
);
let mut events = Vec::new();
let mut event_id = 0u64;
// Queue entries: (current_block, parent_hash, tokens_hash)
let mut queue = VecDeque::new();
// Process root's children first
let root_borrow = self.root.borrow();
for (tokens_hash, child_block) in &root_borrow.children {
queue.push_back((child_block.clone(), None, *tokens_hash));
}
drop(root_borrow);
while let Some((current_block, parent_hash, tokens_hash)) = queue.pop_front() {
let current_borrow = current_block.borrow();
// Get this block's hash (same for all workers)
let block_hash = current_borrow
.block_hash
.expect("non-root block must have block_hash");
// For each worker that has this block
for worker in &current_borrow.workers {
// Create a store event for this worker
let event = RouterEvent {
worker_id: worker.worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: vec![KvCacheStoredBlockData {
block_hash,
mm_extra_info: None,
tokens_hash,
}],
}),
dp_rank: worker.dp_rank,
},
};
events.push(event);
event_id += 1;
}
// Enqueue children with this block's hash as their parent
for (child_tokens_hash, child_block) in &current_borrow.children {
queue.push_back((child_block.clone(), Some(block_hash), *child_tokens_hash));
}
}
events
}
/// Returns the total number of (worker, block) pairs, i.e. the sum of the
/// sizes of all per-worker lookup tables.
pub fn current_size(&self) -> usize {
self.lookup.values().map(|m| m.len()).sum()
}
}
/// Metrics for the KV Indexer. /// Metrics for the KV Indexer.
#[derive(Clone)] #[derive(Clone)]
pub struct KvIndexerMetrics { pub struct KvIndexerMetrics {
...@@ -718,63 +327,6 @@ impl KvIndexerMetrics { ...@@ -718,63 +327,6 @@ impl KvIndexerMetrics {
} }
} }
/// Scores representing the overlap of workers (with their dp_rank).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OverlapScores {
// Map of worker (with dp_rank) to score; `update_scores` adds one per call
// for every worker it is given.
pub scores: HashMap<WorkerWithDpRank, u32>,
// List of frequencies that the blocks have been accessed. Entries with value 0 are omitted.
pub frequencies: Vec<usize>,
// Map of worker to their tree size (number of blocks in the tree for that worker)
pub tree_sizes: HashMap<WorkerWithDpRank, usize>,
}
/// `Default` delegates to [`OverlapScores::new`].
impl Default for OverlapScores {
fn default() -> Self {
Self::new()
}
}
impl OverlapScores {
    /// Build an empty `OverlapScores`.
    ///
    /// ### Returns
    ///
    /// Empty score and tree-size maps, and a frequency list pre-allocated for
    /// 32 entries.
    pub fn new() -> Self {
        let frequencies = Vec::with_capacity(32);
        Self {
            scores: HashMap::new(),
            frequencies,
            tree_sizes: HashMap::new(),
        }
    }

    /// Add one point to the score of every worker yielded by `workers`.
    ///
    /// ### Arguments
    ///
    /// * `workers` - An iterator over `WorkerWithDpRank` references; workers
    ///   seen for the first time start from a score of 0.
    pub fn update_scores<'a, I>(&mut self, workers: I)
    where
        I: IntoIterator<Item = &'a WorkerWithDpRank>,
    {
        workers.into_iter().for_each(|worker| {
            *self.scores.entry(*worker).or_insert(0) += 1;
        });
    }

    /// Append a frequency to the list; zero frequencies are skipped.
    ///
    /// In debug builds this asserts that the new value does not exceed the
    /// previously recorded one.
    pub fn add_frequency(&mut self, frequency: usize) {
        if frequency == 0 {
            return;
        }
        if let Some(previous) = self.frequencies.last() {
            debug_assert!(*previous >= frequency);
        }
        self.frequencies.push(frequency);
    }
}
/// A request to find matches in the Radix Tree. /// A request to find matches in the Radix Tree.
pub struct MatchRequest { pub struct MatchRequest {
/// A vector of `LocalBlockHash` representing the sequence to match. /// A vector of `LocalBlockHash` representing the sequence to match.
...@@ -2047,640 +1599,63 @@ impl KvIndexerSharded { ...@@ -2047,640 +1599,63 @@ impl KvIndexerSharded {
.map_err(|_| KvRouterError::IndexerDroppedRequest)?; .map_err(|_| KvRouterError::IndexerDroppedRequest)?;
Ok(()) Ok(())
} }
} }
/// Dropping a `KvIndexerSharded` calls `shutdown()` on it.
impl Drop for KvIndexerSharded {
fn drop(&mut self) {
self.shutdown();
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
use rstest::rstest;
use rstest_reuse::{self, *};
use tokio::time;
use tokio_util::sync::CancellationToken;
/// Shared test preamble; intentionally a no-op.
fn setup() {
// Logging init removed to avoid dynamo-runtime dependency
}
/// Build one stored block per input value `h`: the tokens hash is `h` and the
/// external block hash is `h * 100`.
fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
    let mut blocks = Vec::with_capacity(hashes.len());
    for hash in hashes {
        blocks.push(KvCacheStoredBlockData {
            tokens_hash: LocalBlockHash(hash),
            block_hash: ExternalSequenceBlockHash(hash * 100),
            mm_extra_info: None,
        });
    }
    blocks
}
/// Build a `Stored` event payload for `hashes` (see `make_blocks`) attached to
/// the given parent.
fn add_blocks(
    hashes: Vec<u64>,
    parent_hash: Option<ExternalSequenceBlockHash>,
) -> KvCacheEventData {
    let blocks = make_blocks(hashes);
    let store = KvCacheStoreData {
        parent_hash,
        blocks,
    };
    KvCacheEventData::Stored(store)
}
/// Build a store `RouterEvent` for `worker_id` with `dp_rank` 0.
fn create_store_event(
    worker_id: WorkerId,
    event_id: u64,
    hashes: Vec<u64>,
    parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent {
    let data = add_blocks(hashes, parent);
    let event = KvCacheEvent {
        event_id,
        data,
        dp_rank: 0,
    };
    RouterEvent { worker_id, event }
}
/// Build a remove `RouterEvent` for `worker_id` with `dp_rank` 0; each input
/// value `h` targets the external block hash `h * 100` (matching `make_blocks`).
fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
    let block_hashes = hashes
        .into_iter()
        .map(|hash| ExternalSequenceBlockHash(hash * 100))
        .collect();
    let data = KvCacheEventData::Removed(KvCacheRemoveData { block_hashes });
    let event = KvCacheEvent {
        event_id,
        data,
        dp_rank: 0,
    };
    RouterEvent { worker_id, event }
}
/// Walks a tree through store / remove / store-with-parent events for two
/// workers and checks scores, per-worker lookup sizes and the shape of the
/// tree around the shared first block after every step.
#[test]
fn test_radix_tree() {
    /// Key used by the tree for a worker with the default dp_rank.
    fn key(worker_id: WorkerId) -> WorkerWithDpRank {
        WorkerWithDpRank::from_worker_id(worker_id)
    }

    /// Number of blocks in `worker_id`'s lookup table.
    fn lookup_len(trie: &RadixTree, worker_id: WorkerId) -> usize {
        trie.lookup.get(&key(worker_id)).unwrap().len()
    }

    /// Overlap score reported for `worker_id`.
    fn score(scores: &OverlapScores, worker_id: WorkerId) -> u32 {
        *scores.scores.get(&key(worker_id)).unwrap()
    }

    /// Root must hold no workers and exactly one child (tokens hash 1), which
    /// in turn must have the expected number of workers and children.
    fn assert_shape(trie: &RadixTree, first_workers: usize, first_children: usize) {
        let root = trie.root.borrow();
        assert_eq!(root.workers.len(), 0);
        assert_eq!(root.children.len(), 1);
        let first = root.children.get(&LocalBlockHash(1)).unwrap().borrow();
        assert_eq!(first.workers.len(), first_workers);
        assert_eq!(first.children.len(), first_children);
    }

    /// The query sequence used throughout the test.
    fn query() -> Vec<LocalBlockHash> {
        vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)]
    }

    setup();

    let mut trie = RadixTree::new();

    let worker_1 = 0;
    let worker_2 = 1;

    // Worker 1 stores 1 -> 2 -> 3.
    trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
        .unwrap();

    let scores = trie.find_matches(query(), false);
    assert_eq!(score(&scores, worker_1), 3);

    assert_eq!(trie.lookup.len(), 1);
    assert_eq!(lookup_len(&trie, worker_1), 3);
    assert_shape(&trie, 1, 1);

    // Worker 2 stores 1 -> 4 -> 5, sharing only the first block.
    trie.apply_event(create_store_event(worker_2, 1, vec![1, 4, 5], None))
        .unwrap();

    let scores = trie.find_matches(query(), false);
    assert_eq!(score(&scores, worker_1), 3);
    assert_eq!(score(&scores, worker_2), 1);

    assert_eq!(trie.lookup.len(), 2);
    assert_eq!(lookup_len(&trie, worker_1), 3);
    assert_eq!(lookup_len(&trie, worker_2), 3);
    assert_shape(&trie, 2, 2);

    // Worker 2 drops block 5.
    trie.apply_event(create_remove_event(worker_2, 2, vec![5]))
        .unwrap();

    assert_eq!(trie.lookup.len(), 2);
    assert_eq!(lookup_len(&trie, worker_1), 3);
    assert_eq!(lookup_len(&trie, worker_2), 2);
    assert_shape(&trie, 2, 2);

    // Worker 2 drops block 4.
    trie.apply_event(create_remove_event(worker_2, 3, vec![4]))
        .unwrap();

    assert_eq!(trie.lookup.len(), 2);
    assert_eq!(lookup_len(&trie, worker_1), 3);
    assert_eq!(lookup_len(&trie, worker_2), 1);
    assert_shape(&trie, 2, 2);

    // Worker 2 stores 2 -> 6 -> 7 below its copy of block 1 (hash 100).
    trie.apply_event(create_store_event(
        worker_2,
        4,
        vec![2, 6, 7],
        Some(ExternalSequenceBlockHash(100)),
    ))
    .unwrap();

    let scores = trie.find_matches(query(), false);
    assert_eq!(score(&scores, worker_1), 3);
    assert_eq!(score(&scores, worker_2), 2);

    assert_eq!(trie.lookup.len(), 2);
    assert_eq!(lookup_len(&trie, worker_1), 3);
    assert_eq!(lookup_len(&trie, worker_2), 4);
    assert_shape(&trie, 2, 2);

    // Block 2 (hash 200) is now shared by both workers.
    assert_eq!(
        trie.lookup
            .get(&key(worker_1))
            .unwrap()
            .get(&ExternalSequenceBlockHash(200))
            .unwrap()
            .borrow()
            .workers
            .len(),
        2
    );
}
/// Each failure mode of `RadixTree::apply_event` must surface as the matching
/// `KvCacheEventError` variant.
#[test]
fn test_radix_tree_apply_event_errors() {
    let mut trie = RadixTree::new();
    let worker_0 = 0;

    // Storing under a parent hash the worker never registered.
    let orphan_store = create_store_event(
        worker_0,
        0,
        vec![1, 2, 3],
        Some(ExternalSequenceBlockHash(12345)),
    );
    assert!(matches!(
        trie.apply_event(orphan_store),
        Err(KvCacheEventError::ParentBlockNotFound)
    ));

    // Removing blocks the worker never stored.
    let unknown_remove = create_remove_event(worker_0, 0, vec![1, 2, 3]);
    assert!(matches!(
        trie.apply_event(unknown_remove),
        Err(KvCacheEventError::BlockNotFound)
    ));

    // parent=1, blocks=[1, 2, 3]: block 1 (hash 100) would be its own child,
    // i.e. a self referencing block, and must be rejected.
    trie.apply_event(create_store_event(worker_0, 4, vec![1], None))
        .unwrap();
    let self_referencing = create_store_event(
        worker_0,
        5,
        vec![1, 2, 3],
        Some(ExternalSequenceBlockHash(100)),
    );
    assert!(matches!(
        trie.apply_event(self_referencing),
        Err(KvCacheEventError::InvalidBlockSequence)
    ));
}
/// Stores sequences of length 1, 2, 4, ..., 2^16 (one worker per length) and
/// expects every store to succeed.
#[test]
fn test_radix_tree_large_stores() {
    setup();

    let mut trie = RadixTree::new();

    for exponent in 0..=16 {
        let worker_id = exponent;
        let len = 1 << exponent;
        tracing::info!("Testing sequence of length {}", len);
        let sequence = (1..len + 1).collect::<Vec<u64>>();
        let store = create_store_event(worker_id, 1, sequence, None);
        trie.apply_event(store).unwrap();
    }
}
/// After `remove_worker`, only the remaining worker is scored for a block
/// both workers used to share.
#[test]
fn test_remove_worker() {
    setup();

    let mut trie = RadixTree::new();

    let worker_0 = 0;
    let worker_1 = 1;
    let key_0 = WorkerWithDpRank::from_worker_id(worker_0);
    let key_1 = WorkerWithDpRank::from_worker_id(worker_1);

    // Nothing stored yet.
    let initial = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
    assert!(initial.is_empty());

    trie.apply_event(create_store_event(worker_0, 0, vec![0], None))
        .unwrap();
    trie.apply_event(create_store_event(worker_1, 0, vec![0], None))
        .unwrap();

    let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
    assert_eq!(result.len(), 2);
    assert_eq!(result[&key_0], 1);
    assert_eq!(result[&key_1], 1);

    trie.remove_worker(worker_0);

    let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
    assert_eq!(result.len(), 1);
    assert_eq!(result[&key_1], 1);
}
/// Exercises `clear_all_blocks` on untracked, tracked, already-cleared,
/// removed and never-seen workers, checking tracking state and scores.
#[test]
fn test_clear_all_blocks() {
    /// Key used by the tree for a worker with the default dp_rank.
    fn key(worker_id: WorkerId) -> WorkerWithDpRank {
        WorkerWithDpRank::from_worker_id(worker_id)
    }

    /// Whether the worker has an entry in the lookup table.
    fn is_tracked(trie: &RadixTree, worker_id: WorkerId) -> bool {
        trie.lookup.contains_key(&key(worker_id))
    }

    /// Whether the worker's (existing) lookup entry holds no blocks.
    fn has_no_blocks(trie: &RadixTree, worker_id: WorkerId) -> bool {
        trie.lookup.get(&key(worker_id)).unwrap().is_empty()
    }

    let mut trie = RadixTree::new();

    let worker_0 = 0;
    let worker_1 = 1;

    assert!(
        trie.find_matches(vec![LocalBlockHash(0)], false)
            .scores
            .is_empty()
    );

    // Test clearing an empty worker
    trie.clear_all_blocks(worker_0);
    assert!(!is_tracked(&trie, worker_0));

    // Test clearing a worker with shared blocks
    trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 3], None))
        .unwrap();
    trie.apply_event(create_store_event(worker_1, 0, vec![0, 2, 3], None))
        .unwrap();

    let result = trie.find_matches(vec![LocalBlockHash(0)], false).scores;
    assert_eq!(result.len(), 2);
    assert_eq!(result[&key(worker_0)], 1);
    assert_eq!(result[&key(worker_1)], 1);

    trie.clear_all_blocks(worker_0);

    assert!(is_tracked(&trie, worker_0));
    assert!(has_no_blocks(&trie, worker_0));

    let result = trie
        .find_matches(vec![LocalBlockHash(0), LocalBlockHash(2)], false)
        .scores;
    assert_eq!(result.len(), 1);
    assert_eq!(result[&key(worker_1)], 2);

    let result = trie
        .find_matches(
            vec![LocalBlockHash(0), LocalBlockHash(1), LocalBlockHash(3)],
            false,
        )
        .scores;
    assert_eq!(result.len(), 1);
    assert_eq!(result[&key(worker_1)], 1);

    // Test re-adding blocks after clearing worker
    trie.apply_event(create_store_event(worker_0, 0, vec![4, 5], None))
        .unwrap();

    let result = trie
        .find_matches(vec![LocalBlockHash(4), LocalBlockHash(5)], false)
        .scores;
    assert_eq!(result.len(), 1);
    assert_eq!(result[&key(worker_0)], 2);

    // Test multiple clears
    trie.clear_all_blocks(worker_0);
    trie.clear_all_blocks(worker_0);
    assert!(is_tracked(&trie, worker_0));

    // Test clearing all workers
    trie.clear_all_blocks(worker_0);
    trie.clear_all_blocks(worker_1);
    assert!(!trie.lookup.is_empty());
    assert!(has_no_blocks(&trie, worker_0));
    assert!(has_no_blocks(&trie, worker_1));

    // Test clearing a worker that has been removed
    trie.apply_event(create_store_event(worker_0, 0, vec![6], None))
        .unwrap();
    trie.apply_event(create_store_event(worker_1, 0, vec![6], None))
        .unwrap();
    trie.remove_worker(worker_0);
    trie.clear_all_blocks(worker_0);
    assert!(!is_tracked(&trie, worker_0));

    let result = trie.find_matches(vec![LocalBlockHash(6)], false).scores;
    assert_eq!(result.len(), 1);
    assert_eq!(result[&key(worker_1)], 1);

    // Test clearing a worker that doesn't exist
    let worker_fake = 2;
    assert!(!is_tracked(&trie, worker_fake));
    trie.clear_all_blocks(worker_fake);
    assert!(!is_tracked(&trie, worker_fake));
    assert!(is_tracked(&trie, worker_1));

    let result = trie.find_matches(vec![LocalBlockHash(6)], false).scores;
    assert_eq!(result.len(), 1);
    assert_eq!(result[&key(worker_1)], 1);
}
#[test]
fn test_early_stopping() {
setup();
let mut trie = RadixTree::new();
let worker_0 = 0;
let worker_1 = 1;
trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 2], None))
.unwrap();
trie.apply_event(create_store_event(worker_1, 0, vec![0], None))
.unwrap();
let result = trie impl Drop for KvIndexerSharded {
.find_matches( fn drop(&mut self) {
vec![LocalBlockHash(0), LocalBlockHash(1), LocalBlockHash(2)], self.shutdown();
true, }
) }
.scores;
assert!( #[cfg(test)]
result.len() == 2 mod tests {
&& result[&WorkerWithDpRank::from_worker_id(worker_0)] == 2 use super::*;
&& result[&WorkerWithDpRank::from_worker_id(worker_1)] == 1 use crate::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
); use rstest::rstest;
use rstest_reuse::{self, *};
use std::time::Instant;
use tokio::time;
use tokio_util::sync::CancellationToken;
let result = trie fn setup() {
.find_matches(vec![LocalBlockHash(0), LocalBlockHash(1)], true) // Logging init removed to avoid dynamo-runtime dependency
.scores;
assert!(
result.len() == 2
&& result[&WorkerWithDpRank::from_worker_id(worker_0)] == 2
&& result[&WorkerWithDpRank::from_worker_id(worker_1)] == 1
);
} }
#[rstest] fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
#[case(11)] hashes
#[case(32)] .iter()
#[case(64)] .map(|i| KvCacheStoredBlockData {
fn test_compute_block_hash_for_seq(#[case] kv_block_size: u32) { tokens_hash: LocalBlockHash(*i),
setup(); block_hash: ExternalSequenceBlockHash(*i * 100),
// create a sequence of 64 elements mm_extra_info: None,
let sequence = (0..kv_block_size).collect::<Vec<u32>>(); })
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None); .collect()
assert_eq!(hashes.len(), 1); }
// create a sequence of 65 elements fn add_blocks(
let sequence = (0..(kv_block_size + 1)).collect::<Vec<u32>>(); hashes: Vec<u64>,
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None); parent_hash: Option<ExternalSequenceBlockHash>,
assert_eq!(hashes.len(), 1); ) -> KvCacheEventData {
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: make_blocks(hashes),
})
}
// create a sequence of 129 elements fn create_store_event(
let sequence = (0..(2 * kv_block_size + 1)).collect::<Vec<u32>>(); worker_id: WorkerId,
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None); event_id: u64,
assert_eq!(hashes.len(), 2); hashes: Vec<u64>,
parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent {
RouterEvent {
worker_id,
event: KvCacheEvent {
event_id,
data: add_blocks(hashes, parent),
dp_rank: 0,
},
}
} }
fn make_indexer( fn make_indexer(
...@@ -2874,54 +1849,6 @@ mod tests { ...@@ -2874,54 +1849,6 @@ mod tests {
assert_eq!(overlap.frequencies, vec![3, 3, 3, 2]); assert_eq!(overlap.frequencies, vec![3, 3, 3, 2]);
} }
/// `RouterEvent::new` must store the worker id and the event unchanged.
#[test]
fn test_router_event_new() {
    setup();
    let worker_id = 0;

    // Build the event bottom-up: one stored block, wrapped in a store operation.
    let stored_block = KvCacheStoredBlockData {
        block_hash: ExternalSequenceBlockHash(0),
        mm_extra_info: None,
        tokens_hash: LocalBlockHash(13226331709069118873),
    };
    let data = KvCacheEventData::Stored(KvCacheStoreData {
        parent_hash: None,
        blocks: vec![stored_block],
    });
    let kv_cache_event = KvCacheEvent {
        event_id: 1,
        data,
        dp_rank: 0,
    };

    let router_event = RouterEvent::new(worker_id, kv_cache_event);

    assert_eq!(router_event.worker_id, worker_id);
    assert_eq!(router_event.event.event_id, 1);

    let KvCacheEventData::Stored(store_op) = &router_event.event.data else {
        panic!("Expected KvCacheEventData::Stored");
    };
    assert_eq!(store_op.blocks.len(), 1);
    // The literal tokens hash above is expected to equal the hash of b"test data".
    assert_eq!(
        store_op.blocks[0].tokens_hash,
        compute_block_hash(b"test data")
    );
    assert_eq!(store_op.blocks[0].block_hash, ExternalSequenceBlockHash(0));
}
/// A default-constructed tree has an empty root and an empty lookup table.
#[test]
fn test_radix_tree_default() {
    setup();
    let tree = RadixTree::default();
    {
        let root = tree.root.borrow();
        assert!(root.children.is_empty());
        assert!(root.workers.is_empty());
    }
    assert!(tree.lookup.is_empty());
}
/// Default `OverlapScores` start with no per-worker scores.
#[test]
fn test_overlap_scores_default() {
    setup();
    let defaults = OverlapScores::default();
    assert!(defaults.scores.is_empty());
}
#[tokio::test] #[tokio::test]
async fn test_dump_tree_as_events_round_trip() { async fn test_dump_tree_as_events_round_trip() {
setup(); setup();
...@@ -3142,126 +2069,6 @@ mod tests { ...@@ -3142,126 +2069,6 @@ mod tests {
); );
} }
/// Removing a worker must drop it from the lookup table and from the `workers`
/// set of every block it held, while leaving the other workers untouched.
#[test]
fn test_remove_worker_verifies_hash_removal() {
    setup();
    let mut trie = RadixTree::new();
    let (worker_0, worker_1, worker_2) = (0, 1, 2);
    let key = WorkerWithDpRank::from_worker_id;

    // worker_0 and worker_1 share [1, 2, 3]; worker_2 only shares the first block.
    for (worker, hashes) in [
        (worker_0, vec![1, 2, 3]),
        (worker_1, vec![1, 2, 3]),
        (worker_2, vec![1, 4, 5]),
    ] {
        trie.apply_event(create_store_event(worker, 0, hashes, None))
            .unwrap();
    }

    // worker_0 tracks three blocks.
    assert_eq!(trie.lookup.get(&key(worker_0)).unwrap().len(), 3);

    // The first block (external hash 100) is held by all three workers.
    {
        let block_1 = trie
            .lookup
            .get(&key(worker_0))
            .unwrap()
            .get(&ExternalSequenceBlockHash(100))
            .unwrap()
            .borrow();
        assert_eq!(block_1.workers.len(), 3);
        for holder in [worker_0, worker_1, worker_2] {
            assert!(block_1.workers.contains(&key(holder)));
        }
    }

    trie.remove_worker(worker_0);

    // worker_0 is gone from the lookup table; the other two workers remain.
    assert!(!trie.lookup.contains_key(&key(worker_0)));
    assert_eq!(trie.lookup.len(), 2);

    // The shared first block no longer lists worker_0.
    {
        let block_1 = trie
            .lookup
            .get(&key(worker_1))
            .unwrap()
            .get(&ExternalSequenceBlockHash(100))
            .unwrap()
            .borrow();
        assert_eq!(block_1.workers.len(), 2);
        assert!(!block_1.workers.contains(&key(worker_0)));
        assert!(block_1.workers.contains(&key(worker_1)));
        assert!(block_1.workers.contains(&key(worker_2)));
    }

    // The second block (external hash 200) was shared by worker_0 and worker_1
    // only, so worker_1 is now its sole holder.
    {
        let block_2 = trie
            .lookup
            .get(&key(worker_1))
            .unwrap()
            .get(&ExternalSequenceBlockHash(200))
            .unwrap()
            .borrow();
        assert_eq!(block_2.workers.len(), 1);
        assert!(block_2.workers.contains(&key(worker_1)));
    }

    // Match results no longer include worker_0.
    let result = trie
        .find_matches(
            vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
            false,
        )
        .scores;
    assert_eq!(result.len(), 2);
    assert!(!result.contains_key(&key(worker_0)));
    assert!(result.contains_key(&key(worker_1)));
    assert!(result.contains_key(&key(worker_2)));
}
// LocalKvIndexer tests // LocalKvIndexer tests
fn make_indexer_with_events(ids: &[u64]) -> LocalKvIndexer { fn make_indexer_with_events(ids: &[u64]) -> LocalKvIndexer {
let indexer = LocalKvIndexer::new( let indexer = LocalKvIndexer::new(
...@@ -3575,3 +2382,302 @@ mod tests { ...@@ -3575,3 +2382,302 @@ mod tests {
} }
} }
} }
/// Tests for KvIndex enum (parametrized over RadixTree and FlatHashMap variants).
#[cfg(test)]
mod kv_index_tests {
    use super::*;
    use crate::protocols::{ExternalSequenceBlockHash, LocalBlockHash, compute_seq_hash_for_block};
    use rstest::rstest;
    use rstest_reuse::{self, *};

    /// Wrap raw `u64` values as [`LocalBlockHash`]es (also used to build queries).
    fn local_hashes_of(raw: &[u64]) -> Vec<LocalBlockHash> {
        raw.iter().copied().map(LocalBlockHash).collect()
    }

    /// Key of `worker_id` at dp_rank 0, as used by every event in this module.
    fn worker(worker_id: u64) -> WorkerWithDpRank {
        WorkerWithDpRank::new(worker_id, 0)
    }

    /// Wrap event data into a `RouterEvent` with event id 0 and dp_rank 0.
    fn event_on(worker_id: u64, data: KvCacheEventData) -> RouterEvent {
        RouterEvent {
            worker_id,
            event: KvCacheEvent {
                event_id: 0,
                data,
                dp_rank: 0,
            },
        }
    }

    /// Create a store event with proper sequence hashes computed from local hashes.
    fn make_store_event(worker_id: u64, local_hashes: &[u64]) -> RouterEvent {
        let locals = local_hashes_of(local_hashes);
        let seq_hashes = compute_seq_hash_for_block(&locals);

        let mut blocks = Vec::with_capacity(locals.len());
        for (tokens_hash, seq) in locals.iter().copied().zip(seq_hashes.iter().copied()) {
            blocks.push(KvCacheStoredBlockData {
                tokens_hash,
                block_hash: ExternalSequenceBlockHash(seq),
                mm_extra_info: None,
            });
        }

        event_on(
            worker_id,
            KvCacheEventData::Stored(KvCacheStoreData {
                parent_hash: None,
                blocks,
            }),
        )
    }

    /// Create a remove event for blocks with given local hashes.
    fn make_remove_event(worker_id: u64, local_hashes: &[u64]) -> RouterEvent {
        let locals = local_hashes_of(local_hashes);
        let block_hashes = compute_seq_hash_for_block(&locals)
            .iter()
            .copied()
            .map(ExternalSequenceBlockHash)
            .collect();

        event_on(
            worker_id,
            KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
        )
    }

    #[template]
    #[rstest]
    fn kv_index_template(#[values("tree", "flat")] variant: &str) {}

    /// Build the index backend named by `variant` ("tree" or "flat").
    fn make_kv_index(variant: &str) -> KvIndex {
        if variant == "tree" {
            KvIndex::new_tree()
        } else if variant == "flat" {
            KvIndex::new_flat()
        } else {
            panic!("Unknown variant: {}", variant)
        }
    }

    #[apply(kv_index_template)]
    fn test_store_and_find(variant: &str) {
        let mut index = make_kv_index(variant);

        // One worker stores a three-block sequence.
        index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();
        assert_eq!(index.current_size(), 3);

        // Querying the same local hashes matches all three blocks.
        let scores = index.find_matches(local_hashes_of(&[1, 2, 3]), false);
        assert_eq!(scores.scores.len(), 1);
        assert_eq!(scores.scores.get(&worker(0)), Some(&3));
    }

    #[apply(kv_index_template)]
    fn test_partial_match(variant: &str) {
        let mut index = make_kv_index(variant);
        index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();

        // [1, 2, 999] diverges at the third block, so only two blocks match.
        let scores = index.find_matches(local_hashes_of(&[1, 2, 999]), false);
        assert_eq!(scores.scores.get(&worker(0)), Some(&2));
    }

    #[apply(kv_index_template)]
    fn test_remove(variant: &str) {
        let mut index = make_kv_index(variant);

        index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();
        assert_eq!(index.current_size(), 3);

        // Removing every stored block empties the index.
        index.apply_event(make_remove_event(0, &[1, 2, 3])).unwrap();
        assert_eq!(index.current_size(), 0);

        let scores = index.find_matches(local_hashes_of(&[1, 2, 3]), false);
        assert!(scores.scores.is_empty());
    }

    #[apply(kv_index_template)]
    fn test_multiple_workers_shared_prefix(variant: &str) {
        let mut index = make_kv_index(variant);

        // Sequence hashes are cumulative: both workers share the hash for [1],
        // but [1, 2] and [1, 3] hash differently.
        index.apply_event(make_store_event(0, &[1, 2])).unwrap();
        index.apply_event(make_store_event(1, &[1, 3])).unwrap();

        // [1] is held by both workers.
        let scores = index.find_matches(local_hashes_of(&[1]), false);
        assert_eq!(scores.scores.len(), 2);
        assert_eq!(scores.scores.get(&worker(0)), Some(&1));
        assert_eq!(scores.scores.get(&worker(1)), Some(&1));

        // [1, 2]: worker 0 matches both blocks, worker 1 only the first.
        let scores = index.find_matches(local_hashes_of(&[1, 2]), false);
        assert_eq!(scores.scores.len(), 2);
        assert_eq!(scores.scores.get(&worker(0)), Some(&2));
        assert_eq!(scores.scores.get(&worker(1)), Some(&1));
    }

    #[apply(kv_index_template)]
    fn test_remove_worker(variant: &str) {
        let mut index = make_kv_index(variant);

        for worker_id in [0, 1] {
            index
                .apply_event(make_store_event(worker_id, &[1, 2, 3]))
                .unwrap();
        }
        assert_eq!(index.current_size(), 6);

        index.remove_worker(0);
        assert_eq!(index.current_size(), 3);

        // Only worker 1 is left to match.
        let scores = index.find_matches(local_hashes_of(&[1, 2, 3]), false);
        assert_eq!(scores.scores.len(), 1);
        assert!(scores.scores.contains_key(&worker(1)));
    }

    #[apply(kv_index_template)]
    fn test_get_workers(variant: &str) {
        let mut index = make_kv_index(variant);

        // Insert out of order; the result is expected in ascending order.
        for worker_id in [0, 2, 1] {
            index.apply_event(make_store_event(worker_id, &[1])).unwrap();
        }

        let workers = index.get_workers();
        assert_eq!(workers, vec![0, 1, 2]);
    }

    #[apply(kv_index_template)]
    fn test_early_exit(variant: &str) {
        let mut index = make_kv_index(variant);

        // Worker 0 has [0, 1, 2], Worker 1 has [0] only
        index.apply_event(make_store_event(0, &[0, 1, 2])).unwrap();
        index.apply_event(make_store_event(1, &[0])).unwrap();

        // With early_exit=true the walk stops after block 1, the first block
        // held by a single worker.
        let scores = index.find_matches(local_hashes_of(&[0, 1, 2]), true);
        assert_eq!(scores.scores.len(), 2);
        // Worker 0: blocks 0 and 1, then the early stop.
        assert_eq!(scores.scores.get(&worker(0)), Some(&2));
        // Worker 1: block 0 only.
        assert_eq!(scores.scores.get(&worker(1)), Some(&1));

        // Without early_exit, worker 0 is credited with all three blocks.
        let scores = index.find_matches(local_hashes_of(&[0, 1, 2]), false);
        assert_eq!(scores.scores.get(&worker(0)), Some(&3));
    }

    #[apply(kv_index_template)]
    fn test_large_stores(variant: &str) {
        let mut index = make_kv_index(variant);

        // Each worker stores a sequence twice as long as the previous one:
        // 1, 2, 4, 8, ..., 512 blocks.
        for worker_id in 0..10u64 {
            let len = 1u64 << worker_id;
            let offset = worker_id * 10000;
            let sequence: Vec<u64> = (1..=len).map(|x| x + offset).collect();
            index
                .apply_event(make_store_event(worker_id, &sequence))
                .unwrap();
            assert!(index.current_size() > 0);
        }
    }

    #[apply(kv_index_template)]
    fn test_dump_and_restore(variant: &str) {
        let mut index = make_kv_index(variant);

        index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();
        index.apply_event(make_store_event(1, &[1, 2, 4])).unwrap();

        let original_size = index.current_size();
        let workers_before = index.get_workers();

        // Dump the index as events.
        let events = index.dump_tree_as_events();
        assert!(!events.is_empty());

        // Replay them into a fresh index of the same variant; replay errors
        // are ignored here and caught by the comparisons below.
        let mut restored = make_kv_index(variant);
        for event in events {
            let _ = restored.apply_event(event);
        }

        // Same size, same workers, same match results.
        assert_eq!(restored.current_size(), original_size);
        assert_eq!(restored.get_workers(), workers_before);

        let original_scores = index.find_matches(local_hashes_of(&[1, 2]), false);
        let restored_scores = restored.find_matches(local_hashes_of(&[1, 2]), false);
        assert_eq!(original_scores.scores, restored_scores.scores);
    }

    #[apply(kv_index_template)]
    fn test_clear_all_blocks(variant: &str) {
        let mut index = make_kv_index(variant);

        for worker_id in [0, 1] {
            index
                .apply_event(make_store_event(worker_id, &[1, 2, 3]))
                .unwrap();
        }
        assert_eq!(index.current_size(), 6);

        index.clear_all_blocks(0);

        // Worker 0's blocks are gone, worker 1's remain.
        assert_eq!(index.current_size(), 3);
        let scores = index.find_matches(local_hashes_of(&[1, 2, 3]), false);
        assert_eq!(scores.scores.len(), 1);
        assert!(scores.scores.contains_key(&worker(1)));
    }

    #[apply(kv_index_template)]
    fn test_empty_query(variant: &str) {
        let mut index = make_kv_index(variant);
        index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();

        // An empty query matches nothing.
        let scores = index.find_matches(vec![], false);
        assert!(scores.scores.is_empty());
    }

    #[apply(kv_index_template)]
    fn test_miss_query(variant: &str) {
        let mut index = make_kv_index(variant);
        index.apply_event(make_store_event(0, &[1, 2, 3])).unwrap();

        // None of the queried blocks were ever stored.
        let scores = index.find_matches(local_hashes_of(&[999, 998]), false);
        assert!(scores.scores.is_empty());
    }
}
...@@ -7,9 +7,16 @@ ...@@ -7,9 +7,16 @@
//! efficient KV cache lookup and routing in distributed LLM inference systems. //! efficient KV cache lookup and routing in distributed LLM inference systems.
pub mod approx; pub mod approx;
pub mod flat_hashmap;
pub mod indexer; pub mod indexer;
pub mod protocols; pub mod protocols;
pub mod radix_tree;
// Re-export key types for convenience // Re-export key types for convenience
pub use indexer::{MaybeError, RadixTree, RouterEvent}; pub use flat_hashmap::FlatHashMap;
pub use protocols::{LocalBlockHash, WorkerId, compute_block_hash_for_seq}; pub use indexer::MaybeError;
pub use protocols::{
KvCacheEventError, LocalBlockHash, OverlapScores, RouterEvent, WorkerId,
compute_block_hash_for_seq,
};
pub use radix_tree::RadixTree;
...@@ -453,6 +453,105 @@ impl<'de> Deserialize<'de> for ExternalSequenceBlockHash { ...@@ -453,6 +453,105 @@ impl<'de> Deserialize<'de> for ExternalSequenceBlockHash {
} }
} }
// ------
// Router Event Types
// ------
/// Errors that can occur during KV Cache Event processing.
///
/// The `#[error]` strings are the `Display` messages generated by `thiserror`.
#[derive(Debug, thiserror::Error)]
pub enum KvCacheEventError {
/// A store operation named a parent block that could not be found
/// (e.g. `RadixTree::apply_event` looks it up in the worker's lookup table).
#[error("Failed to find parent block")]
ParentBlockNotFound,
/// A remove operation named a block that could not be found for the worker.
#[error("Failed to find block")]
BlockNotFound,
/// The blocks of a store operation do not form a valid sequence
/// (e.g. `RadixTree::apply_event` returns this for a self-referencing block).
#[error("Invalid block sequence")]
InvalidBlockSequence,
}
/// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`].
///
/// Only the worker id is stored alongside the event; the data-parallel rank
/// is carried inside the event itself (`event.dp_rank`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RouterEvent {
/// The ID of the worker emitting the event.
pub worker_id: WorkerId,
/// The cache event associated with the worker.
pub event: KvCacheEvent,
}
impl RouterEvent {
/// Create a new `RouterEvent`.
///
/// Both arguments are stored as given; no validation is performed.
///
/// ### Arguments
///
/// * `worker_id` - The ID of the worker emitting the event.
/// * `event` - The cache event.
///
/// ### Returns
///
/// A new `RouterEvent`.
pub fn new(worker_id: WorkerId, event: KvCacheEvent) -> Self {
Self { worker_id, event }
}
}
/// Per-worker overlap results produced by a prefix match.
///
/// Workers are identified together with their dp_rank.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OverlapScores {
    /// Number of matched blocks per worker (with dp_rank).
    pub scores: std::collections::HashMap<WorkerWithDpRank, u32>,
    /// Access frequencies of the matched blocks, in match order. Zero
    /// frequencies are never recorded.
    pub frequencies: Vec<usize>,
    /// Number of blocks tracked for each worker that has a score.
    pub tree_sizes: std::collections::HashMap<WorkerWithDpRank, usize>,
}

impl Default for OverlapScores {
    fn default() -> Self {
        Self::new()
    }
}

impl OverlapScores {
    /// Create an empty `OverlapScores`.
    ///
    /// The frequency list is pre-allocated for 32 entries.
    pub fn new() -> Self {
        let scores = std::collections::HashMap::new();
        let tree_sizes = std::collections::HashMap::new();
        let frequencies = Vec::with_capacity(32);
        Self {
            scores,
            frequencies,
            tree_sizes,
        }
    }

    /// Credit one matched block to every worker yielded by `workers`.
    ///
    /// ### Arguments
    ///
    /// * `workers` - An iterator over `WorkerWithDpRank` references.
    pub fn update_scores<'a, I>(&mut self, workers: I)
    where
        I: IntoIterator<Item = &'a WorkerWithDpRank>,
    {
        workers
            .into_iter()
            .for_each(|worker| *self.scores.entry(*worker).or_insert(0) += 1);
    }

    /// Append `frequency` to the frequency list; a frequency of 0 is ignored.
    ///
    /// Debug builds assert that the list stays non-increasing.
    pub fn add_frequency(&mut self, frequency: usize) {
        if frequency == 0 {
            return;
        }
        if let Some(&previous) = self.frequencies.last() {
            debug_assert!(previous >= frequency);
        }
        self.frequencies.push(frequency);
    }
}
// ------ // ------
// TokensWithHashes // TokensWithHashes
// ------ // ------
...@@ -556,8 +655,67 @@ impl TokensWithHashes { ...@@ -556,8 +655,67 @@ impl TokensWithHashes {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use rstest::rstest;
use serde_json; use serde_json;
/// `RouterEvent::new` must store the worker id and the event unchanged.
#[test]
fn test_router_event_new() {
    let worker_id = 0;

    // Build the event bottom-up: one stored block, wrapped in a store operation.
    let stored_block = KvCacheStoredBlockData {
        block_hash: ExternalSequenceBlockHash(0),
        mm_extra_info: None,
        tokens_hash: LocalBlockHash(13226331709069118873),
    };
    let data = KvCacheEventData::Stored(KvCacheStoreData {
        parent_hash: None,
        blocks: vec![stored_block],
    });
    let kv_cache_event = KvCacheEvent {
        event_id: 1,
        data,
        dp_rank: 0,
    };

    let router_event = RouterEvent::new(worker_id, kv_cache_event);

    assert_eq!(router_event.worker_id, worker_id);
    assert_eq!(router_event.event.event_id, 1);

    let KvCacheEventData::Stored(store_op) = &router_event.event.data else {
        panic!("Expected KvCacheEventData::Stored");
    };
    assert_eq!(store_op.blocks.len(), 1);
    // The literal tokens hash above is expected to equal the hash of b"test data".
    assert_eq!(
        store_op.blocks[0].tokens_hash,
        compute_block_hash(b"test data")
    );
    assert_eq!(store_op.blocks[0].block_hash, ExternalSequenceBlockHash(0));
}
/// Default `OverlapScores` start with no per-worker scores.
#[test]
fn test_overlap_scores_default() {
    let defaults = OverlapScores::default();
    assert!(defaults.scores.is_empty());
}
/// Only complete blocks are hashed: a trailing partial block produces no hash.
#[rstest]
#[case(11)]
#[case(32)]
#[case(64)]
fn test_compute_block_hash_for_seq(#[case] kv_block_size: u32) {
    // (number of tokens, expected number of block hashes)
    let expectations = [
        // exactly one block
        (kv_block_size, 1),
        // one block plus one leftover token
        (kv_block_size + 1, 1),
        // two blocks plus one leftover token
        (2 * kv_block_size + 1, 2),
    ];

    for (token_count, expected_blocks) in expectations {
        let sequence = (0..token_count).collect::<Vec<u32>>();
        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None);
        assert_eq!(hashes.len(), expected_blocks);
    }
}
#[test] #[test]
fn test_local_block_hash_serialization() { fn test_local_block_hash_serialization() {
let hash = LocalBlockHash(12345); let hash = LocalBlockHash(12345);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Radix Tree implementation for KV cache routing.
//!
//! This module provides a radix tree (prefix tree) data structure optimized for
//! efficient KV cache block lookup and management in distributed LLM inference.
//!
//! # Overview
//!
//! The main components include:
//!
//! - **RadixTree**: The main data structure with nodes (`RadixBlock`) containing
//! children and associated worker IDs. Allows efficient storage and retrieval
//! of data blocks based on their hashes.
use std::{
cell::RefCell,
collections::{HashMap, HashSet, VecDeque},
rc::Rc,
time::{Duration, Instant},
};
use crate::protocols::*;
/// A shared reference to a [`RadixBlock`].
///
/// The same block can be referenced from its parent's `children` map and from
/// per-worker lookup tables, hence the reference counting; `RefCell` provides
/// the mutation through those shared handles.
pub(crate) type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
/// A single node of the radix tree.
#[derive(Debug)]
pub(crate) struct RadixBlock {
    /// Child blocks, keyed by the local (token) hash of the child.
    pub(crate) children: HashMap<LocalBlockHash, SharedRadixBlock>,
    /// Workers that currently have this block cached.
    pub(crate) workers: HashSet<WorkerWithDpRank>,
    /// External sequence hash of this block; `None` only for the root.
    /// Assumed to be identical across workers (simplifying assumption).
    pub(crate) block_hash: Option<ExternalSequenceBlockHash>,
    /// Recent times at which this block was traversed.
    pub(crate) recent_uses: VecDeque<Instant>,
}

impl RadixBlock {
    /// Create an empty block without a block hash (used for the root node).
    ///
    /// ### Returns
    ///
    /// A new `RadixBlock` with no children, no workers and no recorded uses.
    pub fn new() -> Self {
        Self {
            block_hash: None,
            children: HashMap::new(),
            workers: HashSet::new(),
            recent_uses: VecDeque::new(),
        }
    }

    /// Create an empty block that carries the given external block hash.
    ///
    /// ### Returns
    ///
    /// A new `RadixBlock` identical to [`RadixBlock::new`] except for `block_hash`.
    pub fn with_hash(block_hash: ExternalSequenceBlockHash) -> Self {
        Self {
            block_hash: Some(block_hash),
            ..Self::new()
        }
    }
}
/// A radix (prefix) tree of KV cache blocks shared by all workers, plus a
/// per-worker index into its nodes.
pub struct RadixTree {
/// This is the root of the radix/prefix tree.
/// The root itself carries no block hash; store events without a parent
/// hash attach their first block directly under it.
pub(crate) root: SharedRadixBlock,
/// Per-worker lookup table for O(1) block access.
/// Maps worker -> (block_hash -> block).
pub(crate) lookup:
HashMap<WorkerWithDpRank, HashMap<ExternalSequenceBlockHash, SharedRadixBlock>>,
/// The time window the radix tree uses when computing the frequency of block accesses.
/// `None` disables access-time tracking in `find_matches`.
pub(crate) expiration_duration: Option<Duration>,
}
/// The default tree is the one built by [`RadixTree::new`]: empty, with no
/// expiration window configured.
impl Default for RadixTree {
fn default() -> Self {
Self::new()
}
}
// A deep tree dropped the default way would free its blocks recursively, one
// stack frame per level, which can overflow the stack. This implementation
// tears the tree down with an explicit work list instead.
impl Drop for RadixTree {
    fn drop(&mut self) {
        // Seed the work list with the root's children, detaching them from the root.
        let mut pending: Vec<SharedRadixBlock> = self
            .root
            .borrow_mut()
            .children
            .drain()
            .map(|(_, child)| child)
            .collect();

        // Lookup tables may reference blocks that are no longer reachable from
        // the root, so release those handles through the work list as well.
        for (_, worker_blocks) in self.lookup.drain() {
            pending.extend(worker_blocks.into_values());
        }

        // Free blocks one at a time. Only a uniquely-owned block can be taken
        // apart here; a handle that is still shared is simply released (the
        // `Err` value is dropped) and the block is revisited via another handle.
        while let Some(block) = pending.pop() {
            if let Ok(cell) = Rc::try_unwrap(block) {
                let mut node: RadixBlock = cell.into_inner();
                pending.extend(node.children.drain().map(|(_, child)| child));
                // `node` is dropped here with no children left, so no recursion.
            }
        }
    }
}
impl RadixTree {
/// Create a new, empty `RadixTree` with an optional access-frequency window.
///
/// ### Arguments
///
/// * `expiration_duration` - When `Some`, `find_matches` records each block
///   traversal and reports how many earlier traversals fall within this
///   duration. When `None`, no access times are recorded.
///
/// ### Returns
///
/// A new `RadixTree`.
pub fn new_with_frequency(expiration_duration: Option<Duration>) -> Self {
Self {
root: Rc::new(RefCell::new(RadixBlock::new())),
lookup: HashMap::new(),
expiration_duration,
}
}
/// Create a new, empty `RadixTree` without access-frequency tracking.
///
/// Equivalent to `new_with_frequency(None)`.
pub fn new() -> Self {
Self::new_with_frequency(None)
}
/// Traverse the radix tree to find the best match for a given sequence of [`LocalBlockHash`]es.
///
/// ### Arguments
///
/// * `sequence` - A vector of `LocalBlockHash` representing the sequence to match.
/// * `early_exit` - A boolean indicating whether to exit early if a single match is found.
///
/// ### Returns
///
/// An `OverlapScores` representing the match scores.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
let mut scores = OverlapScores::new();
let mut current = self.root.clone();
let now = Instant::now();
tracing::trace!(
"RadixTree::find_matches: looking for sequence={:?}",
sequence.iter().map(|h| h.0).collect::<Vec<_>>()
);
for (idx, block_hash) in sequence.iter().enumerate() {
let next_block = {
let current_borrow = current.borrow();
current_borrow.children.get(block_hash).cloned()
};
if let Some(block) = next_block {
scores.update_scores(block.borrow().workers.iter());
if let Some(expiration_duration) = self.expiration_duration {
let mut block_mut = block.borrow_mut();
while let Some(access_time) = block_mut.recent_uses.front() {
if now.duration_since(*access_time) > expiration_duration {
block_mut.recent_uses.pop_front();
} else {
break;
}
}
scores.add_frequency(block_mut.recent_uses.len());
block_mut.recent_uses.push_back(now);
}
if early_exit && block.borrow().workers.len() == 1 {
break;
}
current = block;
} else {
tracing::trace!(
"RadixTree::find_matches: block not found at index {} for hash {}",
idx,
block_hash.0
);
break;
}
}
tracing::trace!("RadixTree::find_matches: final scores={:?}", scores.scores);
// Populate tree sizes for all workers that have scores
for worker in scores.scores.keys() {
let tree_size = self
.lookup
.get(worker)
.expect("worker in scores must exist in lookup table")
.len();
scores.tree_sizes.insert(*worker, tree_size);
}
scores
}
/// Apply a [`RouterEvent`] to the radix tree.
///
/// ### Arguments
///
/// * `event` - The `RouterEvent` to apply.
pub fn apply_event(&mut self, event: RouterEvent) -> Result<(), KvCacheEventError> {
let (worker_id, kv_event) = (event.worker_id, event.event);
let (id, op) = (kv_event.event_id, kv_event.data);
// Construct WorkerWithDpRank from worker_id and dp_rank from the event
let worker = WorkerWithDpRank::new(worker_id, kv_event.dp_rank);
tracing::trace!(id, "RadixTree::apply_event: Store operation: {:?}", op);
let worker_lookup = self.lookup.entry(worker).or_default();
match op {
KvCacheEventData::Stored(op) => {
// find the parent block from this worker's lookup
let mut current = match op.parent_hash {
Some(parent) => match worker_lookup.get(&parent) {
Some(current) => current.clone(),
None => {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
id,
parent_hash = ?op.parent_hash,
num_blocks = op.blocks.len(),
"Failed to find parent block; skipping store operation"
);
return Err(KvCacheEventError::ParentBlockNotFound);
}
},
None => self.root.clone(),
};
for block_data in op.blocks {
let mut parent_mut = current.borrow_mut();
let child = match parent_mut.children.get(&block_data.tokens_hash) {
Some(block) => {
// Verify our simplifying assumption: block_hash is uniform across workers
if block.borrow().block_hash != Some(block_data.block_hash) {
tracing::warn!(
expected = ?block_data.block_hash,
actual = ?block.borrow().block_hash,
"block_hash mismatch: sequence hashes should be uniform across workers"
);
}
block.clone()
}
None => {
// create new block or reuse existing from worker's lookup
let new_block = worker_lookup
.get(&block_data.block_hash)
.cloned()
.unwrap_or_else(|| {
Rc::new(RefCell::new(RadixBlock::with_hash(
block_data.block_hash,
)))
});
// insert into radix tree
parent_mut
.children
.insert(block_data.tokens_hash, new_block.clone());
new_block
}
};
// Update child and check for self referential blocks
{
// Try to borrow the child mutably - if it fails, it's already borrowed
// which means a self referencing block.
let mut child_mut = match child.try_borrow_mut() {
Ok(b) => b,
Err(_) => {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
id,
block_hash = ?block_data.block_hash,
"Detected self referencing block in store event; rejecting sequence"
);
return Err(KvCacheEventError::InvalidBlockSequence);
}
};
// add our worker to the block
child_mut.workers.insert(worker);
}
// add the block to the worker's lookup table
worker_lookup.insert(block_data.block_hash, child.clone());
// drop child so we can shift current to this block
drop(parent_mut);
current = child;
}
Ok(())
}
KvCacheEventData::Removed(remove) => {
let mut kv_cache_err: Option<KvCacheEventError> = None;
for block in remove.block_hashes {
// lookup block in worker's table
let entry = match worker_lookup.get(&block) {
Some(entry) => entry.clone(),
None => {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
id,
block_hash = ?block,
"Failed to find block to remove; skipping remove operation"
);
// Kv cache removed events may be batched; we should try to apply all
// operations in the batch before returning an error. Return the first
// error.
if kv_cache_err.is_none() {
kv_cache_err = Some(KvCacheEventError::BlockNotFound);
}
continue;
}
};
let mut guard = entry.borrow_mut();
guard.workers.remove(&worker);
if guard.workers.is_empty() {
// if no workers are using this block, that is true for all children
guard.children.clear();
}
// remove the block from the worker's lookup table
worker_lookup.remove(&block);
}
kv_cache_err.map_or(Ok(()), Err)
}
KvCacheEventData::Cleared => {
self.clear_all_blocks(worker.worker_id);
Ok(())
}
}
}
/// Detaches every block recorded for `worker_id` (across all of its dp ranks)
/// from the tree.
///
/// * `keep_worker == true`: each matching key stays in `lookup`, mapped to an
///   empty block table.
/// * `keep_worker == false`: each matching key is dropped from `lookup`.
fn remove_or_clear_worker_blocks(&mut self, worker_id: WorkerId, keep_worker: bool) {
    // Snapshot the matching keys first; `lookup` is mutated inside the loop below.
    let matching: Vec<WorkerWithDpRank> = self
        .lookup
        .keys()
        .copied()
        .filter(|key| key.worker_id == worker_id)
        .collect();

    for key in matching {
        let Some((stored_key, owned_blocks)) = self.lookup.remove_entry(&key) else {
            continue;
        };

        for shared_block in owned_blocks.into_values() {
            let mut node = shared_block.borrow_mut();
            node.workers.remove(&key);
            // A block that no worker holds cannot have children held by any worker.
            if node.workers.is_empty() {
                node.children.clear();
            }
        }

        if keep_worker {
            // Keep the worker tracked, but with no blocks.
            self.lookup.insert(stored_key, HashMap::new());
        }
    }
}
/// Drops all blocks of `worker_id` and removes the worker from the lookup table.
pub fn remove_worker(&mut self, worker_id: WorkerId) {
    let keep_worker = false;
    self.remove_or_clear_worker_blocks(worker_id, keep_worker);
}
/// Drops all blocks of `worker_id` while leaving any of its existing entries
/// in the lookup table with an empty block map.
pub fn clear_all_blocks(&mut self, worker_id: WorkerId) {
    let keep_worker = true;
    self.remove_or_clear_worker_blocks(worker_id, keep_worker);
}
/// Lists the worker IDs that currently have an entry in the lookup table.
///
/// Entries that differ only by dp_rank collapse into a single ID; the result
/// is sorted in ascending order.
pub fn get_workers(&self) -> Vec<WorkerId> {
    // An ordered set performs the sort and the de-duplication in one step.
    let unique_ids: std::collections::BTreeSet<WorkerId> =
        self.lookup.keys().map(|key| key.worker_id).collect();
    unique_ids.into_iter().collect()
}
/// Serializes the radix tree into `Stored` RouterEvents from which the tree
/// can be rebuilt.
///
/// Blocks are visited breadth-first starting from the root's children, so a
/// block's events are always emitted before those of its children. One
/// single-block event is produced per (block, worker) pair; event ids are
/// assigned sequentially from 0. The original event ordering is not preserved.
///
/// # Panics
///
/// Panics if a non-root block has no `block_hash`.
pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
    tracing::debug!(
        "Dumping radix tree as events (contains information about {:?} workers)",
        self.lookup.len()
    );

    let mut dumped = Vec::new();
    let mut next_event_id = 0u64;

    // Work list of (block, hash of its parent, tokens hash under which the parent
    // stores it). Children of the root have no parent hash.
    let mut pending: VecDeque<_> = self
        .root
        .borrow()
        .children
        .iter()
        .map(|(tokens_hash, child)| (child.clone(), None, *tokens_hash))
        .collect();

    while let Some((node, parent_hash, tokens_hash)) = pending.pop_front() {
        let node_ref = node.borrow();
        // The hash is shared by every worker holding this block.
        let block_hash = node_ref
            .block_hash
            .expect("non-root block must have block_hash");

        // One store event per worker that holds this block.
        dumped.extend(node_ref.workers.iter().map(|worker| {
            let event_id = next_event_id;
            next_event_id += 1;
            RouterEvent {
                worker_id: worker.worker_id,
                event: KvCacheEvent {
                    event_id,
                    data: KvCacheEventData::Stored(KvCacheStoreData {
                        parent_hash,
                        blocks: vec![KvCacheStoredBlockData {
                            block_hash,
                            mm_extra_info: None,
                            tokens_hash,
                        }],
                    }),
                    dp_rank: worker.dp_rank,
                },
            }
        }));

        // Children are queued with this block's hash as their parent hash.
        pending.extend(
            node_ref
                .children
                .iter()
                .map(|(child_tokens_hash, child)| {
                    (child.clone(), Some(block_hash), *child_tokens_hash)
                }),
        );
    }

    dumped
}
/// Total number of block entries across every worker's lookup table.
/// A block held by several workers is counted once per worker.
pub fn current_size(&self) -> usize {
    self.lookup
        .values()
        .fold(0, |total, blocks| total + blocks.len())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData,
KvCacheStoreData, KvCacheStoredBlockData, LocalBlockHash, WorkerId,
};
/// Builds one stored-block record per input value for exercising RadixTree internals.
/// The value itself becomes the tokens hash; the block hash is the value times 100.
fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
    let mut blocks = Vec::with_capacity(hashes.len());
    for hash in hashes {
        blocks.push(KvCacheStoredBlockData {
            tokens_hash: LocalBlockHash(hash),
            block_hash: ExternalSequenceBlockHash(hash * 100),
            mm_extra_info: None,
        });
    }
    blocks
}
/// Wraps the blocks produced by `make_blocks` in a `Stored` event payload
/// attached to `parent_hash`.
fn add_blocks(
    hashes: Vec<u64>,
    parent_hash: Option<ExternalSequenceBlockHash>,
) -> KvCacheEventData {
    let blocks = make_blocks(hashes);
    let store = KvCacheStoreData {
        parent_hash,
        blocks,
    };
    KvCacheEventData::Stored(store)
}
/// Builds a dp-rank-0 store event for `worker_id` carrying the given hashes
/// under the optional `parent` block.
fn create_store_event(
    worker_id: WorkerId,
    event_id: u64,
    hashes: Vec<u64>,
    parent: Option<ExternalSequenceBlockHash>,
) -> RouterEvent {
    let event = KvCacheEvent {
        event_id,
        data: add_blocks(hashes, parent),
        dp_rank: 0,
    };
    RouterEvent { worker_id, event }
}
/// Builds a dp-rank-0 remove event for `worker_id`. Each input value is mapped to
/// the block hash `value * 100`, matching the mapping used by `make_blocks`.
fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
    let mut block_hashes = Vec::with_capacity(hashes.len());
    for hash in hashes {
        block_hashes.push(ExternalSequenceBlockHash(hash * 100));
    }
    let event = KvCacheEvent {
        event_id,
        data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
        dp_rank: 0,
    };
    RouterEvent { worker_id, event }
}
/// Walks a two-worker scenario (store, store, remove, remove, store-with-parent)
/// and checks match scores, lookup-table sizes and tree shape after each step.
#[test]
fn test_radix_tree() {
    // Lookup key of a worker, as built by `from_worker_id`.
    let key = |id: WorkerId| WorkerWithDpRank::from_worker_id(id);
    // Query sequence reused by every `find_matches` call below.
    let probe = || vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)];
    // Number of blocks recorded for a worker in the lookup table.
    let tracked_blocks = |t: &RadixTree, id: WorkerId| t.lookup.get(&key(id)).unwrap().len();
    // (worker count, child count) of the root node.
    let root_shape = |t: &RadixTree| {
        let root = t.root.borrow();
        (root.workers.len(), root.children.len())
    };
    // (worker count, child count) of the root's child stored under tokens hash 1.
    let first_block_shape = |t: &RadixTree| {
        let root = t.root.borrow();
        let first = root.children.get(&LocalBlockHash(1)).unwrap().borrow();
        (first.workers.len(), first.children.len())
    };

    let mut trie = RadixTree::new();
    let worker_1: WorkerId = 0;
    let worker_2: WorkerId = 1;

    // Step 1: worker_1 stores 1 -> 2 -> 3 from the root.
    trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
        .unwrap();
    let scores = trie.find_matches(probe(), false);
    assert_eq!(scores.scores[&key(worker_1)], 3);
    assert_eq!(trie.lookup.len(), 1);
    assert_eq!(tracked_blocks(&trie, worker_1), 3);
    assert_eq!(root_shape(&trie), (0, 1));
    assert_eq!(first_block_shape(&trie), (1, 1));

    // Step 2: worker_2 stores 1 -> 4 -> 5, sharing only the first block.
    trie.apply_event(create_store_event(worker_2, 1, vec![1, 4, 5], None))
        .unwrap();
    let scores = trie.find_matches(probe(), false);
    assert_eq!(scores.scores[&key(worker_1)], 3);
    assert_eq!(scores.scores[&key(worker_2)], 1);
    assert_eq!(trie.lookup.len(), 2);
    assert_eq!(tracked_blocks(&trie, worker_1), 3);
    assert_eq!(tracked_blocks(&trie, worker_2), 3);
    assert_eq!(root_shape(&trie), (0, 1));
    assert_eq!(first_block_shape(&trie), (2, 2));

    // Step 3: worker_2 removes block 5.
    trie.apply_event(create_remove_event(worker_2, 2, vec![5]))
        .unwrap();
    assert_eq!(trie.lookup.len(), 2);
    assert_eq!(tracked_blocks(&trie, worker_1), 3);
    assert_eq!(tracked_blocks(&trie, worker_2), 2);
    assert_eq!(root_shape(&trie), (0, 1));
    assert_eq!(first_block_shape(&trie), (2, 2));

    // Step 4: worker_2 removes block 4.
    trie.apply_event(create_remove_event(worker_2, 3, vec![4]))
        .unwrap();
    assert_eq!(trie.lookup.len(), 2);
    assert_eq!(tracked_blocks(&trie, worker_1), 3);
    assert_eq!(tracked_blocks(&trie, worker_2), 1);
    assert_eq!(root_shape(&trie), (0, 1));
    assert_eq!(first_block_shape(&trie), (2, 2));

    // Step 5: worker_2 stores 2 -> 6 -> 7 underneath the block with hash 100.
    trie.apply_event(create_store_event(
        worker_2,
        4,
        vec![2, 6, 7],
        Some(ExternalSequenceBlockHash(100)),
    ))
    .unwrap();
    let scores = trie.find_matches(probe(), false);
    assert_eq!(scores.scores[&key(worker_1)], 3);
    assert_eq!(scores.scores[&key(worker_2)], 2);
    assert_eq!(trie.lookup.len(), 2);
    assert_eq!(tracked_blocks(&trie, worker_1), 3);
    assert_eq!(tracked_blocks(&trie, worker_2), 4);
    assert_eq!(root_shape(&trie), (0, 1));
    assert_eq!(first_block_shape(&trie), (2, 2));

    // The block with hash 200 is now held by both workers.
    let shared_workers = trie
        .lookup
        .get(&key(worker_1))
        .unwrap()
        .get(&ExternalSequenceBlockHash(200))
        .unwrap()
        .borrow()
        .workers
        .len();
    assert_eq!(shared_workers, 2);
}
/// Checks the error variant returned by `apply_event` for three malformed events.
#[test]
fn test_radix_tree_apply_event_errors() {
    let mut trie = RadixTree::new();
    let worker_0 = 0;

    // Storing under a parent hash that was never stored.
    let orphan_store = create_store_event(
        worker_0,
        0,
        vec![1, 2, 3],
        Some(ExternalSequenceBlockHash(12345)),
    );
    let outcome = trie.apply_event(orphan_store);
    assert!(matches!(
        outcome,
        Err(KvCacheEventError::ParentBlockNotFound)
    ));

    // Removing blocks the worker never stored.
    let unknown_remove = create_remove_event(worker_0, 0, vec![1, 2, 3]);
    let outcome = trie.apply_event(unknown_remove);
    assert!(matches!(outcome, Err(KvCacheEventError::BlockNotFound)));

    // Storing [1, 2, 3] under parent hash 100 after block 1 (hash 100) exists:
    // the first block of the sequence is its own parent, i.e. self referencing.
    trie.apply_event(create_store_event(worker_0, 4, vec![1], None))
        .unwrap();
    let self_referencing = create_store_event(
        worker_0,
        5,
        vec![1, 2, 3],
        Some(ExternalSequenceBlockHash(100)),
    );
    let outcome = trie.apply_event(self_referencing);
    assert!(matches!(
        outcome,
        Err(KvCacheEventError::InvalidBlockSequence)
    ));
}
/// Exercises `clear_all_blocks` on empty, shared, re-populated, removed and
/// never-seen workers.
#[test]
fn test_clear_all_blocks() {
    // Lookup key of a worker, as built by `from_worker_id`.
    let key = |id: WorkerId| WorkerWithDpRank::from_worker_id(id);
    // Whether the worker has an entry in the lookup table.
    let tracked = |t: &RadixTree, id: WorkerId| t.lookup.contains_key(&key(id));
    // Whether the worker's lookup entry exists and holds no blocks.
    let tracked_empty = |t: &RadixTree, id: WorkerId| t.lookup.get(&key(id)).unwrap().is_empty();
    // Overlap scores for a sequence of local block hashes.
    let scores_for = |t: &RadixTree, hashes: &[u64]| {
        t.find_matches(hashes.iter().copied().map(LocalBlockHash).collect(), false)
            .scores
    };

    let mut trie = RadixTree::new();
    let worker_0: WorkerId = 0;
    let worker_1: WorkerId = 1;

    assert!(scores_for(&trie, &[0]).is_empty());

    // Clearing a worker that never stored anything does not start tracking it.
    trie.clear_all_blocks(worker_0);
    assert!(!tracked(&trie, worker_0));

    // Clearing a worker whose blocks are partly shared with another worker.
    trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 3], None))
        .unwrap();
    trie.apply_event(create_store_event(worker_1, 0, vec![0, 2, 3], None))
        .unwrap();
    let result = scores_for(&trie, &[0]);
    assert_eq!(result.len(), 2);
    assert!(result[&key(worker_0)] == 1);
    assert!(result[&key(worker_1)] == 1);

    trie.clear_all_blocks(worker_0);
    assert!(tracked(&trie, worker_0));
    assert!(tracked_empty(&trie, worker_0));

    let result = scores_for(&trie, &[0, 2]);
    assert_eq!(result.len(), 1);
    assert_eq!(result[&key(worker_1)], 2);

    let result = scores_for(&trie, &[0, 1, 3]);
    assert_eq!(result.len(), 1);
    assert_eq!(result[&key(worker_1)], 1);

    // Storing new blocks for a worker after it was cleared.
    trie.apply_event(create_store_event(worker_0, 0, vec![4, 5], None))
        .unwrap();
    let result = scores_for(&trie, &[4, 5]);
    assert_eq!(result.len(), 1);
    assert_eq!(result[&key(worker_0)], 2);

    // Clearing the same worker twice in a row.
    trie.clear_all_blocks(worker_0);
    trie.clear_all_blocks(worker_0);
    assert!(tracked(&trie, worker_0));

    // Clearing every worker leaves them tracked with empty block tables.
    trie.clear_all_blocks(worker_0);
    trie.clear_all_blocks(worker_1);
    assert!(!trie.lookup.is_empty());
    assert!(tracked_empty(&trie, worker_0));
    assert!(tracked_empty(&trie, worker_1));

    // Clearing a worker that has already been removed.
    trie.apply_event(create_store_event(worker_0, 0, vec![6], None))
        .unwrap();
    trie.apply_event(create_store_event(worker_1, 0, vec![6], None))
        .unwrap();
    trie.remove_worker(worker_0);
    trie.clear_all_blocks(worker_0);
    assert!(!tracked(&trie, worker_0));
    let result = scores_for(&trie, &[6]);
    assert_eq!(result.len(), 1);
    assert_eq!(result[&key(worker_1)], 1);

    // Clearing a worker that was never seen.
    let worker_fake: WorkerId = 2;
    assert!(!tracked(&trie, worker_fake));
    trie.clear_all_blocks(worker_fake);
    assert!(!tracked(&trie, worker_fake));
    assert!(tracked(&trie, worker_1));
    let result = scores_for(&trie, &[6]);
    assert_eq!(result.len(), 1);
    assert_eq!(result[&key(worker_1)], 1);
}
/// A default-constructed tree has an empty root and an empty lookup table.
#[test]
fn test_radix_tree_default() {
    let tree = RadixTree::default();
    {
        let root = tree.root.borrow();
        assert!(root.children.is_empty());
        assert!(root.workers.is_empty());
    }
    assert!(tree.lookup.is_empty());
}
/// Verifies that `remove_worker` drops the worker from the lookup table and from
/// every block's worker set, and that a block left without any worker has its
/// children cleared.
#[test]
fn test_remove_worker_verifies_hash_removal() {
    let mut trie = RadixTree::new();
    let worker_0 = 0;
    let worker_1 = 1;
    let worker_2 = 2;

    // Add blocks for multiple workers
    trie.apply_event(create_store_event(worker_0, 0, vec![1, 2, 3], None))
        .unwrap();
    trie.apply_event(create_store_event(worker_1, 0, vec![1, 2, 3], None))
        .unwrap();
    trie.apply_event(create_store_event(worker_2, 0, vec![1, 4, 5], None))
        .unwrap();

    // Verify worker_0 has 3 blocks in lookup
    assert_eq!(
        trie.lookup
            .get(&WorkerWithDpRank::from_worker_id(worker_0))
            .unwrap()
            .len(),
        3
    );

    // Verify that blocks have the correct workers
    let block_1 = trie
        .lookup
        .get(&WorkerWithDpRank::from_worker_id(worker_0))
        .unwrap()
        .get(&ExternalSequenceBlockHash(100))
        .unwrap();
    assert_eq!(block_1.borrow().workers.len(), 3); // worker_0, worker_1, and worker_2 (all have hash 1)
    for worker in [worker_0, worker_1, worker_2] {
        assert!(
            block_1
                .borrow()
                .workers
                .contains(&WorkerWithDpRank::from_worker_id(worker))
        );
    }

    // Remove worker_0
    trie.remove_worker(worker_0);

    // Verify worker_0 is completely removed from lookup table
    assert!(
        !trie
            .lookup
            .contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
    );
    assert_eq!(trie.lookup.len(), 2);

    // Verify that worker_0 is removed from the shared block's workers set
    let block_1 = trie
        .lookup
        .get(&WorkerWithDpRank::from_worker_id(worker_1))
        .unwrap()
        .get(&ExternalSequenceBlockHash(100))
        .unwrap();
    assert_eq!(block_1.borrow().workers.len(), 2); // worker_1 and worker_2 remain
    assert!(
        !block_1
            .borrow()
            .workers
            .contains(&WorkerWithDpRank::from_worker_id(worker_0))
    );
    for worker in [worker_1, worker_2] {
        assert!(
            block_1
                .borrow()
                .workers
                .contains(&WorkerWithDpRank::from_worker_id(worker))
        );
    }

    // A block that worker_0 shared only with worker_1 keeps worker_1 as its sole owner
    let block_2 = trie
        .lookup
        .get(&WorkerWithDpRank::from_worker_id(worker_1))
        .unwrap()
        .get(&ExternalSequenceBlockHash(200))
        .unwrap();
    assert_eq!(block_2.borrow().workers.len(), 1); // only worker_1
    assert!(
        block_2
            .borrow()
            .workers
            .contains(&WorkerWithDpRank::from_worker_id(worker_1))
    );

    // Verify match results no longer include worker_0
    let result = trie
        .find_matches(
            vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
            false,
        )
        .scores;
    assert_eq!(result.len(), 2);
    assert!(!result.contains_key(&WorkerWithDpRank::from_worker_id(worker_0)));
    assert!(result.contains_key(&WorkerWithDpRank::from_worker_id(worker_1)));
    assert!(result.contains_key(&WorkerWithDpRank::from_worker_id(worker_2)));

    // Verify that blocks with no remaining workers have their children cleared.
    // worker_2 is the only owner of blocks 4 and 5, so removing it leaves block 4
    // without workers, which must drop its link to block 5.
    trie.remove_worker(worker_2);
    assert!(
        !trie
            .lookup
            .contains_key(&WorkerWithDpRank::from_worker_id(worker_2))
    );
    assert_eq!(trie.lookup.len(), 1);
    {
        let root = trie.root.borrow();
        let shared = root.children.get(&LocalBlockHash(1)).unwrap().borrow();
        // The shared first block is still held by worker_1.
        assert_eq!(shared.workers.len(), 1);
        assert!(
            shared
                .workers
                .contains(&WorkerWithDpRank::from_worker_id(worker_1))
        );
        let orphan = shared.children.get(&LocalBlockHash(4)).unwrap().borrow();
        assert!(orphan.workers.is_empty());
        assert!(orphan.children.is_empty());
    }
}
}
...@@ -44,11 +44,11 @@ use crate::{ ...@@ -44,11 +44,11 @@ use crate::{
discovery::RuntimeConfigsWithNotify, discovery::RuntimeConfigsWithNotify,
kv_router::{ kv_router::{
approx::PruneConfig, approx::PruneConfig,
indexer::{KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent}, indexer::{KvIndexer, KvIndexerInterface, KvRouterError},
protocols::{ protocols::{
LocalBlockHash, RouterRequest, RouterResponse, TokensWithHashes, WorkerId, LocalBlockHash, OverlapScores, RouterEvent, RouterRequest, RouterResponse,
WorkerSelectionResult, WorkerWithDpRank, compute_block_hash_for_seq, TokensWithHashes, WorkerId, WorkerSelectionResult, WorkerWithDpRank,
compute_seq_hash_for_block, compute_block_hash_for_seq, compute_seq_hash_for_block,
}, },
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest}, scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
sequence::SequenceError, sequence::SequenceError,
......
...@@ -42,7 +42,7 @@ fn create_kv_stream_name(component: &Component, subject: &str) -> String { ...@@ -42,7 +42,7 @@ fn create_kv_stream_name(component: &Component, subject: &str) -> String {
use crate::kv_router::{ use crate::kv_router::{
KV_EVENT_SUBJECT, KV_METRICS_SUBJECT, WORKER_KV_INDEXER_BUFFER_SIZE, KV_EVENT_SUBJECT, KV_METRICS_SUBJECT, WORKER_KV_INDEXER_BUFFER_SIZE,
indexer::{KvIndexerMetrics, LocalKvIndexer, RouterEvent}, indexer::{KvIndexerMetrics, LocalKvIndexer},
protocols::*, protocols::*,
worker_query::start_worker_kv_query_endpoint, worker_query::start_worker_kv_query_endpoint,
}; };
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use crate::kv_router::indexer::RouterEvent; use crate::kv_router::protocols::RouterEvent;
use crate::recorder::Recorder; use crate::recorder::Recorder;
// Type alias for backward compatibility // Type alias for backward compatibility
......
...@@ -17,8 +17,7 @@ use super::KV_HIT_RATE_SUBJECT; ...@@ -17,8 +17,7 @@ use super::KV_HIT_RATE_SUBJECT;
use super::KvRouterConfig; use super::KvRouterConfig;
use super::RouterConfigOverride; use super::RouterConfigOverride;
use super::WorkerSelector; use super::WorkerSelector;
use super::indexer::OverlapScores; use super::protocols::{DpRank, OverlapScores, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
use super::protocols::{DpRank, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
use super::sequence::{ActiveSequencesMultiWorker, SequenceError}; use super::sequence::{ActiveSequencesMultiWorker, SequenceError};
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple //! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! requests share common prefixes (e.g., system prompts, few-shot examples). //! requests share common prefixes (e.g., system prompts, few-shot examples).
use crate::kv_router::indexer::OverlapScores; use crate::kv_router::protocols::OverlapScores;
use anyhow::Result; use anyhow::Result;
use dashmap::DashMap; use dashmap::DashMap;
use derive_getters::Getters; use derive_getters::Getters;
......
...@@ -19,8 +19,8 @@ use tokio_util::sync::CancellationToken; ...@@ -19,8 +19,8 @@ use tokio_util::sync::CancellationToken;
use crate::kv_router::{ use crate::kv_router::{
KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE,
indexer::{DumpRequest, GetWorkersRequest, RouterEvent, WorkerKvQueryResponse}, indexer::{DumpRequest, GetWorkersRequest, WorkerKvQueryResponse},
protocols::WorkerId, protocols::{RouterEvent, WorkerId},
router_discovery_query, router_discovery_query,
worker_query::WorkerQueryClient, worker_query::WorkerQueryClient,
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment