Unverified commit 02666f04, authored by Yan Ru Pei, committed by GitHub
Browse files

chore(kv-router): remove sharded indexer path (#8041)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent ba274a03
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//! Combined benchmark for KvIndexer, KvIndexerSharded, and PositionalIndexer (nested). //! Combined benchmark for KvIndexer, PositionalIndexer (nested), and ConcurrentRadixTree.
//! //!
//! Provides two modes: //! Provides two modes:
//! - `microbench`: Per-operation latency benchmarks comparing indexer implementations //! - `microbench`: Per-operation latency benchmarks comparing indexer implementations
//! - `stress`: Queue saturation stress test under load //! - `stress`: Queue saturation stress test under load
//! //!
//! Supported indexer types: single, sharded, nested, all //! Supported indexer types: single, nested, concurrent, all
//! //!
//! Run with: //! Run with:
//! cargo bench --package dynamo-bench --bench kv_indexer_bench -- microbench --help //! cargo bench --package dynamo-bench --bench kv_indexer_bench -- microbench --help
...@@ -21,9 +21,7 @@ use clap::{Args, Parser, Subcommand, ValueEnum}; ...@@ -21,9 +21,7 @@ use clap::{Args, Parser, Subcommand, ValueEnum};
use dynamo_bench::common::LatencyStats; use dynamo_bench::common::LatencyStats;
use dynamo_kv_router::{ use dynamo_kv_router::{
ConcurrentRadixTree, ConcurrentRadixTree,
indexer::{ indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics, ThreadPoolIndexer},
KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvIndexerSharded, ThreadPoolIndexer,
},
nested_map::PositionalIndexer, nested_map::PositionalIndexer,
protocols::{LocalBlockHash, RouterEvent}, protocols::{LocalBlockHash, RouterEvent},
}; };
...@@ -40,7 +38,7 @@ use tokio_util::sync::CancellationToken; ...@@ -40,7 +38,7 @@ use tokio_util::sync::CancellationToken;
#[derive(Parser)] #[derive(Parser)]
#[command(name = "kv_indexer_bench")] #[command(name = "kv_indexer_bench")]
#[command(about = "Combined benchmark for KvIndexer, KvIndexerSharded, and PositionalIndexer")] #[command(about = "Combined benchmark for KvIndexer, PositionalIndexer, and ConcurrentRadixTree")]
struct Cli { struct Cli {
#[command(subcommand)] #[command(subcommand)]
command: Command, command: Command,
...@@ -63,8 +61,6 @@ enum Command { ...@@ -63,8 +61,6 @@ enum Command {
enum IndexerType { enum IndexerType {
/// Non-sharded KvIndexer (single background thread) /// Non-sharded KvIndexer (single background thread)
Single, Single,
/// Sharded KvIndexer (multiple shards with separate trees)
Sharded,
/// Nested PositionalIndexer (position-based HashMap with jump search) /// Nested PositionalIndexer (position-based HashMap with jump search)
Nested, Nested,
/// Concurrent radix tree (lock-per-node with DashMap lookup) /// Concurrent radix tree (lock-per-node with DashMap lookup)
...@@ -122,9 +118,9 @@ struct MicrobenchArgs { ...@@ -122,9 +118,9 @@ struct MicrobenchArgs {
#[arg(long, value_enum, default_value = "all")] #[arg(long, value_enum, default_value = "all")]
indexer_type: IndexerType, indexer_type: IndexerType,
/// Number of shards for sharded indexer /// Number of event worker threads for nested/concurrent indexers
#[arg(long, default_value = "4")] #[arg(long, default_value = "4")]
num_shards: usize, num_event_workers: usize,
/// Jump size for nested/positional indexer /// Jump size for nested/positional indexer
#[arg(long, default_value = "32")] #[arg(long, default_value = "32")]
...@@ -164,9 +160,9 @@ struct StressArgs { ...@@ -164,9 +160,9 @@ struct StressArgs {
#[arg(long, value_enum, default_value = "single")] #[arg(long, value_enum, default_value = "single")]
indexer_type: IndexerType, indexer_type: IndexerType,
/// Number of shards for sharded indexer /// Number of event worker threads for nested/concurrent indexers
#[arg(long, default_value = "4")] #[arg(long, default_value = "4")]
num_shards: usize, num_event_workers: usize,
/// Jump size for nested/positional indexer /// Jump size for nested/positional indexer
#[arg(long, default_value = "32")] #[arg(long, default_value = "32")]
...@@ -177,7 +173,7 @@ struct StressArgs { ...@@ -177,7 +173,7 @@ struct StressArgs {
// Benchable Indexer Trait // Benchable Indexer Trait
// ============================================================================ // ============================================================================
/// Trait for abstracting over KvIndexer and KvIndexerSharded /// Trait for abstracting over benchmarked indexers
#[async_trait::async_trait] #[async_trait::async_trait]
trait BenchableIndexer: Send + Sync { trait BenchableIndexer: Send + Sync {
async fn apply_event(&mut self, event: RouterEvent); async fn apply_event(&mut self, event: RouterEvent);
...@@ -207,25 +203,6 @@ impl BenchableIndexer for KvIndexer { ...@@ -207,25 +203,6 @@ impl BenchableIndexer for KvIndexer {
} }
} }
#[async_trait::async_trait]
impl BenchableIndexer for KvIndexerSharded {
async fn apply_event(&mut self, event: RouterEvent) {
KvIndexerInterface::apply_event(self, event).await;
}
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<(), dynamo_kv_router::indexer::KvRouterError> {
KvIndexerInterface::find_matches(self, sequence).await?;
Ok(())
}
fn name(&self) -> &str {
"KvIndexerSharded"
}
}
#[async_trait::async_trait] #[async_trait::async_trait]
impl BenchableIndexer for ThreadPoolIndexer<PositionalIndexer> { impl BenchableIndexer for ThreadPoolIndexer<PositionalIndexer> {
async fn apply_event(&mut self, event: RouterEvent) { async fn apply_event(&mut self, event: RouterEvent) {
...@@ -697,6 +674,16 @@ async fn run_microbench_mode(args: MicrobenchArgs) { ...@@ -697,6 +674,16 @@ async fn run_microbench_mode(args: MicrobenchArgs) {
eprintln!("Error: size must be >= depth"); eprintln!("Error: size must be >= depth");
std::process::exit(1); std::process::exit(1);
} }
if matches!(
args.indexer_type,
IndexerType::Nested | IndexerType::Concurrent | IndexerType::All
) && args.num_event_workers == 0
{
eprintln!(
"Error: num_event_workers must be > 0 when using Nested, Concurrent, or All indexer type"
);
std::process::exit(1);
}
println!("KvIndexer Microbenchmark"); println!("KvIndexer Microbenchmark");
println!("========================\n"); println!("========================\n");
...@@ -716,7 +703,7 @@ async fn run_microbench_mode(args: MicrobenchArgs) { ...@@ -716,7 +703,7 @@ async fn run_microbench_mode(args: MicrobenchArgs) {
args.prefix_prompt_ratio * 100.0 args.prefix_prompt_ratio * 100.0
); );
println!(" Prefix prompt groups: {}", args.num_prefix_prompts); println!(" Prefix prompt groups: {}", args.num_prefix_prompts);
println!(" Num shards (for sharded): {}", args.num_shards); println!(" Event worker threads: {}", args.num_event_workers);
println!(" Indexer type: {:?}", args.indexer_type); println!(" Indexer type: {:?}", args.indexer_type);
println!(" Benchmark type: {}", args.benchmark_type); println!(" Benchmark type: {}", args.benchmark_type);
println!( println!(
...@@ -751,26 +738,11 @@ async fn run_microbench_mode(args: MicrobenchArgs) { ...@@ -751,26 +738,11 @@ async fn run_microbench_mode(args: MicrobenchArgs) {
tokio::time::sleep(Duration::from_millis(50)).await; tokio::time::sleep(Duration::from_millis(50)).await;
} }
// Benchmark sharded indexer
if matches!(args.indexer_type, IndexerType::Sharded | IndexerType::All) {
let token = CancellationToken::new();
let mut indexer = KvIndexerSharded::new(
token.clone(),
args.num_shards,
args.common.block_size,
metrics.clone(),
);
let result = run_microbenchmarks(&mut indexer, sequences, extra_sequences, &args).await;
results.push(result);
token.cancel();
tokio::time::sleep(Duration::from_millis(50)).await;
}
// Benchmark nested indexer // Benchmark nested indexer
if matches!(args.indexer_type, IndexerType::Nested | IndexerType::All) { if matches!(args.indexer_type, IndexerType::Nested | IndexerType::All) {
let mut indexer = ThreadPoolIndexer::new( let mut indexer = ThreadPoolIndexer::new(
PositionalIndexer::new(args.jump_size), PositionalIndexer::new(args.jump_size),
args.num_shards, args.num_event_workers,
args.common.block_size, args.common.block_size,
); );
let result = run_microbenchmarks(&mut indexer, sequences, extra_sequences, &args).await; let result = run_microbenchmarks(&mut indexer, sequences, extra_sequences, &args).await;
...@@ -786,7 +758,7 @@ async fn run_microbench_mode(args: MicrobenchArgs) { ...@@ -786,7 +758,7 @@ async fn run_microbench_mode(args: MicrobenchArgs) {
) { ) {
let mut indexer = ThreadPoolIndexer::new( let mut indexer = ThreadPoolIndexer::new(
ConcurrentRadixTree::new(), ConcurrentRadixTree::new(),
args.num_shards, args.num_event_workers,
args.common.block_size, args.common.block_size,
); );
let result = run_microbenchmarks(&mut indexer, sequences, extra_sequences, &args).await; let result = run_microbenchmarks(&mut indexer, sequences, extra_sequences, &args).await;
...@@ -1226,10 +1198,12 @@ async fn run_stress_mode(args: StressArgs) { ...@@ -1226,10 +1198,12 @@ async fn run_stress_mode(args: StressArgs) {
} }
if matches!( if matches!(
args.indexer_type, args.indexer_type,
IndexerType::Sharded | IndexerType::Nested | IndexerType::All IndexerType::Nested | IndexerType::Concurrent | IndexerType::All
) && args.num_shards == 0 ) && args.num_event_workers == 0
{ {
eprintln!("Error: num_shards must be > 0 when using Sharded, Nested, or All indexer type"); eprintln!(
"Error: num_event_workers must be > 0 when using Nested, Concurrent, or All indexer type"
);
std::process::exit(1); std::process::exit(1);
} }
...@@ -1254,11 +1228,13 @@ async fn run_stress_mode(args: StressArgs) { ...@@ -1254,11 +1228,13 @@ async fn run_stress_mode(args: StressArgs) {
println!(" Duration: {}s", args.duration); println!(" Duration: {}s", args.duration);
println!(" In-flight timeout: {}s", args.in_flight_timeout); println!(" In-flight timeout: {}s", args.in_flight_timeout);
println!(" Indexer type: {:?}", args.indexer_type); println!(" Indexer type: {:?}", args.indexer_type);
if matches!(args.indexer_type, IndexerType::Sharded | IndexerType::All) { if matches!(
println!(" Num shards (sharded): {}", args.num_shards); args.indexer_type,
IndexerType::Nested | IndexerType::Concurrent | IndexerType::All
) {
println!(" Event worker threads: {}", args.num_event_workers);
} }
if matches!(args.indexer_type, IndexerType::Nested | IndexerType::All) { if matches!(args.indexer_type, IndexerType::Nested | IndexerType::All) {
println!(" Num workers (nested): {}", args.num_shards);
println!(" Jump size (nested): {}", args.jump_size); println!(" Jump size (nested): {}", args.jump_size);
} }
...@@ -1322,58 +1298,11 @@ async fn run_stress_mode(args: StressArgs) { ...@@ -1322,58 +1298,11 @@ async fn run_stress_mode(args: StressArgs) {
tokio::time::sleep(Duration::from_millis(50)).await; tokio::time::sleep(Duration::from_millis(50)).await;
} }
// Test sharded indexer
if matches!(args.indexer_type, IndexerType::Sharded | IndexerType::All) {
let token = CancellationToken::new();
let indexer = KvIndexerSharded::new(
token.clone(),
args.num_shards,
args.common.block_size,
metrics.clone(),
);
println!(
"\n Applying {} store events to KvIndexerSharded...",
sequences.len()
);
let construction_start = Instant::now();
for (event_id, seq) in sequences.iter().enumerate() {
let event = seq.to_store_event(event_id as u64);
KvIndexerInterface::apply_event(&indexer, event).await;
if args.common.verbose && (event_id + 1) % 100 == 0 {
println!(" Applied {}/{} events...", event_id + 1, sequences.len());
}
}
let construction_time = construction_start.elapsed();
let construction_events = sequences.len() as u64;
println!(" Tree construction completed in {:?}", construction_time);
println!(
" Throughput: {:.0} events/sec",
construction_events as f64 / construction_time.as_secs_f64()
);
tokio::time::sleep(Duration::from_millis(100)).await;
let mut results = run_stress_test(Arc::new(indexer), &sequences, &args).await;
results.construction_time = construction_time;
results.construction_events = construction_events;
print_stress_results(&args, &results);
all_results.push(results);
token.cancel();
tokio::time::sleep(Duration::from_millis(50)).await;
}
// Test nested indexer // Test nested indexer
if matches!(args.indexer_type, IndexerType::Nested | IndexerType::All) { if matches!(args.indexer_type, IndexerType::Nested | IndexerType::All) {
let indexer = ThreadPoolIndexer::new( let indexer = ThreadPoolIndexer::new(
PositionalIndexer::new(args.jump_size), PositionalIndexer::new(args.jump_size),
args.num_shards, args.num_event_workers,
args.common.block_size, args.common.block_size,
); );
...@@ -1425,7 +1354,7 @@ async fn run_stress_mode(args: StressArgs) { ...@@ -1425,7 +1354,7 @@ async fn run_stress_mode(args: StressArgs) {
) { ) {
let indexer = ThreadPoolIndexer::new( let indexer = ThreadPoolIndexer::new(
ConcurrentRadixTree::new(), ConcurrentRadixTree::new(),
args.num_shards, args.num_event_workers,
args.common.block_size, args.common.block_size,
); );
......
...@@ -7,9 +7,7 @@ use common::*; ...@@ -7,9 +7,7 @@ use common::*;
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use dynamo_kv_router::LocalBlockHash; use dynamo_kv_router::LocalBlockHash;
use dynamo_kv_router::indexer::{ use dynamo_kv_router::indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics};
KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvIndexerSharded,
};
use dynamo_kv_router::protocols::{KvCacheEvent, KvCacheEventData, RouterEvent}; use dynamo_kv_router::protocols::{KvCacheEvent, KvCacheEventData, RouterEvent};
use dynamo_kv_router::{ use dynamo_kv_router::{
ConcurrentRadixTree, ConcurrentRadixTreeCompressed, PositionalIndexer, ThreadPoolIndexer, ConcurrentRadixTree, ConcurrentRadixTreeCompressed, PositionalIndexer, ThreadPoolIndexer,
...@@ -26,13 +24,6 @@ enum IndexerArgs { ...@@ -26,13 +24,6 @@ enum IndexerArgs {
/// Single-threaded radix tree indexer. /// Single-threaded radix tree indexer.
RadixTree {}, RadixTree {},
/// Sharded radix tree indexer that partitions workers across independent shards.
RadixTreeSharded {
/// Number of independent shards to split workers across.
#[clap(long, default_value = "4")]
num_shards: usize,
},
/// Position-based nested map indexer with jump search. /// Position-based nested map indexer with jump search.
NestedMap { NestedMap {
/// Number of positions to skip during jump search before scanning back. /// Number of positions to skip during jump search before scanning back.
...@@ -68,12 +59,6 @@ impl IndexerArgs { ...@@ -68,12 +59,6 @@ impl IndexerArgs {
IndexerArgs::RadixTree {} => { IndexerArgs::RadixTree {} => {
Arc::new(KvIndexer::new(cancel_token, block_size, metrics)) Arc::new(KvIndexer::new(cancel_token, block_size, metrics))
} }
IndexerArgs::RadixTreeSharded { num_shards } => Arc::new(KvIndexerSharded::new(
cancel_token,
num_shards,
block_size,
metrics,
)),
IndexerArgs::NestedMap { IndexerArgs::NestedMap {
jump_size, jump_size,
num_event_workers, num_event_workers,
...@@ -115,7 +100,6 @@ impl IndexerArgs { ...@@ -115,7 +100,6 @@ impl IndexerArgs {
let nw = num_event_workers; let nw = num_event_workers;
let indexer_args = match name { let indexer_args = match name {
"radix-tree" => IndexerArgs::RadixTree {}, "radix-tree" => IndexerArgs::RadixTree {},
"radix-tree-sharded" => IndexerArgs::RadixTreeSharded { num_shards: 4 },
"nested-map" => IndexerArgs::NestedMap { "nested-map" => IndexerArgs::NestedMap {
jump_size: 8, jump_size: 8,
num_event_workers: nw, num_event_workers: nw,
...@@ -127,7 +111,7 @@ impl IndexerArgs { ...@@ -127,7 +111,7 @@ impl IndexerArgs {
num_event_workers: nw, num_event_workers: nw,
}, },
_ => anyhow::bail!( _ => anyhow::bail!(
"Unknown indexer '{}'. Valid names: radix-tree, radix-tree-sharded, \ "Unknown indexer '{}'. Valid names: radix-tree, \
nested-map, concurrent-radix-tree, concurrent-radix-tree-compressed", nested-map, concurrent-radix-tree, concurrent-radix-tree-compressed",
name name
), ),
...@@ -148,14 +132,15 @@ struct Args { ...@@ -148,14 +132,15 @@ struct Args {
/// Comma-separated list of indexer names to benchmark and compare on the /// Comma-separated list of indexer names to benchmark and compare on the
/// same plot. Overrides the subcommand indexer when present. Valid names: /// same plot. Overrides the subcommand indexer when present. Valid names:
/// radix-tree, radix-tree-sharded, nested-map, concurrent-radix-tree, /// radix-tree, nested-map, concurrent-radix-tree,
/// concurrent-radix-tree-compressed. /// concurrent-radix-tree-compressed.
#[clap(long, value_delimiter = ',')] #[clap(long, value_delimiter = ',')]
compare: Vec<String>, compare: Vec<String>,
/// Number of OS threads for event processing in compare mode. Applies to /// Number of OS threads for event processing in compare mode. Applies to
/// indexers that use a thread pool (nested-map, concurrent-radix-tree). /// indexers that use a thread pool (nested-map, concurrent-radix-tree,
/// Ignored by radix-tree and radix-tree-sharded. /// concurrent-radix-tree-compressed).
/// Ignored by radix-tree.
#[clap(long, default_value = "16")] #[clap(long, default_value = "16")]
num_event_workers: usize, num_event_workers: usize,
...@@ -555,7 +540,6 @@ async fn main() -> anyhow::Result<()> { ...@@ -555,7 +540,6 @@ async fn main() -> anyhow::Result<()> {
let indexer_names: Vec<String> = if args.compare.is_empty() { let indexer_names: Vec<String> = if args.compare.is_empty() {
let name = match args.get_indexer() { let name = match args.get_indexer() {
IndexerArgs::RadixTree {} => "radix-tree", IndexerArgs::RadixTree {} => "radix-tree",
IndexerArgs::RadixTreeSharded { .. } => "radix-tree-sharded",
IndexerArgs::NestedMap { .. } => "nested-map", IndexerArgs::NestedMap { .. } => "nested-map",
IndexerArgs::ConcurrentRadixTree { .. } => "concurrent-radix-tree", IndexerArgs::ConcurrentRadixTree { .. } => "concurrent-radix-tree",
IndexerArgs::ConcurrentRadixTreeCompressed { .. } => "concurrent-radix-tree-compressed", IndexerArgs::ConcurrentRadixTreeCompressed { .. } => "concurrent-radix-tree-compressed",
......
...@@ -17,7 +17,6 @@ The concurrent indexers achieve a combined throughput of over **10 million event ...@@ -17,7 +17,6 @@ The concurrent indexers achieve a combined throughput of over **10 million event
| `concurrent_radix_tree.rs` | `ConcurrentRadixTree` — thread-safe variant with `Arc<RwLock<Block>>` nodes and `DashMap` lookup | | `concurrent_radix_tree.rs` | `ConcurrentRadixTree` — thread-safe variant with `Arc<RwLock<Block>>` nodes and `DashMap` lookup |
| `positional.rs` | `PositionalIndexer` — flat `DashMap<(pos, hash), SeqEntry>` with jump optimization | | `positional.rs` | `PositionalIndexer` — flat `DashMap<(pos, hash), SeqEntry>` with jump optimization |
| `thread_pool.rs` | `ThreadPoolIndexer<T: SyncIndexer>` — N OS threads for sticky-routed writes, inline reads; wraps `ConcurrentRadixTree` or `PositionalIndexer` | | `thread_pool.rs` | `ThreadPoolIndexer<T: SyncIndexer>` — N OS threads for sticky-routed writes, inline reads; wraps `ConcurrentRadixTree` or `PositionalIndexer` |
| `sharded.rs` | `KvIndexerSharded` — N independent `RadixTree` shards each in its own OS thread, scatter-gather for matches |
| `local.rs` | `LocalKvIndexer` — thin wrapper around `KvIndexer` with a circular event buffer for worker-side decentralized routing | | `local.rs` | `LocalKvIndexer` — thin wrapper around `KvIndexer` with a circular event buffer for worker-side decentralized routing |
| `pruning.rs` | `PruneManager` — TTL-based expiration and size-based pruning via `BinaryHeap<BlockEntry>` | | `pruning.rs` | `PruneManager` — TTL-based expiration and size-based pruning via `BinaryHeap<BlockEntry>` |
| `naive.rs` | Brute-force baseline indexers (bench-only, behind `bench` feature flag) | | `naive.rs` | Brute-force baseline indexers (bench-only, behind `bench` feature flag) |
......
...@@ -34,7 +34,6 @@ ...@@ -34,7 +34,6 @@
mod kv_indexer; mod kv_indexer;
mod local; mod local;
mod metrics; mod metrics;
mod sharded;
mod thread_pool; mod thread_pool;
mod traits; mod traits;
mod types; mod types;
...@@ -52,7 +51,6 @@ mod tests; ...@@ -52,7 +51,6 @@ mod tests;
pub use kv_indexer::*; pub use kv_indexer::*;
pub use local::*; pub use local::*;
pub use metrics::*; pub use metrics::*;
pub use sharded::*;
pub use thread_pool::*; pub use thread_pool::*;
pub use traits::*; pub use traits::*;
pub use types::*; pub use types::*;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#[cfg(feature = "bench")]
use std::time::Instant;
use std::{
iter,
sync::{Arc, Mutex},
thread::JoinHandle,
time::Duration,
};
use async_trait::async_trait;
use dashmap::DashMap;
use rustc_hash::FxBuildHasher;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use super::{
DumpRequest, GetWorkersRequest, KvIndexerInterface, KvIndexerMetrics, KvRouterError, RadixTree,
RoutingDecisionRequest, ShardedMatchRequest,
};
use crate::indexer::pruning::{BlockEntry, PruneConfig, PruneManager};
use crate::protocols::*;
use dynamo_tokens::SequenceHash;
/// A sharded KV Indexer that partitions the RadixTree across multiple independent shards.
///
/// ## Sharding Strategy
/// - Each worker is **permanently assigned** to a single shard on first event
/// - All KV blocks from a worker exist only in that worker's assigned shard
/// - New workers are assigned to the shard with the fewest workers (load balancing)
///
/// ## Operation
/// - **Events**: Routed directly to the worker's assigned shard
/// - **Match requests**: Broadcast to all shards (scatter-gather pattern)
/// - **Threading**: Each shard runs in its own thread with a single-threaded runtime
///
/// This design ensures no cross-shard synchronization for writes while enabling
/// parallel processing and better scalability.
pub struct KvIndexerSharded {
/// A `CancellationToken` for managing shutdown.
cancel: CancellationToken,
/// The size of the KV block this indexer can handle.
kv_block_size: u32,
/// Maps each known worker to the index of the shard it was assigned to.
worker_assignments: DashMap<WorkerId, usize, FxBuildHasher>,
/// Number of workers currently assigned to each shard, indexed by shard.
worker_counts: Arc<Mutex<Vec<usize>>>,
/// Per-shard senders for router events, indexed by shard.
event_tx: Vec<mpsc::Sender<RouterEvent>>,
/// Broadcast sender for match requests; every shard thread holds a subscription.
request_broadcast_tx: broadcast::Sender<ShardedMatchRequest>,
/// Per-shard senders for whole-worker removal requests, indexed by shard.
remove_worker_tx: Vec<mpsc::Sender<WorkerId>>,
/// Per-shard senders for removing a single `(worker, dp_rank)` pair, indexed by shard.
remove_worker_dp_rank_tx: Vec<mpsc::Sender<(WorkerId, DpRank)>>,
/// Per-shard senders for tree dump requests, indexed by shard.
dump_tx: Vec<mpsc::Sender<DumpRequest>>,
/// Per-shard senders for routing decisions; shards only act on these when pruning is configured.
routing_tx: Vec<mpsc::Sender<RoutingDecisionRequest>>,
/// Join handles of the OS threads spawned for the shards.
tasks: Arc<Mutex<Vec<JoinHandle<()>>>>,
}
impl KvIndexerSharded {
/// Create a new `KvIndexerSharded`, spawning one OS thread per shard.
///
/// Each shard thread owns its own `RadixTree` and drives a current-thread Tokio
/// runtime that services the shard's channels until `token` is cancelled.
///
/// ### Arguments
///
/// * `token` - A `CancellationToken` for managing shutdown.
/// * `num_shards` - The number of shards (and shard threads) to create.
/// * `expiration_duration` - The amount of time that block usage should be buffered;
///   forwarded to `RadixTree::new_with_frequency` for every shard.
/// * `kv_block_size` - The size of the KV block this indexer can handle.
/// * `metrics` - Metrics handle; `increment_event_applied` is called for every applied event.
/// * `prune_config` - Configuration for tree-size based pruning. When `None`, no
///   `PruneManager` is created and the pruning/expiry/routing branches are no-ops.
///
/// ### Returns
///
/// A new `KvIndexerSharded`.
///
/// ### Panics
///
/// Panics if a per-shard Tokio runtime cannot be built.
pub fn new_with_frequency(
token: CancellationToken,
num_shards: usize,
expiration_duration: Option<Duration>,
kv_block_size: u32,
metrics: Arc<KvIndexerMetrics>,
prune_config: Option<PruneConfig>,
) -> Self {
let worker_assignments = DashMap::with_hasher(FxBuildHasher);
let worker_counts = Arc::new(Mutex::new(vec![0; num_shards]));
let mut event_tx = Vec::new();
let mut remove_worker_tx = Vec::new();
let mut remove_worker_dp_rank_tx = Vec::new();
// NOTE(review): `get_workers_tx` is filled below but never stored on `Self`, so the
// senders are dropped when this constructor returns and the shard's get-workers
// branch can never receive a request. Confirm whether this channel is still needed.
let mut get_workers_tx = Vec::new();
let mut dump_tx = Vec::new();
let mut routing_tx = Vec::new();
let tasks = Arc::new(Mutex::new(Vec::new()));
// Single broadcast channel shared by all shards for scatter-gather match requests.
let (request_broadcast_tx, _) = broadcast::channel::<ShardedMatchRequest>(1048576);
for _ in 0..num_shards {
let (shard_event_tx, mut shard_event_rx) = mpsc::channel::<RouterEvent>(2048);
let (shard_remove_worker_tx, mut shard_remove_worker_rx) =
mpsc::channel::<WorkerId>(16);
let (shard_remove_worker_dp_rank_tx, mut shard_remove_worker_dp_rank_rx) =
mpsc::channel::<(WorkerId, DpRank)>(16);
let (shard_get_workers_tx, mut shard_get_workers_rx) =
mpsc::channel::<GetWorkersRequest>(16);
let (shard_dump_tx, mut shard_dump_rx) = mpsc::channel::<DumpRequest>(16);
let (shard_routing_tx, mut shard_routing_rx) =
mpsc::channel::<RoutingDecisionRequest>(2048);
// Capacity 1: a single pending "please prune" signal is enough; extra signals are dropped by `try_send`.
let (shard_prune_tx, mut shard_prune_rx) = mpsc::channel::<()>(1);
let mut shard_broadcast_rx = request_broadcast_tx.subscribe();
let cancel = token.clone();
let metrics = metrics.clone();
let prune_config_clone = prune_config.clone();
event_tx.push(shard_event_tx);
remove_worker_tx.push(shard_remove_worker_tx);
remove_worker_dp_rank_tx.push(shard_remove_worker_dp_rank_tx);
get_workers_tx.push(shard_get_workers_tx);
dump_tx.push(shard_dump_tx);
routing_tx.push(shard_routing_tx);
// Each shard gets its own single-threaded runtime, driven from a dedicated OS thread.
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
tasks.lock().unwrap().push(std::thread::spawn(move || {
runtime.block_on(async move {
let mut trie = RadixTree::new_with_frequency(expiration_duration);
// Create PruneManager if prune_config is specified
let mut prune_manager = prune_config_clone.map(|config| {
PruneManager::<BlockEntry>::new(50, config)
});
// Source of event ids for the events this shard synthesizes itself (pruning, expiry, routing).
let mut event_id_counter = 0u64;
loop {
// Create a future that sleeps until the next expiration time
// (or for `Duration::MAX` when there is no prune manager or nothing pending).
let expiry_fut = if let Some(ref pm) = prune_manager
&& let Some(next_expiry) = pm.peek_next_expiry() {
tokio::time::sleep_until(next_expiry)
} else {
tokio::time::sleep(Duration::MAX)
};
tokio::select! {
// `biased` polls the branches in the order written, so cancellation is checked first.
biased;
_ = cancel.cancelled() => {
tracing::trace!("KvCacheIndexer progress loop shutting down");
return;
}
Some(worker) = shard_remove_worker_rx.recv() => {
trie.remove_worker(worker);
}
Some((worker_id, dp_rank)) = shard_remove_worker_dp_rank_rx.recv() => {
trie.remove_worker_dp_rank(worker_id, dp_rank);
}
Some(get_workers_req) = shard_get_workers_rx.recv() => {
let workers = trie.get_workers();
let _ = get_workers_req.resp.send(workers);
}
Some(_) = shard_prune_rx.recv() => {
// Tree size-based pruning triggered
let Some(ref mut pm) = prune_manager else { continue };
let Ok(pruned) = pm.prune(trie.current_size()) else { continue };
// Apply one synthetic `Removed` event per pruned block.
for p in pruned {
event_id_counter += 1;
let event = RouterEvent::new(
p.worker.worker_id,
KvCacheEvent {
event_id: event_id_counter,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![p.key],
}),
dp_rank: p.worker.dp_rank,
}
);
let _ = trie.apply_event(event);
}
}
Some(event) = shard_event_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
// Only clone if we need the event for prune_manager afterward
let event_for_prune = prune_manager.is_some().then(|| event.clone());
let result = trie.apply_event(event);
let result_is_ok = result.is_ok();
metrics.increment_event_applied(event_type, result);
// Track blocks in PruneManager if TTL is enabled and event was stored successfully
let Some(ref mut pm) = prune_manager else { continue };
if !result_is_ok { continue };
let Some(ref event) = event_for_prune else { continue };
let KvCacheEventData::Stored(ref store_data) = event.event.data else { continue };
let worker = WorkerWithDpRank::new(event.worker_id, event.event.dp_rank);
let block_entries: Vec<BlockEntry> = store_data.blocks.iter().enumerate().map(|(idx, block)| {
BlockEntry {
key: block.block_hash,
worker,
seq_position: idx,
}
}).collect();
pm.insert(block_entries);
// Check if we need to prune due to tree size
let Some(ref pc) = pm.prune_config else { continue };
let current_size = trie.current_size();
if current_size > pc.max_tree_size {
tracing::info!(
"Pruning: tree size ({}) exceeded max tree size ({}), scheduling pruning",
current_size,
pc.max_tree_size
);
let _ = shard_prune_tx.try_send(());
}
}
Some(routing_req) = shard_routing_rx.recv() => {
// Process routing decisions when TTL/pruning is enabled
let Some(ref mut pm) = prune_manager else { continue };
event_id_counter += 1;
// Turn the routing decision into a synthetic `Stored` event for the chosen worker.
let hashes = routing_req.local_hashes.iter().zip(routing_req.sequence_hashes.iter());
let stored_event = KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: hashes.map(|(local_hash, sequence_hash)| KvCacheStoredBlockData {
tokens_hash: *local_hash,
block_hash: ExternalSequenceBlockHash(*sequence_hash),
mm_extra_info: None,
}).collect(),
});
let event = RouterEvent::new(
routing_req.worker.worker_id,
KvCacheEvent {
event_id: event_id_counter,
data: stored_event,
dp_rank: routing_req.worker.dp_rank,
}
);
if trie.apply_event(event).is_err() {
continue;
}
let block_entries: Vec<BlockEntry> = routing_req.sequence_hashes.iter().enumerate().map(|(idx, h)| {
BlockEntry {
key: ExternalSequenceBlockHash(*h),
worker: routing_req.worker,
seq_position: idx,
}
}).collect();
pm.insert(block_entries);
// Check if we need to prune due to tree size
let Some(ref pc) = pm.prune_config else { continue };
let current_size = trie.current_size();
if current_size > pc.max_tree_size {
tracing::info!(
"Pruning: tree size ({}) exceeded max tree size ({}), scheduling pruning",
current_size,
pc.max_tree_size
);
let _ = shard_prune_tx.try_send(());
}
}
Some(dump_req) = shard_dump_rx.recv() => {
let events = trie.dump_tree_as_events();
let _ = dump_req.resp.send(events);
}
Ok(req) = shard_broadcast_rx.recv() => {
// Match request broadcast to every shard; each shard answers on the request's own response channel.
#[cfg(feature = "bench")]
let queue_wait = req.created_at.elapsed();
#[cfg(feature = "bench")]
let seq_len = req.sequence.len();
#[cfg(feature = "bench")]
let process_start = Instant::now();
let matches = trie.find_matches(req.sequence, req.early_exit);
#[cfg(feature = "bench")]
let process_time = process_start.elapsed();
#[cfg(feature = "bench")]
tracing::info!(
seq_len,
queue_wait_us = queue_wait.as_micros() as u64,
process_us = process_time.as_micros() as u64,
"sharded indexer: processed find_matches"
);
if let Err(e) = req.resp.send(matches).await {
tracing::trace!("Failed to send match response: {:?}", e);
}
}
_ = expiry_fut => {
// TTL-based expiry triggered
let Some(ref mut pm) = prune_manager else { continue };
let expired = pm.pop_expired();
// Apply one synthetic `Removed` event per expired block.
for e in expired {
event_id_counter += 1;
let event = RouterEvent::new(
e.worker.worker_id,
KvCacheEvent {
event_id: event_id_counter,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![e.key],
}),
dp_rank: e.worker.dp_rank,
}
);
let _ = trie.apply_event(event);
}
}
}
}
});
tracing::debug!("KvCacheIndexer task completed");
}));
}
Self {
cancel: token,
kv_block_size,
worker_assignments,
worker_counts,
event_tx,
request_broadcast_tx,
remove_worker_tx,
remove_worker_dp_rank_tx,
dump_tx,
routing_tx,
tasks,
}
}
pub fn block_size(&self) -> u32 {
self.kv_block_size
}
/// Create a `KvIndexerSharded` with default settings.
///
/// Delegates to [`Self::new_with_frequency`] with no expiration duration and
/// no prune configuration.
pub fn new(
    token: CancellationToken,
    num_shards: usize,
    kv_block_size: u32,
    metrics: Arc<KvIndexerMetrics>,
) -> Self {
    // Both optional features are switched off for the plain constructor.
    let expiration_duration = None;
    let prune_config = None;
    Self::new_with_frequency(
        token,
        num_shards,
        expiration_duration,
        kv_block_size,
        metrics,
        prune_config,
    )
}
/// Returns the shard index assigned to `worker_id`, assigning one on first use.
///
/// A worker seen for the first time is placed on the shard that currently has
/// the fewest workers (the lowest-indexed one on ties), and that shard's count
/// is incremented. Later calls return the stored assignment.
///
/// # Panics
///
/// Panics if the indexer was built with zero shards or if the `worker_counts`
/// mutex is poisoned.
fn shard_for_worker(&self, worker_id: WorkerId) -> usize {
    *self.worker_assignments.entry(worker_id).or_insert_with(|| {
        // Keep one guard across both the selection and the increment. Releasing
        // the lock in between would let concurrent first-time assignments for
        // different workers pick the same shard from a stale snapshot.
        let mut worker_counts = self.worker_counts.lock().unwrap();
        let selected_shard = worker_counts
            .iter()
            .enumerate()
            .min_by_key(|&(_, &count)| count)
            .map(|(shard, _)| shard)
            .expect("KvIndexerSharded requires at least one shard");
        worker_counts[selected_shard] += 1;
        selected_shard
    })
}
}
#[async_trait]
impl KvIndexerInterface for KvIndexerSharded {
/// Scatter `sequence` to every shard and gather the merged overlap scores.
///
/// The request is broadcast to all shards. Per-shard `scores` and `tree_sizes`
/// are unioned and `frequencies` are summed element-wise. Returns
/// `KvRouterError::IndexerOffline` if the broadcast cannot be sent.
async fn find_matches(
    &self,
    sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
    #[cfg(feature = "bench")]
    let start = Instant::now();
    #[cfg(feature = "bench")]
    let seq_len = sequence.len();
    #[cfg(feature = "bench")]
    let num_shards = self.event_tx.len();

    'match_loop: loop {
        // One response slot per shard.
        let (match_tx, mut match_rx) = mpsc::channel(self.event_tx.len());
        let sharded_req = ShardedMatchRequest::new(sequence.clone(), false, match_tx);
        if self.request_broadcast_tx.send(sharded_req).is_err() {
            return Err(KvRouterError::IndexerOffline);
        }

        let mut merged = OverlapScores::new();
        for response_idx in 0..self.event_tx.len() {
            let Some(response) = match_rx.recv().await else {
                // The response channel closed before every shard answered
                // (per the original note: the broadcast channel overflowed).
                // Retry by looping rather than recursing, so the stack cannot grow.
                continue 'match_loop;
            };

            merged.scores.extend(response.scores);
            merged.tree_sizes.extend(response.tree_sizes);

            if response_idx == 0 {
                // The first response seeds the frequency vector.
                merged.frequencies = response.frequencies;
                continue;
            }

            // Zero-pad the accumulator up to this response's length, then add element-wise.
            let missing = response
                .frequencies
                .len()
                .saturating_sub(merged.frequencies.len());
            if missing > 0 {
                merged.frequencies.extend(iter::repeat_n(0, missing));
            }
            for (total, freq) in merged
                .frequencies
                .iter_mut()
                .zip(response.frequencies.iter())
            {
                *total += *freq;
            }
        }

        #[cfg(feature = "bench")]
        {
            let elapsed = start.elapsed();
            tracing::info!(
                seq_len,
                num_shards,
                elapsed_us = elapsed.as_micros() as u64,
                "find_matches (sharded) completed"
            );
        }
        return Ok(merged);
    }
}
/// Hashes `tokens` into blocks and looks up overlaps for the resulting sequence.
///
/// The block hashes are computed with this indexer's `kv_block_size` and the
/// given `lora_name` / `is_eagle` options (all other options left at their
/// defaults), then passed to `find_matches`.
async fn find_matches_for_request(
    &self,
    tokens: &[u32],
    lora_name: Option<&str>,
    is_eagle: Option<bool>,
) -> Result<OverlapScores, KvRouterError> {
    let hash_options = BlockHashOptions {
        lora_name,
        is_eagle,
        ..Default::default()
    };
    let block_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size, hash_options);
    self.find_matches(block_hashes).await
}
/// Forwards `event` to the shard that owns the event's worker.
///
/// # Panics
///
/// Panics if sending on that shard's event channel fails.
async fn apply_event(&self, event: RouterEvent) {
    let shard_idx = self.shard_for_worker(event.worker_id);
    let shard_tx = &self.event_tx[shard_idx];
    shard_tx.send(event).await.unwrap();
}
/// Drops `worker`'s shard assignment and tells that shard to remove the worker.
///
/// Does nothing when the worker has no recorded assignment.
///
/// # Panics
///
/// Panics if the `worker_counts` mutex is poisoned or the shard's channel send fails.
async fn remove_worker(&self, worker: WorkerId) {
    let Some((_, shard)) = self.worker_assignments.remove(&worker) else {
        return;
    };
    // The worker no longer counts toward this shard's load.
    self.worker_counts.lock().unwrap()[shard] -= 1;
    self.remove_worker_tx[shard].send(worker).await.unwrap();
}
/// Asks the shard that owns `worker` to remove a single `dp_rank` of that worker.
///
/// The worker's entry in `worker_assignments` is kept, since other dp_ranks of
/// the same worker may still exist. Does nothing when the worker has no
/// recorded assignment.
///
/// # Panics
///
/// Panics if sending on the shard's channel fails.
async fn remove_worker_dp_rank(&self, worker: WorkerId, dp_rank: DpRank) {
    // Copy the shard index out first so whatever guard `get` returns is
    // released before the await below; holding a map reference across a
    // suspension point can stall other tasks that touch the same map.
    let Some(shard) = self.worker_assignments.get(&worker).map(|entry| *entry) else {
        return;
    };
    // A worker lives on exactly one shard, so route there directly.
    self.remove_worker_dp_rank_tx[shard]
        .send((worker, dp_rank))
        .await
        .unwrap();
}
/// Cancels the indexer's token and joins every stored task handle.
///
/// Handles are joined in last-in, first-out order while the `tasks` lock is held.
///
/// # Panics
///
/// Panics if the `tasks` mutex is poisoned or if joining a task returns an error.
fn shutdown(&self) {
    self.cancel.cancel();
    let mut handles = self.tasks.lock().unwrap();
    while let Some(handle) = handles.pop() {
        handle.join().unwrap();
    }
}
/// Collects the events dumped by every shard into one vector.
///
/// A `DumpRequest` is sent to each shard first; the replies are then awaited
/// in shard order and concatenated.
///
/// # Errors
///
/// * `KvRouterError::IndexerOffline` if a dump request cannot be sent.
/// * `KvRouterError::IndexerDroppedRequest` if a shard drops its reply channel.
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
    // Fan out: one oneshot reply channel per shard.
    let mut pending = Vec::new();
    for shard_dump_tx in &self.dump_tx {
        let (resp_tx, resp_rx) = oneshot::channel();
        shard_dump_tx
            .send(DumpRequest { resp: resp_tx })
            .await
            .map_err(|e| {
                tracing::error!("Failed to send dump request to shard: {:?}", e);
                KvRouterError::IndexerOffline
            })?;
        pending.push(resp_rx);
    }

    // Fan in: gather replies in the same order the requests were sent.
    let mut all_events = Vec::new();
    for resp_rx in pending {
        let events = resp_rx
            .await
            .map_err(|_| KvRouterError::IndexerDroppedRequest)?;
        all_events.extend(events);
    }
    Ok(all_events)
}
/// Records a routing decision for `worker` using hashes taken from `tokens_with_hashes`.
///
/// The block hashes and sequence hashes are obtained (computed if needed) from
/// `tokens_with_hashes`, copied, and passed to
/// `process_routing_decision_with_hashes`.
async fn process_routing_decision_for_request(
    &self,
    tokens_with_hashes: &mut TokensWithHashes,
    worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
    let block_hashes = tokens_with_hashes.get_or_compute_block_hashes().to_vec();
    let seq_hashes = tokens_with_hashes.get_or_compute_seq_hashes().to_vec();
    let outcome = self
        .process_routing_decision_with_hashes(worker, block_hashes, seq_hashes)
        .await;
    outcome
}
/// Waits until every shard's event channel is empty.
///
/// Returns the number of events that were queued across all shard channels at
/// the time of the call (`max_capacity - capacity`, summed). The channels are
/// polled every 5 ms until each one is back at full capacity.
async fn flush(&self) -> usize {
    let queued_at_start: usize = self
        .event_tx
        .iter()
        .map(|tx| tx.max_capacity() - tx.capacity())
        .sum();

    // Keep sleeping while any channel still has items in flight.
    while self
        .event_tx
        .iter()
        .any(|tx| tx.capacity() != tx.max_capacity())
    {
        tokio::time::sleep(Duration::from_millis(5)).await;
    }

    queued_at_start
}
}
impl KvIndexerSharded {
    /// Sends a routing decision with already-computed hashes to the worker's shard.
    ///
    /// The shard is looked up (or assigned) from `worker.worker_id`, and a
    /// `RoutingDecisionRequest` is sent on that shard's routing channel.
    ///
    /// # Errors
    ///
    /// Returns `KvRouterError::IndexerDroppedRequest` if the send fails.
    pub async fn process_routing_decision_with_hashes(
        &self,
        worker: WorkerWithDpRank,
        local_hashes: Vec<LocalBlockHash>,
        sequence_hashes: Vec<SequenceHash>,
    ) -> Result<(), KvRouterError> {
        let shard_idx = self.shard_for_worker(worker.worker_id);
        let decision = RoutingDecisionRequest {
            worker,
            local_hashes,
            sequence_hashes,
        };
        self.routing_tx[shard_idx]
            .send(decision)
            .await
            .map_err(|_| KvRouterError::IndexerDroppedRequest)
    }
}
/// Runs `shutdown` when the indexer is dropped.
impl Drop for KvIndexerSharded {
fn drop(&mut self) {
// `shutdown` cancels the token and joins the stored task handles.
self.shutdown();
}
}
...@@ -207,14 +207,14 @@ fn make_clear_event_with_dp_rank(worker_id: u64, dp_rank: u32) -> RouterEvent { ...@@ -207,14 +207,14 @@ fn make_clear_event_with_dp_rank(worker_id: u64, dp_rank: u32) -> RouterEvent {
#[template] #[template]
#[rstest] #[rstest]
fn indexer_template( fn indexer_template(
#[values("single", "sharded", "flat", "concurrent", "concurrent_compressed")] variant: &str, #[values("single", "flat", "concurrent", "concurrent_compressed")] variant: &str,
) { ) {
} }
#[template] #[template]
#[rstest] #[rstest]
fn tree_size_indexer_template( fn tree_size_indexer_template(
#[values("single", "sharded", "concurrent", "concurrent_compressed")] variant: &str, #[values("single", "concurrent", "concurrent_compressed")] variant: &str,
) { ) {
} }
...@@ -225,7 +225,6 @@ fn make_indexer(variant: &str) -> Box<dyn KvIndexerInterface> { ...@@ -225,7 +225,6 @@ fn make_indexer(variant: &str) -> Box<dyn KvIndexerInterface> {
match variant { match variant {
"single" => Box::new(KvIndexer::new(token, kv_block_size, metrics)), "single" => Box::new(KvIndexer::new(token, kv_block_size, metrics)),
"sharded" => Box::new(KvIndexerSharded::new(token, 4, kv_block_size, metrics)),
"flat" => Box::new(ThreadPoolIndexer::new( "flat" => Box::new(ThreadPoolIndexer::new(
PositionalIndexer::new(32), PositionalIndexer::new(32),
4, 4,
...@@ -330,7 +329,7 @@ mod interface_tests { ...@@ -330,7 +329,7 @@ mod interface_tests {
// tree-size accounting gap after mid-chain removes because descendant // tree-size accounting gap after mid-chain removes because descendant
// lookup entries are cleaned up lazily. That means "store -> partial // lookup entries are cleaned up lazily. That means "store -> partial
// remove -> restore continuation" can still miscount restored coverage // remove -> restore continuation" can still miscount restored coverage
// in single, sharded, and concurrent. This test is intentionally scoped // in single and concurrent. This test is intentionally scoped
// to duplicate store/remove replay so all tree-size variants share the // to duplicate store/remove replay so all tree-size variants share the
// same stable baseline. // same stable baseline.
...@@ -1854,13 +1853,13 @@ mod long_sequence_tests { ...@@ -1854,13 +1853,13 @@ mod long_sequence_tests {
} }
// ============================================================================ // ============================================================================
// Tests specific to tree-based implementations (KvIndexer, KvIndexerSharded) // Tests specific to tree-based implementations with frequency/pruning support.
// These use features not available in PositionalIndexer // These use features not available in PositionalIndexer
// ============================================================================ // ============================================================================
#[template] #[template]
#[rstest] #[rstest]
fn tree_indexer_template(#[values("single", "sharded")] variant: &str) {} fn tree_indexer_template(#[values("single")] variant: &str) {}
fn make_tree_indexer_with_frequency( fn make_tree_indexer_with_frequency(
variant: &str, variant: &str,
...@@ -1878,25 +1877,16 @@ fn make_tree_indexer_with_frequency( ...@@ -1878,25 +1877,16 @@ fn make_tree_indexer_with_frequency(
metrics, metrics,
None, None,
)), )),
"sharded" => Box::new(KvIndexerSharded::new_with_frequency(
token,
4,
Some(expiration),
kv_block_size,
metrics,
None,
)),
_ => panic!("Unknown variant: {}", variant), _ => panic!("Unknown variant: {}", variant),
} }
} }
#[tokio::test] #[tokio::test]
async fn test_sharded_routing_decision_assigns_first_seen_worker() { async fn test_routing_decision_assigns_first_seen_worker() {
let token = CancellationToken::new(); let token = CancellationToken::new();
let metrics = Arc::new(KvIndexerMetrics::new_unregistered()); let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let index = KvIndexerSharded::new_with_frequency( let index = KvIndexer::new_with_frequency(
token, token,
4,
Some(Duration::from_secs(60)), Some(Duration::from_secs(60)),
32, 32,
metrics, metrics,
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
use std::time::Instant; use std::time::Instant;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::sync::{mpsc, oneshot}; use tokio::sync::oneshot;
use crate::protocols::*; use crate::protocols::*;
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
...@@ -307,28 +307,3 @@ pub(super) struct RoutingDecisionRequest { ...@@ -307,28 +307,3 @@ pub(super) struct RoutingDecisionRequest {
pub(super) local_hashes: Vec<LocalBlockHash>, pub(super) local_hashes: Vec<LocalBlockHash>,
pub(super) sequence_hashes: Vec<SequenceHash>, pub(super) sequence_hashes: Vec<SequenceHash>,
} }
/// A match request carrying a block-hash sequence and a channel for
/// `OverlapScores` replies.
#[derive(Debug, Clone)]
pub struct ShardedMatchRequest {
/// Block hashes to match.
pub(super) sequence: Vec<LocalBlockHash>,
/// NOTE(review): semantics are not visible here; presumably lets a matcher
/// stop early — confirm against the consumer.
pub(super) early_exit: bool,
/// Sender on which `OverlapScores` results are returned.
pub(super) resp: mpsc::Sender<OverlapScores>,
/// Construction timestamp; only present when the `bench` feature is enabled.
#[cfg(feature = "bench")]
pub(super) created_at: Instant,
}
impl ShardedMatchRequest {
/// Builds a request from its parts.
///
/// With the `bench` feature enabled, `created_at` is set to the current
/// instant at construction time.
pub(super) fn new(
sequence: Vec<LocalBlockHash>,
early_exit: bool,
resp: mpsc::Sender<OverlapScores>,
) -> Self {
Self {
sequence,
early_exit,
resp,
#[cfg(feature = "bench")]
created_at: Instant::now(),
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment