Unverified Commit c5c6a551 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: more flash indexer optimizations (#6305)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarCursor <cursoragent@cursor.com>
parent a8226eb0
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -42,6 +42,7 @@ parking_lot = { workspace = true } ...@@ -42,6 +42,7 @@ parking_lot = { workspace = true }
clap = { version = "4.5", features = ["derive"], optional = true } clap = { version = "4.5", features = ["derive"], optional = true }
indicatif = { version = "0.18.0", optional = true } indicatif = { version = "0.18.0", optional = true }
uuid = { workspace = true, optional = true } uuid = { workspace = true, optional = true }
rustc-hash = "2.1.1"
[dev-dependencies] [dev-dependencies]
dynamo-bench = { path = "../bench" } dynamo-bench = { path = "../bench" }
...@@ -53,6 +54,7 @@ dynamo-mocker = { workspace = true } ...@@ -53,6 +54,7 @@ dynamo-mocker = { workspace = true }
dynamo-tokens = { workspace = true } dynamo-tokens = { workspace = true }
minstant = "0.1.7" minstant = "0.1.7"
futures = "0.3" futures = "0.3"
plotters = "0.3"
[[bench]] [[bench]]
......
...@@ -7,7 +7,9 @@ use dynamo_kv_router::indexer::{ ...@@ -7,7 +7,9 @@ use dynamo_kv_router::indexer::{
KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvIndexerSharded, KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvIndexerSharded,
}; };
use dynamo_kv_router::protocols::{RouterEvent, XXH3_SEED}; use dynamo_kv_router::protocols::{RouterEvent, XXH3_SEED};
use dynamo_kv_router::{ConcurrentRadixTree, PositionalIndexer, ThreadPoolIndexer}; use dynamo_kv_router::{
ConcurrentRadixTree, InvertedIndex, NaiveNestedMap, PositionalIndexer, ThreadPoolIndexer,
};
use dynamo_tokens::compute_hash_v2; use dynamo_tokens::compute_hash_v2;
use rand::prelude::*; use rand::prelude::*;
use std::fs::File; use std::fs::File;
...@@ -24,6 +26,7 @@ use tokio::task::JoinHandle; ...@@ -24,6 +26,7 @@ use tokio::task::JoinHandle;
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use plotters::prelude::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// Indexer backend selection and its backend-specific parameters. /// Indexer backend selection and its backend-specific parameters.
...@@ -56,6 +59,17 @@ enum IndexerArgs { ...@@ -56,6 +59,17 @@ enum IndexerArgs {
#[clap(long, default_value = "16")] #[clap(long, default_value = "16")]
num_event_workers: usize, num_event_workers: usize,
}, },
/// Naive per-worker nested HashMap indexer behind a single-threaded actor
/// (blog section 2).
NaiveNestedMap {},
/// Inverted index keyed by local_hash (blog section 3).
InvertedIndex {
/// Number of OS threads that consume and apply KV cache events.
#[clap(long, default_value = "16")]
num_event_workers: usize,
},
} }
impl IndexerArgs { impl IndexerArgs {
...@@ -88,8 +102,39 @@ impl IndexerArgs { ...@@ -88,8 +102,39 @@ impl IndexerArgs {
args.block_size, args.block_size,
)) ))
} }
IndexerArgs::NaiveNestedMap {} => Arc::new(NaiveNestedMap::new()),
IndexerArgs::InvertedIndex { .. } => Arc::new(InvertedIndex::new()),
} }
} }
/// Construct an indexer from a short name string, using `args.num_event_workers`.
fn from_name(
name: &str,
args: &Args,
) -> anyhow::Result<Arc<dyn KvIndexerInterface + Send + Sync>> {
let nw = args.num_event_workers;
let indexer_args = match name {
"radix-tree" => IndexerArgs::RadixTree {},
"radix-tree-sharded" => IndexerArgs::RadixTreeSharded { num_shards: 4 },
"nested-map" => IndexerArgs::NestedMap {
jump_size: 8,
num_event_workers: nw,
},
"concurrent-radix-tree" => IndexerArgs::ConcurrentRadixTree {
num_event_workers: nw,
},
"naive-nested-map" => IndexerArgs::NaiveNestedMap {},
"inverted-index" => IndexerArgs::InvertedIndex {
num_event_workers: 0,
},
_ => anyhow::bail!(
"Unknown indexer '{}'. Valid names: radix-tree, radix-tree-sharded, \
nested-map, concurrent-radix-tree, naive-nested-map, inverted-index",
name
),
};
Ok(indexer_args.build(args))
}
} }
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
...@@ -106,7 +151,7 @@ struct Args { ...@@ -106,7 +151,7 @@ struct Args {
/// Number of GPU blocks available in the mock engine's KV cache. /// Number of GPU blocks available in the mock engine's KV cache.
/// Smaller values force more evictions and produce more remove events. /// Smaller values force more evictions and produce more remove events.
#[clap(long, default_value = "2048")] #[clap(long, default_value = "1048576")]
num_gpu_blocks: usize, num_gpu_blocks: usize,
/// Number of tokens per KV cache block. /// Number of tokens per KV cache block.
...@@ -126,7 +171,7 @@ struct Args { ...@@ -126,7 +171,7 @@ struct Args {
/// Number of unique simulated inference workers. Each gets a random /// Number of unique simulated inference workers. Each gets a random
/// partition of the trace and its own mock engine for event generation. /// partition of the trace and its own mock engine for event generation.
#[clap(short, long, default_value = "64")] #[clap(short, long, default_value = "256")]
num_unique_inference_workers: usize, num_unique_inference_workers: usize,
/// How many times to duplicate the set of unique workers during the /// How many times to duplicate the set of unique workers during the
...@@ -152,6 +197,41 @@ struct Args { ...@@ -152,6 +197,41 @@ struct Args {
#[clap(long, default_value = "42")] #[clap(long, default_value = "42")]
seed: u64, seed: u64,
/// Enable throughput vs p99 latency sweep mode. Runs the benchmark at
/// multiple benchmark_duration_ms values and plots the results.
#[clap(long)]
sweep: bool,
/// Minimum benchmark duration (ms) for sweep mode.
#[clap(long, default_value = "1000")]
sweep_min_ms: u64,
/// Maximum benchmark duration (ms) for sweep mode.
#[clap(long, default_value = "50000")]
sweep_max_ms: u64,
/// Number of logarithmically spaced sweep steps between min and max.
#[clap(long, default_value = "10")]
sweep_steps: usize,
/// Output path for the sweep plot PNG.
#[clap(long, default_value = "sweep_plot.png")]
sweep_output: String,
/// Comma-separated list of indexer names to benchmark and compare on the
/// same plot. Overrides the subcommand indexer when present. Valid names:
/// radix-tree, radix-tree-sharded, nested-map, concurrent-radix-tree,
/// naive-nested-map, inverted-index.
#[clap(long, value_delimiter = ',')]
compare: Vec<String>,
/// Number of OS threads for event processing in compare mode. Applies to
/// indexers that use a thread pool (nested-map, concurrent-radix-tree,
/// inverted-index). Ignored by radix-tree, radix-tree-sharded, and
/// naive-nested-map.
#[clap(long, default_value = "16")]
num_event_workers: usize,
/// Indexer backend to benchmark (defaults to radix-tree if not specified). /// Indexer backend to benchmark (defaults to radix-tree if not specified).
#[clap(subcommand)] #[clap(subcommand)]
indexer: Option<IndexerArgs>, indexer: Option<IndexerArgs>,
...@@ -340,6 +420,7 @@ fn duplicate_traces(requests: Vec<MooncakeRequest>, factor: usize) -> Vec<Moonca ...@@ -340,6 +420,7 @@ fn duplicate_traces(requests: Vec<MooncakeRequest>, factor: usize) -> Vec<Moonca
for d in 0..factor { for d in 0..factor {
let offset = offset_base * d as u64; let offset = offset_base * d as u64;
out.push(MooncakeRequest { out.push(MooncakeRequest {
uuid: Uuid::new_v4(),
hash_ids: r.hash_ids.iter().map(|&h| h + offset).collect(), hash_ids: r.hash_ids.iter().map(|&h| h + offset).collect(),
..r.clone() ..r.clone()
}); });
...@@ -513,7 +594,9 @@ async fn generate_events( ...@@ -513,7 +594,9 @@ async fn generate_events(
fn prepare_worker_traces( fn prepare_worker_traces(
traces: Vec<Vec<MooncakeRequest>>, traces: Vec<Vec<MooncakeRequest>>,
events: Vec<Vec<(KvCacheEvent, Instant)>>, events: Vec<Vec<(KvCacheEvent, Instant)>>,
args: &Args, block_size: u32,
benchmark_duration_ms: u64,
trace_simulation_duration_ms: u64,
) -> Vec<Vec<WorkerTrace>> { ) -> Vec<Vec<WorkerTrace>> {
assert!(traces.len() == events.len()); assert!(traces.len() == events.len());
...@@ -525,13 +608,13 @@ fn prepare_worker_traces( ...@@ -525,13 +608,13 @@ fn prepare_worker_traces(
trace trace
.into_iter() .into_iter()
.map(|request| WorkerTrace { .map(|request| WorkerTrace {
timestamp_us: request.timestamp * 1000 * args.benchmark_duration_ms timestamp_us: request.timestamp * 1000 * benchmark_duration_ms
/ trace_duration_ms, / trace_duration_ms,
entry: WorkerTraceEntry::Request( entry: WorkerTraceEntry::Request(
request request
.hash_ids .hash_ids
.iter() .iter()
.map(|id| local_block_hash_from_id(*id, args.block_size)) .map(|id| local_block_hash_from_id(*id, block_size))
.collect(), .collect(),
), ),
}) })
...@@ -547,8 +630,8 @@ fn prepare_worker_traces( ...@@ -547,8 +630,8 @@ fn prepare_worker_traces(
.into_iter() .into_iter()
.map(|(event, timestamp)| WorkerTrace { .map(|(event, timestamp)| WorkerTrace {
timestamp_us: (timestamp - start_instant).as_micros() as u64 timestamp_us: (timestamp - start_instant).as_micros() as u64
* args.benchmark_duration_ms * benchmark_duration_ms
/ args.trace_simulation_duration_ms, / trace_simulation_duration_ms,
entry: WorkerTraceEntry::Event(event), entry: WorkerTraceEntry::Event(event),
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
...@@ -569,6 +652,15 @@ fn prepare_worker_traces( ...@@ -569,6 +652,15 @@ fn prepare_worker_traces(
.collect() .collect()
} }
/// Results from a single benchmark run.
struct BenchmarkResults {
offered_ops_throughput: f32,
ops_throughput: f32,
offered_block_throughput: f32,
block_throughput: f32,
latency_p99_us: f32,
}
/// Run the benchmark: replay each worker's merged trace against the indexer, /// Run the benchmark: replay each worker's merged trace against the indexer,
/// measuring find_matches latency and event processing throughput. /// measuring find_matches latency and event processing throughput.
/// ///
...@@ -580,8 +672,15 @@ async fn run_benchmark( ...@@ -580,8 +672,15 @@ async fn run_benchmark(
traces: Vec<Vec<MooncakeRequest>>, traces: Vec<Vec<MooncakeRequest>>,
events: Vec<Vec<(KvCacheEvent, Instant)>>, events: Vec<Vec<(KvCacheEvent, Instant)>>,
args: &Args, args: &Args,
) -> anyhow::Result<()> { benchmark_duration_ms: u64,
let worker_traces = prepare_worker_traces(traces, events, args); ) -> anyhow::Result<BenchmarkResults> {
let worker_traces = prepare_worker_traces(
traces,
events,
args.block_size,
benchmark_duration_ms,
args.trace_simulation_duration_ms,
);
let worker_traces = worker_traces let worker_traces = worker_traces
.into_iter() .into_iter()
.map(|trace| Arc::new(trace)) .map(|trace| Arc::new(trace))
...@@ -680,21 +779,13 @@ async fn run_benchmark( ...@@ -680,21 +779,13 @@ async fn run_benchmark(
latencies.extend(task.await??); latencies.extend(task.await??);
} }
if progress.elapsed() > Duration::from_millis(args.benchmark_duration_ms * 11 / 10) { if progress.elapsed() > Duration::from_millis(benchmark_duration_ms * 11 / 10) {
eprintln!( eprintln!(
"WARNING: The benchmarker is unable to keep up with the request/event generation rate. Rerun with a larger --benchmark-duration-ms." "WARNING: The benchmarker is unable to keep up with the request/event generation rate. Rerun with a larger --benchmark-duration-ms."
) )
} }
println!("Flushing event queue..."); let total_duration = progress.elapsed();
let request_duration = progress.elapsed();
let flush_start = Instant::now();
let flush_size = indexer.flush().await;
let flush_duration = flush_start.elapsed();
let event_duration = progress.elapsed();
let total_events = worker_traces let total_events = worker_traces
.iter() .iter()
...@@ -711,48 +802,154 @@ async fn run_benchmark( ...@@ -711,48 +802,154 @@ async fn run_benchmark(
* args.inference_worker_duplication_factor * args.inference_worker_duplication_factor
- total_events; - total_events;
let event_queue_flush_percentage = flush_size as f32 / total_events as f32 * 100.0; let total_request_blocks: usize = worker_traces
.iter()
.flat_map(|t| t.iter())
.filter_map(|entry| match &entry.entry {
WorkerTraceEntry::Request(hashes) => Some(hashes.len()),
_ => None,
})
.sum::<usize>()
* args.inference_worker_duplication_factor;
println!("Event queue flush duration: {:?}", flush_duration); let total_event_blocks: usize = worker_traces
println!( .iter()
"Event queue flush size: {} ({}% of total events)", .flat_map(|t| t.iter())
flush_size, event_queue_flush_percentage .filter_map(|entry| match &entry.entry {
); WorkerTraceEntry::Event(ev) => match &ev.data {
KvCacheEventData::Stored(s) => Some(s.blocks.len()),
_ => Some(0),
},
_ => None,
})
.sum::<usize>()
* args.inference_worker_duplication_factor;
if event_queue_flush_percentage > 5.0 { let total_blocks = total_request_blocks + total_event_blocks;
eprintln!(
"ERROR: Over 5% of events were unable to be completed within the benchmark duration.
Results are invalid. Rerun with a smaller trace or less worker duplication."
);
}
println!( let total_ops = total_requests + total_events;
"Request Throughput: {} req/s", let offered_ops_throughput = total_ops as f32 / benchmark_duration_ms as f32 * 1000.0;
total_requests as f32 / request_duration.as_millis() as f32 * 1000.0 let ops_throughput = total_ops as f32 / total_duration.as_millis() as f32 * 1000.0;
); let offered_block_throughput = total_blocks as f32 / benchmark_duration_ms as f32 * 1000.0;
println!( let block_throughput = total_blocks as f32 / total_duration.as_millis() as f32 * 1000.0;
"Event Throughput: {} events/s",
total_events as f32 / event_duration.as_millis() as f32 * 1000.0
);
latencies.sort_unstable(); latencies.sort_unstable();
let latency_p99_us = latencies[latencies.len() * 99 / 100] as f32 / 1000.0;
println!( println!(
"Latency p50: {}us", "Ops Throughput: {} ops/s (requests + events)",
latencies[latencies.len() / 2] as f32 / 1000.0 ops_throughput
);
println!(
"Latency p95: {}us",
latencies[latencies.len() * 95 / 100] as f32 / 1000.0
);
println!(
"Latency p99: {}us",
latencies[latencies.len() * 99 / 100] as f32 / 1000.0
);
println!(
"Latency max: {}us",
*latencies.last().unwrap() as f32 / 1000.0
); );
println!("Block Throughput: {} block ops/s", block_throughput);
println!("Latency p99: {}us", latency_p99_us);
Ok(BenchmarkResults {
offered_ops_throughput,
ops_throughput,
offered_block_throughput,
block_throughput,
latency_p99_us,
})
}
fn plot_sweep(
all_results: &[(&str, Vec<(u64, BenchmarkResults)>)],
output_path: &str,
) -> anyhow::Result<()> {
use plotters::coord::combinators::IntoLogRange;
use plotters::element::DashedPathElement;
use plotters::style::ShapeStyle;
let colors = [
RGBColor(31, 119, 180),
RGBColor(255, 127, 14),
RGBColor(44, 160, 44),
RGBColor(214, 39, 40),
RGBColor(148, 103, 189),
RGBColor(140, 86, 75),
];
let mut global_min = f64::MAX;
let mut global_max = f64::MIN;
for (_, results) in all_results {
for (_, r) in results {
let offered = r.offered_block_throughput as f64;
let achieved = r.block_throughput as f64;
global_min = global_min.min(offered).min(achieved);
global_max = global_max.max(offered).max(achieved);
}
}
let axis_min = global_min * 0.9;
let axis_max = global_max * 1.1;
let root = BitMapBackend::new(output_path, (800, 600)).into_drawing_area();
root.fill(&WHITE)?;
let mut chart = ChartBuilder::on(&root)
.caption(
"Achieved vs Offered Throughput",
("sans-serif", 22).into_font(),
)
.margin(20)
.x_label_area_size(40)
.y_label_area_size(80)
.build_cartesian_2d(
(axis_min..axis_max).log_scale(),
(axis_min..axis_max).log_scale(),
)?;
chart
.configure_mesh()
.x_desc("Offered Throughput (block ops/s)")
.y_desc("Achieved Throughput (block ops/s)")
.draw()?;
let identity_style = ShapeStyle::from(&BLACK.mix(0.4)).stroke_width(1);
chart.draw_series(std::iter::once(DashedPathElement::new(
vec![(axis_min, axis_min), (axis_max, axis_max)],
5,
3,
identity_style,
)))?;
for (i, (name, results)) in all_results.iter().enumerate() {
let color = &colors[i % colors.len()];
let points: Vec<(f64, f64)> = results
.iter()
.map(|(_, r)| (r.offered_block_throughput as f64, r.block_throughput as f64))
.collect();
let series_color = *color;
chart
.draw_series(LineSeries::new(
points.iter().map(|&(x, y)| (x, y)),
&series_color,
))?
.label(*name)
.legend(move |(x, y)| {
plotters::element::PathElement::new(
vec![(x, y), (x + 20, y)],
series_color.stroke_width(2),
)
});
chart.draw_series(
points
.iter()
.map(|&(x, y)| Circle::new((x, y), 4, series_color.filled())),
)?;
}
chart
.configure_series_labels()
.position(SeriesLabelPosition::LowerRight)
.background_style(WHITE.mix(0.8))
.border_style(BLACK)
.draw()?;
root.present()?;
println!("Sweep plot saved to {}", output_path);
Ok(()) Ok(())
} }
...@@ -841,12 +1038,93 @@ async fn main() -> anyhow::Result<()> { ...@@ -841,12 +1038,93 @@ async fn main() -> anyhow::Result<()> {
} }
let traces = process_mooncake_trace(&args)?; let traces = process_mooncake_trace(&args)?;
let events = generate_events(&traces, &args).await?; let events = generate_events(&traces, &args).await?;
let indexer = args.get_indexer().build(&args); let indexer_names: Vec<String> = if args.compare.is_empty() {
let name = match args.get_indexer() {
IndexerArgs::RadixTree {} => "radix-tree",
IndexerArgs::RadixTreeSharded { .. } => "radix-tree-sharded",
IndexerArgs::NestedMap { .. } => "nested-map",
IndexerArgs::ConcurrentRadixTree { .. } => "concurrent-radix-tree",
IndexerArgs::NaiveNestedMap {} => "naive-nested-map",
IndexerArgs::InvertedIndex { .. } => "inverted-index",
};
vec![name.to_string()]
} else {
args.compare.clone()
};
if args.sweep {
let log_min = (args.sweep_min_ms as f64).ln();
let log_max = (args.sweep_max_ms as f64).ln();
let n = args.sweep_steps;
let durations: Vec<u64> = (0..n)
.map(|i| {
let t = i as f64 / (n - 1) as f64;
(log_max * (1.0 - t) + log_min * t).exp().round() as u64
})
.collect();
let mut all_results: Vec<(&str, Vec<(u64, BenchmarkResults)>)> = Vec::new();
for name in &indexer_names {
println!("\n{}", "=".repeat(60));
println!("Benchmarking indexer: {}", name);
println!("{}", "=".repeat(60));
let mut results: Vec<(u64, BenchmarkResults)> = Vec::new();
for &dur_ms in &durations {
println!("\n=== Sweep: benchmark_duration_ms = {} ===", dur_ms);
let indexer = if args.compare.is_empty() {
args.get_indexer().build(&args)
} else {
IndexerArgs::from_name(name, &args)?
};
let result =
run_benchmark(indexer, traces.clone(), events.clone(), &args, dur_ms).await?;
results.push((dur_ms, result));
}
println!("\n=== Sweep Summary: {} ===", name);
println!(
"{:>12} {:>14} {:>14} {:>14} {:>14} {:>10}",
"duration_ms", "ops/s_off", "ops/s", "blk_ops/s_off", "blk_ops/s", "p99(us)"
);
for (dur, r) in &results {
println!(
"{:>12} {:>14.1} {:>14.1} {:>14.1} {:>14.1} {:>10.1}",
dur,
r.offered_ops_throughput,
r.ops_throughput,
r.offered_block_throughput,
r.block_throughput,
r.latency_p99_us,
);
}
run_benchmark(indexer, traces, events, &args).await?; all_results.push((name, results));
}
plot_sweep(&all_results, &args.sweep_output)?;
} else {
for name in &indexer_names {
println!("\nBenchmarking indexer: {}", name);
let indexer = if args.compare.is_empty() {
args.get_indexer().build(&args)
} else {
IndexerArgs::from_name(name, &args)?
};
run_benchmark(
indexer,
traces.clone(),
events.clone(),
&args,
args.benchmark_duration_ms,
)
.await?;
}
}
Ok(()) Ok(())
} }
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
//! //!
//! Unlike `RadixTree` which uses `Rc<RefCell<>>` and requires single-threaded access, //! Unlike `RadixTree` which uses `Rc<RefCell<>>` and requires single-threaded access,
//! `ConcurrentRadixTree` uses `Arc<RwLock<>>` per node and a //! `ConcurrentRadixTree` uses `Arc<RwLock<>>` per node and a
//! `DashMap<..., RwLock<HashMap<...>>>` for the lookup table. //! `RwLock<FxHashMap<..., RwLock<FxHashMap<...>>>>` for the lookup table.
//! //!
//! # Limitations vs RadixTree //! # Limitations vs RadixTree
//! //!
...@@ -20,17 +20,15 @@ ...@@ -20,17 +20,15 @@
//! //!
//! - Multiple `find_matches` can run in parallel (read locks only) //! - Multiple `find_matches` can run in parallel (read locks only)
//! - Write operations (`apply_event`, `remove_worker`) acquire write locks //! - Write operations (`apply_event`, `remove_worker`) acquire write locks
//! - The outer `DashMap` distributes contention across shards; inner `RwLock` //! - Outer `RwLock` is read-locked on the hot path; structural mutations
//! per worker allows per-worker write concurrency. //! (adding/removing workers) are rare. Inner `RwLock` per worker allows
//! per-worker write concurrency.
//! - Deadlock prevention: always lock parent before child, hand-over-hand locking //! - Deadlock prevention: always lock parent before child, hand-over-hand locking
use std::{ use std::{collections::VecDeque, sync::Arc};
collections::{HashMap, HashSet, VecDeque},
sync::Arc,
};
use dashmap::DashMap;
use parking_lot::RwLock; use parking_lot::RwLock;
use rustc_hash::{FxHashMap, FxHashSet};
use crate::indexer::SyncIndexer; use crate::indexer::SyncIndexer;
use crate::protocols::*; use crate::protocols::*;
...@@ -38,13 +36,16 @@ use crate::protocols::*; ...@@ -38,13 +36,16 @@ use crate::protocols::*;
/// Thread-safe shared reference to a Block. /// Thread-safe shared reference to a Block.
type SharedBlock = Arc<RwLock<Block>>; type SharedBlock = Arc<RwLock<Block>>;
/// Per-worker block-hash map. Inner RwLock allows concurrent reads of different workers.
type WorkerLookup = FxHashMap<ExternalSequenceBlockHash, SharedBlock>;
/// A block in the concurrent radix tree. /// A block in the concurrent radix tree.
#[derive(Debug)] #[derive(Debug)]
struct Block { struct Block {
/// A map of child blocks, keyed by their local block hash. /// A map of child blocks, keyed by their local block hash.
children: HashMap<LocalBlockHash, SharedBlock>, children: FxHashMap<LocalBlockHash, SharedBlock>,
/// The set of workers that have this block cached. /// The set of workers that have this block cached.
workers: HashSet<WorkerWithDpRank>, workers: FxHashSet<WorkerWithDpRank>,
/// The external sequence block hash for this block (None for root). /// The external sequence block hash for this block (None for root).
block_hash: Option<ExternalSequenceBlockHash>, block_hash: Option<ExternalSequenceBlockHash>,
// NOTE: No recent_uses field. // NOTE: No recent_uses field.
...@@ -55,8 +56,8 @@ impl Block { ...@@ -55,8 +56,8 @@ impl Block {
/// Create a new `Block` (used for root node). /// Create a new `Block` (used for root node).
fn new() -> Self { fn new() -> Self {
Self { Self {
children: HashMap::new(), children: FxHashMap::default(),
workers: HashSet::new(), workers: FxHashSet::default(),
block_hash: None, block_hash: None,
} }
} }
...@@ -64,8 +65,8 @@ impl Block { ...@@ -64,8 +65,8 @@ impl Block {
/// Create a new `Block` with a specific block hash. /// Create a new `Block` with a specific block hash.
fn with_hash(block_hash: ExternalSequenceBlockHash) -> Self { fn with_hash(block_hash: ExternalSequenceBlockHash) -> Self {
Self { Self {
children: HashMap::new(), children: FxHashMap::default(),
workers: HashSet::new(), workers: FxHashSet::default(),
block_hash: Some(block_hash), block_hash: Some(block_hash),
} }
} }
...@@ -75,7 +76,7 @@ impl Block { ...@@ -75,7 +76,7 @@ impl Block {
/// ///
/// Unlike `RadixTree` which uses `Rc<RefCell<>>` and requires single-threaded access, /// Unlike `RadixTree` which uses `Rc<RefCell<>>` and requires single-threaded access,
/// `ConcurrentRadixTree` uses `Arc<RwLock<>>` per node and a /// `ConcurrentRadixTree` uses `Arc<RwLock<>>` per node and a
/// `DashMap<..., RwLock<HashMap<...>>>` for the lookup table, /// `RwLock<FxHashMap<..., RwLock<FxHashMap<...>>>>` for the lookup table,
/// enabling concurrent `find_matches` operations. /// enabling concurrent `find_matches` operations.
/// ///
/// # Limitations vs RadixTree /// # Limitations vs RadixTree
...@@ -88,8 +89,9 @@ impl Block { ...@@ -88,8 +89,9 @@ impl Block {
/// ///
/// - Multiple `find_matches` can run in parallel (read locks only) /// - Multiple `find_matches` can run in parallel (read locks only)
/// - Write operations (`apply_event`, `remove_worker`) acquire write locks /// - Write operations (`apply_event`, `remove_worker`) acquire write locks
/// - The outer `DashMap` distributes contention across shards; inner `RwLock` /// - Outer RwLock is read-locked on the hot path; structural mutations
/// per worker allows per-worker write concurrency. /// (adding/removing workers) are rare and take a write lock.
/// - Inner `RwLock` per worker allows per-worker write concurrency.
/// - Deadlock prevention: always lock parent before child, hand-over-hand locking /// - Deadlock prevention: always lock parent before child, hand-over-hand locking
pub struct ConcurrentRadixTree { pub struct ConcurrentRadixTree {
/// This is the root of the radix/prefix tree. /// This is the root of the radix/prefix tree.
...@@ -97,9 +99,9 @@ pub struct ConcurrentRadixTree { ...@@ -97,9 +99,9 @@ pub struct ConcurrentRadixTree {
root: SharedBlock, root: SharedBlock,
/// Per-worker lookup table for O(1) block access. /// Per-worker lookup table for O(1) block access.
/// Outer `DashMap` distributes lock contention across shards; inner `RwLock` /// Outer RwLock protects the worker map structure (rarely mutated);
/// per worker protects that worker's block-hash map. /// inner RwLock per worker protects that worker's block-hash map.
lookup: DashMap<WorkerWithDpRank, RwLock<HashMap<ExternalSequenceBlockHash, SharedBlock>>>, lookup: RwLock<FxHashMap<WorkerWithDpRank, RwLock<WorkerLookup>>>,
} }
impl Default for ConcurrentRadixTree { impl Default for ConcurrentRadixTree {
...@@ -121,15 +123,11 @@ impl Drop for ConcurrentRadixTree { ...@@ -121,15 +123,11 @@ impl Drop for ConcurrentRadixTree {
} }
// Remove all lookup references (they may include blocks not reachable from root). // Remove all lookup references (they may include blocks not reachable from root).
// We have &mut self so no concurrent access; drain the DashMap by clearing it // We have &mut self so no concurrent access; drain the map.
// after collecting all inner values. let lookup = self.lookup.get_mut();
let entries: Vec<_> = self for (_, inner_lock) in lookup.drain() {
.lookup stack.extend(inner_lock.into_inner().into_values());
.iter() }
.flat_map(|entry| entry.value().read().values().cloned().collect::<Vec<_>>())
.collect();
stack.extend(entries);
self.lookup.clear();
// Iteratively free any uniquely-owned blocks without recursion // Iteratively free any uniquely-owned blocks without recursion
while let Some(block) = stack.pop() { while let Some(block) = stack.pop() {
...@@ -146,7 +144,7 @@ impl ConcurrentRadixTree { ...@@ -146,7 +144,7 @@ impl ConcurrentRadixTree {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
root: Arc::new(RwLock::new(Block::new())), root: Arc::new(RwLock::new(Block::new())),
lookup: DashMap::new(), lookup: RwLock::new(FxHashMap::default()),
} }
} }
...@@ -199,8 +197,9 @@ impl ConcurrentRadixTree { ...@@ -199,8 +197,9 @@ impl ConcurrentRadixTree {
for worker in &active { for worker in &active {
scores.scores.insert(*worker, 1); scores.scores.insert(*worker, 1);
} }
let lk = self.lookup.read();
for worker in scores.scores.keys() { for worker in scores.scores.keys() {
if let Some(inner_lock) = self.lookup.get(worker) { if let Some(inner_lock) = lk.get(worker) {
scores.tree_sizes.insert(*worker, inner_lock.read().len()); scores.tree_sizes.insert(*worker, inner_lock.read().len());
} }
} }
...@@ -234,25 +233,10 @@ impl ConcurrentRadixTree { ...@@ -234,25 +233,10 @@ impl ConcurrentRadixTree {
let guard = block.read(); let guard = block.read();
let child_count = guard.workers.len(); let child_count = guard.workers.len();
if child_count < active_count { if child_count != active_count {
// Workers dropped out. Record scores for those that left. // Workers changed: either dropped out (child < active) or
// Score = matched_depth (number of nodes they were present at). // stale entries exist (child > active). In both cases,
for worker in &active { // retain only workers present in the child, scoring dropouts.
if !guard.workers.contains(worker) {
scores.scores.insert(*worker, matched_depth);
}
}
active.clone_from(&guard.workers);
active_count = child_count;
if active_count == 0 {
break;
}
} else if child_count > active_count {
// child_count > active_count means stale entries exist
// (child retains workers already removed from an ancestor).
// Fall back to full membership check: keep only workers
// present in both active and this child, scoring dropouts.
active.retain(|w| { active.retain(|w| {
if guard.workers.contains(w) { if guard.workers.contains(w) {
true true
...@@ -288,8 +272,9 @@ impl ConcurrentRadixTree { ...@@ -288,8 +272,9 @@ impl ConcurrentRadixTree {
} }
// Get tree sizes from lookup. // Get tree sizes from lookup.
let lk = self.lookup.read();
for worker in scores.scores.keys() { for worker in scores.scores.keys() {
if let Some(inner_lock) = self.lookup.get(worker) { if let Some(inner_lock) = lk.get(worker) {
scores.tree_sizes.insert(*worker, inner_lock.read().len()); scores.tree_sizes.insert(*worker, inner_lock.read().len());
} }
} }
...@@ -330,14 +315,15 @@ impl ConcurrentRadixTree { ...@@ -330,14 +315,15 @@ impl ConcurrentRadixTree {
id: u64, id: u64,
) -> Result<(), KvCacheEventError> { ) -> Result<(), KvCacheEventError> {
// Ensure this worker has an entry in the outer map. // Ensure this worker has an entry in the outer map.
if !self.lookup.contains_key(&worker) { if !self.lookup.read().contains_key(&worker) {
self.lookup self.lookup
.write()
.entry(worker) .entry(worker)
.or_insert_with(|| RwLock::new(HashMap::new())); .or_insert_with(|| RwLock::new(FxHashMap::default()));
} }
let inner_ref = self.lookup.get(&worker).unwrap(); let lk = self.lookup.read();
let mut worker_lookup = inner_ref.write(); let mut worker_lookup = lk.get(&worker).unwrap().write();
// Find parent block // Find parent block
let mut current = match op.parent_hash { let mut current = match op.parent_hash {
...@@ -435,7 +421,8 @@ impl ConcurrentRadixTree { ...@@ -435,7 +421,8 @@ impl ConcurrentRadixTree {
op: KvCacheRemoveData, op: KvCacheRemoveData,
id: u64, id: u64,
) -> Result<(), KvCacheEventError> { ) -> Result<(), KvCacheEventError> {
let Some(inner_ref) = self.lookup.get(&worker) else { let lk = self.lookup.read();
let Some(inner_ref) = lk.get(&worker) else {
return Err(KvCacheEventError::BlockNotFound); return Err(KvCacheEventError::BlockNotFound);
}; };
let mut worker_lookup = inner_ref.write(); let mut worker_lookup = inner_ref.write();
...@@ -467,17 +454,17 @@ impl ConcurrentRadixTree { ...@@ -467,17 +454,17 @@ impl ConcurrentRadixTree {
/// If `keep_worker` is true, the worker remains in lookup with empty blocks. /// If `keep_worker` is true, the worker remains in lookup with empty blocks.
/// If `keep_worker` is false, the worker is completely removed from lookup. /// If `keep_worker` is false, the worker is completely removed from lookup.
fn remove_or_clear_worker_blocks(&self, worker_id: WorkerId, keep_worker: bool) { fn remove_or_clear_worker_blocks(&self, worker_id: WorkerId, keep_worker: bool) {
// Collect all WorkerWithDpRank keys that match this worker_id
let workers: Vec<WorkerWithDpRank> = self let workers: Vec<WorkerWithDpRank> = self
.lookup .lookup
.iter() .read()
.filter(|entry| entry.key().worker_id == worker_id) .keys()
.map(|entry| *entry.key()) .filter(|w| w.worker_id == worker_id)
.copied()
.collect(); .collect();
let mut lk = self.lookup.write();
for worker in workers { for worker in workers {
if let Some((_, inner_lock)) = self.lookup.remove(&worker) { if let Some(inner_lock) = lk.remove(&worker) {
// We now own the inner RwLock; extract the HashMap.
let blocks = inner_lock.into_inner(); let blocks = inner_lock.into_inner();
for (_, block) in blocks { for (_, block) in blocks {
let mut guard = block.write(); let mut guard = block.write();
...@@ -488,7 +475,7 @@ impl ConcurrentRadixTree { ...@@ -488,7 +475,7 @@ impl ConcurrentRadixTree {
} }
if keep_worker { if keep_worker {
self.lookup.insert(worker, RwLock::new(HashMap::new())); lk.insert(worker, RwLock::new(FxHashMap::default()));
} }
} }
} }
...@@ -509,9 +496,10 @@ impl ConcurrentRadixTree { ...@@ -509,9 +496,10 @@ impl ConcurrentRadixTree {
pub fn get_workers(&self) -> Vec<WorkerId> { pub fn get_workers(&self) -> Vec<WorkerId> {
let mut worker_ids: Vec<WorkerId> = self let mut worker_ids: Vec<WorkerId> = self
.lookup .lookup
.iter() .read()
.map(|entry| entry.key().worker_id) .keys()
.collect::<HashSet<_>>() .map(|w| w.worker_id)
.collect::<FxHashSet<_>>()
.into_iter() .into_iter()
.collect(); .collect();
worker_ids.sort_unstable(); worker_ids.sort_unstable();
...@@ -524,7 +512,7 @@ impl ConcurrentRadixTree { ...@@ -524,7 +512,7 @@ impl ConcurrentRadixTree {
pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> { pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
tracing::debug!( tracing::debug!(
"Dumping concurrent radix tree as events (contains information about {:?} workers)", "Dumping concurrent radix tree as events (contains information about {:?} workers)",
self.lookup.len() self.lookup.read().len()
); );
let mut events = Vec::new(); let mut events = Vec::new();
...@@ -583,8 +571,9 @@ impl ConcurrentRadixTree { ...@@ -583,8 +571,9 @@ impl ConcurrentRadixTree {
/// Get total number of blocks across all workers. /// Get total number of blocks across all workers.
pub fn current_size(&self) -> usize { pub fn current_size(&self) -> usize {
self.lookup self.lookup
.iter() .read()
.map(|entry| entry.value().read().len()) .values()
.map(|inner| inner.read().len())
.sum() .sum()
} }
} }
...@@ -641,9 +630,10 @@ mod tests { ...@@ -641,9 +630,10 @@ mod tests {
&3 &3
); );
assert_eq!(trie.lookup.len(), 1); assert_eq!(trie.lookup.read().len(), 1);
assert_eq!( assert_eq!(
trie.lookup trie.lookup
.read()
.get(&WorkerWithDpRank::from_worker_id(worker_1)) .get(&WorkerWithDpRank::from_worker_id(worker_1))
.unwrap() .unwrap()
.read() .read()
...@@ -673,7 +663,7 @@ mod tests { ...@@ -673,7 +663,7 @@ mod tests {
&1 &1
); );
assert_eq!(trie.lookup.len(), 2); assert_eq!(trie.lookup.read().len(), 2);
} }
#[test] #[test]
...@@ -693,6 +683,7 @@ mod tests { ...@@ -693,6 +683,7 @@ mod tests {
assert_eq!( assert_eq!(
trie.lookup trie.lookup
.read()
.get(&WorkerWithDpRank::from_worker_id(worker_2)) .get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap() .unwrap()
.read() .read()
...@@ -705,6 +696,7 @@ mod tests { ...@@ -705,6 +696,7 @@ mod tests {
assert_eq!( assert_eq!(
trie.lookup trie.lookup
.read()
.get(&WorkerWithDpRank::from_worker_id(worker_2)) .get(&WorkerWithDpRank::from_worker_id(worker_2))
.unwrap() .unwrap()
.read() .read()
...@@ -751,10 +743,12 @@ mod tests { ...@@ -751,10 +743,12 @@ mod tests {
assert!( assert!(
trie.lookup trie.lookup
.read()
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0)) .contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
); );
assert!( assert!(
trie.lookup trie.lookup
.read()
.get(&WorkerWithDpRank::from_worker_id(worker_0)) .get(&WorkerWithDpRank::from_worker_id(worker_0))
.unwrap() .unwrap()
.read() .read()
...@@ -780,16 +774,17 @@ mod tests { ...@@ -780,16 +774,17 @@ mod tests {
trie.apply_event(create_store_event(worker_1, 0, vec![1, 2, 3], None)) trie.apply_event(create_store_event(worker_1, 0, vec![1, 2, 3], None))
.unwrap(); .unwrap();
assert_eq!(trie.lookup.len(), 2); assert_eq!(trie.lookup.read().len(), 2);
trie.remove_worker(worker_0); trie.remove_worker(worker_0);
assert!( assert!(
!trie !trie
.lookup .lookup
.read()
.contains_key(&WorkerWithDpRank::from_worker_id(worker_0)) .contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
); );
assert_eq!(trie.lookup.len(), 1); assert_eq!(trie.lookup.read().len(), 1);
let result = trie let result = trie
.find_matches_impl( .find_matches_impl(
...@@ -807,7 +802,7 @@ mod tests { ...@@ -807,7 +802,7 @@ mod tests {
let trie: ConcurrentRadixTree = Default::default(); let trie: ConcurrentRadixTree = Default::default();
assert!(trie.root.read().children.is_empty()); assert!(trie.root.read().children.is_empty());
assert!(trie.root.read().workers.is_empty()); assert!(trie.root.read().workers.is_empty());
assert!(trie.lookup.is_empty()); assert!(trie.lookup.read().is_empty());
} }
#[test] #[test]
...@@ -920,14 +915,14 @@ mod tests { ...@@ -920,14 +915,14 @@ mod tests {
.unwrap(); .unwrap();
let worker_key = WorkerWithDpRank::from_worker_id(worker_1); let worker_key = WorkerWithDpRank::from_worker_id(worker_1);
assert_eq!(trie.lookup.get(&worker_key).unwrap().read().len(), 3); assert_eq!(trie.lookup.read().get(&worker_key).unwrap().read().len(), 3);
// Remove ONLY block1 -- descendants should NOT be cascade-removed // Remove ONLY block1 -- descendants should NOT be cascade-removed
trie.apply_event(create_remove_event(worker_1, 2, vec![1])) trie.apply_event(create_remove_event(worker_1, 2, vec![1]))
.unwrap(); .unwrap();
let inner_ref = trie.lookup.get(&worker_key).unwrap(); let lk = trie.lookup.read();
let worker_lookup = inner_ref.read(); let worker_lookup = lk.get(&worker_key).unwrap().read();
assert!( assert!(
!worker_lookup.contains_key(&ExternalSequenceBlockHash(100)), !worker_lookup.contains_key(&ExternalSequenceBlockHash(100)),
"block1 should be removed" "block1 should be removed"
...@@ -959,8 +954,8 @@ mod tests { ...@@ -959,8 +954,8 @@ mod tests {
trie.apply_event(create_remove_event(worker_1, 2, vec![1, 2, 3])) trie.apply_event(create_remove_event(worker_1, 2, vec![1, 2, 3]))
.unwrap(); .unwrap();
let inner_ref = trie.lookup.get(&worker_key).unwrap(); let lk = trie.lookup.read();
let worker_lookup = inner_ref.read(); let worker_lookup = lk.get(&worker_key).unwrap().read();
assert_eq!(worker_lookup.len(), 0, "all blocks should be removed"); assert_eq!(worker_lookup.len(), 0, "all blocks should be removed");
} }
......
...@@ -44,6 +44,7 @@ use dynamo_runtime::{ ...@@ -44,6 +44,7 @@ use dynamo_runtime::{
metrics::{MetricsHierarchy, prometheus_names::kvrouter}, metrics::{MetricsHierarchy, prometheus_names::kvrouter},
}; };
use prometheus::{IntCounterVec, Opts}; use prometheus::{IntCounterVec, Opts};
use rustc_hash::FxBuildHasher;
/// Trait for types that may represent an error response. /// Trait for types that may represent an error response.
/// Used for RPC-style responses that can indicate success or failure. /// Used for RPC-style responses that can indicate success or failure.
...@@ -406,7 +407,7 @@ pub struct ThreadPoolIndexer<T: SyncIndexer> { ...@@ -406,7 +407,7 @@ pub struct ThreadPoolIndexer<T: SyncIndexer> {
backend: Arc<T>, backend: Arc<T>,
/// Maps WorkerId to worker thread index for sticky routing. /// Maps WorkerId to worker thread index for sticky routing.
worker_assignments: DashMap<WorkerId, usize>, worker_assignments: DashMap<WorkerId, usize, FxBuildHasher>,
/// Counter for round-robin assignment of new WorkerIds. /// Counter for round-robin assignment of new WorkerIds.
worker_assignment_count: AtomicUsize, worker_assignment_count: AtomicUsize,
...@@ -463,7 +464,7 @@ impl<T: SyncIndexer> ThreadPoolIndexer<T> { ...@@ -463,7 +464,7 @@ impl<T: SyncIndexer> ThreadPoolIndexer<T> {
Self { Self {
backend, backend,
worker_assignments: DashMap::new(), worker_assignments: DashMap::with_hasher(FxBuildHasher),
worker_assignment_count: AtomicUsize::new(0), worker_assignment_count: AtomicUsize::new(0),
worker_event_channels: worker_event_senders, worker_event_channels: worker_event_senders,
num_workers, num_workers,
...@@ -1388,7 +1389,7 @@ pub struct KvIndexerSharded { ...@@ -1388,7 +1389,7 @@ pub struct KvIndexerSharded {
cancel: CancellationToken, cancel: CancellationToken,
/// The size of the KV block this indexer can handle. /// The size of the KV block this indexer can handle.
kv_block_size: u32, kv_block_size: u32,
worker_assignments: DashMap<WorkerId, usize>, worker_assignments: DashMap<WorkerId, usize, FxBuildHasher>,
worker_counts: Arc<Mutex<Vec<usize>>>, worker_counts: Arc<Mutex<Vec<usize>>>,
event_tx: Vec<mpsc::Sender<RouterEvent>>, event_tx: Vec<mpsc::Sender<RouterEvent>>,
...@@ -1421,7 +1422,7 @@ impl KvIndexerSharded { ...@@ -1421,7 +1422,7 @@ impl KvIndexerSharded {
metrics: Arc<KvIndexerMetrics>, metrics: Arc<KvIndexerMetrics>,
prune_config: Option<PruneConfig>, prune_config: Option<PruneConfig>,
) -> Self { ) -> Self {
let worker_assignments = DashMap::new(); let worker_assignments = DashMap::with_hasher(FxBuildHasher);
let worker_counts = Arc::new(Mutex::new(vec![0; num_shards])); let worker_counts = Arc::new(Mutex::new(vec![0; num_shards]));
let mut event_tx = Vec::new(); let mut event_tx = Vec::new();
......
...@@ -11,6 +11,8 @@ pub mod approx; ...@@ -11,6 +11,8 @@ pub mod approx;
pub mod bench_utils; pub mod bench_utils;
pub mod concurrent_radix_tree; pub mod concurrent_radix_tree;
pub mod indexer; pub mod indexer;
#[cfg(feature = "bench")]
pub mod naive_indexers;
pub mod nested_map; pub mod nested_map;
pub mod protocols; pub mod protocols;
pub mod radix_tree; pub mod radix_tree;
...@@ -21,6 +23,8 @@ pub(crate) mod test_utils; ...@@ -21,6 +23,8 @@ pub(crate) mod test_utils;
// Re-export key types for convenience // Re-export key types for convenience
pub use concurrent_radix_tree::ConcurrentRadixTree; pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer}; pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer};
#[cfg(feature = "bench")]
pub use naive_indexers::{InvertedIndex, NaiveNestedMap};
pub use nested_map::PositionalIndexer; pub use nested_map::PositionalIndexer;
pub use protocols::{ pub use protocols::{
KvCacheEventError, LocalBlockHash, OverlapScores, RouterEvent, WorkerId, KvCacheEventError, LocalBlockHash, OverlapScores, RouterEvent, WorkerId,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! **DO NOT USE IN PRODUCTION.** These are intentionally simplified indexer
//! implementations for benchmarking and blog illustrations only. They cut
//! corners (no reverse lookup, Remove events are unimplemented) that make
//! them incorrect under real workloads with eviction pressure.
//!
//! They correspond to blog sections 2 and 3 and exist to show the performance
//! progression from naive approaches to the production indexers.
//!
//! - [`NaiveNestedMap`]: `worker -> set<local_hash>`. O(W × D) per
//! `find_matches` call, behind a single-threaded actor. Blog section 2.
//! - [`InvertedIndex`]: `local_hash -> set<worker>`. O(D + W) per
//! `find_matches` call, single-threaded actor. Blog section 3.
use async_trait::async_trait;
use std::collections::{HashMap, HashSet};
use tokio::sync::{mpsc, oneshot};
use crate::indexer::{KvIndexerInterface, KvRouterError};
use crate::protocols::{
KvCacheEventData, LocalBlockHash, OverlapScores, RouterEvent, TokensWithHashes, WorkerId,
WorkerWithDpRank,
};
// ============================================================================
// Section 2 — Naive Nested Map + Actor
// ============================================================================
/// Plain nested `HashMap` index — no locks, owned exclusively by the actor thread.
///
/// Structure: `worker -> set<local_hash>`.
/// No reverse lookup — Remove is unimplemented (relies on large GPU block
/// budget to avoid evictions).
struct NaiveNestedMapInner {
index: HashMap<WorkerWithDpRank, HashSet<LocalBlockHash>>,
}
impl NaiveNestedMapInner {
fn new() -> Self {
Self {
index: HashMap::new(),
}
}
fn find_matches(&self, sequence: &[LocalBlockHash]) -> OverlapScores {
let mut scores = OverlapScores::new();
if sequence.is_empty() {
return scores;
}
for (worker, blocks) in &self.index {
let mut depth = 0u32;
for local_hash in sequence {
if !blocks.contains(local_hash) {
break;
}
depth += 1;
}
if depth > 0 {
scores.scores.insert(*worker, depth);
}
}
scores
}
fn apply_event(&mut self, event: RouterEvent) {
let worker = WorkerWithDpRank::new(event.worker_id, event.event.dp_rank);
match event.event.data {
KvCacheEventData::Stored(store_data) => {
let worker_set = self.index.entry(worker).or_default();
for block in store_data.blocks {
worker_set.insert(block.tokens_hash);
}
}
KvCacheEventData::Removed(_) => {
unimplemented!(
"NaiveNestedMap does not support Remove events; increase --num-gpu-blocks to avoid evictions"
);
}
KvCacheEventData::Cleared => {
self.index.remove(&worker);
}
}
}
fn remove_worker(&mut self, worker_id: WorkerId) {
self.index.retain(|w, _| w.worker_id != worker_id);
}
}
struct MatchRequest {
sequence: Vec<LocalBlockHash>,
reply: oneshot::Sender<OverlapScores>,
}
enum ActorMessage {
Event(RouterEvent),
Match(MatchRequest),
RemoveWorker(WorkerId),
}
/// Single-threaded actor wrapping [`NaiveNestedMapInner`] (blog section 2).
///
/// All reads and writes are serialized through a single OS thread via channels.
/// This is the pure actor pattern described in the blog — no concurrent access
/// to the data structure at all.
pub struct NaiveNestedMap {
tx: mpsc::Sender<ActorMessage>,
}
impl Default for NaiveNestedMap {
fn default() -> Self {
Self::new()
}
}
impl NaiveNestedMap {
pub fn new() -> Self {
let (tx, mut rx) = mpsc::channel::<ActorMessage>(2048);
std::thread::spawn(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async move {
let mut inner = NaiveNestedMapInner::new();
while let Some(msg) = rx.recv().await {
match msg {
ActorMessage::Event(event) => {
inner.apply_event(event);
}
ActorMessage::Match(req) => {
let scores = inner.find_matches(&req.sequence);
let _ = req.reply.send(scores);
}
ActorMessage::RemoveWorker(worker_id) => {
inner.remove_worker(worker_id);
}
}
}
});
});
Self { tx }
}
}
#[async_trait]
impl KvIndexerInterface for NaiveNestedMap {
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(ActorMessage::Match(MatchRequest {
sequence,
reply: reply_tx,
}))
.await
.map_err(|_| KvRouterError::IndexerOffline)?;
reply_rx
.await
.map_err(|_| KvRouterError::IndexerDroppedRequest)
}
async fn find_matches_for_request(
&self,
_tokens: &[u32],
) -> Result<OverlapScores, KvRouterError> {
unimplemented!("not used in bench")
}
async fn apply_event(&self, event: RouterEvent) {
let _ = self.tx.send(ActorMessage::Event(event)).await;
}
async fn remove_worker(&self, worker: WorkerId) {
let _ = self.tx.send(ActorMessage::RemoveWorker(worker)).await;
}
fn shutdown(&self) {}
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
Ok(Vec::new())
}
async fn process_routing_decision_for_request(
&self,
_tokens_with_hashes: &mut TokensWithHashes,
_worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
unimplemented!("not used in bench")
}
async fn flush(&self) -> usize {
let curr_size = self.tx.max_capacity() - self.tx.capacity();
loop {
if self.tx.capacity() == self.tx.max_capacity() {
break;
}
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
}
curr_size
}
}
// ============================================================================
// Section 3 — Inverted Index
// ============================================================================
/// Plain inverted index — no locks, owned exclusively by the actor thread.
///
/// Flat forward index: `local_hash -> set<worker>`.
/// No reverse lookup — Remove is a no-op (relies on large GPU block budget
/// to avoid evictions), Clear/remove_worker scan the forward index.
struct InvertedIndexInner {
index: HashMap<LocalBlockHash, HashSet<WorkerWithDpRank>>,
}
impl InvertedIndexInner {
fn new() -> Self {
Self {
index: HashMap::new(),
}
}
fn find_matches(&self, sequence: &[LocalBlockHash]) -> OverlapScores {
let mut scores = OverlapScores::new();
if sequence.is_empty() {
return scores;
}
let Some(workers) = self.index.get(&sequence[0]) else {
return scores;
};
let mut active: HashSet<WorkerWithDpRank> = workers.clone();
if active.is_empty() {
return scores;
}
for (depth, local_hash) in sequence.iter().enumerate() {
let empty = HashSet::new();
let workers_here = self.index.get(local_hash).unwrap_or(&empty);
active.retain(|w| {
if workers_here.contains(w) {
true
} else {
scores.scores.insert(*w, depth as u32);
false
}
});
if active.is_empty() {
break;
}
}
for w in active {
scores.scores.insert(w, sequence.len() as u32);
}
scores
}
fn apply_event(&mut self, event: RouterEvent) {
let worker = WorkerWithDpRank::new(event.worker_id, event.event.dp_rank);
match event.event.data {
KvCacheEventData::Stored(store_data) => {
for block in store_data.blocks {
self.index
.entry(block.tokens_hash)
.or_default()
.insert(worker);
}
}
KvCacheEventData::Removed(_) => {
unimplemented!(
"InvertedIndex does not support Remove events; increase --num-gpu-blocks to avoid evictions"
);
}
KvCacheEventData::Cleared => {
self.clear_worker(worker);
}
}
}
fn remove_worker(&mut self, worker_id: WorkerId) {
for workers in self.index.values_mut() {
workers.retain(|w| w.worker_id != worker_id);
}
}
fn clear_worker(&mut self, worker: WorkerWithDpRank) {
for workers in self.index.values_mut() {
workers.remove(&worker);
}
}
}
enum InvertedIndexMessage {
Event(RouterEvent),
Match(MatchRequest),
RemoveWorker(WorkerId),
}
/// Single-threaded actor wrapping [`InvertedIndexInner`] (blog section 3).
///
/// Same actor pattern as [`NaiveNestedMap`] — all reads and writes are
/// serialized through a single OS thread via channels.
pub struct InvertedIndex {
tx: mpsc::Sender<InvertedIndexMessage>,
}
impl Default for InvertedIndex {
fn default() -> Self {
Self::new()
}
}
impl InvertedIndex {
pub fn new() -> Self {
let (tx, mut rx) = mpsc::channel::<InvertedIndexMessage>(2048);
std::thread::spawn(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async move {
let mut inner = InvertedIndexInner::new();
while let Some(msg) = rx.recv().await {
match msg {
InvertedIndexMessage::Event(event) => {
inner.apply_event(event);
}
InvertedIndexMessage::Match(req) => {
let scores = inner.find_matches(&req.sequence);
let _ = req.reply.send(scores);
}
InvertedIndexMessage::RemoveWorker(worker_id) => {
inner.remove_worker(worker_id);
}
}
}
});
});
Self { tx }
}
}
#[async_trait]
impl KvIndexerInterface for InvertedIndex {
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(InvertedIndexMessage::Match(MatchRequest {
sequence,
reply: reply_tx,
}))
.await
.map_err(|_| KvRouterError::IndexerOffline)?;
reply_rx
.await
.map_err(|_| KvRouterError::IndexerDroppedRequest)
}
async fn find_matches_for_request(
&self,
_tokens: &[u32],
) -> Result<OverlapScores, KvRouterError> {
unimplemented!("not used in bench")
}
async fn apply_event(&self, event: RouterEvent) {
let _ = self.tx.send(InvertedIndexMessage::Event(event)).await;
}
async fn remove_worker(&self, worker: WorkerId) {
let _ = self
.tx
.send(InvertedIndexMessage::RemoveWorker(worker))
.await;
}
fn shutdown(&self) {}
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
Ok(Vec::new())
}
async fn process_routing_decision_for_request(
&self,
_tokens_with_hashes: &mut TokensWithHashes,
_worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
unimplemented!("not used in bench")
}
async fn flush(&self) -> usize {
let curr_size = self.tx.max_capacity() - self.tx.capacity();
loop {
if self.tx.capacity() == self.tx.max_capacity() {
break;
}
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
}
curr_size
}
}
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
//! `KvIndexerInterface` with sticky event routing and worker threads, wrap it //! `KvIndexerInterface` with sticky event routing and worker threads, wrap it
//! in a `ThreadPoolIndexer`. //! in a `ThreadPoolIndexer`.
use dashmap::DashMap; use dashmap::DashMap;
use std::collections::{HashMap, HashSet}; use parking_lot::RwLock;
use std::sync::RwLock; use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet};
use crate::indexer::SyncIndexer; use crate::indexer::SyncIndexer;
use crate::protocols::{ use crate::protocols::{
...@@ -37,15 +37,15 @@ use crate::protocols::{ ...@@ -37,15 +37,15 @@ use crate::protocols::{
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
enum SeqEntry { enum SeqEntry {
/// Single seq_hash -> workers mapping (common case, no HashMap allocation) /// Single seq_hash -> workers mapping (common case, no HashMap allocation)
Single(ExternalSequenceBlockHash, HashSet<WorkerWithDpRank>), Single(ExternalSequenceBlockHash, FxHashSet<WorkerWithDpRank>),
/// Multiple seq_hash -> workers mappings (rare case, different prefixes) /// Multiple seq_hash -> workers mappings (rare case, different prefixes)
Multi(HashMap<ExternalSequenceBlockHash, HashSet<WorkerWithDpRank>>), Multi(FxHashMap<ExternalSequenceBlockHash, FxHashSet<WorkerWithDpRank>>),
} }
impl SeqEntry { impl SeqEntry {
/// Create a new entry with a single worker. /// Create a new entry with a single worker.
fn new(seq_hash: ExternalSequenceBlockHash, worker: WorkerWithDpRank) -> Self { fn new(seq_hash: ExternalSequenceBlockHash, worker: WorkerWithDpRank) -> Self {
let mut workers = HashSet::new(); let mut workers = FxHashSet::default();
workers.insert(worker); workers.insert(worker);
Self::Single(seq_hash, workers) Self::Single(seq_hash, workers)
} }
...@@ -58,7 +58,7 @@ impl SeqEntry { ...@@ -58,7 +58,7 @@ impl SeqEntry {
} }
Self::Single(existing_hash, existing_workers) => { Self::Single(existing_hash, existing_workers) => {
// Upgrade to Multi // Upgrade to Multi
let mut map = HashMap::with_capacity(2); let mut map = FxHashMap::with_capacity_and_hasher(2, FxBuildHasher);
map.insert(*existing_hash, std::mem::take(existing_workers)); map.insert(*existing_hash, std::mem::take(existing_workers));
map.entry(seq_hash).or_default().insert(worker); map.entry(seq_hash).or_default().insert(worker);
*self = Self::Multi(map); *self = Self::Multi(map);
...@@ -91,7 +91,7 @@ impl SeqEntry { ...@@ -91,7 +91,7 @@ impl SeqEntry {
} }
/// Get workers for a specific seq_hash. /// Get workers for a specific seq_hash.
fn get(&self, seq_hash: ExternalSequenceBlockHash) -> Option<&HashSet<WorkerWithDpRank>> { fn get(&self, seq_hash: ExternalSequenceBlockHash) -> Option<&FxHashSet<WorkerWithDpRank>> {
match self { match self {
Self::Single(existing_hash, workers) if *existing_hash == seq_hash => Some(workers), Self::Single(existing_hash, workers) if *existing_hash == seq_hash => Some(workers),
Self::Single(_, _) => None, Self::Single(_, _) => None,
...@@ -100,17 +100,19 @@ impl SeqEntry { ...@@ -100,17 +100,19 @@ impl SeqEntry {
} }
} }
type LevelIndex = RwLock<HashMap<ExternalSequenceBlockHash, (usize, LocalBlockHash)>>; type LevelIndex = RwLock<FxHashMap<ExternalSequenceBlockHash, (usize, LocalBlockHash)>>;
/// Positional HashMap-based KV cache index. /// Positional HashMap-based KV cache index.
/// ///
/// Implements [`SyncIndexer`] for use with [`ThreadPoolIndexer`](crate::indexer::ThreadPoolIndexer). /// Implements [`SyncIndexer`] for use with [`ThreadPoolIndexer`](crate::indexer::ThreadPoolIndexer).
/// All methods are synchronous and thread-safe. /// All methods are synchronous and thread-safe.
pub struct PositionalIndexer { pub struct PositionalIndexer {
index: DashMap<(usize, LocalBlockHash), SeqEntry>, index: DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
/// Per-worker reverse lookup: worker -> seq_hash -> (position, local_hash) /// Per-worker reverse lookup: worker -> seq_hash -> (position, local_hash)
/// Enables efficient remove operations without global flat reverse map. /// Enables efficient remove operations without global flat reverse map.
worker_blocks: DashMap<WorkerWithDpRank, LevelIndex>, /// Uses a single RwLock rather than DashMap because structural mutations
/// (adding/removing workers) are rare; the hot path is read-only.
worker_blocks: RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
jump_size: usize, jump_size: usize,
} }
...@@ -126,8 +128,8 @@ impl PositionalIndexer { ...@@ -126,8 +128,8 @@ impl PositionalIndexer {
assert!(jump_size > 0, "jump_size must be greater than 0"); assert!(jump_size > 0, "jump_size must be greater than 0");
Self { Self {
index: DashMap::new(), index: DashMap::with_hasher(FxBuildHasher),
worker_blocks: DashMap::new(), worker_blocks: RwLock::new(FxHashMap::default()),
jump_size, jump_size,
} }
} }
...@@ -159,9 +161,10 @@ impl SyncIndexer for PositionalIndexer { ...@@ -159,9 +161,10 @@ impl SyncIndexer for PositionalIndexer {
let mut events = Vec::new(); let mut events = Vec::new();
let mut event_id = 0u64; let mut event_id = 0u64;
for entry in self.worker_blocks.iter() { let wb = self.worker_blocks.read();
let worker = *entry.key(); for (worker, level_index) in wb.iter() {
let worker_map = entry.value().read().unwrap(); let worker = *worker;
let worker_map = level_index.read();
// Collect (position, local_hash, seq_hash) and sort by position // Collect (position, local_hash, seq_hash) and sort by position
// so parents are emitted before children during replay. // so parents are emitted before children during replay.
...@@ -172,7 +175,8 @@ impl SyncIndexer for PositionalIndexer { ...@@ -172,7 +175,8 @@ impl SyncIndexer for PositionalIndexer {
blocks.sort_unstable_by_key(|(pos, _, _)| *pos); blocks.sort_unstable_by_key(|(pos, _, _)| *pos);
// Track one valid seq_hash per position for parent_hash synthesis. // Track one valid seq_hash per position for parent_hash synthesis.
let mut last_at_position: HashMap<usize, ExternalSequenceBlockHash> = HashMap::new(); let mut last_at_position: FxHashMap<usize, ExternalSequenceBlockHash> =
FxHashMap::default();
for (pos, local_hash, seq_hash) in blocks { for (pos, local_hash, seq_hash) in blocks {
let parent_hash = if pos == 0 { let parent_hash = if pos == 0 {
...@@ -224,8 +228,8 @@ impl PositionalIndexer { ...@@ -224,8 +228,8 @@ impl PositionalIndexer {
/// Process an event using the provided index and worker_blocks. /// Process an event using the provided index and worker_blocks.
/// This is called from worker threads. /// This is called from worker threads.
fn apply_event_impl( fn apply_event_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry>, index: &DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
worker_blocks: &DashMap<WorkerWithDpRank, LevelIndex>, worker_blocks: &RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
event: RouterEvent, event: RouterEvent,
) -> Result<(), KvCacheEventError> { ) -> Result<(), KvCacheEventError> {
let (worker_id, kv_event) = (event.worker_id, event.event); let (worker_id, kv_event) = (event.worker_id, event.event);
...@@ -263,8 +267,8 @@ impl PositionalIndexer { ...@@ -263,8 +267,8 @@ impl PositionalIndexer {
} }
fn store_blocks_impl( fn store_blocks_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry>, index: &DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
worker_blocks: &DashMap<WorkerWithDpRank, LevelIndex>, worker_blocks: &RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
store_data: KvCacheStoreData, store_data: KvCacheStoreData,
event_id: u64, event_id: u64,
...@@ -272,9 +276,8 @@ impl PositionalIndexer { ...@@ -272,9 +276,8 @@ impl PositionalIndexer {
// Determine starting position based on parent_hash // Determine starting position based on parent_hash
let start_pos = match store_data.parent_hash { let start_pos = match store_data.parent_hash {
Some(parent_hash) => { Some(parent_hash) => {
// Find parent position from worker_blocks let wb = worker_blocks.read();
let Some(level_index) = wb.get(&worker) else {
let Some(worker_map) = worker_blocks.get(&worker) else {
tracing::warn!( tracing::warn!(
worker_id = worker.worker_id.to_string(), worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank, dp_rank = worker.dp_rank,
...@@ -284,7 +287,7 @@ impl PositionalIndexer { ...@@ -284,7 +287,7 @@ impl PositionalIndexer {
return Err(KvCacheEventError::ParentBlockNotFound); return Err(KvCacheEventError::ParentBlockNotFound);
}; };
let worker_map = worker_map.read().unwrap(); let worker_map = level_index.read();
let Some(entry) = worker_map.get(&parent_hash) else { let Some(entry) = worker_map.get(&parent_hash) else {
tracing::warn!( tracing::warn!(
...@@ -301,12 +304,15 @@ impl PositionalIndexer { ...@@ -301,12 +304,15 @@ impl PositionalIndexer {
None => 0, // Start from position 0 None => 0, // Start from position 0
}; };
if !worker_blocks.contains_key(&worker) { if !worker_blocks.read().contains_key(&worker) {
worker_blocks.insert(worker, RwLock::new(HashMap::new())); worker_blocks
.write()
.entry(worker)
.or_insert_with(|| RwLock::new(FxHashMap::default()));
} }
let worker_blocks_entry = worker_blocks.get(&worker).unwrap(); let wb = worker_blocks.read();
let mut worker_map = worker_blocks_entry.write().unwrap(); let mut worker_map = wb.get(&worker).unwrap().write();
for (i, block_data) in store_data.blocks.into_iter().enumerate() { for (i, block_data) in store_data.blocks.into_iter().enumerate() {
let position = start_pos + i; let position = start_pos + i;
...@@ -326,13 +332,14 @@ impl PositionalIndexer { ...@@ -326,13 +332,14 @@ impl PositionalIndexer {
} }
fn remove_blocks_impl( fn remove_blocks_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry>, index: &DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
worker_blocks: &DashMap<WorkerWithDpRank, LevelIndex>, worker_blocks: &RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
seq_hashes: &Vec<ExternalSequenceBlockHash>, seq_hashes: &Vec<ExternalSequenceBlockHash>,
event_id: u64, event_id: u64,
) -> Result<(), KvCacheEventError> { ) -> Result<(), KvCacheEventError> {
let worker_map = worker_blocks.get(&worker).ok_or_else(|| { let wb = worker_blocks.read();
let level_index = wb.get(&worker).ok_or_else(|| {
tracing::warn!( tracing::warn!(
worker_id = worker.worker_id.to_string(), worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank, dp_rank = worker.dp_rank,
...@@ -343,7 +350,7 @@ impl PositionalIndexer { ...@@ -343,7 +350,7 @@ impl PositionalIndexer {
KvCacheEventError::BlockNotFound KvCacheEventError::BlockNotFound
})?; })?;
let mut worker_map = worker_map.write().unwrap(); let mut worker_map = level_index.write();
for seq_hash in seq_hashes { for seq_hash in seq_hashes {
let Some((position, local_hash)) = worker_map.remove(seq_hash) else { let Some((position, local_hash)) = worker_map.remove(seq_hash) else {
...@@ -369,8 +376,8 @@ impl PositionalIndexer { ...@@ -369,8 +376,8 @@ impl PositionalIndexer {
/// Clear all blocks for a specific worker_id (all dp_ranks), but keep worker tracked. /// Clear all blocks for a specific worker_id (all dp_ranks), but keep worker tracked.
/// Static version for use in worker threads. /// Static version for use in worker threads.
fn clear_worker_blocks_impl( fn clear_worker_blocks_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry>, index: &DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
worker_blocks: &DashMap<WorkerWithDpRank, LevelIndex>, worker_blocks: &RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
worker_id: WorkerId, worker_id: WorkerId,
) { ) {
Self::remove_or_clear_worker_blocks_impl(index, worker_blocks, worker_id, true); Self::remove_or_clear_worker_blocks_impl(index, worker_blocks, worker_id, true);
...@@ -379,8 +386,9 @@ impl PositionalIndexer { ...@@ -379,8 +386,9 @@ impl PositionalIndexer {
/// Get total number of blocks across all workers. /// Get total number of blocks across all workers.
pub fn current_size(&self) -> usize { pub fn current_size(&self) -> usize {
self.worker_blocks self.worker_blocks
.iter() .read()
.map(|entry| entry.value().read().unwrap().len()) .values()
.map(|level_index| level_index.read().len())
.sum() .sum()
} }
...@@ -399,34 +407,30 @@ impl PositionalIndexer { ...@@ -399,34 +407,30 @@ impl PositionalIndexer {
/// If `keep_worker` is true, the worker remains tracked with empty blocks. /// If `keep_worker` is true, the worker remains tracked with empty blocks.
/// If `keep_worker` is false, the worker is completely removed. /// If `keep_worker` is false, the worker is completely removed.
fn remove_or_clear_worker_blocks_impl( fn remove_or_clear_worker_blocks_impl(
index: &DashMap<(usize, LocalBlockHash), SeqEntry>, index: &DashMap<(usize, LocalBlockHash), SeqEntry, FxBuildHasher>,
worker_blocks: &DashMap<WorkerWithDpRank, LevelIndex>, worker_blocks: &RwLock<FxHashMap<WorkerWithDpRank, LevelIndex>>,
worker_id: WorkerId, worker_id: WorkerId,
keep_worker: bool, keep_worker: bool,
) { ) {
// Collect all WorkerWithDpRank keys that match this worker_id
let workers: Vec<WorkerWithDpRank> = worker_blocks let workers: Vec<WorkerWithDpRank> = worker_blocks
.iter() .read()
.filter(|entry| entry.key().worker_id == worker_id) .keys()
.map(|entry| *entry.key()) .filter(|w| w.worker_id == worker_id)
.copied()
.collect(); .collect();
let mut wb = worker_blocks.write();
for worker in workers { for worker in workers {
if let Some((_, worker_map)) = worker_blocks.remove(&worker) { if let Some(worker_map) = wb.remove(&worker) {
// Remove each block from the index for (seq_hash, (position, local_hash)) in worker_map.read().iter() {
for entry in worker_map.read().unwrap().iter() { if let Some(mut entry) = index.get_mut(&(*position, *local_hash)) {
let seq_hash = *entry.0; let _ = entry.remove(*seq_hash, worker);
let (position, local_hash) = *entry.1;
if let Some(mut entry) = index.get_mut(&(position, local_hash)) {
let _ = entry.remove(seq_hash, worker);
} }
} }
} }
if keep_worker { if keep_worker {
// Re-insert worker with empty map to keep it tracked wb.insert(worker, RwLock::new(FxHashMap::default()));
worker_blocks.insert(worker, RwLock::new(HashMap::new()));
} }
} }
} }
...@@ -481,7 +485,7 @@ impl PositionalIndexer { ...@@ -481,7 +485,7 @@ impl PositionalIndexer {
local_hash: LocalBlockHash, local_hash: LocalBlockHash,
seq_hashes: &mut Vec<ExternalSequenceBlockHash>, seq_hashes: &mut Vec<ExternalSequenceBlockHash>,
sequence: &[LocalBlockHash], sequence: &[LocalBlockHash],
) -> Option<HashSet<WorkerWithDpRank>> { ) -> Option<FxHashSet<WorkerWithDpRank>> {
let entry = self.index.get(&(position, local_hash))?; let entry = self.index.get(&(position, local_hash))?;
// Always compute and verify seq_hash to handle divergent queries correctly. // Always compute and verify seq_hash to handle divergent queries correctly.
...@@ -517,13 +521,13 @@ impl PositionalIndexer { ...@@ -517,13 +521,13 @@ impl PositionalIndexer {
/// Scan positions sequentially, updating active set and recording drain scores. /// Scan positions sequentially, updating active set and recording drain scores.
/// ///
/// Inlines the DashMap lookup so the guard lives for each iteration, /// Inlines the DashMap lookup so the guard lives for each iteration,
/// avoiding a per-position `HashSet` clone. /// avoiding a per-position `FxHashSet` clone.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn linear_scan_drain( fn linear_scan_drain(
&self, &self,
sequence: &[LocalBlockHash], sequence: &[LocalBlockHash],
seq_hashes: &mut Vec<ExternalSequenceBlockHash>, seq_hashes: &mut Vec<ExternalSequenceBlockHash>,
active: &mut HashSet<WorkerWithDpRank>, active: &mut FxHashSet<WorkerWithDpRank>,
scores: &mut OverlapScores, scores: &mut OverlapScores,
lo: usize, lo: usize,
hi: usize, hi: usize,
...@@ -622,10 +626,10 @@ impl PositionalIndexer { ...@@ -622,10 +626,10 @@ impl PositionalIndexer {
scores.scores.insert(*worker, 1); scores.scores.insert(*worker, 1);
} }
// Populate tree_sizes // Populate tree_sizes
let wb = self.worker_blocks.read();
for worker in scores.scores.keys() { for worker in scores.scores.keys() {
if let Some(worker_map) = self.worker_blocks.get(worker) { if let Some(level_index) = wb.get(worker) {
let worker_map = worker_map.read().unwrap(); scores.tree_sizes.insert(*worker, level_index.read().len());
scores.tree_sizes.insert(*worker, worker_map.len());
} }
} }
return scores; return scores;
...@@ -674,10 +678,10 @@ impl PositionalIndexer { ...@@ -674,10 +678,10 @@ impl PositionalIndexer {
} }
// Populate tree_sizes from worker_blocks // Populate tree_sizes from worker_blocks
let wb = self.worker_blocks.read();
for worker in scores.scores.keys() { for worker in scores.scores.keys() {
if let Some(worker_map) = self.worker_blocks.get(worker) { if let Some(level_index) = wb.get(worker) {
let worker_map = worker_map.read().unwrap(); scores.tree_sizes.insert(*worker, level_index.read().len());
scores.tree_sizes.insert(*worker, worker_map.len());
} }
} }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use dynamo_tokens::{SequenceHash, Token}; use dynamo_tokens::{SequenceHash, Token};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use xxhash_rust::xxh3; use xxhash_rust::xxh3;
...@@ -506,11 +507,11 @@ impl RouterEvent { ...@@ -506,11 +507,11 @@ impl RouterEvent {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OverlapScores { pub struct OverlapScores {
/// Map of worker (with dp_rank) to score. /// Map of worker (with dp_rank) to score.
pub scores: std::collections::HashMap<WorkerWithDpRank, u32>, pub scores: FxHashMap<WorkerWithDpRank, u32>,
/// List of frequencies that the blocks have been accessed. Entries with value 0 are omitted. /// List of frequencies that the blocks have been accessed. Entries with value 0 are omitted.
pub frequencies: Vec<usize>, pub frequencies: Vec<usize>,
/// Map of worker to their tree size (number of blocks in the tree for that worker). /// Map of worker to their tree size (number of blocks in the tree for that worker).
pub tree_sizes: std::collections::HashMap<WorkerWithDpRank, usize>, pub tree_sizes: FxHashMap<WorkerWithDpRank, usize>,
} }
impl Default for OverlapScores { impl Default for OverlapScores {
...@@ -527,9 +528,9 @@ impl OverlapScores { ...@@ -527,9 +528,9 @@ impl OverlapScores {
/// A new `OverlapScores`. /// A new `OverlapScores`.
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
scores: std::collections::HashMap::new(), scores: FxHashMap::default(),
frequencies: Vec::with_capacity(32), frequencies: Vec::with_capacity(32),
tree_sizes: std::collections::HashMap::new(), tree_sizes: FxHashMap::default(),
} }
} }
......
...@@ -16,11 +16,13 @@ ...@@ -16,11 +16,13 @@
use std::{ use std::{
cell::RefCell, cell::RefCell,
collections::{HashMap, HashSet, VecDeque}, collections::VecDeque,
rc::Rc, rc::Rc,
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use rustc_hash::{FxHashMap, FxHashSet};
use crate::protocols::*; use crate::protocols::*;
/// A shared reference to a [`RadixBlock`]. /// A shared reference to a [`RadixBlock`].
...@@ -30,9 +32,9 @@ pub(crate) type SharedRadixBlock = Rc<RefCell<RadixBlock>>; ...@@ -30,9 +32,9 @@ pub(crate) type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct RadixBlock { pub(crate) struct RadixBlock {
/// A map of child blocks, keyed by their local block hash. /// A map of child blocks, keyed by their local block hash.
pub(crate) children: HashMap<LocalBlockHash, SharedRadixBlock>, pub(crate) children: FxHashMap<LocalBlockHash, SharedRadixBlock>,
/// The set of workers that have this block cached. /// The set of workers that have this block cached.
pub(crate) workers: HashSet<WorkerWithDpRank>, pub(crate) workers: FxHashSet<WorkerWithDpRank>,
/// The external sequence block hash for this block (None for root). /// The external sequence block hash for this block (None for root).
/// This is the same for all workers under the simplifying assumption. /// This is the same for all workers under the simplifying assumption.
pub(crate) block_hash: Option<ExternalSequenceBlockHash>, pub(crate) block_hash: Option<ExternalSequenceBlockHash>,
...@@ -48,8 +50,8 @@ impl RadixBlock { ...@@ -48,8 +50,8 @@ impl RadixBlock {
/// A new `RadixBlock` with no block_hash. /// A new `RadixBlock` with no block_hash.
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
children: HashMap::new(), children: FxHashMap::default(),
workers: HashSet::new(), workers: FxHashSet::default(),
block_hash: None, block_hash: None,
recent_uses: VecDeque::new(), recent_uses: VecDeque::new(),
} }
...@@ -62,8 +64,8 @@ impl RadixBlock { ...@@ -62,8 +64,8 @@ impl RadixBlock {
/// A new `RadixBlock` with the given block_hash. /// A new `RadixBlock` with the given block_hash.
pub fn with_hash(block_hash: ExternalSequenceBlockHash) -> Self { pub fn with_hash(block_hash: ExternalSequenceBlockHash) -> Self {
Self { Self {
children: HashMap::new(), children: FxHashMap::default(),
workers: HashSet::new(), workers: FxHashSet::default(),
block_hash: Some(block_hash), block_hash: Some(block_hash),
recent_uses: VecDeque::new(), recent_uses: VecDeque::new(),
} }
...@@ -78,7 +80,7 @@ pub struct RadixTree { ...@@ -78,7 +80,7 @@ pub struct RadixTree {
/// Per-worker lookup table for O(1) block access. /// Per-worker lookup table for O(1) block access.
/// Maps worker -> (block_hash -> block). /// Maps worker -> (block_hash -> block).
pub(crate) lookup: pub(crate) lookup:
HashMap<WorkerWithDpRank, HashMap<ExternalSequenceBlockHash, SharedRadixBlock>>, FxHashMap<WorkerWithDpRank, FxHashMap<ExternalSequenceBlockHash, SharedRadixBlock>>,
/// The time buffer the radix tree should check when considering frequence of block accesses /// The time buffer the radix tree should check when considering frequence of block accesses
pub(crate) expiration_duration: Option<Duration>, pub(crate) expiration_duration: Option<Duration>,
...@@ -132,7 +134,7 @@ impl RadixTree { ...@@ -132,7 +134,7 @@ impl RadixTree {
pub fn new_with_frequency(expiration_duration: Option<Duration>) -> Self { pub fn new_with_frequency(expiration_duration: Option<Duration>) -> Self {
Self { Self {
root: Rc::new(RefCell::new(RadixBlock::new())), root: Rc::new(RefCell::new(RadixBlock::new())),
lookup: HashMap::new(), lookup: FxHashMap::default(),
expiration_duration, expiration_duration,
} }
} }
...@@ -485,7 +487,7 @@ impl RadixTree { ...@@ -485,7 +487,7 @@ impl RadixTree {
if keep_worker { if keep_worker {
// Re-insert worker with empty blocks map to keep it tracked // Re-insert worker with empty blocks map to keep it tracked
self.lookup.insert(worker_key, HashMap::new()); self.lookup.insert(worker_key, FxHashMap::default());
} }
} }
} }
......
...@@ -184,11 +184,7 @@ impl Indexer { ...@@ -184,11 +184,7 @@ impl Indexer {
match self { match self {
Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await, Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
Indexer::Concurrent(tpi) => tpi.find_matches(sequence).await, Indexer::Concurrent(tpi) => tpi.find_matches(sequence).await,
Indexer::None => Ok(OverlapScores { Indexer::None => Ok(OverlapScores::new()),
scores: HashMap::new(),
frequencies: Vec::new(),
tree_sizes: HashMap::new(),
}),
} }
} }
......
...@@ -47,9 +47,9 @@ checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea" ...@@ -47,9 +47,9 @@ checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea"
[[package]] [[package]]
name = "arc-swap" name = "arc-swap"
version = "1.8.1" version = "1.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ded5f9a03ac8f24d1b8a25101ee812cd32cdc8c50a4c50237de2c4915850e73" checksum = "f9f3647c145568cec02c42054e07bdf9a5a698e15b466fb2341bfc393cd24aa5"
dependencies = [ dependencies = [
"rustversion", "rustversion",
] ]
...@@ -1100,9 +1100,9 @@ checksum = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7" ...@@ -1100,9 +1100,9 @@ checksum = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7"
[[package]] [[package]]
name = "futures" name = "futures"
version = "0.3.31" version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d"
dependencies = [ dependencies = [
"futures-channel", "futures-channel",
"futures-core", "futures-core",
...@@ -1115,9 +1115,9 @@ dependencies = [ ...@@ -1115,9 +1115,9 @@ dependencies = [
[[package]] [[package]]
name = "futures-channel" name = "futures-channel"
version = "0.3.31" version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-sink", "futures-sink",
...@@ -1125,15 +1125,15 @@ dependencies = [ ...@@ -1125,15 +1125,15 @@ dependencies = [
[[package]] [[package]]
name = "futures-core" name = "futures-core"
version = "0.3.31" version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
[[package]] [[package]]
name = "futures-executor" name = "futures-executor"
version = "0.3.31" version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-task", "futures-task",
...@@ -1142,15 +1142,15 @@ dependencies = [ ...@@ -1142,15 +1142,15 @@ dependencies = [
[[package]] [[package]]
name = "futures-io" name = "futures-io"
version = "0.3.31" version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
[[package]] [[package]]
name = "futures-macro" name = "futures-macro"
version = "0.3.31" version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
...@@ -1159,21 +1159,21 @@ dependencies = [ ...@@ -1159,21 +1159,21 @@ dependencies = [
[[package]] [[package]]
name = "futures-sink" name = "futures-sink"
version = "0.3.31" version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893"
[[package]] [[package]]
name = "futures-task" name = "futures-task"
version = "0.3.31" version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393"
[[package]] [[package]]
name = "futures-util" name = "futures-util"
version = "0.3.31" version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
dependencies = [ dependencies = [
"futures-channel", "futures-channel",
"futures-core", "futures-core",
...@@ -1183,7 +1183,6 @@ dependencies = [ ...@@ -1183,7 +1183,6 @@ dependencies = [
"futures-task", "futures-task",
"memchr", "memchr",
"pin-project-lite", "pin-project-lite",
"pin-utils",
"slab", "slab",
] ]
...@@ -3466,9 +3465,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" ...@@ -3466,9 +3465,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.115" version = "2.0.116"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e614ed320ac28113fa64972c4262d5dbc89deacdfd00c34a3e4cea073243c12" checksum = "3df424c70518695237746f84cede799c9c58fcb37450d7b23716568cc8bc69cb"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
...@@ -4068,9 +4067,9 @@ checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" ...@@ -4068,9 +4067,9 @@ checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142"
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.23" version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "537dd038a89878be9b64dd4bd1b260315c1bb94f4d784956b81e27a088d9a09e" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
[[package]] [[package]]
name = "unicode-xid" name = "unicode-xid"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment