Unverified Commit 9e5014da authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

perf: Concurrent router perf improvements (#6536)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
parent fd035b19
......@@ -107,7 +107,8 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
# Native deps for Python/Rust wheels
patchelf \
clang \
libclang-dev && \
libclang-dev \
libfontconfig-dev && \
rm -rf /var/lib/apt/lists/* && \
# Initialize Git LFS for the dynamo user (required for requirements with lfs=true)
git lfs install
......
......@@ -12,12 +12,12 @@ repository.workspace = true
[features]
default = []
metrics = ["dep:dynamo-runtime"]
metrics = []
bench = ["dep:clap", "dep:indicatif", "dep:serde_json", "dynamo-runtime/integration", "dep:plotters"]
[dependencies]
# repo
dynamo-runtime = { workspace = true, optional = true }
dynamo-runtime = { workspace = true }
dynamo-tokens = { workspace = true }
# workspace
......@@ -58,12 +58,6 @@ dynamo-mocker = { workspace = true }
dynamo-tokens = { workspace = true }
minstant = "0.1.7"
[[bench]]
name = "radix_tree_microbench"
harness = false
required-features = ["bench"]
[[bench]]
name = "kv_indexer_bench"
harness = false
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Microbenchmark for radix tree operations with configurable size and depth.
//!
//! Measures latency and throughput of:
//! - store_block: Adding blocks to the tree
//! - remove_block: Removing blocks from the tree
//! - find_matches: Finding prefix matches in the tree
//!
//! Size is defined as total (worker, block) pairs in the tree.
//! Depth is the number of blocks per sequence (depth = (isl + osl) / block_size).
//!
//! Run with: cargo bench --package dynamo-kv-router --bench radix_tree_microbench --features bench -- --help
#[path = "common/mod.rs"]
mod common;
use common::{SequenceData, generate_sequences};
use clap::{Parser, ValueEnum};
use dynamo_bench::common::LatencyStats;
use dynamo_kv_router::{
ConcurrentRadixTree, OverlapScores, PositionalIndexer, RadixTree, RouterEvent, SyncIndexer,
compute_block_hash_for_seq, protocols::LocalBlockHash,
};
use std::time::{Duration, Instant};
/// Unified wrapper over the three index implementations exercised by this benchmark:
/// `RadixTree`, `ConcurrentRadixTree`, and `PositionalIndexer`.
///
/// The `impl` below forwards store/remove events, `find_matches`, and `current_size`
/// to whichever implementation is wrapped, so the benchmark loops are written once.
/// All three are queried with `LocalBlockHash` sequences; the only difference visible
/// here is the calling convention:
/// - `RadixTree::find_matches` takes the hash vector by value
/// - `ConcurrentRadixTree::find_matches_impl` and `PositionalIndexer::find_matches`
///   take it by reference
///
/// `dump_tree_as_events` is forwarded for the two tree variants only; the `Nested`
/// variant returns an empty list.
enum KvIndex {
/// Single-threaded radix tree (the default when no flag is given).
Tree(RadixTree),
/// Concurrent radix tree, selected with `--concurrent`.
Concurrent(ConcurrentRadixTree),
/// Positional indexer, selected with `--nested-map`; `name()` reports it as "PositionalIndexer".
Nested(PositionalIndexer),
}
impl KvIndex {
    /// Display name of the wrapped index implementation, used in benchmark output.
    fn name(&self) -> &'static str {
        match self {
            Self::Tree(_) => "RadixTree",
            Self::Concurrent(_) => "ConcurrentRadixTree",
            Self::Nested(_) => "PositionalIndexer",
        }
    }

    /// Forward a router event to the wrapped index, discarding whatever it returns.
    fn apply_event(&mut self, event: RouterEvent) {
        match self {
            Self::Tree(radix) => {
                let _ = radix.apply_event(event);
            }
            Self::Concurrent(radix) => {
                let _ = radix.apply_event(event);
            }
            Self::Nested(indexer) => {
                let _ = indexer.apply_event(event).ok();
            }
        }
    }

    /// Time one `find_matches` call for a sequence that was stored in the index.
    ///
    /// The query vector is cloned before the clock starts, so the copy is not measured.
    fn find_matches_timed(&self, seq: &SequenceData, early_exit: bool) -> Duration {
        let query = seq.local_hashes.clone();
        let clock = Instant::now();
        let _ = match self {
            Self::Tree(radix) => radix.find_matches(query, early_exit),
            Self::Concurrent(radix) => radix.find_matches_impl(&query, early_exit),
            Self::Nested(indexer) => indexer.find_matches(&query, early_exit),
        };
        clock.elapsed()
    }

    /// Time one `find_matches` call for a synthetic query of `depth` hashes.
    ///
    /// Each hash is built from a fixed tag, the iteration index `i`, and the block
    /// position; the caller reports this as the MISS case.
    fn find_matches_miss_timed(&self, depth: usize, i: usize, early_exit: bool) -> Duration {
        let tag: u64 = 0xBAD_C0DE_0000_0000 | ((i as u64) << 16);
        let query: Vec<LocalBlockHash> = (0..depth)
            .map(|pos| LocalBlockHash(tag | (pos as u64)))
            .collect();
        let clock = Instant::now();
        let _ = match self {
            Self::Tree(radix) => radix.find_matches(query, early_exit),
            Self::Concurrent(radix) => radix.find_matches_impl(&query, early_exit),
            Self::Nested(indexer) => indexer.find_matches(&query, early_exit),
        };
        clock.elapsed()
    }

    /// Time one `find_matches` call for a query whose first `half` hashes come from
    /// `seq` and whose remaining `half` hashes are synthetic.
    fn find_matches_partial_timed(
        &self,
        seq: &SequenceData,
        half: usize,
        i: usize,
        early_exit: bool,
    ) -> Duration {
        let tag: u64 = 0xDEAD_0000 | ((i as u64) << 16);
        let query: Vec<LocalBlockHash> = seq.local_hashes[..half]
            .iter()
            .cloned()
            .chain((0..half).map(|pos| LocalBlockHash(tag | (pos as u64))))
            .collect();
        let clock = Instant::now();
        let _ = match self {
            Self::Tree(radix) => radix.find_matches(query, early_exit),
            Self::Concurrent(radix) => radix.find_matches_impl(&query, early_exit),
            Self::Nested(indexer) => indexer.find_matches(&query, early_exit),
        };
        clock.elapsed()
    }

    /// Size reported by the wrapped index.
    fn current_size(&self) -> usize {
        match self {
            Self::Tree(radix) => radix.current_size(),
            Self::Concurrent(radix) => radix.current_size(),
            Self::Nested(indexer) => indexer.current_size(),
        }
    }

    /// Untimed `find_matches`, returning the overlap scores from the wrapped index.
    fn find_matches(&self, local_hashes: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
        match self {
            Self::Tree(radix) => radix.find_matches(local_hashes, early_exit),
            Self::Concurrent(radix) => radix.find_matches_impl(&local_hashes, early_exit),
            Self::Nested(indexer) => indexer.find_matches(&local_hashes, early_exit),
        }
    }

    /// Dump the wrapped index as a list of router events.
    ///
    /// Only the two tree variants are dumped; the `Nested` variant yields no events.
    fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
        match self {
            Self::Tree(radix) => radix.dump_tree_as_events(),
            Self::Concurrent(radix) => radix.dump_tree_as_events(),
            Self::Nested(_) => Vec::new(),
        }
    }
}
/// Sweep benchmark mode
// Selected via the `sweep_mode` field of `Args` and matched in `bench_sweep`.
// NOTE: because this type derives clap's `ValueEnum`, the `///` comments on the
// variants are used by clap as help text for the possible values; keep them
// user-facing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
enum SweepMode {
/// Vary sequence/query length (query has exactly `depth` blocks, all matching)
Depth,
/// Vary match length (query has `max_depth` blocks, first `depth` match, rest garbage)
MatchLength,
/// Vary number of prefix prompt groups (width of shared prefixes)
Width,
}
// Command-line options for the microbenchmark, parsed with clap's derive API.
// The `///` comments on the fields are what clap shows as `--help` text, so they are
// part of the program's output; maintainer-only notes below use `//` instead.
// Values are range-checked in `main()` before any benchmark runs.
#[derive(Parser, Debug)]
#[command(name = "radix_tree_microbench")]
#[command(about = "Microbenchmark for radix tree operations")]
struct Args {
/// Ignored: passed by cargo bench harness
#[arg(long, hide = true)]
bench: bool,
/// Target tree size in total (worker, block) pairs
// The number of sequences built is `size / depth` (integer division).
#[arg(long, default_value = "10000")]
size: usize,
/// Sequence depth in blocks (depth = (isl + osl) / block_size, where block_size = 16)
// NOTE(review): the help text hard-codes block_size = 16, but `block_size` below is
// configurable (16 is only its default).
#[arg(long, default_value = "32")]
depth: usize,
/// Number of workers to distribute blocks across
#[arg(long, default_value = "4")]
num_workers: usize,
/// Number of iterations per operation for timing
#[arg(long, default_value = "1000")]
iterations: usize,
/// Warmup ratio (0.0 to 1.0) - fraction of iterations to discard for warmup
// The first `(iterations * warmup_ratio) as usize` iterations are run but not recorded.
#[arg(long, default_value = "0.1")]
warmup_ratio: f64,
/// Prefix prompt ratio (0.0 to 1.0) - portion of sequence from the beginning that is a shared prefix
#[arg(long, default_value = "0.25")]
prefix_prompt_ratio: f64,
/// Number of unique prefix prompt groups to randomly sample from
#[arg(long, default_value = "4")]
num_prefix_prompts: usize,
/// Run only specific benchmark (hash, store, remove, find_matches, dump, sweep, or all)
// Matched as a plain string in `main()`; any other value prints an error and exits.
// "all" runs every benchmark except the sweep.
#[arg(long, default_value = "all")]
benchmark_type: String,
/// KV block size in tokens (for hash computation)
#[arg(long, default_value = "16")]
block_size: u32,
/// Verbose output with per-iteration timings
// In practice this prints a progress line every 100 iterations.
#[arg(short, long)]
verbose: bool,
/// Minimum depth for sweep mode
#[arg(long, default_value = "1")]
min_depth: usize,
/// Maximum depth for sweep mode
// Also used by `bench_sweep` as the fixed length of every generated sequence.
#[arg(long, default_value = "8000")]
max_depth: usize,
/// Number of depth points to sample in sweep mode (logarithmically spaced)
#[arg(long, default_value = "20")]
sweep_points: usize,
/// Iterations per depth point in sweep mode
#[arg(long, default_value = "100")]
sweep_iterations: usize,
/// Output format for sweep mode: "table" or "csv"
// Only the exact string "csv" selects CSV; every other value produces the table.
#[arg(long, default_value = "table")]
sweep_format: String,
/// Sweep mode: what to vary during the sweep
#[arg(long, value_enum, default_value = "depth")]
sweep_mode: SweepMode,
/// Minimum width (num_prefix_prompts) for width sweep mode
#[arg(long, default_value = "1")]
min_width: usize,
/// Maximum width (num_prefix_prompts) for width sweep mode
#[arg(long, default_value = "64")]
max_width: usize,
/// Random seed for reproducibility
#[arg(long, default_value = "42")]
seed: u64,
/// Use nested map instead of radix tree (for comparison)
// Selects `KvIndex::Nested`, i.e. a `PositionalIndexer`; takes precedence over
// `concurrent` when both are set (see `build_index`).
#[arg(long)]
nested_map: bool,
/// Use concurrent radix tree instead of single-threaded radix tree
#[arg(long)]
concurrent: bool,
}
/// Build a pre-populated `KvIndex` from `sequences`, printing progress and throughput.
///
/// The index type is chosen from the two flags: `use_nested_map` wins over
/// `use_concurrent`, and with neither set a plain `RadixTree` is built. Every sequence
/// is applied as one store event whose event id is its position in `sequences`.
///
/// The printed name matches `KvIndex::name()` so that the "Building ..." line and the
/// later "Benchmarking ... (<name>)" lines refer to the same index by the same name.
///
/// # Panics
///
/// Panics if flushing stdout fails.
fn build_index(sequences: &[SequenceData], use_nested_map: bool, use_concurrent: bool) -> KvIndex {
    let num_blocks: usize = sequences.iter().map(|s| s.local_hashes.len()).sum();
    // Keep these labels identical to the strings returned by `KvIndex::name()`.
    let name = if use_nested_map {
        "PositionalIndexer"
    } else if use_concurrent {
        "ConcurrentRadixTree"
    } else {
        "RadixTree"
    };
    print!(
        " Building {} with {} sequences ({} blocks)... ",
        name,
        sequences.len(),
        num_blocks
    );
    // `print!` does not end the line, so flush to make the progress text visible
    // before the (potentially slow) build starts.
    std::io::Write::flush(&mut std::io::stdout()).unwrap();

    // Construction of the empty index is deliberately inside the timed region.
    let start = Instant::now();
    let mut index = if use_nested_map {
        // NOTE(review): the meaning of the `32` argument is defined by
        // `PositionalIndexer::new` and is not visible from this file — confirm.
        KvIndex::Nested(PositionalIndexer::new(32))
    } else if use_concurrent {
        KvIndex::Concurrent(ConcurrentRadixTree::new())
    } else {
        KvIndex::Tree(RadixTree::new())
    };
    for (event_id, seq) in sequences.iter().enumerate() {
        let event = seq.to_store_event(event_id as u64);
        index.apply_event(event);
    }
    let elapsed = start.elapsed();

    println!(
        "done in {:.2?} ({:.2} sequences/sec, {:.2} blocks/sec)",
        elapsed,
        sequences.len() as f64 / elapsed.as_secs_f64(),
        num_blocks as f64 / elapsed.as_secs_f64()
    );
    index
}
/// Measure the latency of `compute_block_hash_for_seq` on synthetic token sequences.
///
/// One sequence of `depth * block_size` tokens is generated per iteration before any
/// timing starts. The first `iterations * warmup_ratio` calls are executed but their
/// timings are discarded.
fn bench_hash(args: &Args) {
    println!("\n=== Benchmarking COMPUTE_BLOCK_HASH (per-request hot path) ===");
    let tokens_per_seq = args.depth * args.block_size as usize;
    println!(
        " Token sequence length: {} tokens ({} blocks)",
        tokens_per_seq, args.depth
    );

    // Pre-build all inputs so generation cost never lands inside a timed call.
    let mut inputs: Vec<Vec<u32>> = Vec::with_capacity(args.iterations);
    for iter_idx in 0..args.iterations {
        let base = iter_idx * tokens_per_seq;
        inputs.push(
            (0..tokens_per_seq)
                .map(|offset| ((base + offset) % 50000) as u32)
                .collect(),
        );
    }

    let warmup = (args.iterations as f64 * args.warmup_ratio) as usize;
    let mut samples = Vec::with_capacity(args.iterations - warmup);
    for (iter_idx, tokens) in inputs.iter().enumerate() {
        let clock = Instant::now();
        let _ = compute_block_hash_for_seq(tokens, args.block_size, None, None);
        let took = clock.elapsed();

        if iter_idx >= warmup {
            samples.push(took);
        }
        let completed = iter_idx + 1;
        if args.verbose && completed % 100 == 0 {
            println!(" Completed {}/{} iterations", completed, args.iterations);
        }
    }

    LatencyStats::from_durations(&samples)
        .unwrap()
        .print("COMPUTE_BLOCK_HASH", args.depth);
}
/// Benchmark store or remove operation on a steady-state index.
///
/// Uses a remove/store cycle to maintain size. If `time_store` is true,
/// the store operation is timed; otherwise the remove operation is timed.
fn bench_store_remove_cycle(args: &Args, time_store: bool) {
let op_name = if time_store {
"STORE_BLOCK"
} else {
"REMOVE_BLOCK"
};
let num_sequences = args.size / args.depth;
let sequences = generate_sequences(
num_sequences,
args.depth,
args.num_workers,
args.prefix_prompt_ratio,
args.num_prefix_prompts,
args.seed,
true,
);
let mut index = build_index(&sequences, args.nested_map, args.concurrent);
println!("\n=== Benchmarking {} ({}) ===", op_name, index.name());
println!(" Size: {} blocks", index.current_size());
let warmup_iters = (args.iterations as f64 * args.warmup_ratio) as usize;
let measured_iters = args.iterations - warmup_iters;
let mut durations = Vec::with_capacity(measured_iters);
for i in 0..args.iterations {
let seq = &sequences[i % sequences.len()];
let remove_event = seq.to_remove_event(i as u64);
let store_event = seq.to_store_event(i as u64 + args.iterations as u64);
let elapsed = if time_store {
index.apply_event(remove_event);
let start = Instant::now();
index.apply_event(store_event);
start.elapsed()
} else {
let start = Instant::now();
index.apply_event(remove_event);
let elapsed = start.elapsed();
index.apply_event(store_event);
elapsed
};
if i >= warmup_iters {
durations.push(elapsed);
}
if args.verbose && (i + 1) % 100 == 0 {
println!(" Completed {}/{} iterations", i + 1, args.iterations);
}
}
let stats = LatencyStats::from_durations(&durations).unwrap();
stats.print(op_name, args.depth);
}
/// Benchmark the store operation.
///
/// Thin wrapper: runs `bench_store_remove_cycle` with `time_store = true`, i.e. each
/// iteration removes a sequence untimed and then times storing it again.
fn bench_store(args: &Args) {
bench_store_remove_cycle(args, true);
}
/// Benchmark the remove operation.
///
/// Thin wrapper: runs `bench_store_remove_cycle` with `time_store = false`, i.e. each
/// iteration times removing a sequence and then stores it again untimed.
fn bench_remove(args: &Args) {
bench_store_remove_cycle(args, false);
}
/// Benchmark find_matches operation
fn bench_find_matches(args: &Args) {
let num_sequences = args.size / args.depth;
let sequences = generate_sequences(
num_sequences,
args.depth,
args.num_workers,
args.prefix_prompt_ratio,
args.num_prefix_prompts,
args.seed,
true,
);
let index = build_index(&sequences, args.nested_map, args.concurrent);
println!("\n=== Benchmarking FIND_MATCHES ({}) ===", index.name());
println!(
" Built with {} sequences, {} total blocks",
sequences.len(),
index.current_size()
);
let warmup_iters = (args.iterations as f64 * args.warmup_ratio) as usize;
let measured_iters = args.iterations - warmup_iters;
let half = args.depth / 2;
// HIT case
println!("\n --- HIT case (existing sequences) ---");
let mut hit_durations = Vec::with_capacity(measured_iters);
for i in 0..args.iterations {
let seq = &sequences[i % sequences.len()];
let elapsed = index.find_matches_timed(seq, false);
if i >= warmup_iters {
hit_durations.push(elapsed);
}
if args.verbose && (i + 1) % 100 == 0 {
println!(" Completed {}/{} iterations", i + 1, args.iterations);
}
}
LatencyStats::from_durations(&hit_durations)
.unwrap()
.print("FIND_MATCHES (HIT)", args.depth);
// MISS case
println!("\n --- MISS case (non-existing sequences) ---");
let mut miss_durations = Vec::with_capacity(measured_iters);
for i in 0..args.iterations {
let elapsed = index.find_matches_miss_timed(args.depth, i, false);
if i >= warmup_iters {
miss_durations.push(elapsed);
}
if args.verbose && (i + 1) % 100 == 0 {
println!(" Completed {}/{} iterations", i + 1, args.iterations);
}
}
LatencyStats::from_durations(&miss_durations)
.unwrap()
.print("FIND_MATCHES (MISS)", args.depth);
// PARTIAL case
println!("\n --- PARTIAL case (prefix match only) ---");
let mut partial_durations = Vec::with_capacity(measured_iters);
for i in 0..args.iterations {
let seq = &sequences[i % sequences.len()];
let elapsed = index.find_matches_partial_timed(seq, half, i, false);
if i >= warmup_iters {
partial_durations.push(elapsed);
}
if args.verbose && (i + 1) % 100 == 0 {
println!(" Completed {}/{} iterations", i + 1, args.iterations);
}
}
LatencyStats::from_durations(&partial_durations)
.unwrap()
.print("FIND_MATCHES (PARTIAL)", args.depth);
// EARLY_EXIT case
println!("\n --- EARLY_EXIT case ---");
let mut early_exit_durations = Vec::with_capacity(measured_iters);
for i in 0..args.iterations {
let seq = &sequences[i % sequences.len()];
let elapsed = index.find_matches_timed(seq, true);
if i >= warmup_iters {
early_exit_durations.push(elapsed);
}
}
LatencyStats::from_durations(&early_exit_durations)
.unwrap()
.print("FIND_MATCHES (EARLY_EXIT)", args.depth);
}
/// Produce up to `num_points` integers spaced evenly on a log scale from `min_val`
/// to `max_val`.
///
/// Each sample is rounded to the nearest integer and raised to at least 1. Consecutive
/// samples that round to the same integer are collapsed, so fewer than `num_points`
/// values may be returned. With `num_points` of 0 or 1 the result is just `[max_val]`.
fn generate_log_spaced_points(min_val: usize, max_val: usize, num_points: usize) -> Vec<usize> {
    if num_points < 2 {
        return vec![max_val];
    }

    let ln_lo = (min_val as f64).ln();
    let ln_hi = (max_val as f64).ln();
    let delta = (ln_hi - ln_lo) / (num_points - 1) as f64;

    let mut samples: Vec<usize> = Vec::with_capacity(num_points);
    for k in 0..num_points {
        let value = ((ln_lo + delta * k as f64).exp().round() as usize).max(1);
        // Skip a value equal to the previous one (rounding collapses nearby samples).
        if samples.last() != Some(&value) {
            samples.push(value);
        }
    }
    samples
}
/// Latency statistics (avg, p50, p99) in nanoseconds
#[derive(Debug)]
struct DurationStats {
    /// Arithmetic mean of the samples, truncated to whole nanoseconds.
    avg_ns: u64,
    /// Sample at index `n / 2` of the sorted input.
    p50_ns: u64,
    /// Sample at index `n * 99 / 100` of the sorted input.
    p99_ns: u64,
}

impl DurationStats {
    /// Compute stats from durations. Sorts the input slice in place.
    ///
    /// An empty slice yields all-zero statistics instead of panicking on the
    /// division by the sample count and the percentile indexing.
    fn from_durations(durations: &mut [Duration]) -> Self {
        if durations.is_empty() {
            return Self {
                avg_ns: 0,
                p50_ns: 0,
                p99_ns: 0,
            };
        }

        // Equal durations are indistinguishable, so the unstable sort gives the same
        // order without the extra allocation.
        durations.sort_unstable();
        let n = durations.len();

        // Sum in u128 nanoseconds so the sample count never has to be narrowed to u32.
        let total_ns: u128 = durations.iter().map(Duration::as_nanos).sum();
        let avg_ns = (total_ns / n as u128) as u64;

        Self {
            avg_ns,
            p50_ns: durations[n / 2].as_nanos() as u64,
            // For n >= 1 the index n * 99 / 100 is at most n - 1.
            p99_ns: durations[n * 99 / 100].as_nanos() as u64,
        }
    }
}
/// Measurements collected at one sweep point, plus the label of the swept quantity.
#[derive(Debug)]
struct SweepResult {
    point: usize,
    point_label: &'static str,
    store: DurationStats,
    remove: DurationStats,
    find_matches: DurationStats,
}

impl SweepResult {
    /// The nine statistics in output order: store, remove, find_matches, each as
    /// (avg, p50, p99) in nanoseconds.
    fn stat_columns(&self) -> [u64; 9] {
        [
            self.store.avg_ns,
            self.store.p50_ns,
            self.store.p99_ns,
            self.remove.avg_ns,
            self.remove.p50_ns,
            self.remove.p99_ns,
            self.find_matches.avg_ns,
            self.find_matches.p50_ns,
            self.find_matches.p99_ns,
        ]
    }

    /// CSV header line; the first column is named after the swept quantity.
    fn csv_header(&self) -> String {
        let mut header = String::from(self.point_label);
        header.push_str(",store_avg_ns,store_p50_ns,store_p99_ns,remove_avg_ns,remove_p50_ns,remove_p99_ns,find_matches_avg_ns,find_matches_p50_ns,find_matches_p99_ns");
        header
    }

    /// CSV data line: the sweep point followed by the nine raw nanosecond values.
    fn csv_row(&self) -> String {
        let columns = self.stat_columns();
        let mut row = self.point.to_string();
        for value in columns.iter() {
            row.push(',');
            row.push_str(&value.to_string());
        }
        row
    }

    /// Table header line; the first column is named after the swept quantity.
    fn table_header(&self) -> String {
        format!(
            "{:>8}{}",
            self.point_label,
            " | store_avg store_p50 store_p99 | remove_avg remove_p50 remove_p99 | fm_avg fm_p50 fm_p99"
        )
    }

    /// Table data line: the sweep point followed by three groups of three
    /// human-readable durations, each group introduced by a `|` separator.
    fn table_row(&self) -> String {
        let columns = self.stat_columns();
        let mut row = format!("{:>8}", self.point);
        for (idx, ns) in columns.iter().copied().enumerate() {
            row.push_str(if idx % 3 == 0 { " | " } else { " " });
            row.push_str(&format!("{:>12}", format_duration_ns(ns)));
        }
        row
    }
}
/// Print all sweep results, either as CSV (when `format` is exactly "csv") or as an
/// aligned table (any other value). Prints nothing when `results` is empty.
///
/// The header is taken from the first result, so all results are expected to share
/// the same point label.
fn print_sweep_results_dynamic(results: &[SweepResult], format: &str) {
    let first = match results.first() {
        Some(result) => result,
        None => return,
    };
    let as_csv = format == "csv";

    println!();
    let header = if as_csv {
        first.csv_header()
    } else {
        first.table_header()
    };
    println!("{}", header);
    if !as_csv {
        println!("{}", "-".repeat(130));
    }

    for result in results {
        let line = if as_csv {
            result.csv_row()
        } else {
            result.table_row()
        };
        println!("{}", line);
    }
}
/// Benchmark store/remove/find_matches across a range of depths or widths.
///
/// For each sweep point, the sequences are regenerated (with the same seed) and the
/// index is rebuilt. Every generated sequence has `max_depth` blocks; the index holds
/// `size / max_depth` of them, and `sweep_iterations` additional sequences are
/// generated for the store benchmark.
///
/// With `--sweep-mode match-length`, find_matches queries have `max_depth` blocks
/// where only the first `depth` blocks match (rest are garbage). With `--sweep-mode depth`,
/// queries have exactly `depth` blocks (all matching). With `--sweep-mode width`,
/// the number of prefix prompt groups is varied and all operations use the full
/// sequence length.
///
/// Exits with status 1 if fewer than 2 sequences fit into `size`.
fn bench_sweep(args: &Args) {
// All sequences are generated at the maximum depth; sweep points only truncate them.
let seq_length = args.max_depth;
let num_sequences = args.size / seq_length;
if num_sequences < 2 {
eprintln!(
"Error: size {} / max_depth {} = {} sequences (need at least 2). \
Increase --size or decrease --max-depth.",
args.size, seq_length, num_sequences
);
std::process::exit(1);
}
// Depth and MatchLength sweep over [min_depth, max_depth]; Width sweeps over
// [min_width, max_width]. The label is used for progress output and result headers.
let (mode_name, point_label, sweep_points) = match args.sweep_mode {
SweepMode::Depth => (
"Depth",
"depth",
generate_log_spaced_points(args.min_depth, args.max_depth, args.sweep_points),
),
SweepMode::MatchLength => (
"Match Length",
"depth",
generate_log_spaced_points(args.min_depth, args.max_depth, args.sweep_points),
),
SweepMode::Width => (
"Width",
"width",
generate_log_spaced_points(args.min_width, args.max_width, args.sweep_points),
),
};
println!("\n=== {} Sweep Benchmark ===", mode_name);
println!(" Sequence length: {} blocks (fixed)", seq_length);
match args.sweep_mode {
SweepMode::Depth | SweepMode::MatchLength => {
println!(
" Sweep range: {} to {} ({} points, log-spaced)",
args.min_depth, args.max_depth, args.sweep_points
);
}
SweepMode::Width => {
println!(
" Width range: {} to {} ({} points, log-spaced)",
args.min_width, args.max_width, args.sweep_points
);
println!(
" Prefix prompt ratio: {:.1}%",
args.prefix_prompt_ratio * 100.0
);
}
}
println!(" Iterations per point: {}", args.sweep_iterations);
println!(
" Tree: {} sequences, {} total blocks",
num_sequences,
num_sequences * seq_length
);
println!(" Workers: {}", args.num_workers);
match args.sweep_mode {
SweepMode::MatchLength => {
println!(" Mode: find_matches queries padded with garbage to max_depth");
}
SweepMode::Depth => {
println!(" Mode: find_matches queries truncated to depth");
}
SweepMode::Width => {
println!(" Mode: varying num_prefix_prompts, full-depth operations");
}
}
println!();
let mut results: Vec<SweepResult> = Vec::with_capacity(sweep_points.len());
for (idx, &point) in sweep_points.iter().enumerate() {
print!(
"[{}/{}] {}={}... ",
idx + 1,
sweep_points.len(),
point_label,
point
);
// `print!` leaves the line open; flush so the progress prefix shows immediately.
std::io::Write::flush(&mut std::io::stdout()).unwrap();
// Determine depth and num_prefix_prompts for this sweep point
let (depth, num_prefix_prompts) = match args.sweep_mode {
SweepMode::Depth | SweepMode::MatchLength => (point, args.num_prefix_prompts),
SweepMode::Width => (seq_length, point),
};
// Generate sequences and rebuild tree for this point
let extra_count = args.sweep_iterations;
let all_sequences = generate_sequences(
num_sequences + extra_count,
seq_length,
args.num_workers,
args.prefix_prompt_ratio,
num_prefix_prompts,
args.seed,
true,
);
// The first `num_sequences` populate the index; the remainder are only used as
// inputs for the store benchmark.
let tree_sequences = &all_sequences[..num_sequences];
let extra_sequences = &all_sequences[num_sequences..];
let mut index = build_index(tree_sequences, args.nested_map, args.concurrent);
// --- STORE benchmark ---
let mut store_durations = Vec::with_capacity(args.sweep_iterations);
for (i, seq) in extra_sequences
.iter()
.enumerate()
.take(args.sweep_iterations)
{
// Only the first `depth` blocks of the sequence are stored at this sweep point.
let truncated = SequenceData {
worker_id: seq.worker_id,
local_hashes: seq.local_hashes[..depth].to_vec(),
external_hashes: seq.external_hashes[..depth].to_vec(),
};
let store_event = truncated.to_store_event(i as u64);
let start = Instant::now();
index.apply_event(store_event);
store_durations.push(start.elapsed());
// Remove to restore index state (untimed)
let remove_event = truncated.to_remove_event(i as u64);
index.apply_event(remove_event);
}
// --- REMOVE benchmark ---
let mut remove_durations = Vec::with_capacity(args.sweep_iterations);
// Capped at `num_sequences`, so each indexed sequence is removed at most once
// and this benchmark may record fewer samples than the other two.
for i in 0..args.sweep_iterations.min(num_sequences) {
let seq = &tree_sequences[i % tree_sequences.len()];
let truncated = SequenceData {
worker_id: seq.worker_id,
local_hashes: seq.local_hashes[..depth].to_vec(),
external_hashes: seq.external_hashes[..depth].to_vec(),
};
let remove_event = truncated.to_remove_event(i as u64);
let start = Instant::now();
index.apply_event(remove_event);
remove_durations.push(start.elapsed());
// Re-add to restore state (untimed)
let store_event = truncated.to_store_event(i as u64 + 1000000);
index.apply_event(store_event);
}
// --- FIND_MATCHES benchmark ---
let mut find_matches_durations = Vec::with_capacity(args.sweep_iterations);
for i in 0..args.sweep_iterations {
let seq = &tree_sequences[i % tree_sequences.len()];
// The query is built before the clock starts, so only the lookup is timed.
let query = match args.sweep_mode {
SweepMode::MatchLength => {
// Match length mode: first `depth` blocks match, rest are garbage
let mut q = seq.local_hashes[..depth].to_vec();
let garbage_len = seq_length - depth;
q.extend((0..garbage_len).map(|j| {
LocalBlockHash(0xBAD_C0DE_0000_0000 | ((i as u64) << 16) | (j as u64))
}));
q
}
SweepMode::Depth | SweepMode::Width => {
// Depth/width mode: query has exactly `depth` blocks
seq.local_hashes[..depth].to_vec()
}
};
let start = Instant::now();
let _ = index.find_matches(query, false);
find_matches_durations.push(start.elapsed());
}
// Compute stats
let store = DurationStats::from_durations(&mut store_durations);
let remove = DurationStats::from_durations(&mut remove_durations);
let find_matches = DurationStats::from_durations(&mut find_matches_durations);
// Finish the progress line opened by `print!` above with the average latencies.
println!(
"store={:.2}us, remove={:.2}us, find_matches={:.2}us",
store.avg_ns as f64 / 1000.0,
remove.avg_ns as f64 / 1000.0,
find_matches.avg_ns as f64 / 1000.0
);
results.push(SweepResult {
point,
point_label,
store,
remove,
find_matches,
});
}
print_sweep_results_dynamic(&results, &args.sweep_format);
}
/// Time a single `dump_tree_as_events` call on a freshly built index and print the
/// elapsed time, the number of events produced, and the resulting events/sec.
fn bench_dump(args: &Args) {
    println!("\n=== Benchmarking DUMP_TREE_AS_EVENTS (BFS dump) ===");

    let sequences = generate_sequences(
        args.size / args.depth,
        args.depth,
        args.num_workers,
        args.prefix_prompt_ratio,
        args.num_prefix_prompts,
        args.seed,
        true,
    );
    let index = build_index(&sequences, args.nested_map, args.concurrent);
    println!(
        " {} built with {} sequences, {} total blocks",
        index.name(),
        sequences.len(),
        index.current_size()
    );

    // One measurement only: the dump is not repeated or warmed up.
    let clock = Instant::now();
    let dumped = index.dump_tree_as_events();
    let took = clock.elapsed();
    let event_count = dumped.len();

    println!("\nDUMP_TREE_AS_EVENTS Results:");
    println!(" Time: {:?}", took);
    println!(" Events: {}", event_count);
    println!(
        " Throughput: {:.2} events/sec",
        event_count as f64 / took.as_secs_f64()
    );
}
/// Render a nanosecond count with the largest unit that keeps the value at 1 or above.
///
/// Values below 1000 are printed as whole nanoseconds; larger values are printed with
/// two decimals in `us`, `ms`, or `s`.
fn format_duration_ns(ns: u64) -> String {
    let as_float = ns as f64;
    match ns {
        0..=999 => format!("{}ns", ns),
        1_000..=999_999 => format!("{:.2}us", as_float / 1_000.0),
        1_000_000..=999_999_999 => format!("{:.2}ms", as_float / 1_000_000.0),
        _ => format!("{:.2}s", as_float / 1_000_000_000.0),
    }
}
/// Entry point: parse and validate the command line, print the configuration, and
/// dispatch to the benchmark selected by `--benchmark-type`.
///
/// Any validation failure or unknown benchmark type prints a message to stderr and
/// exits with status 1.
fn main() {
    let args = Args::parse();

    // Validate arguments to prevent panics
    if args.size == 0
        || args.depth == 0
        || args.num_workers == 0
        || args.iterations == 0
        || args.block_size == 0
        || args.min_depth == 0
        || args.max_depth == 0
        || args.min_width == 0
        || args.max_width == 0
        || args.sweep_iterations == 0
    {
        eprintln!(
            "size, depth, num_workers, iterations, block_size, min_depth, max_depth, min_width, max_width, and sweep_iterations must be > 0"
        );
        std::process::exit(1);
    }
    if args.min_depth > args.max_depth {
        eprintln!("min_depth must be <= max_depth");
        std::process::exit(1);
    }
    if args.min_width > args.max_width {
        eprintln!("min_width must be <= max_width");
        std::process::exit(1);
    }
    if !(0.0..=1.0).contains(&args.prefix_prompt_ratio) {
        eprintln!("prefix_prompt_ratio must be between 0.0 and 1.0");
        std::process::exit(1);
    }
    // NOTE(review): a warmup_ratio of exactly 1.0 passes this check but leaves zero
    // measured iterations; whether `LatencyStats::from_durations` accepts an empty
    // slice is not visible from this file — confirm.
    if !(0.0..=1.0).contains(&args.warmup_ratio) {
        eprintln!("warmup_ratio must be between 0.0 and 1.0");
        std::process::exit(1);
    }

    let num_sequences = args.size / args.depth;
    // The store, remove, and find_matches benchmarks pick sequences with
    // `sequences[i % sequences.len()]`, which panics when no sequence was generated.
    // The list must therefore use the real benchmark-type names accepted below.
    if matches!(
        args.benchmark_type.as_str(),
        "store" | "remove" | "find_matches" | "sweep" | "all"
    ) && num_sequences == 0
    {
        eprintln!(
            "size must be >= depth to produce at least one sequence for {}",
            args.benchmark_type
        );
        std::process::exit(1);
    }

    println!("Radix Tree Microbenchmark");
    println!("=========================\n");
    println!("Configuration:");
    println!(" Target size: {} (worker, block) pairs", args.size);
    println!(
        " Depth: {} blocks/sequence (= {} tokens with block_size={})",
        args.depth,
        args.depth * args.block_size as usize,
        args.block_size
    );
    println!(" Block size: {} tokens", args.block_size);
    println!(" Workers: {}", args.num_workers);
    println!(" Iterations: {}", args.iterations);
    println!(
        " Warmup: {:.0}% ({} iterations discarded)",
        args.warmup_ratio * 100.0,
        (args.iterations as f64 * args.warmup_ratio) as usize
    );
    println!(
        " Prefix prompt ratio: {:.1}% ({} blocks at depth {})",
        args.prefix_prompt_ratio * 100.0,
        (args.depth as f64 * args.prefix_prompt_ratio).round() as usize,
        args.depth
    );
    println!(" Prefix prompt groups: {}", args.num_prefix_prompts);
    println!(
        "\n Derived: {} sequences to reach target size",
        num_sequences
    );

    // "all" runs every benchmark except the sweep, which must be requested explicitly.
    match args.benchmark_type.as_str() {
        "hash" => bench_hash(&args),
        "store" => bench_store(&args),
        "remove" => bench_remove(&args),
        "find_matches" => bench_find_matches(&args),
        "dump" => bench_dump(&args),
        "sweep" => bench_sweep(&args),
        "all" => {
            bench_hash(&args);
            bench_store(&args);
            bench_remove(&args);
            bench_find_matches(&args);
            bench_dump(&args);
        }
        _ => {
            eprintln!(
                "Unknown benchmark type: {}. Use 'hash', 'store', 'remove', 'find_matches', 'dump', 'sweep', or 'all'",
                args.benchmark_type
            );
            std::process::exit(1);
        }
    }
    println!("\nBenchmark complete.");
}
......@@ -25,12 +25,15 @@
//! per-worker write concurrency.
//! - Deadlock prevention: always lock parent before child, hand-over-hand locking
use std::{collections::VecDeque, sync::Arc};
use std::sync::Arc;
use dashmap::DashMap;
use parking_lot::RwLock;
use rustc_hash::{FxHashMap, FxHashSet};
use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet};
use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::indexer::SyncIndexer;
use crate::indexer::{SyncIndexer, WorkerTask};
use crate::protocols::*;
/// Thread-safe shared reference to a Block.
......@@ -98,10 +101,7 @@ pub struct ConcurrentRadixTree {
/// This will only contain root blocks.
root: SharedBlock,
/// Per-worker lookup table for O(1) block access.
/// Outer RwLock protects the worker map structure (rarely mutated);
/// inner RwLock per worker protects that worker's block-hash map.
lookup: RwLock<FxHashMap<WorkerWithDpRank, RwLock<WorkerLookup>>>,
tree_sizes: DashMap<WorkerWithDpRank, AtomicUsize, FxBuildHasher>,
}
impl Default for ConcurrentRadixTree {
......@@ -122,14 +122,9 @@ impl Drop for ConcurrentRadixTree {
stack.extend(root.children.drain().map(|(_, v)| v));
}
// Remove all lookup references (they may include blocks not reachable from root).
// We have &mut self so no concurrent access; drain the map.
let lookup = self.lookup.get_mut();
for (_, inner_lock) in lookup.drain() {
stack.extend(inner_lock.into_inner().into_values());
}
// Iteratively free any uniquely-owned blocks without recursion
// Iteratively drop blocks to avoid stack overflow on deep trees.
// Without this loop, dropping `stack` would recursively drop each
// Arc<RwLock<Block>> through its `children` map.
while let Some(block) = stack.pop() {
if let Ok(rwlock) = Arc::try_unwrap(block) {
let mut inner = rwlock.into_inner();
......@@ -144,7 +139,7 @@ impl ConcurrentRadixTree {
pub fn new() -> Self {
Self {
root: Arc::new(RwLock::new(Block::new())),
lookup: RwLock::new(FxHashMap::default()),
tree_sizes: DashMap::with_hasher(FxBuildHasher),
}
}
......@@ -197,10 +192,11 @@ impl ConcurrentRadixTree {
for worker in &active {
scores.scores.insert(*worker, 1);
}
let lk = self.lookup.read();
for worker in scores.scores.keys() {
if let Some(inner_lock) = lk.get(worker) {
scores.tree_sizes.insert(*worker, inner_lock.read().len());
if let Some(worker_tree_size) = self.tree_sizes.get(worker) {
scores
.tree_sizes
.insert(*worker, worker_tree_size.load(Ordering::Relaxed));
}
}
return scores;
......@@ -272,10 +268,11 @@ impl ConcurrentRadixTree {
}
// Get tree sizes from the per-worker `tree_sizes` counters.
let lk = self.lookup.read();
for worker in scores.scores.keys() {
if let Some(inner_lock) = lk.get(worker) {
scores.tree_sizes.insert(*worker, inner_lock.read().len());
if let Some(worker_tree_size) = self.tree_sizes.get(worker) {
scores
.tree_sizes
.insert(*worker, worker_tree_size.load(Ordering::Relaxed));
}
}
......@@ -290,7 +287,11 @@ impl ConcurrentRadixTree {
/// ### Arguments
///
/// * `event` - The `RouterEvent` to apply.
pub fn apply_event(&self, event: RouterEvent) -> Result<(), KvCacheEventError> {
fn apply_event(
&self,
lookup: &mut FxHashMap<WorkerWithDpRank, WorkerLookup>,
event: RouterEvent,
) -> Result<(), KvCacheEventError> {
let (worker_id, kv_event) = (event.worker_id, event.event);
let (id, op) = (kv_event.event_id, kv_event.data);
......@@ -298,10 +299,17 @@ impl ConcurrentRadixTree {
let worker = WorkerWithDpRank::new(worker_id, kv_event.dp_rank);
match op {
KvCacheEventData::Stored(op) => self.apply_stored(worker, op, id),
KvCacheEventData::Removed(op) => self.apply_removed(worker, op, id),
KvCacheEventData::Stored(op) => self.apply_stored(lookup, worker, op, id),
KvCacheEventData::Removed(op) => self.apply_removed(lookup, worker, op, id),
KvCacheEventData::Cleared => {
self.clear_all_blocks(worker.worker_id);
// Ensure the worker is tracked in lookup before clearing,
// matching RadixTree behavior where `lookup.entry(worker).or_default()`
// fires before the match arm.
lookup.entry(worker).or_default();
self.tree_sizes
.entry(worker)
.or_insert_with(|| AtomicUsize::new(0));
self.clear_all_blocks(lookup, worker.worker_id);
Ok(())
}
}
......@@ -310,20 +318,13 @@ impl ConcurrentRadixTree {
/// Apply a store operation.
fn apply_stored(
&self,
lookup: &mut FxHashMap<WorkerWithDpRank, WorkerLookup>,
worker: WorkerWithDpRank,
op: KvCacheStoreData,
id: u64,
) -> Result<(), KvCacheEventError> {
// Ensure this worker has an entry in the outer map.
if !self.lookup.read().contains_key(&worker) {
self.lookup
.write()
.entry(worker)
.or_insert_with(|| RwLock::new(FxHashMap::default()));
}
let lk = self.lookup.read();
let mut worker_lookup = lk.get(&worker).unwrap().write();
let worker_lookup = lookup.entry(worker).or_default();
// Find parent block
let mut current = match op.parent_hash {
......@@ -346,6 +347,8 @@ impl ConcurrentRadixTree {
let mut needs_worker_insert = false;
let num_blocks_added = op.blocks.len();
// In each iteration, we lock the parent block and insert the worker into it from
// the previous iteration. This avoids locking a block twice.
for block_data in op.blocks {
......@@ -399,6 +402,16 @@ impl ConcurrentRadixTree {
current = child;
}
match self.tree_sizes.get(&worker) {
Some(size) => {
size.fetch_add(num_blocks_added, Ordering::Relaxed);
}
None => {
self.tree_sizes
.insert(worker, AtomicUsize::new(num_blocks_added));
}
}
// Insert worker into the last child (not yet handled since there is
// no subsequent iteration to pick it up).
if needs_worker_insert {
......@@ -417,15 +430,16 @@ impl ConcurrentRadixTree {
/// `child_count > active_count`.
fn apply_removed(
&self,
lookup: &mut FxHashMap<WorkerWithDpRank, WorkerLookup>,
worker: WorkerWithDpRank,
op: KvCacheRemoveData,
id: u64,
) -> Result<(), KvCacheEventError> {
let lk = self.lookup.read();
let Some(inner_ref) = lk.get(&worker) else {
let Some(worker_lookup) = lookup.get_mut(&worker) else {
return Err(KvCacheEventError::BlockNotFound);
};
let mut worker_lookup = inner_ref.write();
let mut num_removed = 0;
for block_hash in op.block_hashes {
let Some(block) = worker_lookup.remove(&block_hash) else {
......@@ -445,6 +459,18 @@ impl ConcurrentRadixTree {
if guard.workers.is_empty() {
guard.children.clear();
}
num_removed += 1;
}
// Subtract the blocks that were actually removed from this worker's size counter.
match self.tree_sizes.get(&worker) {
    Some(size) => {
        size.fetch_sub(num_removed, Ordering::Relaxed);
    }
    None => {
        // No counter exists yet, so there is nothing to subtract from. Register
        // the worker with a size of zero: seeding the counter with `num_removed`
        // would report the blocks that were just removed as still being stored.
        self.tree_sizes.insert(worker, AtomicUsize::new(0));
    }
}
Ok(())
......@@ -453,20 +479,21 @@ impl ConcurrentRadixTree {
/// Helper function to remove or clear blocks for a worker.
/// If `keep_worker` is true, the worker remains in lookup with empty blocks.
/// If `keep_worker` is false, the worker is completely removed from lookup.
fn remove_or_clear_worker_blocks(&self, worker_id: WorkerId, keep_worker: bool) {
let workers: Vec<WorkerWithDpRank> = self
.lookup
.read()
fn remove_or_clear_worker_blocks(
&self,
lookup: &mut FxHashMap<WorkerWithDpRank, WorkerLookup>,
worker_id: WorkerId,
keep_worker: bool,
) {
let workers: Vec<WorkerWithDpRank> = lookup
.keys()
.filter(|w| w.worker_id == worker_id)
.copied()
.collect();
let mut lk = self.lookup.write();
for worker in workers {
if let Some(inner_lock) = lk.remove(&worker) {
let blocks = inner_lock.into_inner();
for (_, block) in blocks {
if let Some(worker_lookup) = lookup.remove(&worker) {
for (_, block) in worker_lookup.into_iter() {
let mut guard = block.write();
guard.workers.remove(&worker);
if guard.workers.is_empty() {
......@@ -475,45 +502,49 @@ impl ConcurrentRadixTree {
}
if keep_worker {
lk.insert(worker, RwLock::new(FxHashMap::default()));
lookup.insert(worker, FxHashMap::default());
// Reset tree size to 0 but keep the entry so get_workers()
// still returns this worker (matches RadixTree::clear_all_blocks behavior).
if let Some(size) = self.tree_sizes.get(&worker) {
size.store(0, Ordering::Relaxed);
}
} else {
// Fully remove the worker from tree_sizes so get_workers()
// no longer returns it (matches RadixTree::remove_worker behavior).
self.tree_sizes.remove(&worker);
}
}
}
/// Remove a worker and all their blocks from the tree.
pub fn remove_worker(&self, worker_id: WorkerId) {
self.remove_or_clear_worker_blocks(worker_id, false);
}
/// Clear all blocks for a worker but keep the worker tracked.
pub fn clear_all_blocks(&self, worker_id: WorkerId) {
self.remove_or_clear_worker_blocks(worker_id, true);
fn clear_all_blocks(
&self,
lookup: &mut FxHashMap<WorkerWithDpRank, WorkerLookup>,
worker_id: WorkerId,
) {
// `keep_worker = true`: every matching worker entry in `lookup` is replaced by an
// empty map and its size counter is reset to 0, so the worker stays tracked.
self.remove_or_clear_worker_blocks(lookup, worker_id, true);
}
/// Get all worker IDs currently tracked in the radix tree.
/// Returns unique worker_ids (ignoring dp_rank differences).
pub fn get_workers(&self) -> Vec<WorkerId> {
let mut worker_ids: Vec<WorkerId> = self
.lookup
.read()
.keys()
.map(|w| w.worker_id)
.collect::<FxHashSet<_>>()
.into_iter()
.tree_sizes
.iter()
.map(|entry| entry.key().worker_id)
.collect();
worker_ids.sort_unstable();
worker_ids.dedup();
worker_ids
}
/// Dump the radix tree as a series of RouterEvents that can reconstruct the tree.
/// Uses BFS traversal to ensure that the tree reconstruction is unique,
/// though the exact event ordering will be lost.
pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
tracing::debug!(
"Dumping concurrent radix tree as events (contains information about {:?} workers)",
self.lookup.read().len()
);
/// Uses BFS traversal over the shared tree. Since all worker/block membership is
/// stored in the tree nodes themselves, this can be called from any thread without
/// needing per-thread lookup state.
fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
tracing::debug!("Dumping concurrent radix tree as events");
let mut events = Vec::new();
let mut event_id = 0u64;
......@@ -567,15 +598,6 @@ impl ConcurrentRadixTree {
events
}
/// Get total number of blocks across all workers.
pub fn current_size(&self) -> usize {
self.lookup
.read()
.values()
.map(|inner| inner.read().len())
.sum()
}
}
// ============================================================================
......@@ -583,646 +605,39 @@ impl ConcurrentRadixTree {
// ============================================================================
impl SyncIndexer for ConcurrentRadixTree {
fn find_matches(&self, sequence: &[LocalBlockHash], early_exit: bool) -> OverlapScores {
// Delegate to the existing find_matches method
self.find_matches_impl(sequence, early_exit)
}
fn apply_event(&self, event: RouterEvent) -> Result<(), KvCacheEventError> {
self.apply_event(event)
}
fn remove_worker(&self, worker_id: WorkerId) {
self.remove_worker(worker_id);
}
fn dump_events(&self) -> Vec<RouterEvent> {
self.dump_tree_as_events()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::{create_remove_event, create_store_event};
use std::sync::Arc;
use std::thread;
#[test]
fn test_concurrent_radix_tree_basic() {
let trie = ConcurrentRadixTree::new();
let worker_1 = 0;
let worker_2 = 1;
trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.unwrap();
let scores = trie.find_matches_impl(
&[LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap(),
&3
);
assert_eq!(trie.lookup.read().len(), 1);
assert_eq!(
trie.lookup
.read()
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.read()
.len(),
3
);
trie.apply_event(create_store_event(worker_2, 1, vec![1, 4, 5], None))
.unwrap();
let scores = trie.find_matches_impl(
&[LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap(),
&3
);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap(),
&1
);
assert_eq!(trie.lookup.read().len(), 2);
}
#[test]
fn test_concurrent_radix_tree_remove() {
let trie = ConcurrentRadixTree::new();
let worker_1 = 0;
let worker_2 = 1;
trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.unwrap();
trie.apply_event(create_store_event(worker_2, 1, vec![1, 4, 5], None))
.unwrap();
trie.apply_event(create_remove_event(worker_2, 2, vec![5]))
.unwrap();
assert_eq!(
trie.lookup
.read()
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap()
.read()
.len(),
2
);
trie.apply_event(create_remove_event(worker_2, 3, vec![4]))
.unwrap();
assert_eq!(
trie.lookup
.read()
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap()
.read()
.len(),
1
);
}
#[test]
fn test_concurrent_radix_tree_apply_event_errors() {
let trie = ConcurrentRadixTree::new();
let worker_0 = 0;
// Parent block not found
let result = trie.apply_event(create_store_event(
worker_0,
0,
vec![1, 2, 3],
Some(ExternalSequenceBlockHash(12345)),
));
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
KvCacheEventError::ParentBlockNotFound
));
}
#[test]
fn test_clear_all_blocks() {
let trie = ConcurrentRadixTree::new();
let worker_0 = 0;
let worker_1 = 1;
trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 3], None))
.unwrap();
trie.apply_event(create_store_event(worker_1, 0, vec![0, 2, 3], None))
.unwrap();
let result = trie.find_matches_impl(&[LocalBlockHash(0)], false).scores;
assert_eq!(result.len(), 2);
trie.clear_all_blocks(worker_0);
assert!(
trie.lookup
.read()
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert!(
trie.lookup
.read()
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.unwrap()
.read()
.is_empty()
);
let result = trie
.find_matches_impl(&[LocalBlockHash(0), LocalBlockHash(2)], false)
.scores;
assert_eq!(result.len(), 1);
assert_eq!(result[&WorkerWithDpRank::from_worker_id(worker_1)], 2);
}
#[test]
fn test_remove_worker() {
let trie = ConcurrentRadixTree::new();
let worker_0 = 0;
let worker_1 = 1;
trie.apply_event(create_store_event(worker_0, 0, vec![1, 2, 3], None))
.unwrap();
trie.apply_event(create_store_event(worker_1, 0, vec![1, 2, 3], None))
.unwrap();
assert_eq!(trie.lookup.read().len(), 2);
trie.remove_worker(worker_0);
assert!(
!trie
.lookup
.read()
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert_eq!(trie.lookup.read().len(), 1);
let result = trie
.find_matches_impl(
&[LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
)
.scores;
assert_eq!(result.len(), 1);
assert!(!result.contains_key(&WorkerWithDpRank::from_worker_id(worker_0)));
assert!(result.contains_key(&WorkerWithDpRank::from_worker_id(worker_1)));
}
#[test]
fn test_concurrent_radix_tree_default() {
let trie: ConcurrentRadixTree = Default::default();
assert!(trie.root.read().children.is_empty());
assert!(trie.root.read().workers.is_empty());
assert!(trie.lookup.read().is_empty());
}
#[test]
fn test_concurrent_find_matches() {
let trie = Arc::new(ConcurrentRadixTree::new());
// Populate tree
trie.apply_event(create_store_event(0, 0, vec![1, 2, 3, 4, 5], None))
.unwrap();
trie.apply_event(create_store_event(1, 0, vec![1, 2, 6, 7, 8], None))
.unwrap();
let sequence = vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
LocalBlockHash(4),
LocalBlockHash(5),
];
// Spawn multiple threads doing concurrent find_matches
let handles: Vec<_> = (0..10)
.map(|_| {
let tree = trie.clone();
let seq = sequence.clone();
thread::spawn(move || tree.find_matches_impl(&seq, false))
})
.collect();
// All should return the same result
let expected_worker_0_score = 5;
let expected_worker_1_score = 2;
for h in handles {
let result = h.join().unwrap();
assert_eq!(
result
.scores
.get(&WorkerWithDpRank::from_worker_id(0))
.unwrap(),
&expected_worker_0_score
);
assert_eq!(
result
.scores
.get(&WorkerWithDpRank::from_worker_id(1))
.unwrap(),
&expected_worker_1_score
);
}
}
#[test]
fn test_concurrent_read_write() {
let trie = Arc::new(ConcurrentRadixTree::new());
// Pre-populate
for i in 0..5 {
trie.apply_event(create_store_event(i, 0, vec![1, 2, 3], None))
.unwrap();
}
fn worker(&self, event_receiver: flume::Receiver<WorkerTask>) -> anyhow::Result<()> {
let mut lookup = FxHashMap::default();
let sequence = vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)];
// Spawn readers
let reader_handles: Vec<_> = (0..5)
.map(|_| {
let tree = trie.clone();
let seq = sequence.clone();
thread::spawn(move || {
for _ in 0..100 {
let _ = tree.find_matches_impl(&seq, false);
}
})
})
.collect();
// Spawn writers (adding more workers)
let writer_handles: Vec<_> = (5..10)
.map(|i| {
let tree = trie.clone();
thread::spawn(move || {
for j in 0..10 {
let _ =
tree.apply_event(create_store_event(i, j, vec![1, 2, 3, 4 + j], None));
}
})
})
.collect();
// Wait for all threads
for h in reader_handles {
h.join().unwrap();
}
for h in writer_handles {
h.join().unwrap();
while let Ok(task) = event_receiver.recv() {
match task {
WorkerTask::Event(event) => {
if let Err(e) = self.apply_event(&mut lookup, event) {
tracing::warn!("Failed to apply event: {:?}", e);
}
// Tree should have 10 workers now
assert_eq!(trie.get_workers().len(), 10);
}
#[test]
fn test_remove_parent_does_not_cascade() {
let trie = ConcurrentRadixTree::new();
let worker_1 = 0;
// Create a chain: root -> block1 -> block2 -> block3
trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.unwrap();
let worker_key = WorkerWithDpRank::from_worker_id(worker_1);
assert_eq!(trie.lookup.read().get(&worker_key).unwrap().read().len(), 3);
// Remove ONLY block1 -- descendants should NOT be cascade-removed
trie.apply_event(create_remove_event(worker_1, 2, vec![1]))
.unwrap();
let lk = trie.lookup.read();
let worker_lookup = lk.get(&worker_key).unwrap().read();
assert!(
!worker_lookup.contains_key(&ExternalSequenceBlockHash(100)),
"block1 should be removed"
);
assert!(
worker_lookup.contains_key(&ExternalSequenceBlockHash(200)),
"block2 should remain (no cascade)"
);
assert!(
worker_lookup.contains_key(&ExternalSequenceBlockHash(300)),
"block3 should remain (no cascade)"
);
assert_eq!(worker_lookup.len(), 2);
}
#[test]
fn test_remove_all_blocks_individually() {
// Verifies that explicitly removing all blocks (as the engine would)
// cleans up fully, even without cascade.
let trie = ConcurrentRadixTree::new();
let worker_1 = 0;
trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.unwrap();
let worker_key = WorkerWithDpRank::from_worker_id(worker_1);
// Remove all three blocks explicitly in one event
trie.apply_event(create_remove_event(worker_1, 2, vec![1, 2, 3]))
.unwrap();
let lk = trie.lookup.read();
let worker_lookup = lk.get(&worker_key).unwrap().read();
assert_eq!(worker_lookup.len(), 0, "all blocks should be removed");
}
#[test]
fn test_find_matches_with_stale_entries() {
// Two workers share a full path. Remove worker_1 from the root block
// only (simulating a partial remove). find_matches should still
// produce correct scores for worker_2, and worker_1 should score at
// the stale descendant depth (transiently inflated but not a crash).
let trie = ConcurrentRadixTree::new();
let worker_1 = 0;
let worker_2 = 1;
// Both workers have blocks 1 -> 2 -> 3
trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.unwrap();
trie.apply_event(create_store_event(worker_2, 2, vec![1, 2, 3], None))
.unwrap();
// Remove worker_1 from block 1 only (no cascade to 2,3)
trie.apply_event(create_remove_event(worker_1, 3, vec![1]))
.unwrap();
let scores = trie.find_matches_impl(
&[LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
);
// worker_2 was never removed, should have full depth
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_2)),
Some(&3),
"worker_2 should score 3 (fully present)"
);
// worker_1 was removed from block 1 so it drops out at depth 1.
// But because blocks 2 and 3 still have worker_1 (stale), the
// child_count > active_count path fires and detects the dropout.
// The exact score depends on the detection logic: worker_1 is absent
// from block 1's workers, so it should be scored at depth 0 from the
// first child initialization (it won't appear in `active` at all).
// So worker_1 should NOT appear in scores (it was never in active).
assert!(
!scores
.scores
.contains_key(&WorkerWithDpRank::from_worker_id(worker_1)),
"worker_1 should not appear in scores (removed from root-level block)"
);
}
// ========================================================================
// ThreadPoolIndexer<ConcurrentRadixTree> Tests
// ========================================================================
mod thread_pool_indexer_tests {
use tokio::time::Duration;
use super::*;
use crate::indexer::{KvIndexerInterface, ThreadPoolIndexer};
fn make_indexer(
num_workers: usize,
kv_block_size: u32,
) -> ThreadPoolIndexer<ConcurrentRadixTree> {
ThreadPoolIndexer::new(ConcurrentRadixTree::new(), num_workers, kv_block_size)
WorkerTask::RemoveWorker(worker_id) => {
self.remove_or_clear_worker_blocks(&mut lookup, worker_id, false);
}
#[tokio::test]
async fn test_thread_pool_indexer_basic() {
let indexer = make_indexer(4, 16);
let worker_1 = 0;
let worker_2 = 1;
indexer
.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.await;
indexer
.apply_event(create_store_event(worker_2, 1, vec![1, 4, 5], None))
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
let scores = indexer
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap(),
&3
);
assert_eq!(
scores
.scores
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap(),
&1
);
indexer.shutdown();
WorkerTask::DumpEvents(sender) => {
    // Dumps are served by `dump_events()` directly from the shared tree, so this
    // task is not expected to arrive here. Reply with an empty list anyway so a
    // caller awaiting the response is not left blocked.
    let _ = sender.send(Ok(Vec::new()));
}
#[tokio::test]
async fn test_thread_pool_indexer_remove_worker() {
let indexer = make_indexer(2, 16);
let worker_0 = 0;
let worker_1 = 1;
indexer
.apply_event(create_store_event(worker_0, 1, vec![1, 2, 3], None))
.await;
indexer
.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None))
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
assert_eq!(indexer.backend().get_workers().len(), 2);
indexer.remove_worker(worker_0).await;
let workers = indexer.backend().get_workers();
assert_eq!(workers.len(), 1);
assert!(!workers.contains(&worker_0));
assert!(workers.contains(&worker_1));
indexer.shutdown();
}
#[tokio::test]
async fn test_thread_pool_indexer_dump_events() {
let indexer = make_indexer(2, 16);
indexer
.apply_event(create_store_event(0, 1, vec![1, 2, 3], None))
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
let events = indexer.dump_events().await.unwrap();
assert_eq!(events.len(), 3);
indexer.shutdown();
}
#[tokio::test]
async fn test_thread_pool_indexer_find_matches_for_request() {
let indexer = make_indexer(2, 1);
indexer
.apply_event(create_store_event(0, 1, vec![100, 200, 300], None))
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
let scores = indexer
.find_matches_for_request(&[100, 200, 300], None)
.await;
assert!(scores.is_ok());
indexer.shutdown();
}
#[tokio::test]
async fn test_thread_pool_indexer_sticky_routing() {
let indexer = make_indexer(4, 16);
for i in 0..10 {
indexer
.apply_event(create_store_event(0, i, vec![i], None))
.await;
}
tokio::time::sleep(Duration::from_millis(100)).await;
assert_eq!(indexer.backend().current_size(), 10);
indexer.shutdown();
}
#[tokio::test]
async fn test_thread_pool_indexer_multiple_workers() {
let indexer = make_indexer(4, 16);
for worker_id in 0..8 {
indexer
.apply_event(create_store_event(
worker_id,
1,
vec![1, 2, worker_id + 10],
None,
))
.await;
}
tokio::time::sleep(Duration::from_millis(100)).await;
assert_eq!(indexer.backend().get_workers().len(), 8);
let scores = indexer
.find_matches(vec![LocalBlockHash(1), LocalBlockHash(2)])
.await
.unwrap();
assert_eq!(scores.scores.len(), 8);
for (_, score) in scores.scores.iter() {
assert_eq!(*score, 2);
}
indexer.shutdown();
WorkerTask::Terminate => {
break;
}
#[tokio::test]
async fn test_thread_pool_indexer_shutdown_idempotent() {
let indexer = make_indexer(2, 16);
indexer
.apply_event(create_store_event(0, 1, vec![1, 2, 3], None))
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
indexer.shutdown();
indexer.shutdown();
}
#[tokio::test]
async fn test_thread_pool_indexer_concurrent_operations() {
use std::sync::Arc;
let indexer = Arc::new(make_indexer(4, 16));
for worker_id in 0..4 {
indexer
.apply_event(create_store_event(worker_id, 1, vec![1, 2, 3, 4, 5], None))
.await;
}
tokio::time::sleep(Duration::from_millis(100)).await;
let sequence = vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)];
let mut handles = Vec::new();
for _ in 0..10 {
let idx = indexer.clone();
let seq = sequence.clone();
handles.push(tokio::spawn(
async move { idx.find_matches(seq).await.unwrap() },
));
tracing::debug!("ConcurrentRadixTree worker thread shutting down");
Ok(())
}
for handle in handles {
let scores = handle.await.unwrap();
assert_eq!(scores.scores.len(), 4);
fn find_matches(&self, sequence: &[LocalBlockHash], early_exit: bool) -> OverlapScores {
self.find_matches_impl(sequence, early_exit)
}
indexer.shutdown();
}
// Always returns `Some`: the dump is produced from the shared tree by
// `dump_tree_as_events`, so no per-thread state has to be collected.
fn dump_events(&self) -> Option<Vec<RouterEvent>> {
Some(self.dump_tree_as_events())
}
}
......@@ -359,6 +359,14 @@ pub trait KvIndexerInterface {
async fn flush(&self) -> usize;
}
/// Unit of work sent over a channel to an indexer worker thread.
pub enum WorkerTask {
/// Apply a router event on the receiving worker thread.
Event(RouterEvent),
/// Permanently remove a worker from tracking (keep_worker: false).
RemoveWorker(WorkerId),
/// Request a dump of the receiving thread's events; the reply (or an error)
/// is sent back through the enclosed oneshot sender.
DumpEvents(oneshot::Sender<anyhow::Result<Vec<RouterEvent>>>),
/// Ask the worker thread to leave its receive loop and shut down.
Terminate,
}
// ============================================================================
// SyncIndexer trait and ThreadPoolIndexer generic wrapper
// ============================================================================
......@@ -373,17 +381,18 @@ pub trait KvIndexerInterface {
/// - Sticky event routing to N worker threads
/// - Inline reads on the caller's thread (no channel dispatch for find_matches)
pub trait SyncIndexer: Send + Sync + 'static {
fn worker(&self, event_receiver: flume::Receiver<WorkerTask>) -> anyhow::Result<()>;
/// Find matches for a sequence of block hashes.
fn find_matches(&self, sequence: &[LocalBlockHash], early_exit: bool) -> OverlapScores;
/// Apply a router event to the data structure.
fn apply_event(&self, event: RouterEvent) -> Result<(), KvCacheEventError>;
/// Remove all entries for a worker.
fn remove_worker(&self, worker_id: WorkerId);
/// Dump the data structure as router events for reconstruction.
fn dump_events(&self) -> Vec<RouterEvent>;
/// Dump events directly from the shared structure, bypassing worker channels.
/// Returns `Some(events)` for backends whose tree state is fully shared (e.g.
/// ConcurrentRadixTree). Returns `None` for backends that keep per-thread
/// state and must dump via the worker channel.
fn dump_events(&self) -> Option<Vec<RouterEvent>> {
None
}
}
/// Generic wrapper that provides [`KvIndexerInterface`] for any [`SyncIndexer`] backend.
......@@ -415,9 +424,9 @@ pub struct ThreadPoolIndexer<T: SyncIndexer> {
/// Counter for round-robin assignment of new WorkerIds.
worker_assignment_count: AtomicUsize,
/// Channels to send events to worker threads (one per thread).
/// Sending `None` signals the thread to shut down.
worker_event_channels: Vec<flume::Sender<Option<RouterEvent>>>,
/// Channels to send tasks to worker threads (one per thread).
/// Sending `WorkerTask::Terminate` signals the thread to shut down.
worker_event_channels: Vec<flume::Sender<WorkerTask>>,
/// Number of worker threads.
num_workers: usize,
......@@ -450,18 +459,13 @@ impl<T: SyncIndexer> ThreadPoolIndexer<T> {
let mut worker_event_senders = Vec::new();
let mut thread_handles = Vec::new();
for _ in 0..num_workers {
let (event_sender, event_receiver) = flume::unbounded::<Option<RouterEvent>>();
let (event_sender, event_receiver) = flume::unbounded::<WorkerTask>();
worker_event_senders.push(event_sender);
let backend = Arc::clone(&backend);
let handle = std::thread::spawn(move || {
while let Ok(Some(event)) = event_receiver.recv() {
if let Err(e) = backend.apply_event(event) {
tracing::warn!("Failed to apply event: {:?}", e);
}
}
tracing::debug!("Worker thread shutting down");
backend.worker(event_receiver).unwrap();
});
thread_handles.push(handle);
}
......@@ -530,7 +534,7 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
});
// Send event to the assigned worker thread
if let Err(e) = self.worker_event_channels[thread_idx].send(Some(event)) {
if let Err(e) = self.worker_event_channels[thread_idx].send(WorkerTask::Event(event)) {
tracing::error!(
"Failed to send event to worker thread {}: {:?}",
thread_idx,
......@@ -540,14 +544,34 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
}
async fn remove_worker(&self, worker_id: WorkerId) {
// Execute inline - the backend is thread-safe
self.backend.remove_worker(worker_id);
// Route to the worker's assigned thread (if any), otherwise broadcast
// to all threads since dp_ranks may be spread across threads.
// The assignment is copied out first so the map guard is released before sending.
let thread_idx = self.worker_assignments.get(&worker_id).map(|v| *v);
match thread_idx {
    Some(idx) => {
        if let Err(e) =
            self.worker_event_channels[idx].send(WorkerTask::RemoveWorker(worker_id))
        {
            tracing::error!(
                "Failed to send RemoveWorker to worker thread {}: {:?}",
                idx,
                e
            );
        }
    }
    None => {
        // Worker was never assigned a thread - broadcast to all.
        // Failures are logged exactly like the routed case instead of being
        // silently dropped, so a dead worker thread is visible on both paths.
        for (idx, channel) in self.worker_event_channels.iter().enumerate() {
            if let Err(e) = channel.send(WorkerTask::RemoveWorker(worker_id)) {
                tracing::error!(
                    "Failed to send RemoveWorker to worker thread {}: {:?}",
                    idx,
                    e
                );
            }
        }
    }
}
}
fn shutdown(&self) {
// Send shutdown signal (None) to all worker threads
// Send shutdown signal to all worker threads
for channel in self.worker_event_channels.iter() {
let _ = channel.send(None);
let _ = channel.send(WorkerTask::Terminate);
}
// Take ownership of thread handles and join them
......@@ -565,8 +589,41 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
}
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
// Execute inline - the backend is thread-safe
Ok(self.backend.dump_events())
// Fast path: backend can dump directly from shared state (e.g. ConcurrentRadixTree).
if let Some(events) = self.backend.dump_events() {
return Ok(events);
}
// Slow path: collect from each worker thread via channel (e.g. PositionalIndexer).
// Every request is sent before any reply is awaited, so the worker threads can
// build their dumps without waiting on one another.
let mut receivers = Vec::new();
for channel in &self.worker_event_channels {
let (resp_tx, resp_rx) = oneshot::channel::<anyhow::Result<Vec<RouterEvent>>>();
let dump_req = WorkerTask::DumpEvents(resp_tx);
// A failed send means the worker's receiver is gone; abort the whole dump.
channel
.send(dump_req)
.map_err(|_| KvRouterError::IndexerOffline)?;
receivers.push(resp_rx);
}
// Event ids are rewritten with a single counter so the merged dump carries one
// sequential id series starting at 0, in channel order.
let mut event_id_counter = 0;
let mut all_events = Vec::new();
for resp_rx in receivers {
// First `map_err`: the reply sender was dropped without answering.
// Second `map_err`: the worker answered with an error.
let mut events = resp_rx
.await
.map_err(|_| KvRouterError::IndexerDroppedRequest)?
.map_err(|_| KvRouterError::IndexerOffline)?;
for event in &mut events {
event.event.event_id = event_id_counter;
event_id_counter += 1;
}
all_events.extend(events);
}
Ok(all_events)
}
async fn process_routing_decision_for_request(
......@@ -2354,6 +2411,16 @@ mod tests {
index.shutdown();
}
// Shutting an indexer down twice must be harmless: the test passes as long as
// the second `shutdown()` call returns without panicking.
#[tokio::test]
#[apply(indexer_template)]
async fn test_shutdown_idempotent(variant: &str) {
let index = make_indexer(variant);
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Presumably gives the worker thread time to process the event before shutdown.
tokio::time::sleep(Duration::from_millis(100)).await;
index.shutdown();
index.shutdown();
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_find_matches_for_request(variant: &str) {
......
......@@ -21,10 +21,10 @@
//! `KvIndexerInterface` with sticky event routing and worker threads, wrap it
//! in a `ThreadPoolIndexer`.
use dashmap::DashMap;
use parking_lot::RwLock;
use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet};
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::indexer::SyncIndexer;
use crate::indexer::{SyncIndexer, WorkerTask};
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheEventError, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, OverlapScores, RouterEvent, WorkerId, WorkerWithDpRank,
......@@ -100,7 +100,7 @@ impl SeqEntry {
}
}
type LevelIndex = RwLock<FxHashMap<ExternalSequenceBlockHash, (usize, LocalBlockHash)>>;
pub type LevelIndex = FxHashMap<ExternalSequenceBlockHash, (usize, LocalBlockHash)>;
/// Positional HashMap-based KV cache index.
///
......@@ -108,11 +108,8 @@ type LevelIndex = RwLock<FxHashMap<ExternalSequenceBlockHash, (usize, LocalBlock
/// All methods are synchronous and thread-safe.
pub struct PositionalIndexer {
index: DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
/// Per-worker reverse lookup: worker -> seq_hash -> (position, local_hash)
/// Enables efficient remove operations without global flat reverse map.
/// Uses a single RwLock rather than DashMap because structural mutations
/// (adding/removing workers) are rare; the hot path is read-only.
worker_blocks: RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
tree_sizes: DashMap<WorkerWithDpRank, AtomicUsize, FxBuildHasher>,
jump_size: usize,
}
......@@ -129,7 +126,7 @@ impl PositionalIndexer {
Self {
index: DashMap::with_hasher(FxBuildHasher),
worker_blocks: RwLock::new(FxHashMap::default()),
tree_sizes: DashMap::with_hasher(FxBuildHasher),
jump_size,
}
}
......@@ -140,83 +137,37 @@ impl PositionalIndexer {
// ============================================================================
impl SyncIndexer for PositionalIndexer {
fn find_matches(&self, sequence: &[LocalBlockHash], early_exit: bool) -> OverlapScores {
self.jump_search_matches(sequence, early_exit)
}
fn worker(&self, event_receiver: flume::Receiver<WorkerTask>) -> anyhow::Result<()> {
let mut worker_blocks = FxHashMap::default();
fn apply_event(&self, event: RouterEvent) -> Result<(), KvCacheEventError> {
Self::apply_event_impl(&self.index, &self.worker_blocks, event)
while let Ok(task) = event_receiver.recv() {
match task {
WorkerTask::Event(event) => {
if let Err(e) = self.apply_event(&mut worker_blocks, event) {
tracing::warn!("Failed to apply event: {:?}", e);
}
fn remove_worker(&self, worker_id: WorkerId) {
Self::remove_or_clear_worker_blocks_impl(
&self.index,
&self.worker_blocks,
worker_id,
false,
);
}
fn dump_events(&self) -> Vec<RouterEvent> {
let mut events = Vec::new();
let mut event_id = 0u64;
let wb = self.worker_blocks.read();
for (worker, level_index) in wb.iter() {
let worker = *worker;
let worker_map = level_index.read();
// Collect (position, local_hash, seq_hash) and sort by position
// so parents are emitted before children during replay.
let mut blocks: Vec<_> = worker_map
.iter()
.map(|(seq_hash, (pos, local_hash))| (*pos, *local_hash, *seq_hash))
.collect();
blocks.sort_unstable_by_key(|(pos, _, _)| *pos);
// Track one valid seq_hash per position for parent_hash synthesis.
let mut last_at_position: FxHashMap<usize, ExternalSequenceBlockHash> =
FxHashMap::default();
for (pos, local_hash, seq_hash) in blocks {
let parent_hash = if pos == 0 {
None
} else {
match last_at_position.get(&(pos - 1)) {
Some(&parent) => Some(parent),
None => {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
position = pos,
"Orphaned block at position with no parent; skipping in dump"
);
continue;
WorkerTask::RemoveWorker(worker_id) => {
self.remove_or_clear_worker_blocks_impl(&mut worker_blocks, worker_id, false);
}
WorkerTask::DumpEvents(sender) => {
let events = self.dump_events(&worker_blocks);
if let Err(e) = sender.send(Ok(events)) {
tracing::warn!("Failed to send events: {:?}", e);
}
}
WorkerTask::Terminate => {
break;
}
};
events.push(RouterEvent {
worker_id: worker.worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: vec![KvCacheStoredBlockData {
block_hash: seq_hash,
tokens_hash: local_hash,
mm_extra_info: None,
}],
}),
dp_rank: worker.dp_rank,
},
});
event_id += 1;
last_at_position.insert(pos, seq_hash);
}
}
events
tracing::debug!("PositionalIndexer worker thread shutting down");
Ok(())
}
fn find_matches(&self, sequence: &[LocalBlockHash], early_exit: bool) -> OverlapScores {
self.jump_search_matches(sequence, early_exit)
}
}
......@@ -227,9 +178,9 @@ impl SyncIndexer for PositionalIndexer {
impl PositionalIndexer {
/// Process an event using the provided index and worker_blocks.
/// This is called from worker threads.
fn apply_event_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
worker_blocks: &RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
pub fn apply_event(
&self,
worker_blocks: &mut FxHashMap<WorkerWithDpRank, LevelIndex>,
event: RouterEvent,
) -> Result<(), KvCacheEventError> {
let (worker_id, kv_event) = (event.worker_id, event.event);
......@@ -245,50 +196,32 @@ impl PositionalIndexer {
match op {
KvCacheEventData::Stored(store_data) => {
Self::store_blocks_impl(index, worker_blocks, worker, store_data, id)?;
self.store_blocks_impl(worker_blocks, worker, store_data, id)?;
Ok(())
}
KvCacheEventData::Removed(remove_data) => {
Self::remove_blocks_impl(
index,
worker_blocks,
worker,
&remove_data.block_hashes,
id,
)?;
self.remove_blocks_impl(worker_blocks, worker, &remove_data.block_hashes, id)?;
Ok(())
}
KvCacheEventData::Cleared => {
Self::clear_worker_blocks_impl(index, worker_blocks, worker_id);
self.clear_worker_blocks_impl(worker_blocks, worker_id);
Ok(())
}
}
}
fn store_blocks_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
worker_blocks: &RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
&self,
worker_blocks: &mut FxHashMap<WorkerWithDpRank, LevelIndex>,
worker: WorkerWithDpRank,
store_data: KvCacheStoreData,
event_id: u64,
) -> Result<(), KvCacheEventError> {
let worker_map = worker_blocks.entry(worker).or_default();
// Determine starting position based on parent_hash
let start_pos = match store_data.parent_hash {
Some(parent_hash) => {
let wb = worker_blocks.read();
let Some(level_index) = wb.get(&worker) else {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
event_id,
parent_hash = ?parent_hash,
);
return Err(KvCacheEventError::ParentBlockNotFound);
};
let worker_map = level_index.read();
let Some(entry) = worker_map.get(&parent_hash) else {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
......@@ -304,42 +237,45 @@ impl PositionalIndexer {
None => 0, // Start from position 0
};
if !worker_blocks.read().contains_key(&worker) {
worker_blocks
.write()
.entry(worker)
.or_insert_with(|| RwLock::new(FxHashMap::default()));
}
let worker_blocks_entry = worker_blocks.entry(worker).or_default();
let wb = worker_blocks.read();
let mut worker_map = wb.get(&worker).unwrap().write();
let num_stored_blocks = store_data.blocks.len();
for (i, block_data) in store_data.blocks.into_iter().enumerate() {
let position = start_pos + i;
let local_hash = block_data.tokens_hash;
let seq_hash = block_data.block_hash;
index
self.index
.entry((position, local_hash))
.and_modify(|entry| entry.insert(seq_hash, worker))
.or_insert_with(|| SeqEntry::new(seq_hash, worker));
// Insert into worker_blocks: worker -> seq_hash -> (position, local_hash)
worker_map.insert(seq_hash, (position, local_hash));
worker_blocks_entry.insert(seq_hash, (position, local_hash));
}
match self.tree_sizes.get(&worker) {
Some(size) => {
size.fetch_add(num_stored_blocks, Ordering::Relaxed);
}
None => {
self.tree_sizes
.insert(worker, AtomicUsize::new(num_stored_blocks));
}
}
Ok(())
}
fn remove_blocks_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
worker_blocks: &RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
&self,
worker_blocks: &mut FxHashMap<WorkerWithDpRank, LevelIndex>,
worker: WorkerWithDpRank,
seq_hashes: &Vec<ExternalSequenceBlockHash>,
event_id: u64,
) -> Result<(), KvCacheEventError> {
let wb = worker_blocks.read();
let level_index = wb.get(&worker).ok_or_else(|| {
let worker_map = worker_blocks.get_mut(&worker).ok_or_else(|| {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
......@@ -350,7 +286,7 @@ impl PositionalIndexer {
KvCacheEventError::BlockNotFound
})?;
let mut worker_map = level_index.write();
let mut num_removed_blocks = 0;
for seq_hash in seq_hashes {
let Some((position, local_hash)) = worker_map.remove(seq_hash) else {
......@@ -361,13 +297,23 @@ impl PositionalIndexer {
block_hash = ?seq_hash,
"Failed to find block to remove; skipping remove operation"
);
if let Some(size) = self.tree_sizes.get(&worker) {
size.fetch_sub(num_removed_blocks, Ordering::Relaxed);
}
return Err(KvCacheEventError::BlockNotFound);
};
// Remove from index
if let Some(mut entry) = index.get_mut(&(position, local_hash)) {
if let Some(mut entry) = self.index.get_mut(&(position, local_hash)) {
let _ = entry.remove(*seq_hash, worker);
}
num_removed_blocks += 1;
}
if let Some(size) = self.tree_sizes.get(&worker) {
size.fetch_sub(num_removed_blocks, Ordering::Relaxed);
}
Ok(())
......@@ -376,63 +322,114 @@ impl PositionalIndexer {
/// Clear all blocks for a specific worker_id (all dp_ranks), but keep worker tracked.
/// Static version for use in worker threads.
fn clear_worker_blocks_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
worker_blocks: &RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
&self,
worker_blocks: &mut FxHashMap<WorkerWithDpRank, LevelIndex>,
worker_id: WorkerId,
) {
Self::remove_or_clear_worker_blocks_impl(index, worker_blocks, worker_id, true);
}
/// Get total number of blocks across all workers.
pub fn current_size(&self) -> usize {
self.worker_blocks
.read()
.values()
.map(|level_index| level_index.read().len())
.sum()
}
/// Remove a worker and all their blocks completely from the index.
#[allow(dead_code)]
fn remove_worker_blocks(&self, worker_id: WorkerId) {
Self::remove_or_clear_worker_blocks_impl(
&self.index,
&self.worker_blocks,
worker_id,
false,
);
self.remove_or_clear_worker_blocks_impl(worker_blocks, worker_id, true);
}
/// Helper function to remove or clear blocks for a worker.
/// If `keep_worker` is true, the worker remains tracked with empty blocks.
/// If `keep_worker` is false, the worker is completely removed.
fn remove_or_clear_worker_blocks_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
worker_blocks: &RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
&self,
worker_blocks: &mut FxHashMap<WorkerWithDpRank, LevelIndex>,
worker_id: WorkerId,
keep_worker: bool,
) {
let workers: Vec<WorkerWithDpRank> = worker_blocks
.read()
.keys()
.filter(|w| w.worker_id == worker_id)
.copied()
.iter()
.filter(|entry| entry.0.worker_id == worker_id)
.map(|entry| *entry.0)
.collect();
let mut wb = worker_blocks.write();
for worker in workers {
if let Some(worker_map) = wb.remove(&worker) {
for (seq_hash, (position, local_hash)) in worker_map.read().iter() {
if let Some(mut entry) = index.get_mut(&(*position, *local_hash)) {
if let Some(worker_map) = worker_blocks.remove(&worker) {
for (seq_hash, (position, local_hash)) in worker_map.iter() {
if let Some(mut entry) = self.index.get_mut(&(*position, *local_hash)) {
let _ = entry.remove(*seq_hash, worker);
}
}
}
if keep_worker {
wb.insert(worker, RwLock::new(FxHashMap::default()));
// Re-insert worker with empty map to keep it tracked
worker_blocks.insert(worker, FxHashMap::default());
// Reset tree size to 0 but keep the entry so scoring remains consistent.
if let Some(size) = self.tree_sizes.get(&worker) {
size.store(0, Ordering::Relaxed);
}
} else {
// Fully remove the worker from tree_sizes.
self.tree_sizes.remove(&worker);
}
}
}
/// Rebuilds a replayable event stream from the per-worker block bookkeeping.
///
/// Every block recorded in `worker_blocks` becomes one single-block `Stored`
/// event. Event ids are assigned consecutively from 0 across all workers, in
/// emission order. Within a worker, blocks are emitted in ascending position
/// order so that parents are emitted before children during replay.
///
/// A block at a non-zero position for which no block was emitted at the
/// preceding position is logged as orphaned and left out of the result.
fn dump_events(
    &self,
    worker_blocks: &FxHashMap<WorkerWithDpRank, LevelIndex>,
) -> Vec<RouterEvent> {
    let mut events = Vec::new();
    let mut next_event_id = 0u64;

    for (worker, worker_map) in worker_blocks {
        // Flatten the worker's map into (position, local_hash, seq_hash)
        // tuples ordered by position.
        let mut ordered: Vec<_> = worker_map
            .iter()
            .map(|(seq_hash, (pos, local_hash))| (*pos, *local_hash, *seq_hash))
            .collect();
        ordered.sort_unstable_by_key(|&(pos, _, _)| pos);

        // Remembers one emitted seq_hash per position. The synthesized
        // parent_hash does not have to be the true logical parent: during
        // replay it is only used to derive `start_pos = parent.position + 1`,
        // so any seq_hash at the previous position is sufficient. The
        // PositionalIndexer is position-based, not tree-topology-based.
        let mut last_at_position: FxHashMap<usize, ExternalSequenceBlockHash> =
            FxHashMap::default();

        for (pos, local_hash, seq_hash) in ordered {
            // `checked_sub` yields `None` exactly when `pos == 0`, i.e. for
            // root blocks, which carry no parent.
            let parent_hash = match pos.checked_sub(1) {
                None => None,
                Some(prev_pos) => {
                    let Some(parent) = last_at_position.get(&prev_pos).copied() else {
                        tracing::warn!(
                            worker_id = worker.worker_id.to_string(),
                            dp_rank = worker.dp_rank,
                            position = pos,
                            "Orphaned block at position with no parent; skipping in dump"
                        );
                        continue;
                    };
                    Some(parent)
                }
            };

            let stored_block = KvCacheStoredBlockData {
                block_hash: seq_hash,
                tokens_hash: local_hash,
                mm_extra_info: None,
            };
            let store_data = KvCacheStoreData {
                parent_hash,
                blocks: vec![stored_block],
            };
            events.push(RouterEvent {
                worker_id: worker.worker_id,
                event: KvCacheEvent {
                    event_id: next_event_id,
                    data: KvCacheEventData::Stored(store_data),
                    dp_rank: worker.dp_rank,
                },
            });

            // Only emitted blocks consume an id and become parent candidates.
            next_event_id += 1;
            last_at_position.insert(pos, seq_hash);
        }
    }

    events
}
}
......@@ -533,11 +530,10 @@ impl PositionalIndexer {
hi: usize,
early_exit: bool,
) {
for pos in lo..hi {
if active.is_empty() {
break;
return;
}
for pos in lo..hi {
let Some(entry) = self.index.get(&(pos, sequence[pos])) else {
for worker in active.iter() {
scores.scores.insert(*worker, pos as u32);
......@@ -568,6 +564,7 @@ impl PositionalIndexer {
scores.scores.insert(*worker, pos as u32);
}
active.clear();
break;
}
}
}
......@@ -626,10 +623,12 @@ impl PositionalIndexer {
scores.scores.insert(*worker, 1);
}
// Populate tree_sizes
let wb = self.worker_blocks.read();
for worker in scores.scores.keys() {
if let Some(level_index) = wb.get(worker) {
scores.tree_sizes.insert(*worker, level_index.read().len());
if let Some(worker_tree_size) = self.tree_sizes.get(worker) {
scores
.tree_sizes
.insert(*worker, worker_tree_size.load(Ordering::Relaxed));
}
}
return scores;
......@@ -677,11 +676,11 @@ impl PositionalIndexer {
scores.scores.insert(worker, final_score);
}
// Populate tree_sizes from worker_blocks
let wb = self.worker_blocks.read();
for worker in scores.scores.keys() {
if let Some(level_index) = wb.get(worker) {
scores.tree_sizes.insert(*worker, level_index.read().len());
if let Some(worker_tree_size) = self.tree_sizes.get(worker) {
scores
.tree_sizes
.insert(*worker, worker_tree_size.load(Ordering::Relaxed));
}
}
......
......@@ -182,6 +182,7 @@ RUN apt-get update -y && \
pybind11-dev \
clang \
libclang-dev \
libfontconfig-dev \
protobuf-compiler && \
rm -rf /var/lib/apt/lists/*
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment