Unverified Commit 73a9a53f authored by Janelle Cai's avatar Janelle Cai Committed by GitHub
Browse files

feat(router): branch sharded kv indexer (#7859)


Signed-off-by: default avatarHannah Zhang <hannahz@nvidia.com>
Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarHannah Zhang <hannahz@nvidia.com>
parent af32579e
......@@ -50,7 +50,7 @@ dynamo-mocker = { workspace = true }
[dev-dependencies]
async-trait = { workspace = true }
dynamo-kv-router = { workspace = true, features = ["bench"] }
dynamo-kv-router = { workspace = true, features = ["bench", "shard-metrics"] }
dynamo-tokens = { workspace = true }
minstant = "0.1.7"
plotters = { version = "0.3", default-features = false, features = ["svg_backend", "line_series", "point_series", "full_palette"] }
......
......@@ -42,11 +42,11 @@ pub struct CommonArgs {
pub test: bool,
/// Number of GPU blocks available in the mock engine's KV cache.
#[clap(long, default_value = "1048576")]
#[clap(long, default_value = "16384")]
pub num_gpu_blocks: usize,
/// Number of tokens per KV cache block.
#[clap(long, default_value = "512")]
#[clap(long, default_value = "128")]
pub block_size: u32,
/// Wall-clock duration (ms) over which the trace is replayed during event generation.
......@@ -58,7 +58,7 @@ pub struct CommonArgs {
pub benchmark_duration_ms: u64,
/// Number of unique simulated inference workers.
#[clap(short, long, default_value = "256")]
#[clap(short, long, default_value = "1000")]
pub num_unique_inference_workers: usize,
/// How many times to duplicate unique workers during the benchmark phase.
......@@ -124,10 +124,28 @@ pub struct MooncakeRequest {
#[serde(default)]
pub input_length: usize,
pub hash_ids: Vec<u64>,
#[serde(alias = "output_length", alias = "osl")]
pub output_length: u64,
}
#[derive(Deserialize)]
struct RawMooncakeRecord {
#[serde(default)]
timestamp: Option<f64>,
#[serde(default)]
delay: Option<f64>,
hash_ids: Vec<u64>,
#[serde(alias = "output_length", alias = "osl")]
output_length: u64,
}
/// Load the mooncake trace from disk into a flat list of requests.
///
/// Supports two JSONL formats:
/// - Legacy: every record has an integer `timestamp` field (absolute ms).
/// - aiperf: first record has `timestamp` (float), subsequent records have
/// `delay` (float ms since previous). Absolute timestamps are reconstructed
/// by accumulating delays.
pub fn load_mooncake_trace(path: &str) -> anyhow::Result<Vec<MooncakeRequest>> {
let file = File::open(path)?;
let reader = BufReader::new(file);
......@@ -136,8 +154,24 @@ pub fn load_mooncake_trace(path: &str) -> anyhow::Result<Vec<MooncakeRequest>> {
let progress = make_progress_bar(None);
let mut requests = Vec::new();
let mut cursor_ms: f64 = 0.0;
for line in reader.lines() {
requests.push(serde_json::from_str::<MooncakeRequest>(&line?)?);
let raw: RawMooncakeRecord = serde_json::from_str(&line?)?;
if let Some(ts) = raw.timestamp {
cursor_ms = ts;
} else if let Some(d) = raw.delay {
cursor_ms += d;
}
requests.push(MooncakeRequest {
uuid: Uuid::new_v4(),
timestamp: cursor_ms as u64,
input_length: 0,
hash_ids: raw.hash_ids,
output_length: raw.output_length,
});
progress.inc(1);
}
......@@ -155,6 +189,14 @@ pub fn partition_trace(
for request in requests {
traces[rng.random_range(0..num_workers)].push(request);
}
// Sort each worker's trace by timestamp so that scale_mooncake_trace and
// generate_kv_events see monotonically increasing timestamps. Without this,
// mixing requests from multiple sessions (each starting at timestamp=0) into
// one worker produces non-monotonic sequences; u64 underflow in the delta
// computation then creates sleep durations measured in centuries.
for trace in &mut traces {
trace.sort_by_key(|r| r.timestamp);
}
traces
}
......
......@@ -7,14 +7,20 @@ use common::*;
use clap::{Parser, Subcommand};
use dynamo_kv_router::LocalBlockHash;
use dynamo_kv_router::indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics};
use dynamo_kv_router::indexer::{
KvIndexer, KvIndexerInterface, KvIndexerMetrics, ShardSizeSnapshot,
};
use dynamo_kv_router::protocols::{KvCacheEvent, KvCacheEventData, RouterEvent};
use dynamo_kv_router::{
ConcurrentRadixTree, ConcurrentRadixTreeCompressed, PositionalIndexer, ThreadPoolIndexer,
BranchShardedIndexer, ConcurrentRadixTree, ConcurrentRadixTreeCompressed, PositionalIndexer,
ThreadPoolIndexer,
};
use dynamo_mocker::loadgen::Trace;
use serde::Serialize;
use std::sync::Arc;
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering},
};
use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
......@@ -48,6 +54,31 @@ enum IndexerArgs {
#[clap(long, default_value = "16")]
num_event_workers: usize,
},
/// Branch-sharded CRTC: N independent CRTC shards assigned via an explicit routing
/// table keyed on the first K block hashes. New branches are assigned to the
/// least-loaded shard. find_matches touches exactly ONE shard (no scatter-gather).
/// Unknown branch keys return empty scores immediately without any dispatch.
BranchShardedCrtc {
/// Number of independent CRTC shards.
#[clap(long, default_value = "2")]
num_shards: usize,
/// Number of OS event-worker threads per shard.
#[clap(long, default_value = "4")]
num_event_workers_per_shard: usize,
/// Number of prefix blocks hashed to identify a branch. K=2 is the
/// recommended default: depth=1 often produces too few distinct branch
/// keys, while depth=2 gives a much larger set of distinguishable branches.
#[clap(long, default_value = "2")]
prefix_depth: usize,
/// Number of OS threads per shard dedicated to find_matches (read isolation).
/// 0 (default): reads run inline on the calling tokio thread.
#[clap(long, default_value = "0")]
num_read_threads_per_shard: usize,
},
}
impl IndexerArgs {
......@@ -77,6 +108,27 @@ impl IndexerArgs {
block_size,
))
}
IndexerArgs::BranchShardedCrtc {
num_shards,
num_event_workers_per_shard,
prefix_depth,
num_read_threads_per_shard: _,
} => {
let shards = (0..num_shards)
.map(|_| {
ThreadPoolIndexer::new(
ConcurrentRadixTreeCompressed::new(),
num_event_workers_per_shard,
block_size,
)
})
.collect();
Arc::new(BranchShardedIndexer::new_with_options(
shards,
prefix_depth,
block_size,
))
}
}
}
......@@ -87,7 +139,10 @@ impl IndexerArgs {
fn is_multi_threaded(name: &str) -> bool {
matches!(
name,
"nested-map" | "concurrent-radix-tree" | "concurrent-radix-tree-compressed"
"nested-map"
| "concurrent-radix-tree"
| "concurrent-radix-tree-compressed"
| "branch-sharded-crtc"
)
}
......@@ -110,9 +165,16 @@ impl IndexerArgs {
"concurrent-radix-tree-compressed" => IndexerArgs::ConcurrentRadixTreeCompressed {
num_event_workers: nw,
},
"branch-sharded-crtc" => IndexerArgs::BranchShardedCrtc {
num_shards: 2,
num_event_workers_per_shard: nw,
prefix_depth: 2,
num_read_threads_per_shard: 0,
},
_ => anyhow::bail!(
"Unknown indexer '{}'. Valid names: radix-tree, \
nested-map, concurrent-radix-tree, concurrent-radix-tree-compressed",
"Unknown indexer '{}'. Valid names: radix-tree, radix-tree-sharded, \
nested-map, concurrent-radix-tree, concurrent-radix-tree-compressed, \
branch-sharded-crtc",
name
),
};
......@@ -144,6 +206,23 @@ struct Args {
#[clap(long, default_value = "16")]
num_event_workers: usize,
/// Number of additional concurrent tokio tasks that issue find_matches in a
/// tight loop to stress the read path. These tasks run alongside the normal
/// trace-replay workers. Set to 0 (default) to disable.
#[clap(long, default_value = "0")]
find_matches_concurrency: usize,
/// Output path for the shard-size CSV produced when `shard-metrics` feature
/// is enabled. Rows: `elapsed_ms,shard_idx,worker_count,block_count,node_count`.
/// An SVG plot is written alongside it (<path>.svg).
/// Omit or leave empty to disable shard-size sampling.
#[clap(long, default_value = "")]
shard_metrics_csv: String,
/// How often (ms) to sample shard sizes when `--shard-metrics-csv` is set.
#[clap(long, default_value = "200")]
shard_metrics_interval_ms: u64,
/// Indexer backend to benchmark (defaults to radix-tree if not specified).
#[clap(subcommand)]
indexer: Option<IndexerArgs>,
......@@ -219,6 +298,193 @@ struct SweepStepResult {
results: BenchmarkResults,
}
// ---------------------------------------------------------------------------
// Shard-size sampling (always compiled; only called when a CSV path is given)
// ---------------------------------------------------------------------------
/// A single row in the shard-size time-series CSV.
#[derive(Clone)]
struct ShardSampleRow {
elapsed_ms: u64,
snapshot: ShardSizeSnapshot,
}
/// Spawn a background tokio task that samples `indexer.shard_sizes()` every
/// `interval_ms` milliseconds until `cancel` is triggered.
///
/// Returns a `JoinHandle` that resolves to all collected samples.
fn start_shard_sampler(
indexer: Arc<dyn KvIndexerInterface + Send + Sync>,
interval_ms: u64,
cancel: tokio_util::sync::CancellationToken,
) -> tokio::task::JoinHandle<Vec<ShardSampleRow>> {
tokio::spawn(async move {
let mut rows = Vec::new();
let start = Instant::now();
let mut interval = tokio::time::interval(Duration::from_millis(interval_ms));
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
loop {
tokio::select! {
_ = interval.tick() => {
let elapsed_ms = start.elapsed().as_millis() as u64;
for snap in indexer.shard_sizes() {
rows.push(ShardSampleRow { elapsed_ms, snapshot: snap });
}
}
_ = cancel.cancelled() => break,
}
}
rows
})
}
/// Write the collected shard-size samples to a CSV file.
fn write_shard_metrics_csv(rows: &[ShardSampleRow], path: &str) -> anyhow::Result<()> {
use std::io::Write;
let mut f = std::fs::File::create(path)?;
writeln!(
f,
"elapsed_ms,shard_idx,worker_count,block_count,node_count"
)?;
for r in rows {
writeln!(
f,
"{},{},{},{},{}",
r.elapsed_ms,
r.snapshot.shard_idx,
r.snapshot.worker_count,
r.snapshot.block_count,
r.snapshot.node_count,
)?;
}
println!("Shard metrics CSV written to {path}");
Ok(())
}
/// Plot per-shard `worker_count` and `block_count` over time and write an SVG.
///
/// Draws two panels stacked vertically:
/// - Top: workers per shard over time
/// - Bottom: blocks per shard over time
///
/// Each shard gets a distinct colour; shards are identified by their `shard_idx`.
fn plot_shard_metrics(rows: &[ShardSampleRow], svg_path: &str) -> anyhow::Result<()> {
use plotters::prelude::*;
if rows.is_empty() {
return Ok(());
}
// Collect the set of shard indices present.
let mut shard_indices: Vec<usize> = rows.iter().map(|r| r.snapshot.shard_idx).collect();
shard_indices.sort_unstable();
shard_indices.dedup();
let max_elapsed = rows.iter().map(|r| r.elapsed_ms).max().unwrap_or(1);
let max_workers = rows
.iter()
.map(|r| r.snapshot.worker_count)
.max()
.unwrap_or(1);
let max_blocks = rows
.iter()
.map(|r| r.snapshot.block_count)
.max()
.unwrap_or(1);
let colors: Vec<RGBColor> = vec![
RGBColor(31, 119, 180),
RGBColor(255, 127, 14),
RGBColor(44, 160, 44),
RGBColor(214, 39, 40),
RGBColor(148, 103, 189),
RGBColor(140, 86, 75),
];
let root = SVGBackend::new(svg_path, (900, 700)).into_drawing_area();
root.fill(&WHITE)?;
let (upper, lower) = root.split_vertically(350);
// --- Top panel: workers per shard ---
let mut chart = ChartBuilder::on(&upper)
.caption("Workers per shard over time", ("sans-serif", 18))
.margin(15)
.x_label_area_size(30)
.y_label_area_size(60)
.build_cartesian_2d(0u64..max_elapsed, 0usize..max_workers + 1)?;
chart
.configure_mesh()
.x_desc("Elapsed (ms)")
.y_desc("Workers")
.draw()?;
for (i, &shard_idx) in shard_indices.iter().enumerate() {
let color = colors[i % colors.len()];
let points: Vec<(u64, usize)> = rows
.iter()
.filter(|r| r.snapshot.shard_idx == shard_idx)
.map(|r| (r.elapsed_ms, r.snapshot.worker_count))
.collect();
let label = format!("shard {shard_idx}");
chart
.draw_series(LineSeries::new(points, &color))?
.label(label)
.legend(move |(x, y)| {
plotters::element::PathElement::new(
vec![(x, y), (x + 20, y)],
color.stroke_width(2),
)
});
}
chart
.configure_series_labels()
.background_style(WHITE.mix(0.8))
.border_style(BLACK)
.draw()?;
// --- Bottom panel: blocks per shard ---
let mut chart2 = ChartBuilder::on(&lower)
.caption("Blocks per shard over time", ("sans-serif", 18))
.margin(15)
.x_label_area_size(30)
.y_label_area_size(60)
.build_cartesian_2d(0u64..max_elapsed, 0usize..max_blocks + 1)?;
chart2
.configure_mesh()
.x_desc("Elapsed (ms)")
.y_desc("Cached blocks")
.draw()?;
for (i, &shard_idx) in shard_indices.iter().enumerate() {
let color = colors[i % colors.len()];
let points: Vec<(u64, usize)> = rows
.iter()
.filter(|r| r.snapshot.shard_idx == shard_idx)
.map(|r| (r.elapsed_ms, r.snapshot.block_count))
.collect();
let label = format!("shard {shard_idx}");
chart2
.draw_series(LineSeries::new(points, &color))?
.label(label)
.legend(move |(x, y)| {
plotters::element::PathElement::new(
vec![(x, y), (x + 20, y)],
color.stroke_width(2),
)
});
}
chart2
.configure_series_labels()
.background_style(WHITE.mix(0.8))
.border_style(BLACK)
.draw()?;
root.present()?;
println!("Shard metrics plot written to {svg_path}");
Ok(())
}
/// Run the benchmark: replay each worker's merged trace against the indexer,
/// measuring find_matches latency and event processing throughput.
///
......@@ -231,6 +497,7 @@ async fn run_benchmark(
args: &Args,
benchmark_duration_ms: u64,
count_events: bool,
find_matches_concurrency: usize,
) -> anyhow::Result<BenchmarkResults> {
let worker_traces = prepare_worker_traces(artifacts, benchmark_duration_ms);
let worker_traces = worker_traces.into_iter().map(Arc::new).collect::<Vec<_>>();
......@@ -319,12 +586,58 @@ async fn run_benchmark(
}
}
// Spawn additional concurrent find_matches callers if requested.
// These tasks run alongside the trace-replay workers to stress the read path.
let fm_stop = Arc::new(AtomicBool::new(false));
let mut fm_tasks = Vec::new();
if find_matches_concurrency > 0 {
// Collect all request sequences as a pool for random selection.
let seq_pool: Arc<Vec<Vec<LocalBlockHash>>> = Arc::new(
worker_traces
.iter()
.flat_map(|t| t.iter())
.filter_map(|entry| match &entry.entry {
WorkerTraceEntry::Request(hashes) => Some(hashes.clone()),
_ => None,
})
.collect(),
);
if !seq_pool.is_empty() {
for task_id in 0..find_matches_concurrency {
let indexer = indexer.clone();
let pool = Arc::clone(&seq_pool);
let stop = Arc::clone(&fm_stop);
fm_tasks.push(tokio::spawn(async move {
let mut latencies = Vec::new();
let mut idx = task_id % pool.len();
while !stop.load(Ordering::Relaxed) {
let seq = pool[idx].clone();
let start = minstant::Instant::now();
let _ = indexer.find_matches(seq).await;
latencies.push(start.elapsed().as_nanos() as u64);
idx = (idx + 1) % pool.len();
}
latencies
}));
}
}
}
let mut latencies = Vec::new();
for task in tasks {
latencies.extend(task.await??);
}
// Signal concurrent find_matches callers to stop and collect their latencies.
fm_stop.store(true, Ordering::Relaxed);
for task in fm_tasks {
if let Ok(fm_latencies) = task.await {
latencies.extend(fm_latencies);
}
}
if progress.elapsed() > Duration::from_millis(benchmark_duration_ms * 11 / 10) {
eprintln!(
"WARNING: The benchmarker is unable to keep up with the request/event generation rate. Rerun with a larger --benchmark-duration-ms."
......@@ -389,10 +702,13 @@ async fn run_benchmark(
};
println!(
"Ops Throughput: {} ops/s (requests + events)",
ops_throughput
"Offered Ops Throughput: {} ops/s | Achieved: {} ops/s (requests + events)",
offered_ops_throughput as u64, ops_throughput as u64,
);
println!(
"Offered Block Throughput: {} block ops/s | Achieved: {} block ops/s",
offered_block_throughput as u64, block_throughput as u64,
);
println!("Block Throughput: {} block ops/s", block_throughput);
println!("Latency p99: {}us", latency_p99_us);
Ok(BenchmarkResults {
......@@ -543,6 +859,7 @@ async fn main() -> anyhow::Result<()> {
IndexerArgs::NestedMap { .. } => "nested-map",
IndexerArgs::ConcurrentRadixTree { .. } => "concurrent-radix-tree",
IndexerArgs::ConcurrentRadixTreeCompressed { .. } => "concurrent-radix-tree-compressed",
IndexerArgs::BranchShardedCrtc { .. } => "branch-sharded-crtc",
};
vec![name.to_string()]
} else {
......@@ -582,8 +899,15 @@ async fn main() -> anyhow::Result<()> {
IndexerArgs::from_name(name, args.common.block_size, args.num_event_workers)?
};
let count_events = IndexerArgs::supports_remove(name);
let result =
run_benchmark(indexer, artifacts.clone(), &args, dur_ms, count_events).await?;
let result = run_benchmark(
indexer,
artifacts.clone(),
&args,
dur_ms,
count_events,
args.find_matches_concurrency,
)
.await?;
if multi_threaded {
if result.block_throughput >= result.offered_block_throughput * 0.95 {
......@@ -640,6 +964,8 @@ async fn main() -> anyhow::Result<()> {
std::fs::write(&json_path, serde_json::to_string_pretty(&json_map)?)?;
println!("Sweep results saved to {}", json_path);
} else {
drop(traces);
for name in &indexer_names {
println!("\nBenchmarking indexer: {}", name);
let indexer = if args.compare.is_empty() {
......@@ -648,14 +974,82 @@ async fn main() -> anyhow::Result<()> {
IndexerArgs::from_name(name, args.common.block_size, args.num_event_workers)?
};
let count_events = IndexerArgs::supports_remove(name);
// Start shard-size sampler if a CSV path was provided.
let shard_cancel = CancellationToken::new();
let shard_sampler = if !args.shard_metrics_csv.is_empty() {
Some(start_shard_sampler(
indexer.clone(),
args.shard_metrics_interval_ms,
shard_cancel.clone(),
))
} else {
None
};
run_benchmark(
indexer,
indexer.clone(),
artifacts.clone(),
&args,
args.common.benchmark_duration_ms,
count_events,
args.find_matches_concurrency,
)
.await?;
// Stop sampler and write CSV + plot.
shard_cancel.cancel();
if let Some(handle) = shard_sampler {
let rows = handle.await?;
// In compare mode, prefix the indexer name to distinguish outputs.
let csv_path = if args.compare.len() > 1 {
let stem = args.shard_metrics_csv.trim_end_matches(".csv");
format!("{stem}_{name}.csv")
} else {
args.shard_metrics_csv.clone()
};
write_shard_metrics_csv(&rows, &csv_path)?;
let svg = format!("{}.svg", csv_path.trim_end_matches(".csv"));
plot_shard_metrics(&rows, &svg)?;
}
let report = indexer.timing_report();
if !report.is_empty() {
println!("{}", report);
}
let sizes = indexer.shard_sizes();
if sizes.len() > 1 {
let total_blocks: usize = sizes.iter().map(|s| s.block_count).sum();
let total_nodes: usize = sizes.iter().map(|s| s.node_count).sum();
println!("Shard block distribution:");
for s in &sizes {
let pct = if total_blocks > 0 {
100.0 * s.block_count as f64 / total_blocks as f64
} else {
0.0
};
println!(
" shard {}: {} blocks ({:.1}%), {} workers, {} nodes",
s.shard_idx, s.block_count, pct, s.worker_count, s.node_count
);
}
if total_nodes > 0 {
println!(" total nodes across shards: {}", total_nodes);
}
}
let mut edge_lengths = indexer.node_edge_lengths();
if !edge_lengths.is_empty() {
let avg = edge_lengths.iter().sum::<usize>() as f64 / edge_lengths.len() as f64;
edge_lengths.sort_unstable();
let p99 = edge_lengths[edge_lengths.len() * 99 / 100];
println!(
"Node edge lengths ({} nodes): avg={:.1} hashes/node, p99={} hashes/node",
edge_lengths.len(),
avg,
p99,
);
}
}
}
......
......@@ -17,7 +17,9 @@ default = []
metrics = ["dep:prometheus"]
runtime-protocols = ["dep:dynamo-runtime"]
bench = []
shard-metrics = []
standalone-indexer = ["dep:axum", "dep:serde_json", "dep:reqwest", "dep:zmq"]
indexer-runtime = ["metrics", "runtime-protocols", "standalone-indexer"]
[dependencies]
# repo
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Branch-based prefix sharding over `ThreadPoolIndexer<T>`.
//!
//! [`BranchShardedIndexer`] partitions the prefix space by building an explicit
//! routing table that maps branch keys (FNV-1a hash of first `prefix_depth`
//! block hashes) to shard indices. Unlike [`PrefixShardedIndexer`] which uses
//! `hash % N`, new branches are assigned to the **least-loaded shard** at first
//! insertion time, so load is balanced regardless of hash distribution.
//!
//! ## Key properties
//!
//! - **Single-shard `find_matches`**: a query routes to exactly one shard — no
//! scatter-gather. Read throughput scales linearly with shard count.
//! - **Least-loaded branch assignment**: each new branch key is assigned to the
//! shard with the fewest branches, ensuring balanced distribution even when
//! the underlying hash values cluster.
//! - **Stable shard assignment**: once a branch is assigned, it never migrates.
//! CRTC-internal splits stay within the owning shard — no migration protocol
//! needed. The shard assignment is keyed on the *sequence prefix* (first K
//! blocks), not on tree nodes, so splits are transparent to this layer.
//! - **Unknown-branch fast path**: if a query's branch key is not in the routing
//! table, no worker has ever stored that prefix. `find_matches` returns empty
//! scores immediately without dispatching to any shard.
//!
//! ## Remove routing
//!
//! Two strategies are used in combination:
//!
//! 1. **Mapping (primary)**: each `block_hash` is looked up in a
//! `block_to_shard` index (populated at Stored time) and routed to its
//! owning shard only.
//! 2. **Broadcast fallback**: if a block hash is absent from the index (evicted,
//! out-of-order event, or index overflow), the Remove is broadcast to all
//! shards. Each shard's CRTC handles a missing block as a no-op.
//! `remove_broadcast_count` tracks how often this occurs.
use std::sync::{
Arc, Mutex,
atomic::{AtomicU64, AtomicUsize, Ordering},
};
use async_trait::async_trait;
use dashmap::DashMap;
use rustc_hash::FxBuildHasher;
use super::{KvIndexerInterface, KvRouterError, ShardSizeSnapshot, SyncIndexer, ThreadPoolIndexer};
use crate::protocols::*;
// ---------------------------------------------------------------------------
// Per-shard read thread pool (kept for potential future use)
// ---------------------------------------------------------------------------
/// A bounded pool of OS threads dedicated to `find_matches` requests for one
/// shard. Mirrors the equivalent struct in `prefix_sharded.rs`.
///
/// Not currently used by [`BranchShardedIndexer`] — reads run inline on the
/// caller's thread. Retained here as a building block if dedicated read
/// isolation is needed in the future.
#[allow(dead_code)]
struct ShardReadPool {
sender: flume::Sender<(
Vec<LocalBlockHash>,
tokio::sync::oneshot::Sender<OverlapScores>,
)>,
_threads: Vec<std::thread::JoinHandle<()>>,
}
#[allow(dead_code)]
impl ShardReadPool {
fn new<T: SyncIndexer>(backend: Arc<T>, num_threads: usize) -> Self {
let (tx, rx) = flume::unbounded();
let mut threads = Vec::with_capacity(num_threads);
for _ in 0..num_threads {
let backend = Arc::clone(&backend);
let rx: flume::Receiver<(
Vec<LocalBlockHash>,
tokio::sync::oneshot::Sender<OverlapScores>,
)> = rx.clone();
threads.push(std::thread::spawn(move || {
while let Ok((seq, resp_tx)) = rx.recv() {
let result = backend.find_matches(&seq, false);
let _ = resp_tx.send(result);
}
}));
}
Self {
sender: tx,
_threads: threads,
}
}
}
// ---------------------------------------------------------------------------
// FNV-1a constants
// ---------------------------------------------------------------------------
const FNV_OFFSET_BASIS: u64 = 14695981039346656037;
const FNV_PRIME: u64 = 1099511628211;
/// Fold one `u64` value into an FNV-1a accumulator.
#[inline(always)]
fn fnv_fold(state: u64, value: u64) -> u64 {
let mut h = state;
for b in value.to_le_bytes() {
h ^= b as u64;
h = h.wrapping_mul(FNV_PRIME);
}
h
}
// ---------------------------------------------------------------------------
// BranchShardedIndexer
// ---------------------------------------------------------------------------
/// Branch-sharded wrapper over N [`ThreadPoolIndexer<T>`] instances.
///
/// Construct with [`BranchShardedIndexer::new`].
pub struct BranchShardedIndexer<T: SyncIndexer> {
shards: Vec<Arc<ThreadPoolIndexer<T>>>,
num_shards: usize,
/// Number of leading blocks used to identify a branch. Default: 2.
prefix_depth: usize,
/// Routing table: FNV-1a(first `prefix_depth` `LocalBlockHash`) → shard index.
///
/// Populated lazily at first `Stored` event for each distinct branch.
branch_to_shard: DashMap<u64, usize, FxBuildHasher>,
/// Number of branches assigned to each shard (for observability).
branch_counts: Mutex<Vec<usize>>,
/// Eagerly-updated block count per shard.
///
/// Incremented synchronously in `apply_event` (before the event is dispatched
/// to the async worker thread) so that `assign_shard` always sees an up-to-date
/// load estimate even when the CRTC backend has not yet processed the event.
/// This prevents every branch from being assigned to the same shard during
/// burst startup, when all CRTC node counts are still zero.
shard_block_counts: Vec<AtomicUsize>,
/// Remove index: `ExternalSequenceBlockHash.0` → `(shard_index, ref_count)`.
///
/// Written on `Stored` (ref_count incremented), decremented on `Removed`.
/// The entry is deleted only when ref_count reaches zero — i.e. every worker
/// that stored the block has since evicted it.
///
/// Note: `block_to_shard` entries are content-addressed — the same
/// `ExternalSequenceBlockHash` can be shared by multiple workers (identical
/// token sequences). Without ref-counting, the first worker to evict a
/// shared block would delete the entry, causing all subsequent workers'
/// Removed events for that block to fall through to broadcast. Ref-counting
/// keeps the entry alive until the last holder evicts it.
///
/// A `Cleared` event does NOT touch this map because doing so would break
/// routing for other workers whose continuations reference the same parent
/// hashes. Only `Removed` events (which carry explicit block hashes)
/// decrement the ref-count.
///
/// Note: parent-hash inheritance via this map is only used once a chain tail
/// has reached `prefix_depth` blocks (depth ≥ prefix_depth). Shallower
/// tails are tracked in `block_to_fnv_state` and route by FNV accumulation.
block_to_shard: DashMap<u64, (usize, usize), FxBuildHasher>,
/// FNV accumulator for chain tails that have not yet reached `prefix_depth` blocks.
///
/// Maps the `ExternalSequenceBlockHash.0` of the **last stored block** in a
/// shallow chain to `(accumulated_fnv, depth)`, where `depth < prefix_depth`.
///
/// # Why this exists
///
/// For workloads with a shared prefix shorter than `prefix_depth` (e.g. a
/// 15-block system prompt with `prefix_depth = 17`), all root events produce
/// the **same** partial FNV hash, collapsing every conversation onto a single
/// shard. By carrying the accumulated FNV forward into continuation events,
/// each conversation extends the hash with its own unique blocks (positions
/// 15 and 16) and thereby receives a distinct, balanced shard assignment.
///
/// # CRTC chain / lookup notes
///
/// When a continuation's finalized FNV routes it to a different shard than its
/// parent, the CRTC on the new shard will not find the parent and will drop the
/// event. Fixing this fully requires replaying the shallow prefix to the new
/// shard ("shallow chain replay"), which is left as a future improvement. For
/// now the routing table is correct — `find_matches` routes to the right shard —
/// but the underlying CRTC may have no data there until replay is implemented.
///
/// Separately, `find_matches` hashes only the available prefix
/// (`min(prefix_depth, len)`). A query shorter than `prefix_depth` therefore
/// probes with a shorter key than a root `Stored` event that first established
/// the branch with `>= prefix_depth` blocks. With `prefix_depth > 1`, that can
/// cause false early-miss returns for short queries unless shorter-prefix keys
/// are also recorded or reads fall back to a broader lookup.
///
/// Like `block_to_shard`, entries are content-addressed and are NOT removed by
/// `Cleared` events; only `Removed` events prune them.
block_to_fnv_state: DashMap<u64, (u64, usize), FxBuildHasher>,
kv_block_size: u32,
// --- timing / observability ---
/// Number of `find_matches` calls that dispatched to a shard.
timing_calls: AtomicU64,
/// Cumulative routing (table-lookup) time for dispatched calls (ns).
timing_sum_routing_ns: AtomicU64,
/// Cumulative delegated shard `find_matches` time (ns).
timing_sum_shard_ns: AtomicU64,
/// `find_matches` calls that returned early (unknown branch key).
find_matches_miss_count: AtomicU64,
/// Individual `Removed` block hashes that fell back to broadcast.
remove_broadcast_count: AtomicU64,
}
impl<T: SyncIndexer> BranchShardedIndexer<T> {
/// Create a branch-sharded indexer from pre-built [`ThreadPoolIndexer`] shards.
///
/// # Arguments
///
/// * `shards` - One `ThreadPoolIndexer` per shard.
/// * `prefix_depth` - Number of prefix blocks to hash for routing. Clamped
/// to ≥ 1. K=2 is the recommended default (depth=1 gives too few distinct
/// branch keys on many workloads).
/// * `kv_block_size` - Block size for KV cache.
///
/// # Panics
///
/// Panics if `shards` is empty.
pub fn new(shards: Vec<ThreadPoolIndexer<T>>, prefix_depth: usize, kv_block_size: u32) -> Self {
assert!(!shards.is_empty(), "Must provide at least one shard");
let num_shards = shards.len();
let shards: Vec<Arc<ThreadPoolIndexer<T>>> = shards.into_iter().map(Arc::new).collect();
Self {
shards,
num_shards,
prefix_depth: prefix_depth.max(1),
branch_to_shard: DashMap::with_hasher(FxBuildHasher),
branch_counts: Mutex::new(vec![0usize; num_shards]),
shard_block_counts: (0..num_shards).map(|_| AtomicUsize::new(0)).collect(),
block_to_shard: DashMap::with_hasher(FxBuildHasher),
block_to_fnv_state: DashMap::with_hasher(FxBuildHasher),
kv_block_size,
timing_calls: AtomicU64::new(0),
timing_sum_routing_ns: AtomicU64::new(0),
timing_sum_shard_ns: AtomicU64::new(0),
find_matches_miss_count: AtomicU64::new(0),
remove_broadcast_count: AtomicU64::new(0),
}
}
/// Alias for [`BranchShardedIndexer::new`], kept for call-site compatibility.
pub fn new_with_options(
shards: Vec<ThreadPoolIndexer<T>>,
prefix_depth: usize,
kv_block_size: u32,
) -> Self {
Self::new(shards, prefix_depth, kv_block_size)
}
// --- branch key computation ---
/// FNV-1a hash of the first `min(prefix_depth, len)` `LocalBlockHash` values.
///
/// Used by `find_matches` to compute the branch key for an incoming query.
fn branch_key_for_local_hashes(&self, hashes: &[LocalBlockHash]) -> u64 {
let k = self.prefix_depth.min(hashes.len());
hashes[..k]
.iter()
.fold(FNV_OFFSET_BASIS, |h, block| fnv_fold(h, block.0))
}
/// FNV-1a hash of the first `min(prefix_depth, len)` `tokens_hash` values
/// from a `Stored` event's block list.
fn branch_key_for_stored_blocks(&self, blocks: &[KvCacheStoredBlockData]) -> u64 {
let k = self.prefix_depth.min(blocks.len());
blocks[..k].iter().fold(FNV_OFFSET_BASIS, |h, block| {
fnv_fold(h, block.tokens_hash.0)
})
}
// --- routing table operations ---
fn lookup_shard(&self, branch_key: u64) -> Option<usize> {
self.branch_to_shard.get(&branch_key).map(|v| *v)
}
/// Get or create a shard assignment for a branch key.
///
/// Fast path if already assigned; otherwise acquires the lock, picks the
/// least-loaded shard, and inserts atomically.
///
/// Load is measured by **live block count** in each shard (an O(1) atomic
/// read). Block count is a better proxy than branch count when conversation
/// lengths vary widely — long conversations contribute many more blocks than
/// short ones even though both count as one branch. Branch count is used as
/// a tiebreaker when block counts are equal (e.g. at startup before any
/// events have been processed).
fn assign_shard(&self, branch_key: u64) -> usize {
if let Some(shard_idx) = self.branch_to_shard.get(&branch_key).map(|v| *v) {
return shard_idx;
}
let mut counts = self.branch_counts.lock().unwrap();
if let Some(shard_idx) = self.branch_to_shard.get(&branch_key).map(|v| *v) {
return shard_idx;
}
let selected = self
.shard_block_counts
.iter()
.enumerate()
.min_by(|(i, a), (j, b)| {
a.load(Ordering::Relaxed)
.cmp(&b.load(Ordering::Relaxed))
.then(counts[*i].cmp(&counts[*j]))
})
.unwrap()
.0;
counts[selected] += 1;
drop(counts);
self.branch_to_shard.insert(branch_key, selected);
selected
}
// -----------------------------------------------------------------------
// Private event handlers (called from apply_event)
// -----------------------------------------------------------------------
/// Compute the target shard and (if still shallow) the updated FNV
/// accumulator state for a `Stored` event.
///
/// Shard assignment uses accumulated FNV until the chain reaches
/// `prefix_depth` blocks, then switches to parent-hash inheritance.
///
/// Three cases:
///
/// A. Parent tail found in `block_to_fnv_state` (depth < prefix_depth):
/// Extend the FNV accumulator with leading blocks from this batch.
/// Once the accumulated depth reaches `prefix_depth`, call
/// `assign_shard` with the finalized key so that distinct
/// continuations receive distinct shard assignments.
/// Record the updated state on the last block of this batch if the
/// chain is still shallow after processing.
///
/// B. Parent tail found in `block_to_shard` (depth >= prefix_depth):
/// Inherit the shard — the branch was already decided.
///
/// C. No parent (root) or OOO (parent not in either map):
/// Compute FNV from this batch's own blocks. For root events
/// shorter than `prefix_depth` this is a partial key; a future
/// continuation in case A will extend it to the full depth.
///
/// Returns `(shard_idx, Option<(fnv, depth)>)`. A `Some` state means
/// the chain has not yet reached `prefix_depth` blocks; the caller should
/// record it on the last block of the batch so the next continuation can
/// extend it.
fn compute_stored_routing(
&self,
store_data: &KvCacheStoreData,
) -> (usize, Option<(u64, usize)>) {
if let Some(parent_hash) = &store_data.parent_hash {
if let Some(entry) = self.block_to_fnv_state.get(&parent_hash.0) {
// Case A: parent is shallow — extend FNV accumulator.
let (parent_fnv, parent_depth) = *entry;
drop(entry);
let remaining = self.prefix_depth - parent_depth;
let to_process = remaining.min(store_data.blocks.len());
let fnv = store_data.blocks[..to_process]
.iter()
.fold(parent_fnv, |h, block| fnv_fold(h, block.tokens_hash.0));
let new_depth = parent_depth + to_process;
let shard = self.assign_shard(fnv);
let state = (new_depth < self.prefix_depth).then_some((fnv, new_depth));
(shard, state)
} else if let Some(shard) = self.block_to_shard.get(&parent_hash.0).map(|v| v.0) {
// Case B: deep chain — inherit shard.
(shard, None)
} else {
// Case C (OOO): parent not in either map; best-effort key from this batch.
let key = self.branch_key_for_stored_blocks(&store_data.blocks);
(self.assign_shard(key), None)
}
} else {
// Case C (root): start FNV accumulation from scratch.
let to_process = self.prefix_depth.min(store_data.blocks.len());
let fnv = store_data.blocks[..to_process]
.iter()
.fold(FNV_OFFSET_BASIS, |h, block| {
fnv_fold(h, block.tokens_hash.0)
});
let depth = to_process;
let shard = self.assign_shard(fnv);
let state = (depth < self.prefix_depth).then_some((fnv, depth));
(shard, state)
}
}
async fn apply_stored(&self, event: RouterEvent) {
let KvCacheEventData::Stored(store_data) = &event.event.data else {
return;
};
let (shard_idx, new_fnv_state) = self.compute_stored_routing(store_data);
// Update eager block count before dispatching.
self.shard_block_counts[shard_idx].fetch_add(store_data.blocks.len(), Ordering::Relaxed);
// Record block → shard before dispatching so a fast continuation
// can find entries immediately.
for block in &store_data.blocks {
self.block_to_shard
.entry(block.block_hash.0)
.and_modify(|e| e.1 += 1)
.or_insert((shard_idx, 1));
}
// Propagate partial FNV state on the last block of this batch.
if let Some(fnv_state) = new_fnv_state
&& let Some(last_block) = store_data.blocks.last()
{
self.block_to_fnv_state
.insert(last_block.block_hash.0, fnv_state);
}
self.shards[shard_idx].apply_event(event).await;
}
async fn apply_removed(&self, event: RouterEvent) {
// Copy metadata before borrowing event.event.data.
let worker_id = event.worker_id;
let storage_tier = event.storage_tier;
let event_id = event.event.event_id;
let dp_rank = event.event.dp_rank;
let KvCacheEventData::Removed(remove_data) = &event.event.data else {
return;
};
// --- Plan: classify each block as mapped-to-shard or broadcast ---
let mut shard_blocks: Vec<Vec<ExternalSequenceBlockHash>> =
vec![Vec::new(); self.num_shards];
let mut broadcast_blocks: Vec<ExternalSequenceBlockHash> = Vec::new();
for &block_hash in &remove_data.block_hashes {
self.block_to_fnv_state.remove(&block_hash.0);
let found_shard = self.block_to_shard.get_mut(&block_hash.0).map(|mut e| {
let shard_idx = e.0;
e.1 = e.1.saturating_sub(1);
shard_idx
});
match found_shard {
Some(shard_idx) => {
self.block_to_shard
.remove_if(&block_hash.0, |_, v| v.1 == 0);
shard_blocks[shard_idx].push(block_hash);
}
None => {
self.remove_broadcast_count.fetch_add(1, Ordering::Relaxed);
broadcast_blocks.push(block_hash);
}
}
}
// --- Dispatch: route mapped removes to their owning shards ---
for (shard_idx, blocks) in shard_blocks.into_iter().enumerate() {
if blocks.is_empty() {
continue;
}
self.shard_block_counts[shard_idx]
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |count| {
Some(count.saturating_sub(blocks.len()))
})
.ok();
let shard_event = RouterEvent {
worker_id,
storage_tier,
event: KvCacheEvent {
event_id,
dp_rank,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: blocks,
}),
},
};
self.shards[shard_idx].apply_event(shard_event).await;
}
// Broadcast unknown blocks to all shards; each CRTC treats a missing
// block as a no-op so correctness is maintained.
if !broadcast_blocks.is_empty() {
for shard in &self.shards {
let broadcast_event = RouterEvent {
worker_id,
storage_tier,
event: KvCacheEvent {
event_id,
dp_rank,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: broadcast_blocks.clone(),
}),
},
};
shard.apply_event(broadcast_event).await;
}
}
}
}
#[async_trait]
impl<T: SyncIndexer> KvIndexerInterface for BranchShardedIndexer<T> {
/// Route to a single shard determined by the first `prefix_depth` block hashes.
///
/// If the branch key is not in the routing table, no worker has ever stored
/// that prefix, so the result would be empty regardless of which shard is
/// queried. We return `OverlapScores::new()` immediately without dispatching.
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
let t_routing = std::time::Instant::now();
let branch_key = self.branch_key_for_local_hashes(&sequence);
let shard_idx = match self.lookup_shard(branch_key) {
Some(idx) => idx,
None => {
self.find_matches_miss_count.fetch_add(1, Ordering::Relaxed);
return Ok(OverlapScores::new());
}
};
let routing_ns = t_routing.elapsed().as_nanos() as u64;
let t_shard = std::time::Instant::now();
let result = self.shards[shard_idx].find_matches(sequence).await;
let shard_ns = t_shard.elapsed().as_nanos() as u64;
self.timing_calls.fetch_add(1, Ordering::Relaxed);
self.timing_sum_routing_ns
.fetch_add(routing_ns, Ordering::Relaxed);
self.timing_sum_shard_ns
.fetch_add(shard_ns, Ordering::Relaxed);
result
}
async fn find_matches_for_request(
&self,
tokens: &[u32],
lora_name: Option<&str>,
is_eagle: Option<bool>,
) -> Result<OverlapScores, KvRouterError> {
let sequence = compute_block_hash_for_seq(
tokens,
self.kv_block_size,
BlockHashOptions {
lora_name,
is_eagle,
block_mm_infos: None,
},
);
let branch_key = self.branch_key_for_local_hashes(&sequence);
match self.lookup_shard(branch_key) {
Some(idx) => self.shards[idx].find_matches(sequence).await,
None => Ok(OverlapScores::new()),
}
}
async fn apply_event(&self, event: RouterEvent) {
match &event.event.data {
KvCacheEventData::Stored(_) => self.apply_stored(event).await,
KvCacheEventData::Removed(_) => self.apply_removed(event).await,
KvCacheEventData::Cleared => {
// A worker may have blocks across multiple shards (different
// branches stored over its lifetime) — broadcast to all.
for shard in &self.shards {
shard.apply_event(event.clone()).await;
}
}
}
}
async fn remove_worker(&self, worker_id: WorkerId) {
// A worker may have blocks on any shard — broadcast.
for shard in &self.shards {
shard.remove_worker(worker_id).await;
}
}
async fn remove_worker_dp_rank(&self, worker_id: WorkerId, dp_rank: DpRank) {
for shard in &self.shards {
shard.remove_worker_dp_rank(worker_id, dp_rank).await;
}
}
fn shutdown(&self) {
for shard in &self.shards {
shard.shutdown();
}
}
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
let mut all_events = Vec::new();
for shard in &self.shards {
all_events.extend(shard.dump_events().await?);
}
Ok(all_events)
}
async fn process_routing_decision_for_request(
&self,
_tokens_with_hashes: &mut TokensWithHashes,
_worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
Ok(())
}
async fn flush(&self) -> usize {
let mut total = 0;
for shard in &self.shards {
total += <ThreadPoolIndexer<T> as KvIndexerInterface>::flush(shard).await;
}
total
}
fn shard_sizes(&self) -> Vec<ShardSizeSnapshot> {
self.shards
.iter()
.enumerate()
.flat_map(|(idx, shard)| {
// ThreadPoolIndexer::shard_sizes() already populates node_count
// via backend.node_count() (O(1)). No need to call
// node_edge_lengths().len() which allocates an O(N) Vec.
shard.shard_sizes().into_iter().map(move |mut s| {
s.shard_idx = idx;
s
})
})
.collect()
}
fn node_edge_lengths(&self) -> Vec<usize> {
self.shards
.iter()
.flat_map(|shard| shard.node_edge_lengths())
.collect()
}
fn timing_report(&self) -> String {
let dispatched = self.timing_calls.load(Ordering::Relaxed);
let misses = self.find_matches_miss_count.load(Ordering::Relaxed);
let total_calls = dispatched + misses;
let broadcasts = self.remove_broadcast_count.load(Ordering::Relaxed);
if total_calls == 0 {
return String::new();
}
let miss_pct = 100.0 * misses as f64 / total_calls as f64;
let avg_routing_ns = if dispatched > 0 {
self.timing_sum_routing_ns.load(Ordering::Relaxed) / dispatched
} else {
0
};
let avg_shard_us = if dispatched > 0 {
self.timing_sum_shard_ns.load(Ordering::Relaxed) / dispatched / 1000
} else {
0
};
let branch_counts = self.branch_counts.lock().unwrap();
let total_branches: usize = branch_counts.iter().sum();
let branch_dist: Vec<String> = branch_counts
.iter()
.enumerate()
.map(|(i, c)| format!("shard[{i}]={c}"))
.collect();
drop(branch_counts);
format!(
"BranchShardedIndexer find_matches ({total_calls} total: {dispatched} dispatched, \
{misses} early-exit / {miss_pct:.1}% miss):\n \
avg routing = {avg_routing_ns}ns (routing table lookup)\n \
avg shard = {avg_shard_us}µs (CRTC traversal, inline on caller thread)\n \
branches known = {total_branches} ({})\n \
remove broadcasts = {broadcasts} (fallback for blocks absent from index)",
branch_dist.join(", ")
)
}
}
......@@ -31,6 +31,8 @@
//!
//! This module provides a scalable and efficient way to manage and retrieve data blocks for LLM inference, leveraging a global KV cache to optimize performance.
mod branch_sharded;
fn warn_on_unit_block_size(indexer_type: &'static str, kv_block_size: u32) {
if kv_block_size == 1 {
tracing::warn!(
......@@ -40,7 +42,6 @@ fn warn_on_unit_block_size(indexer_type: &'static str, kv_block_size: u32) {
);
}
}
mod kv_indexer;
mod local;
mod metrics;
......@@ -58,6 +59,7 @@ pub mod radix_tree;
mod tests;
// Re-export everything that was public in the old single-file module.
pub use branch_sharded::*;
pub use kv_indexer::*;
pub use local::*;
pub use metrics::*;
......
......@@ -12,7 +12,9 @@ use dashmap::DashMap;
use rustc_hash::FxBuildHasher;
use tokio::sync::oneshot;
use super::{KvIndexerInterface, KvIndexerMetrics, KvRouterError, SyncIndexer, WorkerTask};
use super::{
KvIndexerInterface, KvIndexerMetrics, KvRouterError, ShardSizeSnapshot, SyncIndexer, WorkerTask,
};
use crate::protocols::*;
/// Generic wrapper that provides [`KvIndexerInterface`] for any [`SyncIndexer`] backend.
......@@ -133,6 +135,15 @@ impl<T: SyncIndexer> ThreadPoolIndexer<T> {
&self.backend
}
/// Get a cloned `Arc` to the underlying backend.
///
/// Useful when a caller needs to hand off an owned `Arc<T>` to a blocking
/// task (e.g. `tokio::task::spawn_blocking`) without cloning the backend
/// itself.
pub fn backend_arc(&self) -> Arc<T> {
Arc::clone(&self.backend)
}
/// Wait for all worker channels to drain.
///
/// Used primarily for testing and benchmarking to ensure all queued events
......@@ -365,4 +376,17 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
}
curr_size
}
fn shard_sizes(&self) -> Vec<ShardSizeSnapshot> {
vec![ShardSizeSnapshot {
shard_idx: 0,
worker_count: self.backend.worker_count(),
block_count: self.backend.block_count(),
node_count: self.backend.node_count(),
}]
}
fn node_edge_lengths(&self) -> Vec<usize> {
self.backend.node_edge_lengths()
}
}
......@@ -8,6 +8,23 @@ use std::sync::Arc;
use super::{KvIndexerMetrics, KvRouterError, WorkerTask};
use crate::protocols::*;
/// Per-shard size snapshot returned by [`KvIndexerInterface::shard_sizes`].
///
/// `worker_count` and `block_count` are always populated.
/// `node_count` is populated only when the `shard-metrics` feature is enabled
/// on the `dynamo-kv-router` crate; otherwise it is `0`.
#[derive(Debug, Clone)]
pub struct ShardSizeSnapshot {
/// Zero-based shard index.
pub shard_idx: usize,
/// Distinct `(worker_id, dp_rank)` pairs stored in this shard.
pub worker_count: usize,
/// Total cached blocks across all workers in this shard.
pub block_count: usize,
/// Radix-tree node count (only non-zero with `shard-metrics` feature).
pub node_count: usize,
}
#[async_trait]
pub trait KvIndexerInterface {
/// Find matches for a given sequence of `LocalBlockHash`es.
......@@ -93,6 +110,32 @@ pub trait KvIndexerInterface {
/// Returns the amount of events still in the queue at the time of the flush.
/// Used primarily for debugging.
async fn flush(&self) -> usize;
/// Return a human-readable timing breakdown of `find_matches` overhead.
///
/// Implementations that track per-phase timing (e.g. scatter/gather overhead
/// vs. actual shard work) override this to return a multi-line report string.
/// The default returns an empty string so callers can skip printing it.
fn timing_report(&self) -> String {
String::new()
}
/// Return a size snapshot for each shard.
///
/// Single-shard indexers return one entry (shard 0). Multi-shard indexers
/// return one entry per shard. Non-sharded indexers (and implementations
/// that don't override this) return an empty `Vec`.
///
/// See [`ShardSizeSnapshot`] for the fields exposed per shard.
fn shard_sizes(&self) -> Vec<ShardSizeSnapshot> {
vec![]
}
/// Edge lengths (hashes per node) for every non-root node.
/// Returns an empty vec for backends that don't support this.
fn node_edge_lengths(&self) -> Vec<usize> {
vec![]
}
}
// ============================================================================
......@@ -136,4 +179,26 @@ pub trait SyncIndexer: Send + Sync + 'static {
fn dump_events(&self) -> Option<Vec<RouterEvent>> {
None
}
/// Number of distinct workers registered in this backend.
fn worker_count(&self) -> usize {
0
}
/// Total cached blocks across all workers.
fn block_count(&self) -> usize {
0
}
/// Number of radix-tree nodes created since construction.
/// Only meaningful when the `shard-metrics` feature is enabled; returns 0 otherwise.
fn node_count(&self) -> usize {
0
}
/// Edge lengths (hashes per node) for every non-root node in the tree.
/// Returns an empty vec for backends that don't support this.
fn node_edge_lengths(&self) -> Vec<usize> {
vec![]
}
}
......@@ -44,7 +44,7 @@ pub use self::sequence::{ActiveSequences, RequestId};
pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use concurrent_radix_tree_compressed::ConcurrentRadixTreeCompressed;
pub use config::{KvRouterConfig, RouterConfigOverride, RouterPrefillLoadModel, RouterQueuePolicy};
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer};
pub use indexer::{BranchShardedIndexer, MaybeError, SyncIndexer, ThreadPoolIndexer};
pub use nested_map::PositionalIndexer;
pub use protocols::{
KvCacheEventError, LocalBlockHash, OverlapScores, RouterEvent, RouterEventSink,
......
......@@ -162,9 +162,20 @@ impl Trace {
let hash_ids = raw
.hash_ids
.ok_or_else(|| anyhow!("trace line {} is missing hash_ids", line_idx + 1))?;
// Clamp input_length to the synthesizable capacity: in the mooncake
// trace format, input_length is the full prompt token count which may
// exceed hash_ids.len() * block_size (cached portion only).
let synthesizable_capacity =
hash_ids
.len()
.checked_mul(trace_block_size)
.ok_or_else(|| {
anyhow!("trace line {} synthesized capacity overflow", line_idx + 1)
})?;
let input_length = raw
.input_length
.unwrap_or(hash_ids.len() * trace_block_size);
.unwrap_or(synthesizable_capacity)
.min(synthesizable_capacity);
let output_length = raw
.output_length
.ok_or_else(|| anyhow!("trace line {} is missing output_length", line_idx + 1))?;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment