Unverified Commit c5c6a551 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: more flash indexer optimizations (#6305)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarCursor <cursoragent@cursor.com>
parent a8226eb0
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -42,6 +42,7 @@ parking_lot = { workspace = true }
clap = { version = "4.5", features = ["derive"], optional = true }
indicatif = { version = "0.18.0", optional = true }
uuid = { workspace = true, optional = true }
rustc-hash = "2.1.1"
[dev-dependencies]
dynamo-bench = { path = "../bench" }
......@@ -53,6 +54,7 @@ dynamo-mocker = { workspace = true }
dynamo-tokens = { workspace = true }
minstant = "0.1.7"
futures = "0.3"
plotters = "0.3"
[[bench]]
......
......@@ -7,7 +7,9 @@ use dynamo_kv_router::indexer::{
KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvIndexerSharded,
};
use dynamo_kv_router::protocols::{RouterEvent, XXH3_SEED};
use dynamo_kv_router::{ConcurrentRadixTree, PositionalIndexer, ThreadPoolIndexer};
use dynamo_kv_router::{
ConcurrentRadixTree, InvertedIndex, NaiveNestedMap, PositionalIndexer, ThreadPoolIndexer,
};
use dynamo_tokens::compute_hash_v2;
use rand::prelude::*;
use std::fs::File;
......@@ -24,6 +26,7 @@ use tokio::task::JoinHandle;
use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use plotters::prelude::*;
use serde::{Deserialize, Serialize};
/// Indexer backend selection and its backend-specific parameters.
......@@ -56,6 +59,17 @@ enum IndexerArgs {
#[clap(long, default_value = "16")]
num_event_workers: usize,
},
/// Naive per-worker nested HashMap indexer behind a single-threaded actor
/// (blog section 2).
NaiveNestedMap {},
/// Inverted index keyed by local_hash (blog section 3).
InvertedIndex {
/// Number of OS threads that consume and apply KV cache events.
#[clap(long, default_value = "16")]
num_event_workers: usize,
},
}
impl IndexerArgs {
......@@ -88,8 +102,39 @@ impl IndexerArgs {
args.block_size,
))
}
IndexerArgs::NaiveNestedMap {} => Arc::new(NaiveNestedMap::new()),
IndexerArgs::InvertedIndex { .. } => Arc::new(InvertedIndex::new()),
}
}
/// Construct an indexer from a short name string, using `args.num_event_workers`.
fn from_name(
name: &str,
args: &Args,
) -> anyhow::Result<Arc<dyn KvIndexerInterface + Send + Sync>> {
let nw = args.num_event_workers;
let indexer_args = match name {
"radix-tree" => IndexerArgs::RadixTree {},
"radix-tree-sharded" => IndexerArgs::RadixTreeSharded { num_shards: 4 },
"nested-map" => IndexerArgs::NestedMap {
jump_size: 8,
num_event_workers: nw,
},
"concurrent-radix-tree" => IndexerArgs::ConcurrentRadixTree {
num_event_workers: nw,
},
"naive-nested-map" => IndexerArgs::NaiveNestedMap {},
"inverted-index" => IndexerArgs::InvertedIndex {
num_event_workers: 0,
},
_ => anyhow::bail!(
"Unknown indexer '{}'. Valid names: radix-tree, radix-tree-sharded, \
nested-map, concurrent-radix-tree, naive-nested-map, inverted-index",
name
),
};
Ok(indexer_args.build(args))
}
}
#[derive(Parser, Debug)]
......@@ -106,7 +151,7 @@ struct Args {
/// Number of GPU blocks available in the mock engine's KV cache.
/// Smaller values force more evictions and produce more remove events.
#[clap(long, default_value = "2048")]
#[clap(long, default_value = "1048576")]
num_gpu_blocks: usize,
/// Number of tokens per KV cache block.
......@@ -126,7 +171,7 @@ struct Args {
/// Number of unique simulated inference workers. Each gets a random
/// partition of the trace and its own mock engine for event generation.
#[clap(short, long, default_value = "64")]
#[clap(short, long, default_value = "256")]
num_unique_inference_workers: usize,
/// How many times to duplicate the set of unique workers during the
......@@ -152,6 +197,41 @@ struct Args {
#[clap(long, default_value = "42")]
seed: u64,
/// Enable throughput vs p99 latency sweep mode. Runs the benchmark at
/// multiple benchmark_duration_ms values and plots the results.
#[clap(long)]
sweep: bool,
/// Minimum benchmark duration (ms) for sweep mode.
#[clap(long, default_value = "1000")]
sweep_min_ms: u64,
/// Maximum benchmark duration (ms) for sweep mode.
#[clap(long, default_value = "50000")]
sweep_max_ms: u64,
/// Number of logarithmically spaced sweep steps between min and max.
#[clap(long, default_value = "10")]
sweep_steps: usize,
/// Output path for the sweep plot PNG.
#[clap(long, default_value = "sweep_plot.png")]
sweep_output: String,
/// Comma-separated list of indexer names to benchmark and compare on the
/// same plot. Overrides the subcommand indexer when present. Valid names:
/// radix-tree, radix-tree-sharded, nested-map, concurrent-radix-tree,
/// naive-nested-map, inverted-index.
#[clap(long, value_delimiter = ',')]
compare: Vec<String>,
/// Number of OS threads for event processing in compare mode. Applies to
/// indexers that use a thread pool (nested-map, concurrent-radix-tree,
/// inverted-index). Ignored by radix-tree, radix-tree-sharded, and
/// naive-nested-map.
#[clap(long, default_value = "16")]
num_event_workers: usize,
/// Indexer backend to benchmark (defaults to radix-tree if not specified).
#[clap(subcommand)]
indexer: Option<IndexerArgs>,
......@@ -340,6 +420,7 @@ fn duplicate_traces(requests: Vec<MooncakeRequest>, factor: usize) -> Vec<Moonca
for d in 0..factor {
let offset = offset_base * d as u64;
out.push(MooncakeRequest {
uuid: Uuid::new_v4(),
hash_ids: r.hash_ids.iter().map(|&h| h + offset).collect(),
..r.clone()
});
......@@ -513,7 +594,9 @@ async fn generate_events(
fn prepare_worker_traces(
traces: Vec<Vec<MooncakeRequest>>,
events: Vec<Vec<(KvCacheEvent, Instant)>>,
args: &Args,
block_size: u32,
benchmark_duration_ms: u64,
trace_simulation_duration_ms: u64,
) -> Vec<Vec<WorkerTrace>> {
assert!(traces.len() == events.len());
......@@ -525,13 +608,13 @@ fn prepare_worker_traces(
trace
.into_iter()
.map(|request| WorkerTrace {
timestamp_us: request.timestamp * 1000 * args.benchmark_duration_ms
timestamp_us: request.timestamp * 1000 * benchmark_duration_ms
/ trace_duration_ms,
entry: WorkerTraceEntry::Request(
request
.hash_ids
.iter()
.map(|id| local_block_hash_from_id(*id, args.block_size))
.map(|id| local_block_hash_from_id(*id, block_size))
.collect(),
),
})
......@@ -547,8 +630,8 @@ fn prepare_worker_traces(
.into_iter()
.map(|(event, timestamp)| WorkerTrace {
timestamp_us: (timestamp - start_instant).as_micros() as u64
* args.benchmark_duration_ms
/ args.trace_simulation_duration_ms,
* benchmark_duration_ms
/ trace_simulation_duration_ms,
entry: WorkerTraceEntry::Event(event),
})
.collect::<Vec<_>>()
......@@ -569,6 +652,15 @@ fn prepare_worker_traces(
.collect()
}
/// Results from a single benchmark run.
struct BenchmarkResults {
offered_ops_throughput: f32,
ops_throughput: f32,
offered_block_throughput: f32,
block_throughput: f32,
latency_p99_us: f32,
}
/// Run the benchmark: replay each worker's merged trace against the indexer,
/// measuring find_matches latency and event processing throughput.
///
......@@ -580,8 +672,15 @@ async fn run_benchmark(
traces: Vec<Vec<MooncakeRequest>>,
events: Vec<Vec<(KvCacheEvent, Instant)>>,
args: &Args,
) -> anyhow::Result<()> {
let worker_traces = prepare_worker_traces(traces, events, args);
benchmark_duration_ms: u64,
) -> anyhow::Result<BenchmarkResults> {
let worker_traces = prepare_worker_traces(
traces,
events,
args.block_size,
benchmark_duration_ms,
args.trace_simulation_duration_ms,
);
let worker_traces = worker_traces
.into_iter()
.map(|trace| Arc::new(trace))
......@@ -680,21 +779,13 @@ async fn run_benchmark(
latencies.extend(task.await??);
}
if progress.elapsed() > Duration::from_millis(args.benchmark_duration_ms * 11 / 10) {
if progress.elapsed() > Duration::from_millis(benchmark_duration_ms * 11 / 10) {
eprintln!(
"WARNING: The benchmarker is unable to keep up with the request/event generation rate. Rerun with a larger --benchmark-duration-ms."
)
}
println!("Flushing event queue...");
let request_duration = progress.elapsed();
let flush_start = Instant::now();
let flush_size = indexer.flush().await;
let flush_duration = flush_start.elapsed();
let event_duration = progress.elapsed();
let total_duration = progress.elapsed();
let total_events = worker_traces
.iter()
......@@ -711,48 +802,154 @@ async fn run_benchmark(
* args.inference_worker_duplication_factor
- total_events;
let event_queue_flush_percentage = flush_size as f32 / total_events as f32 * 100.0;
let total_request_blocks: usize = worker_traces
.iter()
.flat_map(|t| t.iter())
.filter_map(|entry| match &entry.entry {
WorkerTraceEntry::Request(hashes) => Some(hashes.len()),
_ => None,
})
.sum::<usize>()
* args.inference_worker_duplication_factor;
println!("Event queue flush duration: {:?}", flush_duration);
println!(
"Event queue flush size: {} ({}% of total events)",
flush_size, event_queue_flush_percentage
);
let total_event_blocks: usize = worker_traces
.iter()
.flat_map(|t| t.iter())
.filter_map(|entry| match &entry.entry {
WorkerTraceEntry::Event(ev) => match &ev.data {
KvCacheEventData::Stored(s) => Some(s.blocks.len()),
_ => Some(0),
},
_ => None,
})
.sum::<usize>()
* args.inference_worker_duplication_factor;
if event_queue_flush_percentage > 5.0 {
eprintln!(
"ERROR: Over 5% of events were unable to be completed within the benchmark duration.
Results are invalid. Rerun with a smaller trace or less worker duplication."
);
}
let total_blocks = total_request_blocks + total_event_blocks;
println!(
"Request Throughput: {} req/s",
total_requests as f32 / request_duration.as_millis() as f32 * 1000.0
);
println!(
"Event Throughput: {} events/s",
total_events as f32 / event_duration.as_millis() as f32 * 1000.0
);
let total_ops = total_requests + total_events;
let offered_ops_throughput = total_ops as f32 / benchmark_duration_ms as f32 * 1000.0;
let ops_throughput = total_ops as f32 / total_duration.as_millis() as f32 * 1000.0;
let offered_block_throughput = total_blocks as f32 / benchmark_duration_ms as f32 * 1000.0;
let block_throughput = total_blocks as f32 / total_duration.as_millis() as f32 * 1000.0;
latencies.sort_unstable();
let latency_p99_us = latencies[latencies.len() * 99 / 100] as f32 / 1000.0;
println!(
"Latency p50: {}us",
latencies[latencies.len() / 2] as f32 / 1000.0
);
println!(
"Latency p95: {}us",
latencies[latencies.len() * 95 / 100] as f32 / 1000.0
);
println!(
"Latency p99: {}us",
latencies[latencies.len() * 99 / 100] as f32 / 1000.0
);
println!(
"Latency max: {}us",
*latencies.last().unwrap() as f32 / 1000.0
"Ops Throughput: {} ops/s (requests + events)",
ops_throughput
);
println!("Block Throughput: {} block ops/s", block_throughput);
println!("Latency p99: {}us", latency_p99_us);
Ok(BenchmarkResults {
offered_ops_throughput,
ops_throughput,
offered_block_throughput,
block_throughput,
latency_p99_us,
})
}
fn plot_sweep(
all_results: &[(&str, Vec<(u64, BenchmarkResults)>)],
output_path: &str,
) -> anyhow::Result<()> {
use plotters::coord::combinators::IntoLogRange;
use plotters::element::DashedPathElement;
use plotters::style::ShapeStyle;
let colors = [
RGBColor(31, 119, 180),
RGBColor(255, 127, 14),
RGBColor(44, 160, 44),
RGBColor(214, 39, 40),
RGBColor(148, 103, 189),
RGBColor(140, 86, 75),
];
let mut global_min = f64::MAX;
let mut global_max = f64::MIN;
for (_, results) in all_results {
for (_, r) in results {
let offered = r.offered_block_throughput as f64;
let achieved = r.block_throughput as f64;
global_min = global_min.min(offered).min(achieved);
global_max = global_max.max(offered).max(achieved);
}
}
let axis_min = global_min * 0.9;
let axis_max = global_max * 1.1;
let root = BitMapBackend::new(output_path, (800, 600)).into_drawing_area();
root.fill(&WHITE)?;
let mut chart = ChartBuilder::on(&root)
.caption(
"Achieved vs Offered Throughput",
("sans-serif", 22).into_font(),
)
.margin(20)
.x_label_area_size(40)
.y_label_area_size(80)
.build_cartesian_2d(
(axis_min..axis_max).log_scale(),
(axis_min..axis_max).log_scale(),
)?;
chart
.configure_mesh()
.x_desc("Offered Throughput (block ops/s)")
.y_desc("Achieved Throughput (block ops/s)")
.draw()?;
let identity_style = ShapeStyle::from(&BLACK.mix(0.4)).stroke_width(1);
chart.draw_series(std::iter::once(DashedPathElement::new(
vec![(axis_min, axis_min), (axis_max, axis_max)],
5,
3,
identity_style,
)))?;
for (i, (name, results)) in all_results.iter().enumerate() {
let color = &colors[i % colors.len()];
let points: Vec<(f64, f64)> = results
.iter()
.map(|(_, r)| (r.offered_block_throughput as f64, r.block_throughput as f64))
.collect();
let series_color = *color;
chart
.draw_series(LineSeries::new(
points.iter().map(|&(x, y)| (x, y)),
&series_color,
))?
.label(*name)
.legend(move |(x, y)| {
plotters::element::PathElement::new(
vec![(x, y), (x + 20, y)],
series_color.stroke_width(2),
)
});
chart.draw_series(
points
.iter()
.map(|&(x, y)| Circle::new((x, y), 4, series_color.filled())),
)?;
}
chart
.configure_series_labels()
.position(SeriesLabelPosition::LowerRight)
.background_style(WHITE.mix(0.8))
.border_style(BLACK)
.draw()?;
root.present()?;
println!("Sweep plot saved to {}", output_path);
Ok(())
}
......@@ -841,12 +1038,93 @@ async fn main() -> anyhow::Result<()> {
}
let traces = process_mooncake_trace(&args)?;
let events = generate_events(&traces, &args).await?;
let indexer = args.get_indexer().build(&args);
let indexer_names: Vec<String> = if args.compare.is_empty() {
let name = match args.get_indexer() {
IndexerArgs::RadixTree {} => "radix-tree",
IndexerArgs::RadixTreeSharded { .. } => "radix-tree-sharded",
IndexerArgs::NestedMap { .. } => "nested-map",
IndexerArgs::ConcurrentRadixTree { .. } => "concurrent-radix-tree",
IndexerArgs::NaiveNestedMap {} => "naive-nested-map",
IndexerArgs::InvertedIndex { .. } => "inverted-index",
};
vec![name.to_string()]
} else {
args.compare.clone()
};
if args.sweep {
let log_min = (args.sweep_min_ms as f64).ln();
let log_max = (args.sweep_max_ms as f64).ln();
let n = args.sweep_steps;
let durations: Vec<u64> = (0..n)
.map(|i| {
let t = i as f64 / (n - 1) as f64;
(log_max * (1.0 - t) + log_min * t).exp().round() as u64
})
.collect();
let mut all_results: Vec<(&str, Vec<(u64, BenchmarkResults)>)> = Vec::new();
for name in &indexer_names {
println!("\n{}", "=".repeat(60));
println!("Benchmarking indexer: {}", name);
println!("{}", "=".repeat(60));
let mut results: Vec<(u64, BenchmarkResults)> = Vec::new();
for &dur_ms in &durations {
println!("\n=== Sweep: benchmark_duration_ms = {} ===", dur_ms);
let indexer = if args.compare.is_empty() {
args.get_indexer().build(&args)
} else {
IndexerArgs::from_name(name, &args)?
};
let result =
run_benchmark(indexer, traces.clone(), events.clone(), &args, dur_ms).await?;
results.push((dur_ms, result));
}
println!("\n=== Sweep Summary: {} ===", name);
println!(
"{:>12} {:>14} {:>14} {:>14} {:>14} {:>10}",
"duration_ms", "ops/s_off", "ops/s", "blk_ops/s_off", "blk_ops/s", "p99(us)"
);
for (dur, r) in &results {
println!(
"{:>12} {:>14.1} {:>14.1} {:>14.1} {:>14.1} {:>10.1}",
dur,
r.offered_ops_throughput,
r.ops_throughput,
r.offered_block_throughput,
r.block_throughput,
r.latency_p99_us,
);
}
run_benchmark(indexer, traces, events, &args).await?;
all_results.push((name, results));
}
plot_sweep(&all_results, &args.sweep_output)?;
} else {
for name in &indexer_names {
println!("\nBenchmarking indexer: {}", name);
let indexer = if args.compare.is_empty() {
args.get_indexer().build(&args)
} else {
IndexerArgs::from_name(name, &args)?
};
run_benchmark(
indexer,
traces.clone(),
events.clone(),
&args,
args.benchmark_duration_ms,
)
.await?;
}
}
Ok(())
}
......@@ -8,7 +8,7 @@
//!
//! Unlike `RadixTree` which uses `Rc<RefCell<>>` and requires single-threaded access,
//! `ConcurrentRadixTree` uses `Arc<RwLock<>>` per node and a
//! `DashMap<..., RwLock<HashMap<...>>>` for the lookup table.
//! `RwLock<FxHashMap<..., RwLock<FxHashMap<...>>>>` for the lookup table.
//!
//! # Limitations vs RadixTree
//!
......@@ -20,17 +20,15 @@
//!
//! - Multiple `find_matches` can run in parallel (read locks only)
//! - Write operations (`apply_event`, `remove_worker`) acquire write locks
//! - The outer `DashMap` distributes contention across shards; inner `RwLock`
//! per worker allows per-worker write concurrency.
//! - Outer `RwLock` is read-locked on the hot path; structural mutations
//! (adding/removing workers) are rare. Inner `RwLock` per worker allows
//! per-worker write concurrency.
//! - Deadlock prevention: always lock parent before child, hand-over-hand locking
use std::{
collections::{HashMap, HashSet, VecDeque},
sync::Arc,
};
use std::{collections::VecDeque, sync::Arc};
use dashmap::DashMap;
use parking_lot::RwLock;
use rustc_hash::{FxHashMap, FxHashSet};
use crate::indexer::SyncIndexer;
use crate::protocols::*;
......@@ -38,13 +36,16 @@ use crate::protocols::*;
/// Thread-safe shared reference to a Block.
type SharedBlock = Arc<RwLock<Block>>;
/// Per-worker block-hash map. Inner RwLock allows concurrent reads of different workers.
type WorkerLookup = FxHashMap<ExternalSequenceBlockHash, SharedBlock>;
/// A block in the concurrent radix tree.
#[derive(Debug)]
struct Block {
/// A map of child blocks, keyed by their local block hash.
children: HashMap<LocalBlockHash, SharedBlock>,
children: FxHashMap<LocalBlockHash, SharedBlock>,
/// The set of workers that have this block cached.
workers: HashSet<WorkerWithDpRank>,
workers: FxHashSet<WorkerWithDpRank>,
/// The external sequence block hash for this block (None for root).
block_hash: Option<ExternalSequenceBlockHash>,
// NOTE: No recent_uses field.
......@@ -55,8 +56,8 @@ impl Block {
/// Create a new `Block` (used for root node).
fn new() -> Self {
Self {
children: HashMap::new(),
workers: HashSet::new(),
children: FxHashMap::default(),
workers: FxHashSet::default(),
block_hash: None,
}
}
......@@ -64,8 +65,8 @@ impl Block {
/// Create a new `Block` with a specific block hash.
fn with_hash(block_hash: ExternalSequenceBlockHash) -> Self {
Self {
children: HashMap::new(),
workers: HashSet::new(),
children: FxHashMap::default(),
workers: FxHashSet::default(),
block_hash: Some(block_hash),
}
}
......@@ -75,7 +76,7 @@ impl Block {
///
/// Unlike `RadixTree` which uses `Rc<RefCell<>>` and requires single-threaded access,
/// `ConcurrentRadixTree` uses `Arc<RwLock<>>` per node and a
/// `DashMap<..., RwLock<HashMap<...>>>` for the lookup table,
/// `RwLock<FxHashMap<..., RwLock<FxHashMap<...>>>>` for the lookup table,
/// enabling concurrent `find_matches` operations.
///
/// # Limitations vs RadixTree
......@@ -88,8 +89,9 @@ impl Block {
///
/// - Multiple `find_matches` can run in parallel (read locks only)
/// - Write operations (`apply_event`, `remove_worker`) acquire write locks
/// - The outer `DashMap` distributes contention across shards; inner `RwLock`
/// per worker allows per-worker write concurrency.
/// - Outer RwLock is read-locked on the hot path; structural mutations
/// (adding/removing workers) are rare and take a write lock.
/// - Inner `RwLock` per worker allows per-worker write concurrency.
/// - Deadlock prevention: always lock parent before child, hand-over-hand locking
pub struct ConcurrentRadixTree {
/// This is the root of the radix/prefix tree.
......@@ -97,9 +99,9 @@ pub struct ConcurrentRadixTree {
root: SharedBlock,
/// Per-worker lookup table for O(1) block access.
/// Outer `DashMap` distributes lock contention across shards; inner `RwLock`
/// per worker protects that worker's block-hash map.
lookup: DashMap<WorkerWithDpRank, RwLock<HashMap<ExternalSequenceBlockHash, SharedBlock>>>,
/// Outer RwLock protects the worker map structure (rarely mutated);
/// inner RwLock per worker protects that worker's block-hash map.
lookup: RwLock<FxHashMap<WorkerWithDpRank, RwLock<WorkerLookup>>>,
}
impl Default for ConcurrentRadixTree {
......@@ -121,15 +123,11 @@ impl Drop for ConcurrentRadixTree {
}
// Remove all lookup references (they may include blocks not reachable from root).
// We have &mut self so no concurrent access; drain the DashMap by clearing it
// after collecting all inner values.
let entries: Vec<_> = self
.lookup
.iter()
.flat_map(|entry| entry.value().read().values().cloned().collect::<Vec<_>>())
.collect();
stack.extend(entries);
self.lookup.clear();
// We have &mut self so no concurrent access; drain the map.
let lookup = self.lookup.get_mut();
for (_, inner_lock) in lookup.drain() {
stack.extend(inner_lock.into_inner().into_values());
}
// Iteratively free any uniquely-owned blocks without recursion
while let Some(block) = stack.pop() {
......@@ -146,7 +144,7 @@ impl ConcurrentRadixTree {
pub fn new() -> Self {
Self {
root: Arc::new(RwLock::new(Block::new())),
lookup: DashMap::new(),
lookup: RwLock::new(FxHashMap::default()),
}
}
......@@ -199,8 +197,9 @@ impl ConcurrentRadixTree {
for worker in &active {
scores.scores.insert(*worker, 1);
}
let lk = self.lookup.read();
for worker in scores.scores.keys() {
if let Some(inner_lock) = self.lookup.get(worker) {
if let Some(inner_lock) = lk.get(worker) {
scores.tree_sizes.insert(*worker, inner_lock.read().len());
}
}
......@@ -234,25 +233,10 @@ impl ConcurrentRadixTree {
let guard = block.read();
let child_count = guard.workers.len();
if child_count < active_count {
// Workers dropped out. Record scores for those that left.
// Score = matched_depth (number of nodes they were present at).
for worker in &active {
if !guard.workers.contains(worker) {
scores.scores.insert(*worker, matched_depth);
}
}
active.clone_from(&guard.workers);
active_count = child_count;
if active_count == 0 {
break;
}
} else if child_count > active_count {
// child_count > active_count means stale entries exist
// (child retains workers already removed from an ancestor).
// Fall back to full membership check: keep only workers
// present in both active and this child, scoring dropouts.
if child_count != active_count {
// Workers changed: either dropped out (child < active) or
// stale entries exist (child > active). In both cases,
// retain only workers present in the child, scoring dropouts.
active.retain(|w| {
if guard.workers.contains(w) {
true
......@@ -288,8 +272,9 @@ impl ConcurrentRadixTree {
}
// Get tree sizes from lookup.
let lk = self.lookup.read();
for worker in scores.scores.keys() {
if let Some(inner_lock) = self.lookup.get(worker) {
if let Some(inner_lock) = lk.get(worker) {
scores.tree_sizes.insert(*worker, inner_lock.read().len());
}
}
......@@ -330,14 +315,15 @@ impl ConcurrentRadixTree {
id: u64,
) -> Result<(), KvCacheEventError> {
// Ensure this worker has an entry in the outer map.
if !self.lookup.contains_key(&worker) {
if !self.lookup.read().contains_key(&worker) {
self.lookup
.write()
.entry(worker)
.or_insert_with(|| RwLock::new(HashMap::new()));
.or_insert_with(|| RwLock::new(FxHashMap::default()));
}
let inner_ref = self.lookup.get(&worker).unwrap();
let mut worker_lookup = inner_ref.write();
let lk = self.lookup.read();
let mut worker_lookup = lk.get(&worker).unwrap().write();
// Find parent block
let mut current = match op.parent_hash {
......@@ -435,7 +421,8 @@ impl ConcurrentRadixTree {
op: KvCacheRemoveData,
id: u64,
) -> Result<(), KvCacheEventError> {
let Some(inner_ref) = self.lookup.get(&worker) else {
let lk = self.lookup.read();
let Some(inner_ref) = lk.get(&worker) else {
return Err(KvCacheEventError::BlockNotFound);
};
let mut worker_lookup = inner_ref.write();
......@@ -467,17 +454,17 @@ impl ConcurrentRadixTree {
/// If `keep_worker` is true, the worker remains in lookup with empty blocks.
/// If `keep_worker` is false, the worker is completely removed from lookup.
fn remove_or_clear_worker_blocks(&self, worker_id: WorkerId, keep_worker: bool) {
// Collect all WorkerWithDpRank keys that match this worker_id
let workers: Vec<WorkerWithDpRank> = self
.lookup
.iter()
.filter(|entry| entry.key().worker_id == worker_id)
.map(|entry| *entry.key())
.read()
.keys()
.filter(|w| w.worker_id == worker_id)
.copied()
.collect();
let mut lk = self.lookup.write();
for worker in workers {
if let Some((_, inner_lock)) = self.lookup.remove(&worker) {
// We now own the inner RwLock; extract the HashMap.
if let Some(inner_lock) = lk.remove(&worker) {
let blocks = inner_lock.into_inner();
for (_, block) in blocks {
let mut guard = block.write();
......@@ -488,7 +475,7 @@ impl ConcurrentRadixTree {
}
if keep_worker {
self.lookup.insert(worker, RwLock::new(HashMap::new()));
lk.insert(worker, RwLock::new(FxHashMap::default()));
}
}
}
......@@ -509,9 +496,10 @@ impl ConcurrentRadixTree {
pub fn get_workers(&self) -> Vec<WorkerId> {
let mut worker_ids: Vec<WorkerId> = self
.lookup
.iter()
.map(|entry| entry.key().worker_id)
.collect::<HashSet<_>>()
.read()
.keys()
.map(|w| w.worker_id)
.collect::<FxHashSet<_>>()
.into_iter()
.collect();
worker_ids.sort_unstable();
......@@ -524,7 +512,7 @@ impl ConcurrentRadixTree {
pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
tracing::debug!(
"Dumping concurrent radix tree as events (contains information about {:?} workers)",
self.lookup.len()
self.lookup.read().len()
);
let mut events = Vec::new();
......@@ -583,8 +571,9 @@ impl ConcurrentRadixTree {
/// Get total number of blocks across all workers.
pub fn current_size(&self) -> usize {
self.lookup
.iter()
.map(|entry| entry.value().read().len())
.read()
.values()
.map(|inner| inner.read().len())
.sum()
}
}
......@@ -641,9 +630,10 @@ mod tests {
&3
);
assert_eq!(trie.lookup.len(), 1);
assert_eq!(trie.lookup.read().len(), 1);
assert_eq!(
trie.lookup
.read()
.get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap()
.read()
......@@ -673,7 +663,7 @@ mod tests {
&1
);
assert_eq!(trie.lookup.len(), 2);
assert_eq!(trie.lookup.read().len(), 2);
}
#[test]
......@@ -693,6 +683,7 @@ mod tests {
assert_eq!(
trie.lookup
.read()
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap()
.read()
......@@ -705,6 +696,7 @@ mod tests {
assert_eq!(
trie.lookup
.read()
.get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap()
.read()
......@@ -751,10 +743,12 @@ mod tests {
assert!(
trie.lookup
.read()
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert!(
trie.lookup
.read()
.get(&WorkerWithDpRank::from_worker_id(worker_0))
.unwrap()
.read()
......@@ -780,16 +774,17 @@ mod tests {
trie.apply_event(create_store_event(worker_1, 0, vec![1, 2, 3], None))
.unwrap();
assert_eq!(trie.lookup.len(), 2);
assert_eq!(trie.lookup.read().len(), 2);
trie.remove_worker(worker_0);
assert!(
!trie
.lookup
.read()
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
);
assert_eq!(trie.lookup.len(), 1);
assert_eq!(trie.lookup.read().len(), 1);
let result = trie
.find_matches_impl(
......@@ -807,7 +802,7 @@ mod tests {
let trie: ConcurrentRadixTree = Default::default();
assert!(trie.root.read().children.is_empty());
assert!(trie.root.read().workers.is_empty());
assert!(trie.lookup.is_empty());
assert!(trie.lookup.read().is_empty());
}
#[test]
......@@ -920,14 +915,14 @@ mod tests {
.unwrap();
let worker_key = WorkerWithDpRank::from_worker_id(worker_1);
assert_eq!(trie.lookup.get(&worker_key).unwrap().read().len(), 3);
assert_eq!(trie.lookup.read().get(&worker_key).unwrap().read().len(), 3);
// Remove ONLY block1 -- descendants should NOT be cascade-removed
trie.apply_event(create_remove_event(worker_1, 2, vec![1]))
.unwrap();
let inner_ref = trie.lookup.get(&worker_key).unwrap();
let worker_lookup = inner_ref.read();
let lk = trie.lookup.read();
let worker_lookup = lk.get(&worker_key).unwrap().read();
assert!(
!worker_lookup.contains_key(&ExternalSequenceBlockHash(100)),
"block1 should be removed"
......@@ -959,8 +954,8 @@ mod tests {
trie.apply_event(create_remove_event(worker_1, 2, vec![1, 2, 3]))
.unwrap();
let inner_ref = trie.lookup.get(&worker_key).unwrap();
let worker_lookup = inner_ref.read();
let lk = trie.lookup.read();
let worker_lookup = lk.get(&worker_key).unwrap().read();
assert_eq!(worker_lookup.len(), 0, "all blocks should be removed");
}
......
......@@ -44,6 +44,7 @@ use dynamo_runtime::{
metrics::{MetricsHierarchy, prometheus_names::kvrouter},
};
use prometheus::{IntCounterVec, Opts};
use rustc_hash::FxBuildHasher;
/// Trait for types that may represent an error response.
/// Used for RPC-style responses that can indicate success or failure.
......@@ -406,7 +407,7 @@ pub struct ThreadPoolIndexer<T: SyncIndexer> {
backend: Arc<T>,
/// Maps WorkerId to worker thread index for sticky routing.
worker_assignments: DashMap<WorkerId, usize>,
worker_assignments: DashMap<WorkerId, usize, FxBuildHasher>,
/// Counter for round-robin assignment of new WorkerIds.
worker_assignment_count: AtomicUsize,
......@@ -463,7 +464,7 @@ impl<T: SyncIndexer> ThreadPoolIndexer<T> {
Self {
backend,
worker_assignments: DashMap::new(),
worker_assignments: DashMap::with_hasher(FxBuildHasher),
worker_assignment_count: AtomicUsize::new(0),
worker_event_channels: worker_event_senders,
num_workers,
......@@ -1388,7 +1389,7 @@ pub struct KvIndexerSharded {
cancel: CancellationToken,
/// The size of the KV block this indexer can handle.
kv_block_size: u32,
worker_assignments: DashMap<WorkerId, usize>,
worker_assignments: DashMap<WorkerId, usize, FxBuildHasher>,
worker_counts: Arc<Mutex<Vec<usize>>>,
event_tx: Vec<mpsc::Sender<RouterEvent>>,
......@@ -1421,7 +1422,7 @@ impl KvIndexerSharded {
metrics: Arc<KvIndexerMetrics>,
prune_config: Option<PruneConfig>,
) -> Self {
let worker_assignments = DashMap::new();
let worker_assignments = DashMap::with_hasher(FxBuildHasher);
let worker_counts = Arc::new(Mutex::new(vec![0; num_shards]));
let mut event_tx = Vec::new();
......
......@@ -11,6 +11,8 @@ pub mod approx;
pub mod bench_utils;
pub mod concurrent_radix_tree;
pub mod indexer;
#[cfg(feature = "bench")]
pub mod naive_indexers;
pub mod nested_map;
pub mod protocols;
pub mod radix_tree;
......@@ -21,6 +23,8 @@ pub(crate) mod test_utils;
// Re-export key types for convenience
pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer};
#[cfg(feature = "bench")]
pub use naive_indexers::{InvertedIndex, NaiveNestedMap};
pub use nested_map::PositionalIndexer;
pub use protocols::{
KvCacheEventError, LocalBlockHash, OverlapScores, RouterEvent, WorkerId,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! **DO NOT USE IN PRODUCTION.** These are intentionally simplified indexer
//! implementations for benchmarking and blog illustrations only. They cut
//! corners (no reverse lookup, Remove events are unimplemented) that make
//! them incorrect under real workloads with eviction pressure.
//!
//! They correspond to blog sections 2 and 3 and exist to show the performance
//! progression from naive approaches to the production indexers.
//!
//! - [`NaiveNestedMap`]: `worker -> set<local_hash>`. O(W × D) per
//! `find_matches` call, behind a single-threaded actor. Blog section 2.
//! - [`InvertedIndex`]: `local_hash -> set<worker>`. O(D + W) per
//! `find_matches` call, single-threaded actor. Blog section 3.
use async_trait::async_trait;
use std::collections::{HashMap, HashSet};
use tokio::sync::{mpsc, oneshot};
use crate::indexer::{KvIndexerInterface, KvRouterError};
use crate::protocols::{
KvCacheEventData, LocalBlockHash, OverlapScores, RouterEvent, TokensWithHashes, WorkerId,
WorkerWithDpRank,
};
// ============================================================================
// Section 2 — Naive Nested Map + Actor
// ============================================================================
/// Plain nested `HashMap` index — no locks, owned exclusively by the actor thread.
///
/// Structure: `worker -> set<local_hash>`.
/// No reverse lookup — Remove is unimplemented (relies on large GPU block
/// budget to avoid evictions).
struct NaiveNestedMapInner {
index: HashMap<WorkerWithDpRank, HashSet<LocalBlockHash>>,
}
impl NaiveNestedMapInner {
fn new() -> Self {
Self {
index: HashMap::new(),
}
}
fn find_matches(&self, sequence: &[LocalBlockHash]) -> OverlapScores {
let mut scores = OverlapScores::new();
if sequence.is_empty() {
return scores;
}
for (worker, blocks) in &self.index {
let mut depth = 0u32;
for local_hash in sequence {
if !blocks.contains(local_hash) {
break;
}
depth += 1;
}
if depth > 0 {
scores.scores.insert(*worker, depth);
}
}
scores
}
fn apply_event(&mut self, event: RouterEvent) {
let worker = WorkerWithDpRank::new(event.worker_id, event.event.dp_rank);
match event.event.data {
KvCacheEventData::Stored(store_data) => {
let worker_set = self.index.entry(worker).or_default();
for block in store_data.blocks {
worker_set.insert(block.tokens_hash);
}
}
KvCacheEventData::Removed(_) => {
unimplemented!(
"NaiveNestedMap does not support Remove events; increase --num-gpu-blocks to avoid evictions"
);
}
KvCacheEventData::Cleared => {
self.index.remove(&worker);
}
}
}
fn remove_worker(&mut self, worker_id: WorkerId) {
self.index.retain(|w, _| w.worker_id != worker_id);
}
}
struct MatchRequest {
sequence: Vec<LocalBlockHash>,
reply: oneshot::Sender<OverlapScores>,
}
enum ActorMessage {
Event(RouterEvent),
Match(MatchRequest),
RemoveWorker(WorkerId),
}
/// Single-threaded actor wrapping [`NaiveNestedMapInner`] (blog section 2).
///
/// All reads and writes are serialized through a single OS thread via channels.
/// This is the pure actor pattern described in the blog — no concurrent access
/// to the data structure at all.
pub struct NaiveNestedMap {
tx: mpsc::Sender<ActorMessage>,
}
impl Default for NaiveNestedMap {
fn default() -> Self {
Self::new()
}
}
impl NaiveNestedMap {
pub fn new() -> Self {
let (tx, mut rx) = mpsc::channel::<ActorMessage>(2048);
std::thread::spawn(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async move {
let mut inner = NaiveNestedMapInner::new();
while let Some(msg) = rx.recv().await {
match msg {
ActorMessage::Event(event) => {
inner.apply_event(event);
}
ActorMessage::Match(req) => {
let scores = inner.find_matches(&req.sequence);
let _ = req.reply.send(scores);
}
ActorMessage::RemoveWorker(worker_id) => {
inner.remove_worker(worker_id);
}
}
}
});
});
Self { tx }
}
}
#[async_trait]
impl KvIndexerInterface for NaiveNestedMap {
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(ActorMessage::Match(MatchRequest {
sequence,
reply: reply_tx,
}))
.await
.map_err(|_| KvRouterError::IndexerOffline)?;
reply_rx
.await
.map_err(|_| KvRouterError::IndexerDroppedRequest)
}
async fn find_matches_for_request(
&self,
_tokens: &[u32],
) -> Result<OverlapScores, KvRouterError> {
unimplemented!("not used in bench")
}
async fn apply_event(&self, event: RouterEvent) {
let _ = self.tx.send(ActorMessage::Event(event)).await;
}
async fn remove_worker(&self, worker: WorkerId) {
let _ = self.tx.send(ActorMessage::RemoveWorker(worker)).await;
}
fn shutdown(&self) {}
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
Ok(Vec::new())
}
async fn process_routing_decision_for_request(
&self,
_tokens_with_hashes: &mut TokensWithHashes,
_worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
unimplemented!("not used in bench")
}
async fn flush(&self) -> usize {
let curr_size = self.tx.max_capacity() - self.tx.capacity();
loop {
if self.tx.capacity() == self.tx.max_capacity() {
break;
}
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
}
curr_size
}
}
// ============================================================================
// Section 3 — Inverted Index
// ============================================================================
/// Plain inverted index — no locks, owned exclusively by the actor thread.
///
/// Flat forward index: `local_hash -> set<worker>`.
/// No reverse lookup — Remove is a no-op (relies on large GPU block budget
/// to avoid evictions), Clear/remove_worker scan the forward index.
struct InvertedIndexInner {
index: HashMap<LocalBlockHash, HashSet<WorkerWithDpRank>>,
}
impl InvertedIndexInner {
fn new() -> Self {
Self {
index: HashMap::new(),
}
}
fn find_matches(&self, sequence: &[LocalBlockHash]) -> OverlapScores {
let mut scores = OverlapScores::new();
if sequence.is_empty() {
return scores;
}
let Some(workers) = self.index.get(&sequence[0]) else {
return scores;
};
let mut active: HashSet<WorkerWithDpRank> = workers.clone();
if active.is_empty() {
return scores;
}
for (depth, local_hash) in sequence.iter().enumerate() {
let empty = HashSet::new();
let workers_here = self.index.get(local_hash).unwrap_or(&empty);
active.retain(|w| {
if workers_here.contains(w) {
true
} else {
scores.scores.insert(*w, depth as u32);
false
}
});
if active.is_empty() {
break;
}
}
for w in active {
scores.scores.insert(w, sequence.len() as u32);
}
scores
}
fn apply_event(&mut self, event: RouterEvent) {
let worker = WorkerWithDpRank::new(event.worker_id, event.event.dp_rank);
match event.event.data {
KvCacheEventData::Stored(store_data) => {
for block in store_data.blocks {
self.index
.entry(block.tokens_hash)
.or_default()
.insert(worker);
}
}
KvCacheEventData::Removed(_) => {
unimplemented!(
"InvertedIndex does not support Remove events; increase --num-gpu-blocks to avoid evictions"
);
}
KvCacheEventData::Cleared => {
self.clear_worker(worker);
}
}
}
fn remove_worker(&mut self, worker_id: WorkerId) {
for workers in self.index.values_mut() {
workers.retain(|w| w.worker_id != worker_id);
}
}
fn clear_worker(&mut self, worker: WorkerWithDpRank) {
for workers in self.index.values_mut() {
workers.remove(&worker);
}
}
}
enum InvertedIndexMessage {
Event(RouterEvent),
Match(MatchRequest),
RemoveWorker(WorkerId),
}
/// Single-threaded actor wrapping [`InvertedIndexInner`] (blog section 3).
///
/// Same actor pattern as [`NaiveNestedMap`] — all reads and writes are
/// serialized through a single OS thread via channels.
pub struct InvertedIndex {
tx: mpsc::Sender<InvertedIndexMessage>,
}
impl Default for InvertedIndex {
fn default() -> Self {
Self::new()
}
}
impl InvertedIndex {
pub fn new() -> Self {
let (tx, mut rx) = mpsc::channel::<InvertedIndexMessage>(2048);
std::thread::spawn(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async move {
let mut inner = InvertedIndexInner::new();
while let Some(msg) = rx.recv().await {
match msg {
InvertedIndexMessage::Event(event) => {
inner.apply_event(event);
}
InvertedIndexMessage::Match(req) => {
let scores = inner.find_matches(&req.sequence);
let _ = req.reply.send(scores);
}
InvertedIndexMessage::RemoveWorker(worker_id) => {
inner.remove_worker(worker_id);
}
}
}
});
});
Self { tx }
}
}
#[async_trait]
impl KvIndexerInterface for InvertedIndex {
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(InvertedIndexMessage::Match(MatchRequest {
sequence,
reply: reply_tx,
}))
.await
.map_err(|_| KvRouterError::IndexerOffline)?;
reply_rx
.await
.map_err(|_| KvRouterError::IndexerDroppedRequest)
}
async fn find_matches_for_request(
&self,
_tokens: &[u32],
) -> Result<OverlapScores, KvRouterError> {
unimplemented!("not used in bench")
}
async fn apply_event(&self, event: RouterEvent) {
let _ = self.tx.send(InvertedIndexMessage::Event(event)).await;
}
async fn remove_worker(&self, worker: WorkerId) {
let _ = self
.tx
.send(InvertedIndexMessage::RemoveWorker(worker))
.await;
}
fn shutdown(&self) {}
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
Ok(Vec::new())
}
async fn process_routing_decision_for_request(
&self,
_tokens_with_hashes: &mut TokensWithHashes,
_worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
unimplemented!("not used in bench")
}
async fn flush(&self) -> usize {
let curr_size = self.tx.max_capacity() - self.tx.capacity();
loop {
if self.tx.capacity() == self.tx.max_capacity() {
break;
}
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
}
curr_size
}
}
......@@ -21,8 +21,8 @@
//! `KvIndexerInterface` with sticky event routing and worker threads, wrap it
//! in a `ThreadPoolIndexer`.
use dashmap::DashMap;
use std::collections::{HashMap, HashSet};
use std::sync::RwLock;
use parking_lot::RwLock;
use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet};
use crate::indexer::SyncIndexer;
use crate::protocols::{
......@@ -37,15 +37,15 @@ use crate::protocols::{
#[derive(Debug, Clone)]
enum SeqEntry {
/// Single seq_hash -> workers mapping (common case, no HashMap allocation)
Single(ExternalSequenceBlockHash, HashSet<WorkerWithDpRank>),
Single(ExternalSequenceBlockHash, FxHashSet<WorkerWithDpRank>),
/// Multiple seq_hash -> workers mappings (rare case, different prefixes)
Multi(HashMap<ExternalSequenceBlockHash, HashSet<WorkerWithDpRank>>),
Multi(FxHashMap<ExternalSequenceBlockHash, FxHashSet<WorkerWithDpRank>>),
}
impl SeqEntry {
/// Create a new entry with a single worker.
fn new(seq_hash: ExternalSequenceBlockHash, worker: WorkerWithDpRank) -> Self {
let mut workers = HashSet::new();
let mut workers = FxHashSet::default();
workers.insert(worker);
Self::Single(seq_hash, workers)
}
......@@ -58,7 +58,7 @@ impl SeqEntry {
}
Self::Single(existing_hash, existing_workers) => {
// Upgrade to Multi
let mut map = HashMap::with_capacity(2);
let mut map = FxHashMap::with_capacity_and_hasher(2, FxBuildHasher);
map.insert(*existing_hash, std::mem::take(existing_workers));
map.entry(seq_hash).or_default().insert(worker);
*self = Self::Multi(map);
......@@ -91,7 +91,7 @@ impl SeqEntry {
}
/// Get workers for a specific seq_hash.
fn get(&self, seq_hash: ExternalSequenceBlockHash) -> Option<&HashSet<WorkerWithDpRank>> {
fn get(&self, seq_hash: ExternalSequenceBlockHash) -> Option<&FxHashSet<WorkerWithDpRank>> {
match self {
Self::Single(existing_hash, workers) if *existing_hash == seq_hash => Some(workers),
Self::Single(_, _) => None,
......@@ -100,17 +100,19 @@ impl SeqEntry {
}
}
type LevelIndex = RwLock<HashMap<ExternalSequenceBlockHash, (usize, LocalBlockHash)>>;
type LevelIndex = RwLock<FxHashMap<ExternalSequenceBlockHash, (usize, LocalBlockHash)>>;
/// Positional HashMap-based KV cache index.
///
/// Implements [`SyncIndexer`] for use with [`ThreadPoolIndexer`](crate::indexer::ThreadPoolIndexer).
/// All methods are synchronous and thread-safe.
pub struct PositionalIndexer {
index: DashMap<(usize, LocalBlockHash), SeqEntry>,
index: DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
/// Per-worker reverse lookup: worker -> seq_hash -> (position, local_hash)
/// Enables efficient remove operations without global flat reverse map.
worker_blocks: DashMap<WorkerWithDpRank, LevelIndex>,
/// Uses a single RwLock rather than DashMap because structural mutations
/// (adding/removing workers) are rare; the hot path is read-only.
worker_blocks: RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
jump_size: usize,
}
......@@ -126,8 +128,8 @@ impl PositionalIndexer {
assert!(jump_size > 0, "jump_size must be greater than 0");
Self {
index: DashMap::new(),
worker_blocks: DashMap::new(),
index: DashMap::with_hasher(FxBuildHasher),
worker_blocks: RwLock::new(FxHashMap::default()),
jump_size,
}
}
......@@ -159,9 +161,10 @@ impl SyncIndexer for PositionalIndexer {
let mut events = Vec::new();
let mut event_id = 0u64;
for entry in self.worker_blocks.iter() {
let worker = *entry.key();
let worker_map = entry.value().read().unwrap();
let wb = self.worker_blocks.read();
for (worker, level_index) in wb.iter() {
let worker = *worker;
let worker_map = level_index.read();
// Collect (position, local_hash, seq_hash) and sort by position
// so parents are emitted before children during replay.
......@@ -172,7 +175,8 @@ impl SyncIndexer for PositionalIndexer {
blocks.sort_unstable_by_key(|(pos, _, _)| *pos);
// Track one valid seq_hash per position for parent_hash synthesis.
let mut last_at_position: HashMap<usize, ExternalSequenceBlockHash> = HashMap::new();
let mut last_at_position: FxHashMap<usize, ExternalSequenceBlockHash> =
FxHashMap::default();
for (pos, local_hash, seq_hash) in blocks {
let parent_hash = if pos == 0 {
......@@ -224,8 +228,8 @@ impl PositionalIndexer {
/// Process an event using the provided index and worker_blocks.
/// This is called from worker threads.
fn apply_event_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry>,
worker_blocks: &DashMap<WorkerWithDpRank, LevelIndex>,
index: &DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
worker_blocks: &RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
event: RouterEvent,
) -> Result<(), KvCacheEventError> {
let (worker_id, kv_event) = (event.worker_id, event.event);
......@@ -263,8 +267,8 @@ impl PositionalIndexer {
}
fn store_blocks_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry>,
worker_blocks: &DashMap<WorkerWithDpRank, LevelIndex>,
index: &DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
worker_blocks: &RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
worker: WorkerWithDpRank,
store_data: KvCacheStoreData,
event_id: u64,
......@@ -272,9 +276,8 @@ impl PositionalIndexer {
// Determine starting position based on parent_hash
let start_pos = match store_data.parent_hash {
Some(parent_hash) => {
// Find parent position from worker_blocks
let Some(worker_map) = worker_blocks.get(&worker) else {
let wb = worker_blocks.read();
let Some(level_index) = wb.get(&worker) else {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
......@@ -284,7 +287,7 @@ impl PositionalIndexer {
return Err(KvCacheEventError::ParentBlockNotFound);
};
let worker_map = worker_map.read().unwrap();
let worker_map = level_index.read();
let Some(entry) = worker_map.get(&parent_hash) else {
tracing::warn!(
......@@ -301,12 +304,15 @@ impl PositionalIndexer {
None => 0, // Start from position 0
};
if !worker_blocks.contains_key(&worker) {
worker_blocks.insert(worker, RwLock::new(HashMap::new()));
if !worker_blocks.read().contains_key(&worker) {
worker_blocks
.write()
.entry(worker)
.or_insert_with(|| RwLock::new(FxHashMap::default()));
}
let worker_blocks_entry = worker_blocks.get(&worker).unwrap();
let mut worker_map = worker_blocks_entry.write().unwrap();
let wb = worker_blocks.read();
let mut worker_map = wb.get(&worker).unwrap().write();
for (i, block_data) in store_data.blocks.into_iter().enumerate() {
let position = start_pos + i;
......@@ -326,13 +332,14 @@ impl PositionalIndexer {
}
fn remove_blocks_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry>,
worker_blocks: &DashMap<WorkerWithDpRank, LevelIndex>,
index: &DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
worker_blocks: &RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
worker: WorkerWithDpRank,
seq_hashes: &Vec<ExternalSequenceBlockHash>,
event_id: u64,
) -> Result<(), KvCacheEventError> {
let worker_map = worker_blocks.get(&worker).ok_or_else(|| {
let wb = worker_blocks.read();
let level_index = wb.get(&worker).ok_or_else(|| {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
......@@ -343,7 +350,7 @@ impl PositionalIndexer {
KvCacheEventError::BlockNotFound
})?;
let mut worker_map = worker_map.write().unwrap();
let mut worker_map = level_index.write();
for seq_hash in seq_hashes {
let Some((position, local_hash)) = worker_map.remove(seq_hash) else {
......@@ -369,8 +376,8 @@ impl PositionalIndexer {
/// Clear all blocks for a specific worker_id (all dp_ranks), but keep worker tracked.
/// Static version for use in worker threads.
fn clear_worker_blocks_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry>,
worker_blocks: &DashMap<WorkerWithDpRank, LevelIndex>,
index: &DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
worker_blocks: &RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
worker_id: WorkerId,
) {
Self::remove_or_clear_worker_blocks_impl(index, worker_blocks, worker_id, true);
......@@ -379,8 +386,9 @@ impl PositionalIndexer {
/// Get total number of blocks across all workers.
pub fn current_size(&self) -> usize {
self.worker_blocks
.iter()
.map(|entry| entry.value().read().unwrap().len())
.read()
.values()
.map(|level_index| level_index.read().len())
.sum()
}
......@@ -399,34 +407,30 @@ impl PositionalIndexer {
/// If `keep_worker` is true, the worker remains tracked with empty blocks.
/// If `keep_worker` is false, the worker is completely removed.
fn remove_or_clear_worker_blocks_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry>,
worker_blocks: &DashMap<WorkerWithDpRank, LevelIndex>,
index: &DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
worker_blocks: &RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
worker_id: WorkerId,
keep_worker: bool,
) {
// Collect all WorkerWithDpRank keys that match this worker_id
let workers: Vec<WorkerWithDpRank> = worker_blocks
.iter()
.filter(|entry| entry.key().worker_id == worker_id)
.map(|entry| *entry.key())
.read()
.keys()
.filter(|w| w.worker_id == worker_id)
.copied()
.collect();
let mut wb = worker_blocks.write();
for worker in workers {
if let Some((_, worker_map)) = worker_blocks.remove(&worker) {
// Remove each block from the index
for entry in worker_map.read().unwrap().iter() {
let seq_hash = *entry.0;
let (position, local_hash) = *entry.1;
if let Some(mut entry) = index.get_mut(&(position, local_hash)) {
let _ = entry.remove(seq_hash, worker);
if let Some(worker_map) = wb.remove(&worker) {
for (seq_hash, (position, local_hash)) in worker_map.read().iter() {
if let Some(mut entry) = index.get_mut(&(*position, *local_hash)) {
let _ = entry.remove(*seq_hash, worker);
}
}
}
if keep_worker {
// Re-insert worker with empty map to keep it tracked
worker_blocks.insert(worker, RwLock::new(HashMap::new()));
wb.insert(worker, RwLock::new(FxHashMap::default()));
}
}
}
......@@ -481,7 +485,7 @@ impl PositionalIndexer {
local_hash: LocalBlockHash,
seq_hashes: &mut Vec<ExternalSequenceBlockHash>,
sequence: &[LocalBlockHash],
) -> Option<HashSet<WorkerWithDpRank>> {
) -> Option<FxHashSet<WorkerWithDpRank>> {
let entry = self.index.get(&(position, local_hash))?;
// Always compute and verify seq_hash to handle divergent queries correctly.
......@@ -517,13 +521,13 @@ impl PositionalIndexer {
/// Scan positions sequentially, updating active set and recording drain scores.
///
/// Inlines the DashMap lookup so the guard lives for each iteration,
/// avoiding a per-position `HashSet` clone.
/// avoiding a per-position `FxHashSet` clone.
#[allow(clippy::too_many_arguments)]
fn linear_scan_drain(
&self,
sequence: &[LocalBlockHash],
seq_hashes: &mut Vec<ExternalSequenceBlockHash>,
active: &mut HashSet<WorkerWithDpRank>,
active: &mut FxHashSet<WorkerWithDpRank>,
scores: &mut OverlapScores,
lo: usize,
hi: usize,
......@@ -622,10 +626,10 @@ impl PositionalIndexer {
scores.scores.insert(*worker, 1);
}
// Populate tree_sizes
let wb = self.worker_blocks.read();
for worker in scores.scores.keys() {
if let Some(worker_map) = self.worker_blocks.get(worker) {
let worker_map = worker_map.read().unwrap();
scores.tree_sizes.insert(*worker, worker_map.len());
if let Some(level_index) = wb.get(worker) {
scores.tree_sizes.insert(*worker, level_index.read().len());
}
}
return scores;
......@@ -674,10 +678,10 @@ impl PositionalIndexer {
}
// Populate tree_sizes from worker_blocks
let wb = self.worker_blocks.read();
for worker in scores.scores.keys() {
if let Some(worker_map) = self.worker_blocks.get(worker) {
let worker_map = worker_map.read().unwrap();
scores.tree_sizes.insert(*worker, worker_map.len());
if let Some(level_index) = wb.get(worker) {
scores.tree_sizes.insert(*worker, level_index.read().len());
}
}
......
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use dynamo_tokens::{SequenceHash, Token};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use xxhash_rust::xxh3;
......@@ -506,11 +507,11 @@ impl RouterEvent {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OverlapScores {
/// Map of worker (with dp_rank) to score.
pub scores: std::collections::HashMap<WorkerWithDpRank, u32>,
pub scores: FxHashMap<WorkerWithDpRank, u32>,
/// List of frequencies that the blocks have been accessed. Entries with value 0 are omitted.
pub frequencies: Vec<usize>,
/// Map of worker to their tree size (number of blocks in the tree for that worker).
pub tree_sizes: std::collections::HashMap<WorkerWithDpRank, usize>,
pub tree_sizes: FxHashMap<WorkerWithDpRank, usize>,
}
impl Default for OverlapScores {
......@@ -527,9 +528,9 @@ impl OverlapScores {
/// A new `OverlapScores`.
pub fn new() -> Self {
Self {
scores: std::collections::HashMap::new(),
scores: FxHashMap::default(),
frequencies: Vec::with_capacity(32),
tree_sizes: std::collections::HashMap::new(),
tree_sizes: FxHashMap::default(),
}
}
......
......@@ -16,11 +16,13 @@
use std::{
cell::RefCell,
collections::{HashMap, HashSet, VecDeque},
collections::VecDeque,
rc::Rc,
time::{Duration, Instant},
};
use rustc_hash::{FxHashMap, FxHashSet};
use crate::protocols::*;
/// A shared reference to a [`RadixBlock`].
......@@ -30,9 +32,9 @@ pub(crate) type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
#[derive(Debug)]
pub(crate) struct RadixBlock {
/// A map of child blocks, keyed by their local block hash.
pub(crate) children: HashMap<LocalBlockHash, SharedRadixBlock>,
pub(crate) children: FxHashMap<LocalBlockHash, SharedRadixBlock>,
/// The set of workers that have this block cached.
pub(crate) workers: HashSet<WorkerWithDpRank>,
pub(crate) workers: FxHashSet<WorkerWithDpRank>,
/// The external sequence block hash for this block (None for root).
/// This is the same for all workers under the simplifying assumption.
pub(crate) block_hash: Option<ExternalSequenceBlockHash>,
......@@ -48,8 +50,8 @@ impl RadixBlock {
/// A new `RadixBlock` with no block_hash.
pub fn new() -> Self {
Self {
children: HashMap::new(),
workers: HashSet::new(),
children: FxHashMap::default(),
workers: FxHashSet::default(),
block_hash: None,
recent_uses: VecDeque::new(),
}
......@@ -62,8 +64,8 @@ impl RadixBlock {
/// A new `RadixBlock` with the given block_hash.
pub fn with_hash(block_hash: ExternalSequenceBlockHash) -> Self {
Self {
children: HashMap::new(),
workers: HashSet::new(),
children: FxHashMap::default(),
workers: FxHashSet::default(),
block_hash: Some(block_hash),
recent_uses: VecDeque::new(),
}
......@@ -78,7 +80,7 @@ pub struct RadixTree {
/// Per-worker lookup table for O(1) block access.
/// Maps worker -> (block_hash -> block).
pub(crate) lookup:
HashMap<WorkerWithDpRank, HashMap<ExternalSequenceBlockHash, SharedRadixBlock>>,
FxHashMap<WorkerWithDpRank, FxHashMap<ExternalSequenceBlockHash, SharedRadixBlock>>,
/// The time buffer the radix tree should check when considering frequence of block accesses
pub(crate) expiration_duration: Option<Duration>,
......@@ -132,7 +134,7 @@ impl RadixTree {
pub fn new_with_frequency(expiration_duration: Option<Duration>) -> Self {
Self {
root: Rc::new(RefCell::new(RadixBlock::new())),
lookup: HashMap::new(),
lookup: FxHashMap::default(),
expiration_duration,
}
}
......@@ -485,7 +487,7 @@ impl RadixTree {
if keep_worker {
// Re-insert worker with empty blocks map to keep it tracked
self.lookup.insert(worker_key, HashMap::new());
self.lookup.insert(worker_key, FxHashMap::default());
}
}
}
......
......@@ -184,11 +184,7 @@ impl Indexer {
match self {
Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
Indexer::Concurrent(tpi) => tpi.find_matches(sequence).await,
Indexer::None => Ok(OverlapScores {
scores: HashMap::new(),
frequencies: Vec::new(),
tree_sizes: HashMap::new(),
}),
Indexer::None => Ok(OverlapScores::new()),
}
}
......
......@@ -47,9 +47,9 @@ checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea"
[[package]]
name = "arc-swap"
version = "1.8.1"
version = "1.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ded5f9a03ac8f24d1b8a25101ee812cd32cdc8c50a4c50237de2c4915850e73"
checksum = "f9f3647c145568cec02c42054e07bdf9a5a698e15b466fb2341bfc393cd24aa5"
dependencies = [
"rustversion",
]
......@@ -1100,9 +1100,9 @@ checksum = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7"
[[package]]
name = "futures"
version = "0.3.31"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d"
dependencies = [
"futures-channel",
"futures-core",
......@@ -1115,9 +1115,9 @@ dependencies = [
[[package]]
name = "futures-channel"
version = "0.3.31"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d"
dependencies = [
"futures-core",
"futures-sink",
......@@ -1125,15 +1125,15 @@ dependencies = [
[[package]]
name = "futures-core"
version = "0.3.31"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
[[package]]
name = "futures-executor"
version = "0.3.31"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d"
dependencies = [
"futures-core",
"futures-task",
......@@ -1142,15 +1142,15 @@ dependencies = [
[[package]]
name = "futures-io"
version = "0.3.31"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
[[package]]
name = "futures-macro"
version = "0.3.31"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b"
dependencies = [
"proc-macro2",
"quote",
......@@ -1159,21 +1159,21 @@ dependencies = [
[[package]]
name = "futures-sink"
version = "0.3.31"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7"
checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893"
[[package]]
name = "futures-task"
version = "0.3.31"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393"
[[package]]
name = "futures-util"
version = "0.3.31"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
dependencies = [
"futures-channel",
"futures-core",
......@@ -1183,7 +1183,6 @@ dependencies = [
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
]
......@@ -3466,9 +3465,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
[[package]]
name = "syn"
version = "2.0.115"
version = "2.0.116"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e614ed320ac28113fa64972c4262d5dbc89deacdfd00c34a3e4cea073243c12"
checksum = "3df424c70518695237746f84cede799c9c58fcb37450d7b23716568cc8bc69cb"
dependencies = [
"proc-macro2",
"quote",
......@@ -4068,9 +4067,9 @@ checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142"
[[package]]
name = "unicode-ident"
version = "1.0.23"
version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "537dd038a89878be9b64dd4bd1b260315c1bb94f4d784956b81e27a088d9a09e"
checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
[[package]]
name = "unicode-xid"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment