Unverified Commit 039d35ff authored by Janelle Cai's avatar Janelle Cai Committed by GitHub
Browse files

test: indexer and full router benchmarks (#5784)

parent 051f18a4
...@@ -46,3 +46,8 @@ tokio = { workspace = true, features = ["rt", "macros", "time"] } ...@@ -46,3 +46,8 @@ tokio = { workspace = true, features = ["rt", "macros", "time"] }
name = "radix_tree_microbench" name = "radix_tree_microbench"
harness = false harness = false
required-features = ["bench"] required-features = ["bench"]
[[bench]]
name = "kv_indexer_bench"
harness = false
required-features = ["bench"]
This diff is collapsed.
...@@ -15,16 +15,12 @@ ...@@ -15,16 +15,12 @@
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
use dynamo_kv_router::{ use dynamo_kv_router::{
OverlapScores, RadixTree, RouterEvent, compute_block_hash_for_seq, RadixTree, RouterEvent,
bench_utils::{LatencyStats, SequenceData, generate_sequences},
compute_block_hash_for_seq,
flat_hashmap::FlatHashMap, flat_hashmap::FlatHashMap,
protocols::{ protocols::LocalBlockHash,
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData,
KvCacheStoreData, KvCacheStoredBlockData, LocalBlockHash, WorkerId,
compute_seq_hash_for_block,
},
}; };
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
/// Unified interface for RadixTree and FlatHashMap benchmarking. /// Unified interface for RadixTree and FlatHashMap benchmarking.
...@@ -206,114 +202,6 @@ struct Args { ...@@ -206,114 +202,6 @@ struct Args {
flat_hashmap: bool, flat_hashmap: bool,
} }
/// Pre-generated sequence data for benchmarking
#[derive(Clone)]
struct SequenceData {
worker_id: WorkerId,
local_hashes: Vec<LocalBlockHash>,
external_hashes: Vec<ExternalSequenceBlockHash>,
}
impl SequenceData {
/// Create a new SequenceData from local_hashes.
/// Automatically computes external_hashes using compute_seq_hash_for_block (cumulative hashes).
/// This ensures FlatHashMap can correctly identify block positions.
fn from_local_hashes(worker_id: WorkerId, local_hashes: Vec<LocalBlockHash>) -> Self {
let seq_hashes = compute_seq_hash_for_block(&local_hashes);
let external_hashes = seq_hashes
.into_iter()
.map(ExternalSequenceBlockHash)
.collect();
Self {
worker_id,
local_hashes,
external_hashes,
}
}
fn to_store_event(&self, event_id: u64) -> RouterEvent {
RouterEvent {
worker_id: self.worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: self
.local_hashes
.iter()
.zip(self.external_hashes.iter())
.map(|(local, ext)| KvCacheStoredBlockData {
tokens_hash: *local,
block_hash: *ext,
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
}
}
fn to_remove_event(&self, event_id: u64) -> RouterEvent {
RouterEvent {
worker_id: self.worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: self.external_hashes.clone(),
}),
dp_rank: 0,
},
}
}
}
/// Generate sequences with shared prefix prompts
fn generate_sequences(
num_sequences: usize,
depth: usize,
num_workers: usize,
prefix_prompt_ratio: f64,
num_prefix_prompts: usize,
seed: u64,
) -> Vec<SequenceData> {
let mut sequences = Vec::with_capacity(num_sequences);
let prefix_length = (depth as f64 * prefix_prompt_ratio).round() as usize;
let mut rng: StdRng = StdRng::seed_from_u64(seed);
for seq_id in 0..num_sequences {
let seq_id_u64 = seq_id as u64;
let worker_id = (seq_id % num_workers) as WorkerId;
// Determine prefix group for this sequence
let group_id = if num_prefix_prompts > 0 && prefix_length > 0 {
Some(rng.random_range(0..num_prefix_prompts) as u64)
} else {
None
};
// Build local_hashes: shared prefix (if applicable) + unique suffix
let local_hashes: Vec<LocalBlockHash> = (0..depth)
.map(|block_idx| {
let block_idx_u64 = block_idx as u64;
if let Some(gid) = group_id {
if block_idx < prefix_length {
// Shared prefix based on group_id
return LocalBlockHash(0xDEAD_BEEF_0000_0000 | (gid << 32) | block_idx_u64);
}
}
// Unique suffix (or no shared prefix)
LocalBlockHash((seq_id_u64 << 32) | block_idx_u64)
})
.collect();
sequences.push(SequenceData::from_local_hashes(worker_id, local_hashes));
}
sequences
}
/// Build a pre-populated RadixTree (for sweep/dump benchmarks that specifically need RadixTree) /// Build a pre-populated RadixTree (for sweep/dump benchmarks that specifically need RadixTree)
fn build_tree(sequences: &[SequenceData]) -> RadixTree { fn build_tree(sequences: &[SequenceData]) -> RadixTree {
let num_blocks: usize = sequences.iter().map(|s| s.local_hashes.len()).sum(); let num_blocks: usize = sequences.iter().map(|s| s.local_hashes.len()).sum();
...@@ -381,52 +269,6 @@ fn build_index(sequences: &[SequenceData], use_flat_hashmap: bool) -> KvIndex { ...@@ -381,52 +269,6 @@ fn build_index(sequences: &[SequenceData], use_flat_hashmap: bool) -> KvIndex {
index index
} }
/// Statistics for a set of timing measurements
#[derive(Debug)]
struct LatencyStats {
min: Duration,
max: Duration,
avg: Duration,
p50: Duration,
p95: Duration,
p99: Duration,
throughput_ops_sec: f64,
}
impl LatencyStats {
fn from_durations(mut durations: Vec<Duration>) -> Self {
durations.sort();
let n = durations.len();
let total: Duration = durations.iter().sum();
let avg = total / n as u32;
Self {
min: durations[0],
max: durations[n - 1],
avg,
p50: durations[n / 2],
p95: durations[n * 95 / 100],
p99: durations[n * 99 / 100],
throughput_ops_sec: n as f64 / total.as_secs_f64(),
}
}
fn print(&self, operation: &str, blocks_per_op: usize) {
println!("\n{} Latency Statistics:", operation);
println!(" min: {:>12?}", self.min);
println!(" avg: {:>12?}", self.avg);
println!(" p50: {:>12?}", self.p50);
println!(" p95: {:>12?}", self.p95);
println!(" p99: {:>12?}", self.p99);
println!(" max: {:>12?}", self.max);
println!(" throughput: {:.2} ops/sec", self.throughput_ops_sec);
println!(
" throughput: {:.2} blocks/sec",
self.throughput_ops_sec * blocks_per_op as f64
);
}
}
/// Benchmark compute_block_hash_for_seq operation /// Benchmark compute_block_hash_for_seq operation
fn bench_hash(args: &Args) { fn bench_hash(args: &Args) {
println!("\n=== Benchmarking COMPUTE_BLOCK_HASH (per-request hot path) ==="); println!("\n=== Benchmarking COMPUTE_BLOCK_HASH (per-request hot path) ===");
...@@ -464,7 +306,7 @@ fn bench_hash(args: &Args) { ...@@ -464,7 +306,7 @@ fn bench_hash(args: &Args) {
} }
} }
let stats = LatencyStats::from_durations(durations); let stats = LatencyStats::from_durations(durations).unwrap();
stats.print("COMPUTE_BLOCK_HASH", args.depth); stats.print("COMPUTE_BLOCK_HASH", args.depth);
} }
...@@ -487,6 +329,7 @@ fn bench_store_remove_cycle(args: &Args, time_store: bool) { ...@@ -487,6 +329,7 @@ fn bench_store_remove_cycle(args: &Args, time_store: bool) {
args.prefix_prompt_ratio, args.prefix_prompt_ratio,
args.num_prefix_prompts, args.num_prefix_prompts,
args.seed, args.seed,
true, // use_cumulative_hash
); );
let mut index = build_index(&sequences, args.flat_hashmap); let mut index = build_index(&sequences, args.flat_hashmap);
...@@ -524,7 +367,7 @@ fn bench_store_remove_cycle(args: &Args, time_store: bool) { ...@@ -524,7 +367,7 @@ fn bench_store_remove_cycle(args: &Args, time_store: bool) {
} }
} }
let stats = LatencyStats::from_durations(durations); let stats = LatencyStats::from_durations(durations).unwrap();
stats.print(op_name, args.depth); stats.print(op_name, args.depth);
} }
...@@ -548,6 +391,7 @@ fn bench_find_matches(args: &Args) { ...@@ -548,6 +391,7 @@ fn bench_find_matches(args: &Args) {
args.prefix_prompt_ratio, args.prefix_prompt_ratio,
args.num_prefix_prompts, args.num_prefix_prompts,
args.seed, args.seed,
true, // use_cumulative_hash
); );
let index = build_index(&sequences, args.flat_hashmap); let index = build_index(&sequences, args.flat_hashmap);
...@@ -575,7 +419,9 @@ fn bench_find_matches(args: &Args) { ...@@ -575,7 +419,9 @@ fn bench_find_matches(args: &Args) {
println!(" Completed {}/{} iterations", i + 1, args.iterations); println!(" Completed {}/{} iterations", i + 1, args.iterations);
} }
} }
LatencyStats::from_durations(hit_durations).print("FIND_MATCHES (HIT)", args.depth); LatencyStats::from_durations(hit_durations)
.unwrap()
.print("FIND_MATCHES (HIT)", args.depth);
// MISS case // MISS case
println!("\n --- MISS case (non-existing sequences) ---"); println!("\n --- MISS case (non-existing sequences) ---");
...@@ -589,7 +435,9 @@ fn bench_find_matches(args: &Args) { ...@@ -589,7 +435,9 @@ fn bench_find_matches(args: &Args) {
println!(" Completed {}/{} iterations", i + 1, args.iterations); println!(" Completed {}/{} iterations", i + 1, args.iterations);
} }
} }
LatencyStats::from_durations(miss_durations).print("FIND_MATCHES (MISS)", args.depth); LatencyStats::from_durations(miss_durations)
.unwrap()
.print("FIND_MATCHES (MISS)", args.depth);
// PARTIAL case // PARTIAL case
println!("\n --- PARTIAL case (prefix match only) ---"); println!("\n --- PARTIAL case (prefix match only) ---");
...@@ -604,7 +452,9 @@ fn bench_find_matches(args: &Args) { ...@@ -604,7 +452,9 @@ fn bench_find_matches(args: &Args) {
println!(" Completed {}/{} iterations", i + 1, args.iterations); println!(" Completed {}/{} iterations", i + 1, args.iterations);
} }
} }
LatencyStats::from_durations(partial_durations).print("FIND_MATCHES (PARTIAL)", args.depth); LatencyStats::from_durations(partial_durations)
.unwrap()
.print("FIND_MATCHES (PARTIAL)", args.depth);
// EARLY_EXIT case // EARLY_EXIT case
println!("\n --- EARLY_EXIT case ---"); println!("\n --- EARLY_EXIT case ---");
...@@ -617,6 +467,7 @@ fn bench_find_matches(args: &Args) { ...@@ -617,6 +467,7 @@ fn bench_find_matches(args: &Args) {
} }
} }
LatencyStats::from_durations(early_exit_durations) LatencyStats::from_durations(early_exit_durations)
.unwrap()
.print("FIND_MATCHES (EARLY_EXIT)", args.depth); .print("FIND_MATCHES (EARLY_EXIT)", args.depth);
} }
...@@ -845,6 +696,7 @@ fn bench_sweep(args: &Args) { ...@@ -845,6 +696,7 @@ fn bench_sweep(args: &Args) {
args.prefix_prompt_ratio, args.prefix_prompt_ratio,
num_prefix_prompts, num_prefix_prompts,
args.seed, args.seed,
true, // use_cumulative_hash
); );
let tree_sequences = &all_sequences[..num_sequences]; let tree_sequences = &all_sequences[..num_sequences];
let extra_sequences = &all_sequences[num_sequences..]; let extra_sequences = &all_sequences[num_sequences..];
...@@ -956,6 +808,7 @@ fn bench_dump(args: &Args) { ...@@ -956,6 +808,7 @@ fn bench_dump(args: &Args) {
args.prefix_prompt_ratio, args.prefix_prompt_ratio,
args.num_prefix_prompts, args.num_prefix_prompts,
args.seed, args.seed,
true, // use_cumulative_hash
); );
let tree = build_tree(&sequences); let tree = build_tree(&sequences);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Benchmark utilities for kv-router benchmarks.
//!
//! This module provides shared data structures for benchmarking:
//! - `LatencyStats`: Statistics for latency measurements
//! - `SequenceData`: Pre-generated sequence data for benchmarking
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, RouterEvent, WorkerId, compute_seq_hash_for_block,
};
use rand::{Rng, SeedableRng, rngs::StdRng};
use std::time::Duration;
/// Statistics for latency measurements.
#[derive(Debug, Clone)]
pub struct LatencyStats {
pub min: Duration,
pub max: Duration,
pub avg: Duration,
pub p50: Duration,
pub p95: Duration,
pub p99: Duration,
pub throughput_ops_sec: f64,
}
impl LatencyStats {
/// Compute statistics from a vector of durations.
///
/// Returns `None` if the input is empty.
pub fn from_durations(mut durations: Vec<Duration>) -> Option<Self> {
if durations.is_empty() {
return None;
}
durations.sort();
let n = durations.len();
let total: Duration = durations.iter().sum();
let avg = total / n as u32;
Some(Self {
min: durations[0],
max: durations[n - 1],
avg,
p50: durations[n / 2],
p95: durations[n * 95 / 100],
p99: durations[n * 99 / 100],
throughput_ops_sec: n as f64 / total.as_secs_f64(),
})
}
/// Print formatted latency statistics to stdout.
pub fn print(&self, operation: &str, blocks_per_op: usize) {
println!("\n{} Latency Statistics:", operation);
println!(" min: {:>12?}", self.min);
println!(" avg: {:>12?}", self.avg);
println!(" p50: {:>12?}", self.p50);
println!(" p95: {:>12?}", self.p95);
println!(" p99: {:>12?}", self.p99);
println!(" max: {:>12?}", self.max);
println!(" throughput: {:.2} ops/sec", self.throughput_ops_sec);
println!(
" throughput: {:.2} blocks/sec",
self.throughput_ops_sec * blocks_per_op as f64
);
}
}
/// Pre-generated sequence data for benchmarking.
#[derive(Clone)]
pub struct SequenceData {
pub worker_id: WorkerId,
pub local_hashes: Vec<LocalBlockHash>,
pub external_hashes: Vec<ExternalSequenceBlockHash>,
}
impl SequenceData {
/// Create a new sequence with synthetic hashes based on sequence ID.
pub fn new(seq_id: u64, worker_id: WorkerId, depth: usize) -> Self {
let local_hashes: Vec<LocalBlockHash> = (0..depth)
.map(|block_idx| LocalBlockHash((seq_id << 32) | (block_idx as u64)))
.collect();
let external_hashes: Vec<ExternalSequenceBlockHash> = (0..depth)
.map(|block_idx| ExternalSequenceBlockHash((seq_id << 32) | (block_idx as u64)))
.collect();
Self {
worker_id,
local_hashes,
external_hashes,
}
}
/// Create a sequence from local hashes, computing external hashes using cumulative hash.
///
/// This ensures FlatHashMap can correctly identify block positions.
pub fn from_local_hashes(worker_id: WorkerId, local_hashes: Vec<LocalBlockHash>) -> Self {
let seq_hashes = compute_seq_hash_for_block(&local_hashes);
let external_hashes = seq_hashes
.into_iter()
.map(ExternalSequenceBlockHash)
.collect();
Self {
worker_id,
local_hashes,
external_hashes,
}
}
/// Convert to a store event.
pub fn to_store_event(&self, event_id: u64) -> RouterEvent {
RouterEvent {
worker_id: self.worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: self
.local_hashes
.iter()
.zip(self.external_hashes.iter())
.map(|(local, ext)| KvCacheStoredBlockData {
tokens_hash: *local,
block_hash: *ext,
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
}
}
/// Convert to a remove event.
pub fn to_remove_event(&self, event_id: u64) -> RouterEvent {
RouterEvent {
worker_id: self.worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: self.external_hashes.clone(),
}),
dp_rank: 0,
},
}
}
}
/// Generate sequences with shared prefix prompts.
///
/// # Arguments
/// * `num_sequences` - Number of sequences to generate
/// * `depth` - Number of blocks per sequence
/// * `num_workers` - Number of workers to distribute sequences across
/// * `prefix_ratio` - Ratio of blocks that share a prefix (0.0 to 1.0)
/// * `num_prefix_groups` - Number of distinct prefix groups
/// * `seed` - Random seed for reproducibility
/// * `use_cumulative_hash` - If true, use `from_local_hashes` for proper cumulative hashes
pub fn generate_sequences(
num_sequences: usize,
depth: usize,
num_workers: usize,
prefix_ratio: f64,
num_prefix_groups: usize,
seed: u64,
use_cumulative_hash: bool,
) -> Vec<SequenceData> {
let mut sequences = Vec::with_capacity(num_sequences);
let prefix_length = (depth as f64 * prefix_ratio).round() as usize;
let mut rng: StdRng = StdRng::seed_from_u64(seed);
for seq_id in 0..num_sequences {
let seq_id_u64 = seq_id as u64;
let worker_id = (seq_id % num_workers) as WorkerId;
// Determine prefix group for this sequence
let group_id = if num_prefix_groups > 0 && prefix_length > 0 {
Some(rng.random_range(0..num_prefix_groups) as u64)
} else {
None
};
// Build local_hashes: shared prefix (if applicable) + unique suffix
let local_hashes: Vec<LocalBlockHash> = (0..depth)
.map(|block_idx| {
let block_idx_u64 = block_idx as u64;
if let Some(gid) = group_id {
if block_idx < prefix_length {
// Shared prefix based on group_id
return LocalBlockHash(0xDEAD_BEEF_0000_0000 | (gid << 32) | block_idx_u64);
}
}
// Unique suffix (or no shared prefix)
LocalBlockHash((seq_id_u64 << 32) | block_idx_u64)
})
.collect();
if use_cumulative_hash {
sequences.push(SequenceData::from_local_hashes(worker_id, local_hashes));
} else {
let external_hashes: Vec<ExternalSequenceBlockHash> = (0..depth)
.map(|block_idx| {
let block_idx_u64 = block_idx as u64;
if let Some(gid) = group_id {
if block_idx < prefix_length {
return ExternalSequenceBlockHash(
0xDEAD_BEEF_0000_0000 | (gid << 32) | block_idx_u64,
);
}
}
ExternalSequenceBlockHash((seq_id_u64 << 32) | block_idx_u64)
})
.collect();
sequences.push(SequenceData {
worker_id,
local_hashes,
external_hashes,
});
}
}
sequences
}
/// Compute median of durations.
pub fn median(durations: &[Duration]) -> Duration {
if durations.is_empty() {
return Duration::ZERO;
}
let mut sorted = durations.to_vec();
sorted.sort();
sorted[sorted.len() / 2]
}
...@@ -31,6 +31,9 @@ ...@@ -31,6 +31,9 @@
//! //!
//! This module provides a scalable and efficient way to manage and retrieve data blocks for LLM inference, leveraging a global KV cache to optimize performance. //! This module provides a scalable and efficient way to manage and retrieve data blocks for LLM inference, leveraging a global KV cache to optimize performance.
#[cfg(feature = "bench")]
use std::time::Instant;
use async_trait::async_trait; use async_trait::async_trait;
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
pub use dynamo_runtime::protocols::maybe_error::MaybeError; pub use dynamo_runtime::protocols::maybe_error::MaybeError;
...@@ -335,6 +338,25 @@ pub struct MatchRequest { ...@@ -335,6 +338,25 @@ pub struct MatchRequest {
early_exit: bool, early_exit: bool,
/// A channel sender to send the `OverlapScores` response. /// A channel sender to send the `OverlapScores` response.
resp: oneshot::Sender<OverlapScores>, resp: oneshot::Sender<OverlapScores>,
/// Timestamp when the request was created (for queue wait time measurement)
#[cfg(feature = "bench")]
created_at: Instant,
}
impl MatchRequest {
fn new(
sequence: Vec<LocalBlockHash>,
early_exit: bool,
resp: oneshot::Sender<OverlapScores>,
) -> Self {
Self {
sequence,
early_exit,
resp,
#[cfg(feature = "bench")]
created_at: Instant::now(),
}
}
} }
/// A request to dump the tree as events /// A request to dump the tree as events
...@@ -551,10 +573,16 @@ impl KvIndexer { ...@@ -551,10 +573,16 @@ impl KvIndexer {
Some(event) = event_rx.recv() => { Some(event) = event_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data); let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let event_id = event.event.event_id;
let worker_id = event.worker_id;
// Only clone if we need the event for prune_manager afterward // Only clone if we need the event for prune_manager afterward
let event_for_prune = prune_manager.is_some().then(|| event.clone()); let event_for_prune = prune_manager.is_some().then(|| event.clone());
let result = trie.apply_event(event); let result = trie.apply_event(event);
let result_is_ok = result.is_ok(); let result_is_ok = result.is_ok();
let tree_size = trie.current_size();
tracing::trace!(
"Applied KV event to global radix tree: event_type={event_type}, event_id={event_id}, worker_id={worker_id}, success={result_is_ok}, global_radix_tree_size={tree_size}"
);
metrics.increment_event_applied(event_type, result); metrics.increment_event_applied(event_type, result);
// Track blocks in PruneManager if TTL is enabled and event was stored successfully // Track blocks in PruneManager if TTL is enabled and event was stored successfully
...@@ -643,7 +671,24 @@ impl KvIndexer { ...@@ -643,7 +671,24 @@ impl KvIndexer {
} }
Some(req) = match_rx.recv() => { Some(req) = match_rx.recv() => {
#[cfg(feature = "bench")]
let queue_wait = req.created_at.elapsed();
#[cfg(feature = "bench")]
let seq_len = req.sequence.len();
#[cfg(feature = "bench")]
let process_start = Instant::now();
let matches = trie.find_matches(req.sequence, req.early_exit); let matches = trie.find_matches(req.sequence, req.early_exit);
#[cfg(feature = "bench")]
let process_time = process_start.elapsed();
#[cfg(feature = "bench")]
tracing::info!(
seq_len,
queue_wait_us = queue_wait.as_micros() as u64,
process_us = process_time.as_micros() as u64,
"indexer: processed find_matches"
);
let _ = req.resp.send(matches); let _ = req.resp.send(matches);
} }
...@@ -742,12 +787,11 @@ impl KvIndexerInterface for KvIndexer { ...@@ -742,12 +787,11 @@ impl KvIndexerInterface for KvIndexer {
&self, &self,
sequence: Vec<LocalBlockHash>, sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> { ) -> Result<OverlapScores, KvRouterError> {
#[cfg(feature = "bench")]
let start = Instant::now();
let seq_len = sequence.len();
let (resp_tx, resp_rx) = oneshot::channel(); let (resp_tx, resp_rx) = oneshot::channel();
let req = MatchRequest { let req = MatchRequest::new(sequence, false, resp_tx);
sequence,
early_exit: false,
resp: resp_tx,
};
if let Err(e) = self.match_tx.send(req).await { if let Err(e) = self.match_tx.send(req).await {
tracing::error!( tracing::error!(
...@@ -757,9 +801,23 @@ impl KvIndexerInterface for KvIndexer { ...@@ -757,9 +801,23 @@ impl KvIndexerInterface for KvIndexer {
return Err(KvRouterError::IndexerOffline); return Err(KvRouterError::IndexerOffline);
} }
resp_rx let result = resp_rx
.await .await
.map_err(|_| KvRouterError::IndexerDroppedRequest) .map_err(|_| KvRouterError::IndexerDroppedRequest);
#[cfg(feature = "bench")]
{
let elapsed = start.elapsed();
tracing::info!(
seq_len,
elapsed_us = elapsed.as_micros() as u64,
"find_matches completed"
);
}
#[cfg(not(feature = "bench"))]
let _ = seq_len;
result
} }
async fn find_matches_for_request( async fn find_matches_for_request(
...@@ -1131,6 +1189,24 @@ pub struct ShardedMatchRequest { ...@@ -1131,6 +1189,24 @@ pub struct ShardedMatchRequest {
sequence: Vec<LocalBlockHash>, sequence: Vec<LocalBlockHash>,
early_exit: bool, early_exit: bool,
resp: mpsc::Sender<OverlapScores>, resp: mpsc::Sender<OverlapScores>,
#[cfg(feature = "bench")]
created_at: Instant,
}
impl ShardedMatchRequest {
fn new(
sequence: Vec<LocalBlockHash>,
early_exit: bool,
resp: mpsc::Sender<OverlapScores>,
) -> Self {
Self {
sequence,
early_exit,
resp,
#[cfg(feature = "bench")]
created_at: Instant::now(),
}
}
} }
/// A sharded KV Indexer that partitions the RadixTree across multiple independent shards. /// A sharded KV Indexer that partitions the RadixTree across multiple independent shards.
...@@ -1374,7 +1450,24 @@ impl KvIndexerSharded { ...@@ -1374,7 +1450,24 @@ impl KvIndexerSharded {
} }
Ok(req) = shard_broadcast_rx.recv() => { Ok(req) = shard_broadcast_rx.recv() => {
#[cfg(feature = "bench")]
let queue_wait = req.created_at.elapsed();
#[cfg(feature = "bench")]
let seq_len = req.sequence.len();
#[cfg(feature = "bench")]
let process_start = Instant::now();
let matches = trie.find_matches(req.sequence, req.early_exit); let matches = trie.find_matches(req.sequence, req.early_exit);
#[cfg(feature = "bench")]
let process_time = process_start.elapsed();
#[cfg(feature = "bench")]
tracing::info!(
seq_len,
queue_wait_us = queue_wait.as_micros() as u64,
process_us = process_time.as_micros() as u64,
"sharded indexer: processed find_matches"
);
if let Err(e) = req.resp.send(matches).await { if let Err(e) = req.resp.send(matches).await {
tracing::trace!("Failed to send match response: {:?}", e); tracing::trace!("Failed to send match response: {:?}", e);
} }
...@@ -1442,14 +1535,18 @@ impl KvIndexerInterface for KvIndexerSharded { ...@@ -1442,14 +1535,18 @@ impl KvIndexerInterface for KvIndexerSharded {
&self, &self,
sequence: Vec<LocalBlockHash>, sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> { ) -> Result<OverlapScores, KvRouterError> {
#[cfg(feature = "bench")]
let start = Instant::now();
#[cfg(feature = "bench")]
let seq_len = sequence.len();
#[cfg(feature = "bench")]
let num_shards = self.event_tx.len();
'match_loop: loop { 'match_loop: loop {
let (match_tx, mut match_rx) = mpsc::channel(self.event_tx.len()); let (match_tx, mut match_rx) = mpsc::channel(self.event_tx.len());
let sharded_req = ShardedMatchRequest::new(sequence.clone(), false, match_tx);
self.request_broadcast_tx self.request_broadcast_tx
.send(ShardedMatchRequest { .send(sharded_req)
sequence: sequence.clone(),
early_exit: false,
resp: match_tx,
})
.map_err(|_| KvRouterError::IndexerOffline)?; .map_err(|_| KvRouterError::IndexerOffline)?;
let mut scores = OverlapScores::new(); let mut scores = OverlapScores::new();
...@@ -1482,6 +1579,17 @@ impl KvIndexerInterface for KvIndexerSharded { ...@@ -1482,6 +1579,17 @@ impl KvIndexerInterface for KvIndexerSharded {
} }
} }
} }
#[cfg(feature = "bench")]
{
let elapsed = start.elapsed();
tracing::info!(
seq_len,
num_shards,
elapsed_us = elapsed.as_micros() as u64,
"find_matches (sharded) completed"
);
}
return Ok(scores); return Ok(scores);
} }
} }
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
//! efficient KV cache lookup and routing in distributed LLM inference systems. //! efficient KV cache lookup and routing in distributed LLM inference systems.
pub mod approx; pub mod approx;
#[cfg(feature = "bench")]
pub mod bench_utils;
pub mod flat_hashmap; pub mod flat_hashmap;
pub mod indexer; pub mod indexer;
pub mod protocols; pub mod protocols;
......
...@@ -27,7 +27,8 @@ cuda = ["dep:cudarc"] ...@@ -27,7 +27,8 @@ cuda = ["dep:cudarc"]
integration = ["dynamo-runtime/integration"] integration = ["dynamo-runtime/integration"]
media-nixl = ["dep:nixl-sys", "dep:flate2"] media-nixl = ["dep:nixl-sys", "dep:flate2"]
media-ffmpeg = ["dep:video-rs", "dep:ffmpeg-next", "dep:memfile", "media-nixl"] media-ffmpeg = ["dep:video-rs", "dep:ffmpeg-next", "dep:memfile", "media-nixl"]
kv-router-stress = ["dep:clap", "dep:indicatif"] bench = ["dynamo-kv-router/bench"]
kv-router-stress = ["dep:clap", "dep:indicatif", "bench"]
[[bench]] [[bench]]
name = "tokenizer" name = "tokenizer"
...@@ -38,6 +39,11 @@ name = "transfer_context_v2" ...@@ -38,6 +39,11 @@ name = "transfer_context_v2"
harness = false harness = false
required-features = ["block-manager", "testing-cuda"] required-features = ["block-manager", "testing-cuda"]
[[bench]]
name = "kv_router_bench"
harness = false
required-features = ["kv-router-stress"]
[dependencies] [dependencies]
# repo # repo
dynamo-runtime = { workspace = true } dynamo-runtime = { workspace = true }
......
This diff is collapsed.
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
#[cfg(feature = "bench")]
use std::time::Instant;
use anyhow::Result; use anyhow::Result;
use derive_builder::Builder; use derive_builder::Builder;
...@@ -504,6 +506,9 @@ impl KvRouter { ...@@ -504,6 +506,9 @@ impl KvRouter {
update_states: bool, update_states: bool,
lora_name: Option<String>, lora_name: Option<String>,
) -> anyhow::Result<(WorkerWithDpRank, u32)> { ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
#[cfg(feature = "bench")]
let start = Instant::now();
// Validate that context_id is provided when update_states is true // Validate that context_id is provided when update_states is true
if update_states && context_id.is_none() { if update_states && context_id.is_none() {
panic!("context_id must be provided if update_states is true"); panic!("context_id must be provided if update_states is true");
...@@ -512,7 +517,11 @@ impl KvRouter { ...@@ -512,7 +517,11 @@ impl KvRouter {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
#[cfg(feature = "bench")]
let hash_elapsed = start.elapsed();
let overlap_scores = self.indexer.find_matches(block_hashes).await?; let overlap_scores = self.indexer.find_matches(block_hashes).await?;
#[cfg(feature = "bench")]
let find_matches_elapsed = start.elapsed();
// Compute seq_hashes only if scheduler needs it for active blocks tracking // Compute seq_hashes only if scheduler needs it for active blocks tracking
let maybe_seq_hashes = self let maybe_seq_hashes = self
...@@ -532,6 +541,19 @@ impl KvRouter { ...@@ -532,6 +541,19 @@ impl KvRouter {
) )
.await?; .await?;
#[cfg(feature = "bench")]
{
let total_elapsed = start.elapsed();
tracing::info!(
isl_tokens,
hash_us = hash_elapsed.as_micros() as u64,
find_matches_us = (find_matches_elapsed - hash_elapsed).as_micros() as u64,
schedule_us = (total_elapsed - find_matches_elapsed).as_micros() as u64,
total_us = total_elapsed.as_micros() as u64,
"find_best_match completed"
);
}
// Note: Routing decision recording (for approximate mode) is now handled // Note: Routing decision recording (for approximate mode) is now handled
// by KvPushRouter::generate after select_worker returns. // by KvPushRouter::generate after select_worker returns.
......
...@@ -12,6 +12,8 @@ use serde::{Deserialize, Serialize}; ...@@ -12,6 +12,8 @@ use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
#[cfg(feature = "bench")]
use std::time::Instant;
use super::KV_HIT_RATE_SUBJECT; use super::KV_HIT_RATE_SUBJECT;
use super::KvRouterConfig; use super::KvRouterConfig;
...@@ -288,6 +290,9 @@ impl KvScheduler { ...@@ -288,6 +290,9 @@ impl KvScheduler {
update_states: bool, update_states: bool,
lora_name: Option<String>, lora_name: Option<String>,
) -> Result<WorkerWithDpRank, KvSchedulerError> { ) -> Result<WorkerWithDpRank, KvSchedulerError> {
#[cfg(feature = "bench")]
let start = Instant::now();
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest { let request = SchedulingRequest {
maybe_request_id, maybe_request_id,
...@@ -306,10 +311,24 @@ impl KvScheduler { ...@@ -306,10 +311,24 @@ impl KvScheduler {
.send(request) .send(request)
.await .await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?; .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
#[cfg(feature = "bench")]
let send_elapsed = start.elapsed();
let response = resp_rx let response = resp_rx
.await .await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?; .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
#[cfg(feature = "bench")]
let total_elapsed = start.elapsed();
#[cfg(feature = "bench")]
tracing::info!(
isl_tokens,
send_us = send_elapsed.as_micros() as u64,
total_us = total_elapsed.as_micros() as u64,
"scheduler.schedule completed"
);
Ok(response.best_worker) Ok(response.best_worker)
} }
......
...@@ -1175,6 +1175,11 @@ impl ActiveSequencesMultiWorker { ...@@ -1175,6 +1175,11 @@ impl ActiveSequencesMultiWorker {
HashMap<WorkerWithDpRank, usize>, HashMap<WorkerWithDpRank, usize>,
HashMap<WorkerWithDpRank, usize>, HashMap<WorkerWithDpRank, usize>,
) { ) {
#[cfg(feature = "bench")]
let start = Instant::now();
#[cfg(feature = "bench")]
let num_workers = self.senders.len();
let mut potential_blocks = HashMap::new(); let mut potential_blocks = HashMap::new();
let mut potential_tokens = HashMap::new(); let mut potential_tokens = HashMap::new();
let token_sequence_shared = token_sequence.map(Arc::new); let token_sequence_shared = token_sequence.map(Arc::new);
...@@ -1206,6 +1211,9 @@ impl ActiveSequencesMultiWorker { ...@@ -1206,6 +1211,9 @@ impl ActiveSequencesMultiWorker {
} }
} }
#[cfg(feature = "bench")]
let send_elapsed = start.elapsed();
// Collect results from all workers // Collect results from all workers
for (worker, receiver) in receivers { for (worker, receiver) in receivers {
match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await { match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await {
...@@ -1222,6 +1230,17 @@ impl ActiveSequencesMultiWorker { ...@@ -1222,6 +1230,17 @@ impl ActiveSequencesMultiWorker {
} }
} }
#[cfg(feature = "bench")]
{
let total_elapsed = start.elapsed();
tracing::info!(
num_workers,
send_us = send_elapsed.as_micros() as u64,
total_us = total_elapsed.as_micros() as u64,
"potential_blocks_and_tokens completed"
);
}
(potential_blocks, potential_tokens) (potential_blocks, potential_tokens)
} }
......
...@@ -28,7 +28,7 @@ use crate::preprocessor::media::MediaDecoder; ...@@ -28,7 +28,7 @@ use crate::preprocessor::media::MediaDecoder;
pub mod deepseek_v32; pub mod deepseek_v32;
mod template; mod template;
pub use template::ContextMixins; pub use template::{ChatTemplate, ContextMixins};
#[derive(Debug)] #[derive(Debug)]
pub enum TokenInput { pub enum TokenInput {
...@@ -95,6 +95,7 @@ pub trait OAIPromptFormatter: Send + Sync + 'static { ...@@ -95,6 +95,7 @@ pub trait OAIPromptFormatter: Send + Sync + 'static {
fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String>; fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String>;
} }
#[derive(Clone)]
pub enum PromptFormatter { pub enum PromptFormatter {
OAI(Arc<dyn OAIPromptFormatter>), OAI(Arc<dyn OAIPromptFormatter>),
} }
......
...@@ -14,7 +14,8 @@ mod oai; ...@@ -14,7 +14,8 @@ mod oai;
mod tokcfg; mod tokcfg;
use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter}; use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
use tokcfg::{ChatTemplate, ChatTemplateValue}; pub use tokcfg::ChatTemplate;
use tokcfg::ChatTemplateValue;
impl PromptFormatter { impl PromptFormatter {
pub fn from_mdc(mdc: &ModelDeploymentCard) -> Result<PromptFormatter> { pub fn from_mdc(mdc: &ModelDeploymentCard) -> Result<PromptFormatter> {
......
...@@ -28,7 +28,7 @@ pub use etcd::EtcdStore; ...@@ -28,7 +28,7 @@ pub use etcd::EtcdStore;
mod file; mod file;
pub use file::FileStore; pub use file::FileStore;
const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(100); const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(1000);
/// String we use as the Key in a key-value storage operation. Simple String wrapper /// String we use as the Key in a key-value storage operation. Simple String wrapper
/// that can encode / decode a string. /// that can encode / decode a string.
...@@ -324,7 +324,7 @@ impl Manager { ...@@ -324,7 +324,7 @@ impl Manager {
tokio::sync::mpsc::Receiver<WatchEvent>, tokio::sync::mpsc::Receiver<WatchEvent>,
) { ) {
let bucket_name = bucket_name.to_string(); let bucket_name = bucket_name.to_string();
let (tx, rx) = tokio::sync::mpsc::channel(128); let (tx, rx) = tokio::sync::mpsc::channel(1024);
let watch_task = tokio::spawn(async move { let watch_task = tokio::spawn(async move {
// Start listening for changes but don't poll this yet // Start listening for changes but don't poll this yet
let bucket = self let bucket = self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment