Unverified Commit 039d35ff authored by Janelle Cai's avatar Janelle Cai Committed by GitHub
Browse files

test: indexer and full router benchmarks (#5784)

parent 051f18a4
......@@ -46,3 +46,8 @@ tokio = { workspace = true, features = ["rt", "macros", "time"] }
name = "radix_tree_microbench"
harness = false
required-features = ["bench"]
[[bench]]
name = "kv_indexer_bench"
harness = false
required-features = ["bench"]
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Combined benchmark for KvIndexer and KvIndexerSharded.
//!
//! Provides two modes:
//! - `microbench`: Per-operation latency benchmarks comparing single vs sharded indexer
//! - `stress`: Queue saturation stress test under load
//!
//! Run with:
//! cargo bench --package dynamo-kv-router --bench kv_indexer_bench --features bench -- microbench --help
//! cargo bench --package dynamo-kv-router --bench kv_indexer_bench --features bench -- stress --help
use clap::{Args, Parser, Subcommand, ValueEnum};
use dynamo_kv_router::{
bench_utils::{LatencyStats, SequenceData, generate_sequences, median},
indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvIndexerSharded},
protocols::{LocalBlockHash, RouterEvent},
};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
// ============================================================================
// CLI Definitions
// ============================================================================
// Top-level command line definition; the chosen subcommand decides which
// benchmark mode runs.
#[derive(Parser)]
#[command(name = "kv_indexer_bench")]
#[command(about = "Combined benchmark for KvIndexer and KvIndexerSharded")]
struct Cli {
// Selected benchmark mode together with its mode-specific arguments.
#[command(subcommand)]
command: Command,
// Declared `hide` + `global`: accepted before or after the subcommand and
// kept out of the generated help output.
/// Ignored - passed by cargo bench harness
#[arg(long, hide = true, global = true)]
bench: bool,
}
// Benchmark modes, exposed as subcommands. Each variant carries the argument
// struct for its mode.
// NOTE: with clap's derive API the `///` comments below are also used as the
// subcommands' help text, so rewording them changes `--help` output.
#[derive(Subcommand)]
enum Command {
/// Per-operation latency benchmarks comparing single vs sharded indexer
Microbench(MicrobenchArgs),
/// Queue saturation stress test under load
Stress(StressArgs),
}
/// Indexer type to benchmark
// Parsed from the command line via `ValueEnum`. The mode drivers test it with
// `matches!`, so `Both` simply enables the single-indexer pass and the
// sharded pass (single first, then sharded).
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
enum IndexerType {
/// Non-sharded KvIndexer (single background thread)
Single,
/// Sharded KvIndexer (multiple shards with separate trees)
Sharded,
/// Run both and compare
Both,
}
/// Common arguments shared between subcommands
// Flattened into both `MicrobenchArgs` and `StressArgs`.
#[derive(Args, Debug, Clone)]
struct CommonArgs {
/// Target tree size in total (worker, block) pairs
// The drivers derive the number of generated sequences as `size / depth`.
#[arg(long, default_value = "100000")]
size: usize,
/// Sequence depth in blocks (blocks per sequence)
// Used as a divisor (`size / depth`), so zero must be rejected before use.
#[arg(long, default_value = "64")]
depth: usize,
/// Number of workers to distribute blocks across
// Forwarded to `generate_sequences`.
#[arg(long, default_value = "4")]
num_workers: usize,
/// KV block size in tokens
// Passed to the `KvIndexer` / `KvIndexerSharded` constructors.
#[arg(long, default_value = "16")]
block_size: u32,
/// Random seed for reproducibility
// Forwarded to `generate_sequences`.
#[arg(long, default_value = "42")]
seed: u64,
/// Verbose output
// Enables periodic progress lines in the build and benchmark loops.
#[arg(short, long)]
verbose: bool,
}
// Arguments for the `microbench` subcommand.
#[derive(Args, Debug)]
struct MicrobenchArgs {
#[command(flatten)]
common: CommonArgs,
/// Number of iterations per operation for timing
// Also the number of extra sequences generated for the store benchmark.
#[arg(long, default_value = "1000")]
iterations: usize,
/// Prefix prompt ratio (0.0 to 1.0)
// Forwarded to `generate_sequences`.
#[arg(long, default_value = "0.25")]
prefix_prompt_ratio: f64,
/// Number of unique prefix prompt groups
// Forwarded to `generate_sequences`.
#[arg(long, default_value = "4")]
num_prefix_prompts: usize,
/// Indexer type to benchmark
#[arg(long, value_enum, default_value = "both")]
indexer_type: IndexerType,
/// Number of shards for sharded indexer
#[arg(long, default_value = "4")]
num_shards: usize,
/// Run only specific benchmark (store, find_matches, remove, or all)
// NOTE(review): compared as a plain string in `run_microbenchmarks`; a value
// outside the four listed names matches none of the comparisons there.
#[arg(long, default_value = "all")]
benchmark_type: String,
/// Output format: "table" or "csv"
// NOTE(review): `run_microbench_mode` only tests for "csv"; every other value
// takes the table branch.
#[arg(long, default_value = "table")]
format: String,
}
// Arguments for the `stress` subcommand.
#[derive(Args, Debug)]
struct StressArgs {
#[command(flatten)]
common: CommonArgs,
/// Prefix sharing ratio (0.0 to 1.0) - fraction of sequences sharing a common prefix
// Passed to `generate_sequences` as the prefix ratio with a single prefix group.
#[arg(long, default_value = "0.5")]
prefix_share_ratio: f64,
/// Requests per second to submit
// The submit loop sleeps `1.0 / arrival_rate` seconds between requests, so the
// value must be strictly positive.
#[arg(long, default_value = "20.0")]
arrival_rate: f64,
/// Test duration in seconds
// Length of the submission phase only; draining happens afterwards.
#[arg(long, default_value = "10")]
duration: u64,
/// Seconds to wait for in-flight requests after test
// Upper bound on the drain phase; requests still in flight after it are
// reported as timed out.
#[arg(long, default_value = "5")]
in_flight_timeout: u64,
/// Indexer type to stress test
#[arg(long, value_enum, default_value = "single")]
indexer_type: IndexerType,
/// Number of shards for sharded indexer
#[arg(long, default_value = "4")]
num_shards: usize,
}
// ============================================================================
// Benchable Indexer Trait
// ============================================================================
/// Trait for abstracting over KvIndexer and KvIndexerSharded
///
/// Gives the generic benchmark drivers a single interface for both indexer
/// flavours. The `Send + Sync` bound lets an implementation be shared through
/// an `Arc` with spawned tasks, as `run_stress_test` does.
#[async_trait::async_trait]
trait BenchableIndexer: Send + Sync {
/// Applies one router event to the indexer. The benchmarks feed it the
/// events produced by `to_store_event` / `to_remove_event`.
async fn apply_event(&mut self, event: RouterEvent);
/// Looks up `sequence` in the indexer. The match payload is discarded;
/// only success or failure is reported to the caller.
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<(), dynamo_kv_router::indexer::KvRouterError>;
/// Human-readable label used in progress output and reports.
fn name(&self) -> &str;
}
// Adapter for the non-sharded indexer: every method forwards to the
// `KvIndexerInterface` implementation of `KvIndexer`. The fully qualified
// calls avoid ambiguity with the identically named methods of this trait.
#[async_trait::async_trait]
impl BenchableIndexer for KvIndexer {
async fn apply_event(&mut self, event: RouterEvent) {
KvIndexerInterface::apply_event(self, event).await;
}
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<(), dynamo_kv_router::indexer::KvRouterError> {
// Propagate a lookup error, but drop the successful result: the
// benchmarks only time the call.
KvIndexerInterface::find_matches(self, sequence).await?;
Ok(())
}
fn name(&self) -> &str {
"KvIndexer (single)"
}
}
// Adapter for the sharded indexer: every method forwards to the
// `KvIndexerInterface` implementation of `KvIndexerSharded`. The fully
// qualified calls avoid ambiguity with the identically named methods of this
// trait.
#[async_trait::async_trait]
impl BenchableIndexer for KvIndexerSharded {
async fn apply_event(&mut self, event: RouterEvent) {
KvIndexerInterface::apply_event(self, event).await;
}
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<(), dynamo_kv_router::indexer::KvRouterError> {
// Propagate a lookup error, but drop the successful result: the
// benchmarks only time the call.
KvIndexerInterface::find_matches(self, sequence).await?;
Ok(())
}
fn name(&self) -> &str {
"KvIndexerSharded"
}
}
// ============================================================================
// Microbench Mode
// ============================================================================
/// Results for a single indexer benchmark
///
/// Each `*_stats` field is `None` when the corresponding benchmark was not
/// selected for the run.
#[derive(Debug)]
struct MicrobenchResults {
/// Label returned by `BenchableIndexer::name` for the measured indexer.
indexer_name: String,
/// Wall-clock time spent applying the initial store events.
construction_time: Duration,
/// Number of store events applied during construction (one per sequence).
construction_events: usize,
/// Latency of timed store events.
store_stats: Option<LatencyStats>,
/// Latency of lookups for sequences that were stored during construction.
find_matches_hit_stats: Option<LatencyStats>,
/// Latency of lookups for synthetic hashes built to miss.
find_matches_miss_stats: Option<LatencyStats>,
/// Latency of timed remove events.
remove_stats: Option<LatencyStats>,
}
impl MicrobenchResults {
fn print(&self, depth: usize) {
println!("\n========================================");
println!("Results for: {}", self.indexer_name);
println!("========================================");
println!("\nConstruction:");
println!(" Time: {:?}", self.construction_time);
println!(" Events: {}", self.construction_events);
println!(
" Throughput: {:.0} events/sec",
self.construction_events as f64 / self.construction_time.as_secs_f64()
);
if let Some(ref stats) = self.store_stats {
stats.print("APPLY_EVENT (store)", depth);
}
if let Some(ref stats) = self.find_matches_hit_stats {
stats.print("FIND_MATCHES (hit)", depth);
}
if let Some(ref stats) = self.find_matches_miss_stats {
stats.print("FIND_MATCHES (miss)", depth);
}
if let Some(ref stats) = self.remove_stats {
stats.print("APPLY_EVENT (remove)", depth);
}
}
fn print_csv_header() {
println!(
"indexer,construction_ms,construction_events,construction_throughput,\
store_avg_us,store_p50_us,store_p99_us,store_throughput,\
find_hit_avg_us,find_hit_p50_us,find_hit_p99_us,find_hit_throughput,\
find_miss_avg_us,find_miss_p50_us,find_miss_p99_us,find_miss_throughput,\
remove_avg_us,remove_p50_us,remove_p99_us,remove_throughput"
);
}
fn print_csv_row(&self) {
let construction_throughput =
self.construction_events as f64 / self.construction_time.as_secs_f64();
let store = self.store_stats.as_ref();
let find_hit = self.find_matches_hit_stats.as_ref();
let find_miss = self.find_matches_miss_stats.as_ref();
let remove = self.remove_stats.as_ref();
println!(
"{},{:.3},{},{:.0},{},{},{},{:.0},{},{},{},{:.0},{},{},{},{:.0},{},{},{},{:.0}",
self.indexer_name,
self.construction_time.as_secs_f64() * 1000.0,
self.construction_events,
construction_throughput,
store.map(|s| s.avg.as_micros()).unwrap_or(0),
store.map(|s| s.p50.as_micros()).unwrap_or(0),
store.map(|s| s.p99.as_micros()).unwrap_or(0),
store.map(|s| s.throughput_ops_sec).unwrap_or(0.0),
find_hit.map(|s| s.avg.as_micros()).unwrap_or(0),
find_hit.map(|s| s.p50.as_micros()).unwrap_or(0),
find_hit.map(|s| s.p99.as_micros()).unwrap_or(0),
find_hit.map(|s| s.throughput_ops_sec).unwrap_or(0.0),
find_miss.map(|s| s.avg.as_micros()).unwrap_or(0),
find_miss.map(|s| s.p50.as_micros()).unwrap_or(0),
find_miss.map(|s| s.p99.as_micros()).unwrap_or(0),
find_miss.map(|s| s.throughput_ops_sec).unwrap_or(0.0),
remove.map(|s| s.avg.as_micros()).unwrap_or(0),
remove.map(|s| s.p50.as_micros()).unwrap_or(0),
remove.map(|s| s.p99.as_micros()).unwrap_or(0),
remove.map(|s| s.throughput_ops_sec).unwrap_or(0.0),
);
}
}
/// Populates `indexer` by applying one store event per entry of `sequences`.
///
/// Event ids are the sequence positions (0, 1, 2, ...). With `verbose` set, a
/// dot is printed after every 1000 applied events. Returns the time spent in
/// the apply loop; the short settle pause afterwards is not included.
async fn build_indexer<I: BenchableIndexer>(
    indexer: &mut I,
    sequences: &[SequenceData],
    verbose: bool,
) -> Duration {
    use std::io::Write;

    let total_blocks: usize = sequences.iter().map(|seq| seq.local_hashes.len()).sum();
    print!(
        " Building {} with {} sequences ({} blocks)... ",
        indexer.name(),
        sequences.len(),
        total_blocks
    );
    // The progress line has no trailing newline, so push it out explicitly.
    std::io::stdout().flush().unwrap();

    let timer = Instant::now();
    let mut applied: usize = 0;
    for seq in sequences {
        indexer.apply_event(seq.to_store_event(applied as u64)).await;
        applied += 1;
        if verbose && applied % 1000 == 0 {
            print!(".");
            std::io::stdout().flush().unwrap();
        }
    }
    let build_time = timer.elapsed();

    // Allow background processing to complete
    tokio::time::sleep(Duration::from_millis(50)).await;

    let events_per_sec = sequences.len() as f64 / build_time.as_secs_f64();
    println!("done in {:.2?} ({:.2} events/sec)", build_time, events_per_sec);
    build_time
}
/// Times `apply_event` for store events.
///
/// Takes up to `iterations` sequences from `extra_sequences`, timing only the
/// store call; each stored sequence is removed again (untimed) so the indexer
/// returns to its previous contents. Panics if the collected samples cannot
/// be turned into `LatencyStats`.
async fn bench_store<I: BenchableIndexer>(
    indexer: &mut I,
    extra_sequences: &[SequenceData],
    iterations: usize,
    verbose: bool,
) -> LatencyStats {
    println!("\n Benchmarking APPLY_EVENT (store)...");
    let mut samples = Vec::with_capacity(iterations);

    for (idx, seq) in extra_sequences.iter().take(iterations).enumerate() {
        let store_event = seq.to_store_event((1_000_000 + idx) as u64);
        let timer = Instant::now();
        indexer.apply_event(store_event).await;
        samples.push(timer.elapsed());

        // Undo the store outside the timed region.
        indexer
            .apply_event(seq.to_remove_event((2_000_000 + idx) as u64))
            .await;

        let done = idx + 1;
        if verbose && done % 100 == 0 {
            println!(" Completed {}/{} iterations", done, iterations);
        }
    }

    LatencyStats::from_durations(samples).unwrap()
}
/// Times `find_matches` for sequences taken from `sequences`, cycling through
/// them when `iterations` exceeds their count.
///
/// The hash vector is cloned before the timer starts, so only the lookup is
/// measured. Lookup errors are ignored. Panics if the collected samples
/// cannot be turned into `LatencyStats`.
async fn bench_find_matches_hit<I: BenchableIndexer>(
    indexer: &I,
    sequences: &[SequenceData],
    iterations: usize,
    verbose: bool,
) -> LatencyStats {
    println!("\n Benchmarking FIND_MATCHES (hit)...");
    let mut samples = Vec::with_capacity(iterations);

    let mut round = 0;
    while round < iterations {
        let query = sequences[round % sequences.len()].local_hashes.clone();
        let timer = Instant::now();
        let _ = indexer.find_matches(query).await;
        samples.push(timer.elapsed());

        round += 1;
        if verbose && round % 100 == 0 {
            println!(" Completed {}/{} iterations", round, iterations);
        }
    }

    LatencyStats::from_durations(samples).unwrap()
}
/// Times `find_matches` for synthetic sequences of `depth` hashes.
///
/// Every hash is a fixed marker pattern OR-ed with the iteration number
/// (shifted left by 16 bits) and the position within the sequence. Lookup
/// errors are ignored. Panics if the collected samples cannot be turned into
/// `LatencyStats`.
async fn bench_find_matches_miss<I: BenchableIndexer>(
    indexer: &I,
    depth: usize,
    iterations: usize,
    verbose: bool,
) -> LatencyStats {
    println!("\n Benchmarking FIND_MATCHES (miss)...");
    let mut samples = Vec::with_capacity(iterations);

    for round in 0..iterations {
        // Part of the hash shared by all positions of this iteration.
        let round_tag = 0xBAD_C0DE_0000_0000 | ((round as u64) << 16);
        let mut query: Vec<LocalBlockHash> = Vec::with_capacity(depth);
        for pos in 0..depth {
            query.push(LocalBlockHash(round_tag | (pos as u64)));
        }

        let timer = Instant::now();
        let _ = indexer.find_matches(query).await;
        samples.push(timer.elapsed());

        let done = round + 1;
        if verbose && done % 100 == 0 {
            println!(" Completed {}/{} iterations", done, iterations);
        }
    }

    LatencyStats::from_durations(samples).unwrap()
}
/// Times `apply_event` for remove events.
///
/// Cycles through `sequences`, timing only the remove call; each removed
/// sequence is stored again (untimed) so later iterations see the same
/// contents. Panics if the collected samples cannot be turned into
/// `LatencyStats`.
async fn bench_remove<I: BenchableIndexer>(
    indexer: &mut I,
    sequences: &[SequenceData],
    iterations: usize,
    verbose: bool,
) -> LatencyStats {
    println!("\n Benchmarking APPLY_EVENT (remove)...");
    let mut samples = Vec::with_capacity(iterations);

    let mut round = 0;
    while round < iterations {
        let target = &sequences[round % sequences.len()];
        let removal = target.to_remove_event((3_000_000 + round) as u64);
        let timer = Instant::now();
        indexer.apply_event(removal).await;
        samples.push(timer.elapsed());

        // Put the sequence back outside the timed region.
        indexer
            .apply_event(target.to_store_event((4_000_000 + round) as u64))
            .await;

        round += 1;
        if verbose && round % 100 == 0 {
            println!(" Completed {}/{} iterations", round, iterations);
        }
    }

    LatencyStats::from_durations(samples).unwrap()
}
/// Builds `indexer` from `sequences` and then runs the benchmarks selected by
/// `args.benchmark_type`, in the order store, find_matches (hit),
/// find_matches (miss), remove.
///
/// "all" selects everything; "find_matches" selects both the hit and the miss
/// benchmark. Benchmarks that are not selected are reported as `None`.
async fn run_microbenchmarks<I: BenchableIndexer>(
    indexer: &mut I,
    sequences: &[SequenceData],
    extra_sequences: &[SequenceData],
    args: &MicrobenchArgs,
) -> MicrobenchResults {
    let indexer_name = indexer.name().to_string();
    println!("\n--- Benchmarking {} ---", indexer_name);

    let verbose = args.common.verbose;
    let iterations = args.iterations;
    let selected = args.benchmark_type.as_str();
    let wants = |kind: &str| selected == "all" || selected == kind;

    // Construction always runs, whatever benchmark was selected.
    let construction_time = build_indexer(indexer, sequences, verbose).await;
    let construction_events = sequences.len();

    let store_stats = if wants("store") {
        Some(bench_store(indexer, extra_sequences, iterations, verbose).await)
    } else {
        None
    };

    let find_matches_hit_stats = if wants("find_matches") {
        Some(bench_find_matches_hit(indexer, sequences, iterations, verbose).await)
    } else {
        None
    };

    let find_matches_miss_stats = if wants("find_matches") {
        Some(bench_find_matches_miss(indexer, args.common.depth, iterations, verbose).await)
    } else {
        None
    };

    let remove_stats = if wants("remove") {
        Some(bench_remove(indexer, sequences, iterations, verbose).await)
    } else {
        None
    };

    MicrobenchResults {
        indexer_name,
        construction_time,
        construction_events,
        store_stats,
        find_matches_hit_stats,
        find_matches_miss_stats,
        remove_stats,
    }
}
fn print_microbench_comparison(results: &[MicrobenchResults], _depth: usize) {
if results.len() < 2 {
return;
}
println!("\n========================================");
println!("COMPARISON SUMMARY");
println!("========================================\n");
let single = &results[0];
let sharded = &results[1];
println!(
"{:<30} {:>15} {:>15} {:>10}",
"Metric", "Single", "Sharded", "Ratio"
);
println!("{}", "-".repeat(72));
// Construction
let single_constr = single.construction_time.as_secs_f64() * 1000.0;
let sharded_constr = sharded.construction_time.as_secs_f64() * 1000.0;
println!(
"{:<30} {:>12.2}ms {:>12.2}ms {:>9.2}x",
"Construction time",
single_constr,
sharded_constr,
single_constr / sharded_constr
);
// Store p50
if let (Some(s1), Some(s2)) = (&single.store_stats, &sharded.store_stats) {
let s1_us = s1.p50.as_nanos() as f64 / 1000.0;
let s2_us = s2.p50.as_nanos() as f64 / 1000.0;
println!(
"{:<30} {:>12.2}us {:>12.2}us {:>9.2}x",
"Store p50",
s1_us,
s2_us,
s1_us / s2_us
);
}
// Find matches hit p50
if let (Some(s1), Some(s2)) = (
&single.find_matches_hit_stats,
&sharded.find_matches_hit_stats,
) {
let s1_us = s1.p50.as_nanos() as f64 / 1000.0;
let s2_us = s2.p50.as_nanos() as f64 / 1000.0;
println!(
"{:<30} {:>12.2}us {:>12.2}us {:>9.2}x",
"Find matches (hit) p50",
s1_us,
s2_us,
s1_us / s2_us
);
}
// Find matches hit p99
if let (Some(s1), Some(s2)) = (
&single.find_matches_hit_stats,
&sharded.find_matches_hit_stats,
) {
let s1_us = s1.p99.as_nanos() as f64 / 1000.0;
let s2_us = s2.p99.as_nanos() as f64 / 1000.0;
println!(
"{:<30} {:>12.2}us {:>12.2}us {:>9.2}x",
"Find matches (hit) p99",
s1_us,
s2_us,
s1_us / s2_us
);
}
// Find matches miss p50
if let (Some(s1), Some(s2)) = (
&single.find_matches_miss_stats,
&sharded.find_matches_miss_stats,
) {
let s1_us = s1.p50.as_nanos() as f64 / 1000.0;
let s2_us = s2.p50.as_nanos() as f64 / 1000.0;
println!(
"{:<30} {:>12.2}us {:>12.2}us {:>9.2}x",
"Find matches (miss) p50",
s1_us,
s2_us,
s1_us / s2_us
);
}
// Remove p50
if let (Some(s1), Some(s2)) = (&single.remove_stats, &sharded.remove_stats) {
let s1_us = s1.p50.as_nanos() as f64 / 1000.0;
let s2_us = s2.p50.as_nanos() as f64 / 1000.0;
println!(
"{:<30} {:>12.2}us {:>12.2}us {:>9.2}x",
"Remove p50",
s1_us,
s2_us,
s1_us / s2_us
);
}
// Throughput comparison
println!();
println!(
"{:<30} {:>15} {:>15} {:>10}",
"Throughput (ops/sec)", "Single", "Sharded", "Ratio"
);
println!("{}", "-".repeat(72));
if let (Some(s1), Some(s2)) = (
&single.find_matches_hit_stats,
&sharded.find_matches_hit_stats,
) {
println!(
"{:<30} {:>12.0}/s {:>12.0}/s {:>9.2}x",
"Find matches (hit)",
s1.throughput_ops_sec,
s2.throughput_ops_sec,
s2.throughput_ops_sec / s1.throughput_ops_sec
);
}
println!("\nNote: Ratio > 1.0 means sharded is faster for that metric.");
}
/// Entry point of the `microbench` subcommand.
///
/// Validates the arguments, generates the benchmark sequences, runs the
/// selected benchmarks against the requested indexer type(s) and prints the
/// results as a table (plus a comparison when both indexers ran) or as CSV.
///
/// Exits the process with status 1 when an argument is unusable. All
/// arguments are checked up front, mirroring `run_stress_mode`, because
/// several of them would otherwise fail late or silently:
/// - `depth == 0` would divide by zero when deriving the sequence count;
/// - `iterations == 0` would make the benchmark functions unwrap latency
///   statistics built from zero samples;
/// - an unknown `benchmark_type` would run construction but no benchmark;
/// - an unknown `format` would silently fall back to the table.
async fn run_microbench_mode(args: MicrobenchArgs) {
    const BENCHMARK_TYPES: [&str; 4] = ["all", "store", "find_matches", "remove"];
    const OUTPUT_FORMATS: [&str; 2] = ["table", "csv"];

    let uses_sharded = matches!(args.indexer_type, IndexerType::Sharded | IndexerType::Both);
    let problem: Option<String> = if args.common.depth == 0 {
        Some("depth must be > 0".to_string())
    } else if args.common.num_workers == 0 {
        Some("num_workers must be > 0".to_string())
    } else if args.common.size < args.common.depth {
        Some("size must be >= depth".to_string())
    } else if !(0.0..=1.0).contains(&args.prefix_prompt_ratio) {
        // A NaN ratio also lands here because it is not inside the range.
        Some(format!(
            "prefix_prompt_ratio ({}) must be in range 0.0..=1.0",
            args.prefix_prompt_ratio
        ))
    } else if args.iterations == 0 {
        Some("iterations must be > 0".to_string())
    } else if uses_sharded && args.num_shards == 0 {
        Some("num_shards must be > 0 when using Sharded or Both indexer type".to_string())
    } else if !BENCHMARK_TYPES.contains(&args.benchmark_type.as_str()) {
        Some(format!(
            "unknown benchmark_type '{}' (expected one of: {})",
            args.benchmark_type,
            BENCHMARK_TYPES.join(", ")
        ))
    } else if !OUTPUT_FORMATS.contains(&args.format.as_str()) {
        Some(format!(
            "unknown format '{}' (expected one of: {})",
            args.format,
            OUTPUT_FORMATS.join(", ")
        ))
    } else {
        None
    };
    if let Some(message) = problem {
        eprintln!("Error: {}", message);
        std::process::exit(1);
    }

    // Safe: depth > 0 and size >= depth were checked above, so this is >= 1.
    let num_sequences = args.common.size / args.common.depth;

    println!("KvIndexer Microbenchmark");
    println!("========================\n");
    println!("Configuration:");
    println!(" Target size: {} (worker, block) pairs", args.common.size);
    println!(
        " Depth: {} blocks/sequence (= {} tokens with block_size={})",
        args.common.depth,
        args.common.depth * args.common.block_size as usize,
        args.common.block_size
    );
    println!(" Block size: {} tokens", args.common.block_size);
    println!(" Workers: {}", args.common.num_workers);
    println!(" Iterations: {}", args.iterations);
    println!(
        " Prefix prompt ratio: {:.1}%",
        args.prefix_prompt_ratio * 100.0
    );
    println!(" Prefix prompt groups: {}", args.num_prefix_prompts);
    println!(" Num shards (for sharded): {}", args.num_shards);
    println!(" Indexer type: {:?}", args.indexer_type);
    println!(" Benchmark type: {}", args.benchmark_type);
    println!(
        "\n Derived: {} sequences to reach target size",
        num_sequences
    );

    // Generate sequences: the first `num_sequences` populate the indexer, the
    // remaining `iterations` ones are reserved for the store benchmark.
    let extra_count = args.iterations;
    let all_sequences = generate_sequences(
        num_sequences + extra_count,
        args.common.depth,
        args.common.num_workers,
        args.prefix_prompt_ratio,
        args.num_prefix_prompts,
        args.common.seed,
        false,
    );
    let sequences = &all_sequences[..num_sequences];
    let extra_sequences = &all_sequences[num_sequences..];

    let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
    let mut results = Vec::new();

    // Benchmark single indexer
    if matches!(args.indexer_type, IndexerType::Single | IndexerType::Both) {
        let token = CancellationToken::new();
        let mut indexer = KvIndexer::new(token.clone(), args.common.block_size, metrics.clone());
        let result = run_microbenchmarks(&mut indexer, sequences, extra_sequences, &args).await;
        results.push(result);
        // Cancel the indexer's token, then pause briefly before moving on.
        token.cancel();
        tokio::time::sleep(Duration::from_millis(50)).await;
    }

    // Benchmark sharded indexer
    if uses_sharded {
        let token = CancellationToken::new();
        let mut indexer = KvIndexerSharded::new(
            token.clone(),
            args.num_shards,
            args.common.block_size,
            metrics.clone(),
        );
        let result = run_microbenchmarks(&mut indexer, sequences, extra_sequences, &args).await;
        results.push(result);
        token.cancel();
        tokio::time::sleep(Duration::from_millis(50)).await;
    }

    // Print results
    if args.format == "csv" {
        MicrobenchResults::print_csv_header();
        for result in &results {
            result.print_csv_row();
        }
    } else {
        for result in &results {
            result.print(args.common.depth);
        }
        // The comparison needs one result per indexer flavour.
        if results.len() == 2 {
            print_microbench_comparison(&results, args.common.depth);
        }
    }
    println!("\nMicrobenchmark complete.");
}
// ============================================================================
// Stress Test Mode
// ============================================================================
/// Result of a single request during stress test
///
/// Sent from each spawned lookup task back to the collector over an mpsc
/// channel.
// `dead_code` is allowed because the collector only reads `submit_time` and
// `complete_time`; `request_id` and `success` are recorded but never read.
#[allow(dead_code)]
struct RequestResult {
/// Sequential id assigned by the submit loop, starting at 0.
request_id: u64,
/// Instant taken by the submit loop just before the task was spawned.
submit_time: Instant,
/// Instant taken inside the task right after `find_matches` returned.
complete_time: Instant,
/// Whether `find_matches` returned `Ok`.
success: bool,
}
/// Aggregated results from stress test
struct StressResults {
/// Label returned by `BenchableIndexer::name` for the tested indexer.
indexer_name: String,
/// Number of requests spawned during the submission phase.
submitted: u64,
/// Number of results received from the tasks (failed lookups included).
completed: u64,
/// Requests still in flight when the drain phase ended.
timed_out: u64,
/// Per-request latency from submission to completion.
latencies: Vec<Duration>,
/// Highest in-flight count observed by the submit loop.
max_in_flight: u64,
/// p50 latency of the sequential baseline lookups.
baseline_service_time: Duration,
/// Tree construction time; `run_stress_test` leaves it zero for the caller to fill in.
construction_time: Duration,
/// Number of construction events; `run_stress_test` leaves it zero for the caller to fill in.
construction_events: u64,
}
/// Run the stress test with a generic indexer
///
/// Expects `indexer` to be populated already. Measures a sequential baseline,
/// then submits one `find_matches` lookup per `1 / arrival_rate` seconds for
/// `duration` seconds, each in its own spawned task, and finally waits up to
/// `in_flight_timeout` seconds for outstanding lookups.
///
/// The returned `StressResults` has `construction_time` and
/// `construction_events` zeroed; the caller is expected to fill them in.
///
/// Panics if `sequences` is empty (the baseline statistics are unwrapped and
/// the lookup index is taken modulo `sequences.len()`). `args.arrival_rate`
/// must be strictly positive, since the submit interval is its reciprocal.
async fn run_stress_test<I: BenchableIndexer + 'static>(
indexer: Arc<I>,
sequences: &[SequenceData],
args: &StressArgs,
) -> StressResults {
let indexer_name = indexer.name().to_string();
// Phase 2: Baseline Measurement
println!("\nPhase 2: Baseline Measurement");
println!(" Running 10 sequential find_matches calls...");
let mut baseline_durations = Vec::new();
// NOTE(review): fewer than 10 calls are made when `sequences` holds fewer
// than 10 entries, although the messages always say 10.
for seq in sequences.iter().take(10) {
let start = Instant::now();
let _ = indexer.find_matches(seq.local_hashes.clone()).await;
baseline_durations.push(start.elapsed());
}
let stats = LatencyStats::from_durations(baseline_durations.clone()).unwrap();
let baseline_service_time = stats.p50;
let theoretical_max = stats.throughput_ops_sec;
println!(
" Baseline find_matches latency: {:?} (p50 of 10)",
baseline_service_time
);
println!(
" Theoretical max throughput: {:.1} req/sec",
theoretical_max
);
// Phase 3: Pre-generate Lookup Sequences
println!("\nPhase 3: Pre-generating Lookup Sequences");
// Expected request count plus 100 spare entries. The same number is used as
// the result channel capacity below.
let expected_requests = (args.arrival_rate * args.duration as f64).ceil() as usize + 100;
let lookup_sequences: Vec<Vec<LocalBlockHash>> = (0..expected_requests)
.map(|i| {
let seq = &sequences[i % sequences.len()];
seq.local_hashes.clone()
})
.collect();
println!(
" Pre-generated {} lookup sequences",
lookup_sequences.len()
);
// Phase 4: Stress Test
println!("\nPhase 4: Stress Test");
println!(" Arrival rate: {:.1} req/sec", args.arrival_rate);
println!(" Duration: {}s", args.duration);
let in_flight = Arc::new(AtomicU64::new(0));
let max_in_flight = Arc::new(AtomicU64::new(0));
let (result_tx, mut result_rx) = mpsc::channel::<RequestResult>(expected_requests);
let start = Instant::now();
let mut request_id = 0u64;
let interval = Duration::from_secs_f64(1.0 / args.arrival_rate);
while start.elapsed() < Duration::from_secs(args.duration) {
let submit_time = Instant::now();
// NOTE(review): unchecked index. It stays in bounds only while the loop
// submits at most `expected_requests` requests, which relies on each
// iteration sleeping for at least `interval`.
let seq = lookup_sequences[request_id as usize].clone();
// Track in-flight
let current = in_flight.fetch_add(1, Ordering::Relaxed) + 1;
max_in_flight.fetch_max(current, Ordering::Relaxed);
let indexer = Arc::clone(&indexer);
let result_tx = result_tx.clone();
let in_flight_clone = in_flight.clone();
let req_id = request_id;
let verbose = args.common.verbose;
tokio::spawn(async move {
let result = indexer.find_matches(seq).await;
let complete_time = Instant::now();
// Decremented before the result is sent, so the counter can reach zero
// slightly before the last result is in the channel.
in_flight_clone.fetch_sub(1, Ordering::Relaxed);
if verbose {
let latency = complete_time.duration_since(submit_time);
println!(" Request {} completed in {:?}", req_id, latency);
}
// A send error (receiver closed after the drain timeout) is ignored.
let _ = result_tx
.send(RequestResult {
request_id: req_id,
submit_time,
complete_time,
success: result.is_ok(),
})
.await;
});
request_id += 1;
tokio::time::sleep(interval).await;
}
let submitted = request_id;
println!(" Submitted {} requests", submitted);
// Wait for in-flight requests with timeout
println!("\nPhase 5: Draining In-flight Requests");
let drain_start = Instant::now();
let mut last_in_flight = in_flight.load(Ordering::Relaxed);
println!(
" Waiting for {} in-flight requests (timeout: {}s)...",
last_in_flight, args.in_flight_timeout
);
// Poll every 100 ms until nothing is in flight or the timeout elapses.
while in_flight.load(Ordering::Relaxed) > 0
&& drain_start.elapsed() < Duration::from_secs(args.in_flight_timeout)
{
tokio::time::sleep(Duration::from_millis(100)).await;
let current = in_flight.load(Ordering::Relaxed);
if current != last_in_flight && args.common.verbose {
println!(" In-flight: {}", current);
last_in_flight = current;
}
}
let timed_out = in_flight.load(Ordering::Relaxed);
if timed_out > 0 {
println!(" {} requests timed out", timed_out);
} else {
println!(" All requests completed");
}
// Collect results
// Drop this function's sender so the channel can report completion once the
// task-held clones are gone.
drop(result_tx);
if timed_out > 0 {
// Close the receiver so the collection loop below only takes what is
// already buffered instead of waiting for the stragglers.
result_rx.close();
}
let mut results = Vec::new();
while let Some(r) = result_rx.recv().await {
results.push(r);
}
// Compute latencies
let latencies: Vec<Duration> = results
.iter()
.map(|r| r.complete_time.duration_since(r.submit_time))
.collect();
StressResults {
indexer_name,
submitted,
// Every received result counts as completed, whatever its `success` flag.
completed: results.len() as u64,
timed_out,
latencies,
max_in_flight: max_in_flight.load(Ordering::Relaxed),
baseline_service_time,
construction_time: Duration::ZERO, // Set by caller
construction_events: 0, // Set by caller
}
}
/// Print the final stress test results report
fn print_stress_results(args: &StressArgs, results: &StressResults) {
let num_sequences = args.common.size / args.common.depth;
println!("\n=====================");
println!("Queue Saturation Test Results: {}", results.indexer_name);
println!("=====================\n");
println!("Configuration:");
println!(
" Tree size: {} blocks ({} sequences x {} depth)",
args.common.size, num_sequences, args.common.depth
);
println!(" Workers: {}", args.common.num_workers);
println!(
" Prefix share ratio: {:.1}%",
args.prefix_share_ratio * 100.0
);
println!(" Arrival rate: {:.1} req/sec", args.arrival_rate);
println!(" Duration: {}s", args.duration);
println!();
println!("Tree Construction:");
println!(" Time: {:.2?}", results.construction_time);
println!(" Events: {}", results.construction_events);
let throughput = results.construction_events as f64 / results.construction_time.as_secs_f64();
println!(" Throughput: {:.0} events/sec", throughput);
println!();
println!("Baseline:");
println!(
" find_matches latency: {:?} (median of 10)",
results.baseline_service_time
);
let theoretical_max = 1.0 / results.baseline_service_time.as_secs_f64();
println!(
" Theoretical max throughput: {:.1} req/sec",
theoretical_max
);
println!();
println!("Saturation Test Results:");
println!(" Submitted: {} requests", results.submitted);
println!(" Completed: {} requests", results.completed);
println!(
" Timed out: {} requests (in-flight at end)",
results.timed_out
);
println!();
if !results.latencies.is_empty() {
let test_duration = args.duration as f64 + args.in_flight_timeout as f64;
let achieved_throughput = results.completed as f64 / test_duration;
println!(" Throughput:");
println!(" Requested: {:.1} req/sec", args.arrival_rate);
println!(" Achieved: {:.1} req/sec", achieved_throughput);
println!();
if let Some(stats) = LatencyStats::from_durations(results.latencies.clone()) {
println!(" Latency (end-to-end, includes queue wait):");
println!(" min: {:>12?}", stats.min);
println!(" p50: {:>12?}", stats.p50);
println!(" p95: {:>12?}", stats.p95);
println!(" p99: {:>12?}", stats.p99);
println!(" max: {:>12?}", stats.max);
println!();
let estimated_queue_wait = if stats.p50 > results.baseline_service_time {
stats.p50 - results.baseline_service_time
} else {
Duration::ZERO
};
println!(" Queue Analysis:");
println!(
" Baseline service time: {:?}",
results.baseline_service_time
);
println!(" Estimated queue wait (p50): {:?}", estimated_queue_wait);
println!(" Max in-flight observed: {}", results.max_in_flight);
println!();
// Determine saturation status
let is_saturated = achieved_throughput < args.arrival_rate * 0.95
|| results.timed_out > 0
|| stats.p50 > results.baseline_service_time * 2;
if is_saturated {
println!(" STATUS: SATURATED");
if achieved_throughput < args.arrival_rate * 0.95 {
println!(
" - Throughput ({:.1}) < Arrival rate ({:.1})",
achieved_throughput, args.arrival_rate
);
}
if results.timed_out > 0 {
println!(" - Requests timed out: {}", results.timed_out);
}
if stats.p50 > results.baseline_service_time * 2 {
println!(
" - P50 latency ({:?}) > 2x baseline ({:?})",
stats.p50, results.baseline_service_time
);
}
} else {
println!(" STATUS: NOT SATURATED");
println!(" - Throughput matches arrival rate");
println!(" - No requests timed out");
println!(" - Latency within acceptable bounds");
}
}
}
}
/// Prints a side-by-side table comparing the first two entries of `results`,
/// treating the first as the single indexer and the second as the sharded
/// one. Does nothing when fewer than two results are given.
///
/// Time-based ratios are single / sharded; the completed-request and
/// throughput ratios are sharded / single, so a ratio above 1.0 always
/// favours the sharded indexer. The latency and throughput rows are only
/// printed when both runs produced latency statistics.
///
/// Achieved throughput is computed per run over the submission window,
/// extended by the drain timeout only for a run in which requests actually
/// timed out. Charging the full timeout to every run understated both
/// throughput figures and hid the difference between a run that drained at
/// once and one that exhausted the timeout. The actual drain time is not
/// recorded in `StressResults`, so this is an approximation for runs that
/// drained early.
fn print_stress_comparison(results: &[StressResults], args: &StressArgs) {
    if results.len() < 2 {
        return;
    }
    println!("\n========================================");
    println!("STRESS TEST COMPARISON SUMMARY");
    println!("========================================\n");
    let single = &results[0];
    let sharded = &results[1];
    println!(
        "{:<35} {:>18} {:>18} {:>10}",
        "Metric", "Single", "Sharded", "Ratio"
    );
    println!("{}", "-".repeat(85));

    // Construction time
    let single_constr = single.construction_time.as_secs_f64() * 1000.0;
    let sharded_constr = sharded.construction_time.as_secs_f64() * 1000.0;
    println!(
        "{:<35} {:>15.2}ms {:>15.2}ms {:>9.2}x",
        "Construction time",
        single_constr,
        sharded_constr,
        single_constr / sharded_constr
    );

    // Baseline service time
    let single_baseline = single.baseline_service_time.as_nanos() as f64 / 1000.0;
    let sharded_baseline = sharded.baseline_service_time.as_nanos() as f64 / 1000.0;
    println!(
        "{:<35} {:>15.2}us {:>15.2}us {:>9.2}x",
        "Baseline service time",
        single_baseline,
        sharded_baseline,
        single_baseline / sharded_baseline
    );

    // Completed requests
    println!(
        "{:<35} {:>18} {:>18} {:>9.2}x",
        "Completed requests",
        single.completed,
        sharded.completed,
        sharded.completed as f64 / single.completed as f64
    );

    // Max in-flight
    println!(
        "{:<35} {:>18} {:>18}",
        "Max in-flight", single.max_in_flight, sharded.max_in_flight
    );

    // Timed out
    println!(
        "{:<35} {:>18} {:>18}",
        "Timed out", single.timed_out, sharded.timed_out
    );

    // Latency comparison
    if let (Some(s1), Some(s2)) = (
        LatencyStats::from_durations(single.latencies.clone()),
        LatencyStats::from_durations(sharded.latencies.clone()),
    ) {
        let s1_p50 = s1.p50.as_nanos() as f64 / 1000.0;
        let s2_p50 = s2.p50.as_nanos() as f64 / 1000.0;
        println!(
            "{:<35} {:>15.2}us {:>15.2}us {:>9.2}x",
            "Latency p50",
            s1_p50,
            s2_p50,
            s1_p50 / s2_p50
        );
        let s1_p99 = s1.p99.as_nanos() as f64 / 1000.0;
        let s2_p99 = s2.p99.as_nanos() as f64 / 1000.0;
        println!(
            "{:<35} {:>15.2}us {:>15.2}us {:>9.2}x",
            "Latency p99",
            s1_p99,
            s2_p99,
            s1_p99 / s2_p99
        );

        // Measurement window of one run: the drain timeout is only charged
        // when that run actually exhausted it.
        let window_secs = |run: &StressResults| {
            let drain_window = if run.timed_out > 0 {
                args.in_flight_timeout as f64
            } else {
                0.0
            };
            args.duration as f64 + drain_window
        };
        let s1_throughput = single.completed as f64 / window_secs(single);
        let s2_throughput = sharded.completed as f64 / window_secs(sharded);
        println!(
            "{:<35} {:>14.1}/s {:>14.1}/s {:>9.2}x",
            "Achieved throughput",
            s1_throughput,
            s2_throughput,
            s2_throughput / s1_throughput
        );
    }
    println!("\nNote: Ratio > 1.0 means sharded is better for that metric.");
}
/// Runs the `stress` subcommand.
///
/// Validates the CLI arguments, generates the benchmark sequences, and then, for each
/// selected indexer type (single, sharded, or both):
/// 1. builds the tree by applying one store event per generated sequence,
/// 2. runs `run_stress_test` against the populated indexer,
/// 3. prints the per-indexer results.
///
/// When both indexers were exercised a side-by-side comparison is printed at the end.
/// Any invalid argument is reported on stderr and terminates the process with status 1.
async fn run_stress_mode(args: StressArgs) {
    /// Reports a fatal configuration error on stderr and exits with status 1.
    fn fail(message: &str) -> ! {
        eprintln!("Error: {}", message);
        std::process::exit(1)
    }

    let run_single = matches!(args.indexer_type, IndexerType::Single | IndexerType::Both);
    let run_sharded = matches!(args.indexer_type, IndexerType::Sharded | IndexerType::Both);

    // Validate inputs before proceeding
    if args.common.depth == 0 {
        fail("depth must be > 0");
    }
    if args.common.num_workers == 0 {
        fail("num_workers must be > 0");
    }
    if args.common.size < args.common.depth {
        fail(&format!(
            "size ({}) must be >= depth ({})",
            args.common.size, args.common.depth
        ));
    }
    if !(0.0..=1.0).contains(&args.prefix_share_ratio) {
        fail(&format!(
            "prefix_share_ratio ({}) must be in range 0.0..=1.0",
            args.prefix_share_ratio
        ));
    }
    // NaN fails every ordered comparison, so `<= 0.0` alone would accept it; an infinite
    // rate is equally meaningless. Reject both explicitly.
    if !args.arrival_rate.is_finite() || args.arrival_rate <= 0.0 {
        fail(&format!(
            "arrival_rate ({}) must be a finite value > 0.0",
            args.arrival_rate
        ));
    }
    if run_sharded && args.num_shards == 0 {
        fail("num_shards must be > 0 when using Sharded or Both indexer type");
    }

    let num_sequences = args.common.size / args.common.depth;

    println!("Queue Saturation Stress Test");
    println!("============================\n");
    println!("Configuration:");
    println!(
        " Tree size: {} blocks ({} sequences x {} depth)",
        args.common.size, num_sequences, args.common.depth
    );
    println!(" Workers: {}", args.common.num_workers);
    println!(" Block size: {} tokens", args.common.block_size);
    println!(
        " Prefix share ratio: {:.1}%",
        args.prefix_share_ratio * 100.0
    );
    println!(" Seed: {}", args.common.seed);
    println!(" Arrival rate: {:.1} req/sec", args.arrival_rate);
    println!(" Duration: {}s", args.duration);
    println!(" In-flight timeout: {}s", args.in_flight_timeout);
    println!(" Indexer type: {:?}", args.indexer_type);
    if run_sharded {
        println!(" Num shards: {}", args.num_shards);
    }

    // Generate sequences
    println!("\nPhase 1: Tree Construction");
    println!(" Generating {} sequences...", num_sequences);
    // Use prefix_share_ratio as prefix_ratio and 1 group for stress test
    let sequences = generate_sequences(
        num_sequences,
        args.common.depth,
        args.common.num_workers,
        args.prefix_share_ratio,
        1, // Single prefix group for stress test
        args.common.seed,
        false, // use_cumulative_hash
    );

    // The same metrics handle is shared by both indexers.
    let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
    let mut all_results = Vec::new();

    // Test single indexer
    if run_single {
        let token = CancellationToken::new();
        let mut indexer = KvIndexer::new(token.clone(), args.common.block_size, metrics.clone());

        println!(
            "\n Applying {} store events to KvIndexer...",
            sequences.len()
        );
        let construction_start = Instant::now();
        for (event_id, seq) in sequences.iter().enumerate() {
            let event = seq.to_store_event(event_id as u64);
            KvIndexerInterface::apply_event(&mut indexer, event).await;
            if args.common.verbose && (event_id + 1) % 100 == 0 {
                println!(" Applied {}/{} events...", event_id + 1, sequences.len());
            }
        }
        let construction_time = construction_start.elapsed();
        let construction_events = sequences.len() as u64;
        println!(" Tree construction completed in {:?}", construction_time);
        println!(
            " Throughput: {:.0} events/sec",
            construction_events as f64 / construction_time.as_secs_f64()
        );

        // NOTE(review): presumably gives the indexer's background task time to drain the
        // queued store events before the measurement starts — confirm.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let mut results = run_stress_test(Arc::new(indexer), &sequences, &args).await;
        results.construction_time = construction_time;
        results.construction_events = construction_events;
        print_stress_results(&args, &results);
        all_results.push(results);

        // Cancel this indexer's token, then pause briefly before the next run.
        token.cancel();
        tokio::time::sleep(Duration::from_millis(50)).await;
    }

    // Test sharded indexer
    if run_sharded {
        let token = CancellationToken::new();
        let mut indexer = KvIndexerSharded::new(
            token.clone(),
            args.num_shards,
            args.common.block_size,
            metrics.clone(),
        );

        println!(
            "\n Applying {} store events to KvIndexerSharded...",
            sequences.len()
        );
        let construction_start = Instant::now();
        for (event_id, seq) in sequences.iter().enumerate() {
            let event = seq.to_store_event(event_id as u64);
            KvIndexerInterface::apply_event(&mut indexer, event).await;
            if args.common.verbose && (event_id + 1) % 100 == 0 {
                println!(" Applied {}/{} events...", event_id + 1, sequences.len());
            }
        }
        let construction_time = construction_start.elapsed();
        let construction_events = sequences.len() as u64;
        println!(" Tree construction completed in {:?}", construction_time);
        println!(
            " Throughput: {:.0} events/sec",
            construction_events as f64 / construction_time.as_secs_f64()
        );

        // NOTE(review): same settle delay as for the single indexer — confirm intent.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let mut results = run_stress_test(Arc::new(indexer), &sequences, &args).await;
        results.construction_time = construction_time;
        results.construction_events = construction_events;
        print_stress_results(&args, &results);
        all_results.push(results);

        token.cancel();
        tokio::time::sleep(Duration::from_millis(50)).await;
    }

    // Print comparison if both were run
    if all_results.len() == 2 {
        print_stress_comparison(&all_results, &args);
    }

    println!("\nStress test complete.");
}
// ============================================================================
// Main Entry Point
// ============================================================================
#[tokio::main]
async fn main() {
let cli = Cli::parse();
match cli.command {
Command::Microbench(args) => run_microbench_mode(args).await,
Command::Stress(args) => run_stress_mode(args).await,
}
}
......@@ -15,16 +15,12 @@
use clap::{Parser, ValueEnum};
use dynamo_kv_router::{
OverlapScores, RadixTree, RouterEvent, compute_block_hash_for_seq,
RadixTree, RouterEvent,
bench_utils::{LatencyStats, SequenceData, generate_sequences},
compute_block_hash_for_seq,
flat_hashmap::FlatHashMap,
protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData,
KvCacheStoreData, KvCacheStoredBlockData, LocalBlockHash, WorkerId,
compute_seq_hash_for_block,
},
protocols::LocalBlockHash,
};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::time::{Duration, Instant};
/// Unified interface for RadixTree and FlatHashMap benchmarking.
......@@ -206,114 +202,6 @@ struct Args {
flat_hashmap: bool,
}
/// One pre-generated benchmark sequence: the worker that owns it plus the per-block
/// local hashes and their matching external (sequence) hashes.
#[derive(Clone)]
struct SequenceData {
    worker_id: WorkerId,
    local_hashes: Vec<LocalBlockHash>,
    external_hashes: Vec<ExternalSequenceBlockHash>,
}

impl SequenceData {
    /// Builds a sequence from its local hashes.
    ///
    /// The external hashes are derived with `compute_seq_hash_for_block` (cumulative
    /// hashes), which is what lets FlatHashMap identify block positions.
    fn from_local_hashes(worker_id: WorkerId, local_hashes: Vec<LocalBlockHash>) -> Self {
        let external_hashes = compute_seq_hash_for_block(&local_hashes)
            .into_iter()
            .map(ExternalSequenceBlockHash)
            .collect();
        Self {
            worker_id,
            local_hashes,
            external_hashes,
        }
    }

    /// Produces a `Stored` event carrying every block of this sequence, with no parent hash.
    fn to_store_event(&self, event_id: u64) -> RouterEvent {
        // Pair each local hash with the external hash at the same position.
        let blocks = self
            .local_hashes
            .iter()
            .zip(&self.external_hashes)
            .map(|(&tokens_hash, &block_hash)| KvCacheStoredBlockData {
                tokens_hash,
                block_hash,
                mm_extra_info: None,
            })
            .collect();
        let data = KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash: None,
            blocks,
        });
        RouterEvent {
            worker_id: self.worker_id,
            event: KvCacheEvent {
                event_id,
                data,
                dp_rank: 0,
            },
        }
    }

    /// Produces a `Removed` event listing every external hash of this sequence.
    fn to_remove_event(&self, event_id: u64) -> RouterEvent {
        let data = KvCacheEventData::Removed(KvCacheRemoveData {
            block_hashes: self.external_hashes.clone(),
        });
        RouterEvent {
            worker_id: self.worker_id,
            event: KvCacheEvent {
                event_id,
                data,
                dp_rank: 0,
            },
        }
    }
}
/// Generate sequences with shared prefix prompts.
///
/// Each sequence has `depth` blocks and is assigned to worker `seq_id % num_workers`.
/// When `num_prefix_prompts > 0` and the rounded prefix length
/// (`depth * prefix_prompt_ratio`) is non-zero, each sequence draws one prefix group from
/// the seeded RNG; its first `prefix_length` blocks use hashes shared by that group and
/// the remaining blocks use hashes unique to the sequence.
///
/// # Panics
/// Panics if `num_workers` is 0 while `num_sequences` is non-zero.
fn generate_sequences(
    num_sequences: usize,
    depth: usize,
    num_workers: usize,
    prefix_prompt_ratio: f64,
    num_prefix_prompts: usize,
    seed: u64,
) -> Vec<SequenceData> {
    /// Tag occupying the upper 32 bits of every shared-prefix hash.
    const PREFIX_TAG: u64 = 0xDEAD_BEEF_0000_0000;

    // `seq_id % num_workers` would otherwise fail with an opaque remainder-by-zero panic.
    assert!(
        num_workers > 0 || num_sequences == 0,
        "generate_sequences: num_workers must be > 0"
    );

    let mut sequences = Vec::with_capacity(num_sequences);
    let prefix_length = (depth as f64 * prefix_prompt_ratio).round() as usize;
    let mut rng: StdRng = StdRng::seed_from_u64(seed);

    for seq_id in 0..num_sequences {
        let seq_id_u64 = seq_id as u64;
        let worker_id = (seq_id % num_workers) as WorkerId;

        // Determine prefix group for this sequence
        let group_id = if num_prefix_prompts > 0 && prefix_length > 0 {
            Some(rng.random_range(0..num_prefix_prompts) as u64)
        } else {
            None
        };

        // Build local_hashes: shared prefix (if applicable) + unique suffix
        let local_hashes: Vec<LocalBlockHash> = (0..depth)
            .map(|block_idx| {
                let block_idx_u64 = block_idx as u64;
                match group_id {
                    // The group id is XORed into the tag. OR-ing it (as before) let the
                    // tag's set bits absorb the id, so e.g. groups 0..=15 all produced
                    // the same prefix hashes. Group 0 keeps its previous values.
                    Some(gid) if block_idx < prefix_length => {
                        LocalBlockHash((PREFIX_TAG ^ (gid << 32)) | block_idx_u64)
                    }
                    // Unique suffix (or no shared prefix)
                    _ => LocalBlockHash((seq_id_u64 << 32) | block_idx_u64),
                }
            })
            .collect();

        sequences.push(SequenceData::from_local_hashes(worker_id, local_hashes));
    }

    sequences
}
/// Build a pre-populated RadixTree (for sweep/dump benchmarks that specifically need RadixTree)
fn build_tree(sequences: &[SequenceData]) -> RadixTree {
let num_blocks: usize = sequences.iter().map(|s| s.local_hashes.len()).sum();
......@@ -381,52 +269,6 @@ fn build_index(sequences: &[SequenceData], use_flat_hashmap: bool) -> KvIndex {
index
}
/// Summary of a batch of timing samples: extremes, mean, selected percentiles and the
/// operation throughput implied by the summed sample time.
#[derive(Debug)]
struct LatencyStats {
    min: Duration,
    max: Duration,
    avg: Duration,
    p50: Duration,
    p95: Duration,
    p99: Duration,
    throughput_ops_sec: f64,
}

impl LatencyStats {
    /// Sorts the samples and derives the statistics from them.
    ///
    /// Percentiles are read at index `n * pct / 100` of the sorted samples (`n / 2` for
    /// the median). Panics when `durations` is empty.
    fn from_durations(mut durations: Vec<Duration>) -> Self {
        durations.sort();
        let count = durations.len();
        let elapsed: Duration = durations.iter().sum();
        let mean = elapsed / count as u32;
        let sample_at = |index: usize| durations[index];
        Self {
            min: sample_at(0),
            max: sample_at(count - 1),
            avg: mean,
            p50: sample_at(count / 2),
            p95: sample_at(count * 95 / 100),
            p99: sample_at(count * 99 / 100),
            throughput_ops_sec: count as f64 / elapsed.as_secs_f64(),
        }
    }

    /// Writes the statistics to stdout. `blocks_per_op` scales the operation throughput
    /// into a blocks-per-second figure.
    fn print(&self, operation: &str, blocks_per_op: usize) {
        let ops_per_sec = self.throughput_ops_sec;
        let blocks_per_sec = ops_per_sec * blocks_per_op as f64;
        println!("\n{} Latency Statistics:", operation);
        println!(" min: {:>12?}", self.min);
        println!(" avg: {:>12?}", self.avg);
        println!(" p50: {:>12?}", self.p50);
        println!(" p95: {:>12?}", self.p95);
        println!(" p99: {:>12?}", self.p99);
        println!(" max: {:>12?}", self.max);
        println!(" throughput: {:.2} ops/sec", ops_per_sec);
        println!(" throughput: {:.2} blocks/sec", blocks_per_sec);
    }
}
/// Benchmark compute_block_hash_for_seq operation
fn bench_hash(args: &Args) {
println!("\n=== Benchmarking COMPUTE_BLOCK_HASH (per-request hot path) ===");
......@@ -464,7 +306,7 @@ fn bench_hash(args: &Args) {
}
}
let stats = LatencyStats::from_durations(durations);
let stats = LatencyStats::from_durations(durations).unwrap();
stats.print("COMPUTE_BLOCK_HASH", args.depth);
}
......@@ -487,6 +329,7 @@ fn bench_store_remove_cycle(args: &Args, time_store: bool) {
args.prefix_prompt_ratio,
args.num_prefix_prompts,
args.seed,
true, // use_cumulative_hash
);
let mut index = build_index(&sequences, args.flat_hashmap);
......@@ -524,7 +367,7 @@ fn bench_store_remove_cycle(args: &Args, time_store: bool) {
}
}
let stats = LatencyStats::from_durations(durations);
let stats = LatencyStats::from_durations(durations).unwrap();
stats.print(op_name, args.depth);
}
......@@ -548,6 +391,7 @@ fn bench_find_matches(args: &Args) {
args.prefix_prompt_ratio,
args.num_prefix_prompts,
args.seed,
true, // use_cumulative_hash
);
let index = build_index(&sequences, args.flat_hashmap);
......@@ -575,7 +419,9 @@ fn bench_find_matches(args: &Args) {
println!(" Completed {}/{} iterations", i + 1, args.iterations);
}
}
LatencyStats::from_durations(hit_durations).print("FIND_MATCHES (HIT)", args.depth);
LatencyStats::from_durations(hit_durations)
.unwrap()
.print("FIND_MATCHES (HIT)", args.depth);
// MISS case
println!("\n --- MISS case (non-existing sequences) ---");
......@@ -589,7 +435,9 @@ fn bench_find_matches(args: &Args) {
println!(" Completed {}/{} iterations", i + 1, args.iterations);
}
}
LatencyStats::from_durations(miss_durations).print("FIND_MATCHES (MISS)", args.depth);
LatencyStats::from_durations(miss_durations)
.unwrap()
.print("FIND_MATCHES (MISS)", args.depth);
// PARTIAL case
println!("\n --- PARTIAL case (prefix match only) ---");
......@@ -604,7 +452,9 @@ fn bench_find_matches(args: &Args) {
println!(" Completed {}/{} iterations", i + 1, args.iterations);
}
}
LatencyStats::from_durations(partial_durations).print("FIND_MATCHES (PARTIAL)", args.depth);
LatencyStats::from_durations(partial_durations)
.unwrap()
.print("FIND_MATCHES (PARTIAL)", args.depth);
// EARLY_EXIT case
println!("\n --- EARLY_EXIT case ---");
......@@ -617,6 +467,7 @@ fn bench_find_matches(args: &Args) {
}
}
LatencyStats::from_durations(early_exit_durations)
.unwrap()
.print("FIND_MATCHES (EARLY_EXIT)", args.depth);
}
......@@ -845,6 +696,7 @@ fn bench_sweep(args: &Args) {
args.prefix_prompt_ratio,
num_prefix_prompts,
args.seed,
true, // use_cumulative_hash
);
let tree_sequences = &all_sequences[..num_sequences];
let extra_sequences = &all_sequences[num_sequences..];
......@@ -956,6 +808,7 @@ fn bench_dump(args: &Args) {
args.prefix_prompt_ratio,
args.num_prefix_prompts,
args.seed,
true, // use_cumulative_hash
);
let tree = build_tree(&sequences);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Benchmark utilities for kv-router benchmarks.
//!
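//! This module provides shared data structures and helpers for benchmarking:
//! - `LatencyStats`: Statistics for latency measurements
//! - `SequenceData`: Pre-generated sequence data for benchmarking
//! - `generate_sequences`: Seeded generator for sequences with shared prefixes
//! - `median`: Median of a slice of durations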
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, RouterEvent, WorkerId, compute_seq_hash_for_block,
};
use rand::{Rng, SeedableRng, rngs::StdRng};
use std::time::Duration;
/// Summary of a batch of latency samples: extremes, mean, selected percentiles and the
/// operation throughput implied by the summed sample time.
#[derive(Debug, Clone)]
pub struct LatencyStats {
    pub min: Duration,
    pub max: Duration,
    pub avg: Duration,
    pub p50: Duration,
    pub p95: Duration,
    pub p99: Duration,
    pub throughput_ops_sec: f64,
}

impl LatencyStats {
    /// Derives statistics from a set of samples.
    ///
    /// The samples are sorted; percentiles are read at index `n * pct / 100` of the
    /// sorted data (`n / 2` for the median).
    ///
    /// Returns `None` when `durations` is empty.
    pub fn from_durations(mut durations: Vec<Duration>) -> Option<Self> {
        durations.sort();
        // `first`/`last` double as the emptiness check.
        let (&min, &max) = (durations.first()?, durations.last()?);
        let count = durations.len();
        let elapsed: Duration = durations.iter().sum();
        let sample_at = |index: usize| durations[index];
        Some(Self {
            min,
            max,
            avg: elapsed / count as u32,
            p50: sample_at(count / 2),
            p95: sample_at(count * 95 / 100),
            p99: sample_at(count * 99 / 100),
            throughput_ops_sec: count as f64 / elapsed.as_secs_f64(),
        })
    }

    /// Writes the statistics to stdout. `blocks_per_op` scales the operation throughput
    /// into a blocks-per-second figure.
    pub fn print(&self, operation: &str, blocks_per_op: usize) {
        let ops_per_sec = self.throughput_ops_sec;
        let blocks_per_sec = ops_per_sec * blocks_per_op as f64;
        println!("\n{} Latency Statistics:", operation);
        println!(" min: {:>12?}", self.min);
        println!(" avg: {:>12?}", self.avg);
        println!(" p50: {:>12?}", self.p50);
        println!(" p95: {:>12?}", self.p95);
        println!(" p99: {:>12?}", self.p99);
        println!(" max: {:>12?}", self.max);
        println!(" throughput: {:.2} ops/sec", ops_per_sec);
        println!(" throughput: {:.2} blocks/sec", blocks_per_sec);
    }
}
/// Pre-generated sequence data for benchmarking.
#[derive(Clone)]
pub struct SequenceData {
pub worker_id: WorkerId,
pub local_hashes: Vec<LocalBlockHash>,
pub external_hashes: Vec<ExternalSequenceBlockHash>,
}
impl SequenceData {
/// Create a new sequence with synthetic hashes based on sequence ID.
pub fn new(seq_id: u64, worker_id: WorkerId, depth: usize) -> Self {
let local_hashes: Vec<LocalBlockHash> = (0..depth)
.map(|block_idx| LocalBlockHash((seq_id << 32) | (block_idx as u64)))
.collect();
let external_hashes: Vec<ExternalSequenceBlockHash> = (0..depth)
.map(|block_idx| ExternalSequenceBlockHash((seq_id << 32) | (block_idx as u64)))
.collect();
Self {
worker_id,
local_hashes,
external_hashes,
}
}
/// Create a sequence from local hashes, computing external hashes using cumulative hash.
///
/// This ensures FlatHashMap can correctly identify block positions.
pub fn from_local_hashes(worker_id: WorkerId, local_hashes: Vec<LocalBlockHash>) -> Self {
let seq_hashes = compute_seq_hash_for_block(&local_hashes);
let external_hashes = seq_hashes
.into_iter()
.map(ExternalSequenceBlockHash)
.collect();
Self {
worker_id,
local_hashes,
external_hashes,
}
}
/// Convert to a store event.
pub fn to_store_event(&self, event_id: u64) -> RouterEvent {
RouterEvent {
worker_id: self.worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: self
.local_hashes
.iter()
.zip(self.external_hashes.iter())
.map(|(local, ext)| KvCacheStoredBlockData {
tokens_hash: *local,
block_hash: *ext,
mm_extra_info: None,
})
.collect(),
}),
dp_rank: 0,
},
}
}
/// Convert to a remove event.
pub fn to_remove_event(&self, event_id: u64) -> RouterEvent {
RouterEvent {
worker_id: self.worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: self.external_hashes.clone(),
}),
dp_rank: 0,
},
}
}
}
/// Generate sequences with shared prefix prompts.
///
/// Each sequence has `depth` blocks and is assigned to worker `seq_id % num_workers`.
/// When `num_prefix_groups > 0` and the rounded prefix length (`depth * prefix_ratio`) is
/// non-zero, each sequence draws one prefix group from the seeded RNG; its first
/// `prefix_length` blocks use hashes shared by that group and the remaining blocks use
/// hashes unique to the sequence.
///
/// # Arguments
/// * `num_sequences` - Number of sequences to generate
/// * `depth` - Number of blocks per sequence
/// * `num_workers` - Number of workers to distribute sequences across
/// * `prefix_ratio` - Fraction of each sequence's blocks that form the shared prefix (0.0 to 1.0)
/// * `num_prefix_groups` - Number of distinct prefix groups
/// * `seed` - Random seed for reproducibility
/// * `use_cumulative_hash` - If true, use `from_local_hashes` for proper cumulative hashes;
///   otherwise each external hash carries the same raw value as its local hash
///
/// # Panics
/// Panics if `num_workers` is 0 while `num_sequences` is non-zero.
pub fn generate_sequences(
    num_sequences: usize,
    depth: usize,
    num_workers: usize,
    prefix_ratio: f64,
    num_prefix_groups: usize,
    seed: u64,
    use_cumulative_hash: bool,
) -> Vec<SequenceData> {
    /// Tag occupying the upper 32 bits of every shared-prefix hash.
    const PREFIX_TAG: u64 = 0xDEAD_BEEF_0000_0000;

    // `seq_id % num_workers` would otherwise fail with an opaque remainder-by-zero panic.
    assert!(
        num_workers > 0 || num_sequences == 0,
        "generate_sequences: num_workers must be > 0"
    );

    let mut sequences = Vec::with_capacity(num_sequences);
    let prefix_length = (depth as f64 * prefix_ratio).round() as usize;
    let mut rng: StdRng = StdRng::seed_from_u64(seed);

    for seq_id in 0..num_sequences {
        let seq_id_u64 = seq_id as u64;
        let worker_id = (seq_id % num_workers) as WorkerId;

        // Determine prefix group for this sequence
        let group_id = if num_prefix_groups > 0 && prefix_length > 0 {
            Some(rng.random_range(0..num_prefix_groups) as u64)
        } else {
            None
        };

        // Build local_hashes: shared prefix (if applicable) + unique suffix
        let local_hashes: Vec<LocalBlockHash> = (0..depth)
            .map(|block_idx| {
                let block_idx_u64 = block_idx as u64;
                match group_id {
                    // The group id is XORed into the tag. OR-ing it (as before) let the
                    // tag's set bits absorb the id, so e.g. groups 0..=15 all produced
                    // the same prefix hashes. Group 0 keeps its previous values.
                    Some(gid) if block_idx < prefix_length => {
                        LocalBlockHash((PREFIX_TAG ^ (gid << 32)) | block_idx_u64)
                    }
                    // Unique suffix (or no shared prefix)
                    _ => LocalBlockHash((seq_id_u64 << 32) | block_idx_u64),
                }
            })
            .collect();

        let sequence = if use_cumulative_hash {
            SequenceData::from_local_hashes(worker_id, local_hashes)
        } else {
            // Non-cumulative mode: the external hash of a block is the raw value of its
            // local hash, so derive it instead of repeating the formula.
            let external_hashes = local_hashes
                .iter()
                .map(|local| ExternalSequenceBlockHash(local.0))
                .collect();
            SequenceData {
                worker_id,
                local_hashes,
                external_hashes,
            }
        };
        sequences.push(sequence);
    }

    sequences
}
/// Returns the median of `durations`: the element at index `len / 2` of a sorted copy
/// (the upper middle for even lengths), or `Duration::ZERO` for an empty slice.
pub fn median(durations: &[Duration]) -> Duration {
    let mut ordered: Vec<Duration> = durations.to_vec();
    ordered.sort();
    // `get` yields `None` only when the slice is empty.
    ordered
        .get(ordered.len() / 2)
        .copied()
        .unwrap_or(Duration::ZERO)
}
......@@ -31,6 +31,9 @@
//!
//! This module provides a scalable and efficient way to manage and retrieve data blocks for LLM inference, leveraging a global KV cache to optimize performance.
#[cfg(feature = "bench")]
use std::time::Instant;
use async_trait::async_trait;
#[cfg(feature = "metrics")]
pub use dynamo_runtime::protocols::maybe_error::MaybeError;
......@@ -335,6 +338,25 @@ pub struct MatchRequest {
early_exit: bool,
/// A channel sender to send the `OverlapScores` response.
resp: oneshot::Sender<OverlapScores>,
/// Timestamp when the request was created (for queue wait time measurement)
#[cfg(feature = "bench")]
created_at: Instant,
}
impl MatchRequest {
fn new(
sequence: Vec<LocalBlockHash>,
early_exit: bool,
resp: oneshot::Sender<OverlapScores>,
) -> Self {
Self {
sequence,
early_exit,
resp,
#[cfg(feature = "bench")]
created_at: Instant::now(),
}
}
}
/// A request to dump the tree as events
......@@ -551,10 +573,16 @@ impl KvIndexer {
Some(event) = event_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let event_id = event.event.event_id;
let worker_id = event.worker_id;
// Only clone if we need the event for prune_manager afterward
let event_for_prune = prune_manager.is_some().then(|| event.clone());
let result = trie.apply_event(event);
let result_is_ok = result.is_ok();
let tree_size = trie.current_size();
tracing::trace!(
"Applied KV event to global radix tree: event_type={event_type}, event_id={event_id}, worker_id={worker_id}, success={result_is_ok}, global_radix_tree_size={tree_size}"
);
metrics.increment_event_applied(event_type, result);
// Track blocks in PruneManager if TTL is enabled and event was stored successfully
......@@ -643,7 +671,24 @@ impl KvIndexer {
}
Some(req) = match_rx.recv() => {
#[cfg(feature = "bench")]
let queue_wait = req.created_at.elapsed();
#[cfg(feature = "bench")]
let seq_len = req.sequence.len();
#[cfg(feature = "bench")]
let process_start = Instant::now();
let matches = trie.find_matches(req.sequence, req.early_exit);
#[cfg(feature = "bench")]
let process_time = process_start.elapsed();
#[cfg(feature = "bench")]
tracing::info!(
seq_len,
queue_wait_us = queue_wait.as_micros() as u64,
process_us = process_time.as_micros() as u64,
"indexer: processed find_matches"
);
let _ = req.resp.send(matches);
}
......@@ -742,12 +787,11 @@ impl KvIndexerInterface for KvIndexer {
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
#[cfg(feature = "bench")]
let start = Instant::now();
let seq_len = sequence.len();
let (resp_tx, resp_rx) = oneshot::channel();
let req = MatchRequest {
sequence,
early_exit: false,
resp: resp_tx,
};
let req = MatchRequest::new(sequence, false, resp_tx);
if let Err(e) = self.match_tx.send(req).await {
tracing::error!(
......@@ -757,9 +801,23 @@ impl KvIndexerInterface for KvIndexer {
return Err(KvRouterError::IndexerOffline);
}
resp_rx
let result = resp_rx
.await
.map_err(|_| KvRouterError::IndexerDroppedRequest)
.map_err(|_| KvRouterError::IndexerDroppedRequest);
#[cfg(feature = "bench")]
{
let elapsed = start.elapsed();
tracing::info!(
seq_len,
elapsed_us = elapsed.as_micros() as u64,
"find_matches completed"
);
}
#[cfg(not(feature = "bench"))]
let _ = seq_len;
result
}
async fn find_matches_for_request(
......@@ -1131,6 +1189,24 @@ pub struct ShardedMatchRequest {
sequence: Vec<LocalBlockHash>,
early_exit: bool,
resp: mpsc::Sender<OverlapScores>,
#[cfg(feature = "bench")]
created_at: Instant,
}
impl ShardedMatchRequest {
fn new(
sequence: Vec<LocalBlockHash>,
early_exit: bool,
resp: mpsc::Sender<OverlapScores>,
) -> Self {
Self {
sequence,
early_exit,
resp,
#[cfg(feature = "bench")]
created_at: Instant::now(),
}
}
}
/// A sharded KV Indexer that partitions the RadixTree across multiple independent shards.
......@@ -1374,7 +1450,24 @@ impl KvIndexerSharded {
}
Ok(req) = shard_broadcast_rx.recv() => {
#[cfg(feature = "bench")]
let queue_wait = req.created_at.elapsed();
#[cfg(feature = "bench")]
let seq_len = req.sequence.len();
#[cfg(feature = "bench")]
let process_start = Instant::now();
let matches = trie.find_matches(req.sequence, req.early_exit);
#[cfg(feature = "bench")]
let process_time = process_start.elapsed();
#[cfg(feature = "bench")]
tracing::info!(
seq_len,
queue_wait_us = queue_wait.as_micros() as u64,
process_us = process_time.as_micros() as u64,
"sharded indexer: processed find_matches"
);
if let Err(e) = req.resp.send(matches).await {
tracing::trace!("Failed to send match response: {:?}", e);
}
......@@ -1442,14 +1535,18 @@ impl KvIndexerInterface for KvIndexerSharded {
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
#[cfg(feature = "bench")]
let start = Instant::now();
#[cfg(feature = "bench")]
let seq_len = sequence.len();
#[cfg(feature = "bench")]
let num_shards = self.event_tx.len();
'match_loop: loop {
let (match_tx, mut match_rx) = mpsc::channel(self.event_tx.len());
let sharded_req = ShardedMatchRequest::new(sequence.clone(), false, match_tx);
self.request_broadcast_tx
.send(ShardedMatchRequest {
sequence: sequence.clone(),
early_exit: false,
resp: match_tx,
})
.send(sharded_req)
.map_err(|_| KvRouterError::IndexerOffline)?;
let mut scores = OverlapScores::new();
......@@ -1482,6 +1579,17 @@ impl KvIndexerInterface for KvIndexerSharded {
}
}
}
#[cfg(feature = "bench")]
{
let elapsed = start.elapsed();
tracing::info!(
seq_len,
num_shards,
elapsed_us = elapsed.as_micros() as u64,
"find_matches (sharded) completed"
);
}
return Ok(scores);
}
}
......
......@@ -7,6 +7,8 @@
//! efficient KV cache lookup and routing in distributed LLM inference systems.
pub mod approx;
#[cfg(feature = "bench")]
pub mod bench_utils;
pub mod flat_hashmap;
pub mod indexer;
pub mod protocols;
......
......@@ -27,7 +27,8 @@ cuda = ["dep:cudarc"]
integration = ["dynamo-runtime/integration"]
media-nixl = ["dep:nixl-sys", "dep:flate2"]
media-ffmpeg = ["dep:video-rs", "dep:ffmpeg-next", "dep:memfile", "media-nixl"]
kv-router-stress = ["dep:clap", "dep:indicatif"]
bench = ["dynamo-kv-router/bench"]
kv-router-stress = ["dep:clap", "dep:indicatif", "bench"]
[[bench]]
name = "tokenizer"
......@@ -38,6 +39,11 @@ name = "transfer_context_v2"
harness = false
required-features = ["block-manager", "testing-cuda"]
[[bench]]
name = "kv_router_bench"
harness = false
required-features = ["kv-router-stress"]
[dependencies]
# repo
dynamo-runtime = { workspace = true }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Full stress test for the KV Router.
//!
//! Stress tests the full `KvRouter` frontend without worker backends:
//! - Phase 1: Build tree by publishing KV events to NATS (with computed hashes matching tokenized requests)
//! - Phase 2: Send HTTP requests and measure routing decision latency
//!
//! The key feature is that tree construction uses the same hash computation as the frontend,
//! ensuring that HTTP requests will match the pre-populated tree entries.
//!
//! Run with: cargo bench --package dynamo-llm --bench kv_router_bench --features kv-router-stress -- --help
use anyhow::{Context, Result};
use bytes::Bytes;
use clap::Parser;
use dynamo_runtime::transports::event_plane::EventEnvelope;
use hf_hub;
use indicatif::{ProgressBar, ProgressStyle};
use minijinja::{Environment, context, value::Value};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tokenizers::Tokenizer;
use tokio::sync::{Mutex, Semaphore};
use dynamo_llm::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, RouterEvent, WorkerId, compute_hash,
compute_seq_hash_for_block,
};
use dynamo_llm::model_card::ModelDeploymentCard;
use dynamo_llm::preprocessor::prompt::{
ChatTemplate, ContextMixins, OAIChatLikeRequest, PromptFormatter,
};
/// Suffix of the KV Router event subject (appended to Component.subject()).
/// Full subject format: namespace.{namespace}.component.{component}.kv-events
const KV_EVENT_SUBJECT: &str = "kv-events";

/// Publisher ID for this benchmark instance, taken once (on first use) from the
/// wall-clock nanoseconds since the Unix epoch; 0 if the clock reads before the epoch.
static PUBLISHER_ID: std::sync::LazyLock<u64> = std::sync::LazyLock::new(|| {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos() as u64,
        Err(_) => 0,
    }
});

/// Monotonic counter supplying the `sequence` field of outgoing envelopes.
static ENVELOPE_SEQUENCE: AtomicU64 = AtomicU64::new(0);
/// Serializes `event` as named Msgpack and wraps it in an `EventEnvelope`, which is then
/// Msgpack-encoded as well. This matches the format expected by the event plane subscriber.
///
/// Every call consumes one value of the global envelope sequence counter. The publish
/// timestamp is wall-clock milliseconds since the Unix epoch (0 if the clock reads
/// before the epoch).
///
/// # Errors
/// Returns an error when either the payload or the envelope cannot be encoded.
fn encode_event_with_envelope<T>(event: &T, topic: &str) -> Result<Vec<u8>>
where
    T: Serialize,
{
    // Inner payload first: the envelope carries it as opaque bytes.
    let payload_bytes =
        rmp_serde::to_vec_named(event).context("Failed to encode event payload")?;

    let publisher_id = *PUBLISHER_ID;
    let sequence = ENVELOPE_SEQUENCE.fetch_add(1, Ordering::SeqCst);
    let published_at = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    };

    let envelope = EventEnvelope {
        publisher_id,
        sequence,
        published_at,
        topic: topic.to_string(),
        payload: Bytes::from(payload_bytes),
    };

    rmp_serde::to_vec_named(&envelope).context("Failed to encode envelope")
}
// Command-line options of the router stress benchmark, parsed by clap's derive macro.
// The `///` comments on the fields below are picked up by clap as the `--help` text, so
// rewording them changes the CLI output; maintainer notes use plain `//` comments.
#[derive(Parser, Debug)]
#[command(name = "kv_router_bench")]
#[command(about = "Full stress test for the KV Router via NATS events and HTTP requests")]
struct Args {
// Tree construction parameters
/// Target tree size in total (worker, block) pairs
#[arg(long, default_value = "500000")]
tree_size: usize,
/// Sequence depth in blocks (blocks per sequence)
#[arg(long, default_value = "512")]
depth: usize,
/// Number of workers to distribute blocks across
#[arg(long, default_value = "4")]
num_workers: usize,
/// Portion of sequence that is shared prefix (0.0 to 1.0)
#[arg(long, default_value = "0.25")]
prefix_prompt_ratio: f64,
/// Number of unique prefix groups
#[arg(long, default_value = "20")]
num_prefix_prompts: usize,
/// Random seed for reproducibility
#[arg(long, default_value = "42")]
seed: u64,
// Stress test parameters
/// KV events per second during stress test (0 = no ongoing events)
#[arg(long, default_value = "0")]
event_rate: f64,
/// HTTP requests per second
#[arg(long, default_value = "100")]
request_rate: f64,
/// Test duration in seconds
#[arg(long, default_value = "30")]
duration: u64,
/// Warmup duration before measurement in seconds
#[arg(long, default_value = "5")]
warmup: u64,
/// Maximum concurrent HTTP requests
#[arg(long, default_value = "50")]
concurrency: usize,
// Infrastructure
/// NATS server URL
#[arg(long, default_value = "nats://localhost:4222")]
nats_url: String,
/// Frontend HTTP URL
#[arg(long, default_value = "http://localhost:8000")]
frontend_url: String,
/// NATS namespace (used to construct subject)
#[arg(long, default_value = "dynamo")]
namespace: String,
/// Component name (used to construct subject)
#[arg(long, default_value = "backend")]
component: String,
// Output
/// Write results to JSON file
#[arg(long)]
output: Option<String>,
/// Print per-request timings
#[arg(short, long)]
verbose: bool,
/// Skip tree construction via NATS (use when kv_stress_worker handles it)
#[arg(long)]
skip_tree_construction: bool,
/// Time bucket size in seconds for latency-over-time tracking (0 to disable)
#[arg(long, default_value = "5")]
bucket_size: u64,
/// Include raw latency samples in JSON output (for graphing)
#[arg(long)]
include_raw_samples: bool,
// Model / tokenization
/// Model name to use in requests (should match the registered model).
/// If not specified, auto-detects from /v1/models when exactly one model is available.
#[arg(long)]
model: Option<String>,
// NOTE(review): `compute_block_hashes` feeds this value to `chunks_exact`, which panics
// on a chunk size of 0 — presumably 0 is rejected before use; confirm.
/// KV block size in tokens (must match frontend configuration)
#[arg(long, default_value = "16")]
kv_block_size: u32,
/// Path to tokenizer (HuggingFace model ID or local path). Defaults to --model value.
#[arg(long)]
tokenizer_path: Option<String>,
// Hidden from `--help`; declared only so that the argument parses.
/// Ignored - passed by cargo bench
#[arg(long, hide = true)]
bench: bool,
}
/// Compute LocalBlockHash (tokens_hash) values for a token stream.
///
/// The tokens are cut into consecutive blocks of exactly `kv_block_size` IDs (a trailing
/// partial block is dropped); each block is hashed over the little-endian bytes of its IDs.
/// Uses the same algorithm as lib/llm/src/kv_router/protocols.rs::compute_block_hash_for_seq
fn compute_block_hashes(tokens: &[u32], kv_block_size: u32) -> Vec<LocalBlockHash> {
    let block_len = kv_block_size as usize;
    let mut hashes = Vec::new();
    for block in tokens.chunks_exact(block_len) {
        let mut bytes: Vec<u8> = Vec::new();
        for token_id in block {
            bytes.extend(token_id.to_le_bytes());
        }
        hashes.push(LocalBlockHash(compute_hash(&bytes)));
    }
    hashes
}
/// Compute ExternalSequenceBlockHash (block_hash) values from LocalBlockHash values.
///
/// Delegates to the router's `compute_seq_hash_for_block` so the benchmark and the router
/// perform the identical computation, then wraps each result.
fn compute_sequence_hashes(block_hashes: &[LocalBlockHash]) -> Vec<ExternalSequenceBlockHash> {
    let mut sequence_hashes = Vec::new();
    for raw_hash in compute_seq_hash_for_block(block_hashes) {
        sequence_hashes.push(ExternalSequenceBlockHash::from(raw_hash));
    }
    sequence_hashes
}
/// Computes the local and external block hashes for one request body.
///
/// `content` is first rendered as a user message when a `prompt_renderer` is supplied
/// (otherwise used verbatim), then tokenized without adding special tokens, split into
/// blocks of `kv_block_size` tokens and hashed.
///
/// # Errors
/// Fails when rendering or tokenization fails.
fn compute_hashes_for_content(
    content: &str,
    tokenizer: &Tokenizer,
    kv_block_size: u32,
    prompt_renderer: Option<&PromptRenderer>,
) -> Result<(Vec<LocalBlockHash>, Vec<ExternalSequenceBlockHash>)> {
    let formatted_text = match prompt_renderer {
        Some(renderer) => renderer.render_user_message(content)?,
        None => content.to_string(),
    };

    let encoding = tokenizer
        .encode(formatted_text.as_str(), false)
        .map_err(|e| anyhow::anyhow!("Failed to tokenize request content: {}", e))?;

    let local_hashes = compute_block_hashes(encoding.get_ids(), kv_block_size);
    let external_hashes = compute_sequence_hashes(&local_hashes);
    Ok((local_hashes, external_hashes))
}
/// Tokenizer config from tokenizer_config.json
///
/// Only the fields read by this benchmark are declared. The special tokens are kept as
/// raw JSON because they are handled in two shapes: a plain string, or an object whose
/// `content` member holds the string.
#[derive(Debug, Deserialize)]
struct TokenizerConfig {
    chat_template: Option<String>,
    bos_token: Option<serde_json::Value>,
    eos_token: Option<serde_json::Value>,
}

impl TokenizerConfig {
    /// Extracts the text of a special token given in either supported shape.
    ///
    /// Returns `None` when the token is absent, is neither a string nor an object, or is
    /// an object without a string `content` member.
    fn special_token_text(token: Option<&serde_json::Value>) -> Option<String> {
        let value = token?;
        value
            .as_str()
            .or_else(|| value.as_object()?.get("content")?.as_str())
            .map(str::to_owned)
    }

    /// Extract bos_token as a string (handles both string and object formats)
    fn bos_token_str(&self) -> Option<String> {
        Self::special_token_text(self.bos_token.as_ref())
    }

    /// Extract eos_token as a string (handles both string and object formats)
    fn eos_token_str(&self) -> Option<String> {
        Self::special_token_text(self.eos_token.as_ref())
    }
}
/// Locate and parse `tokenizer_config.json` for a model.
///
/// * If `model_or_path` is a directory, the file is looked up inside it; a directory
///   without the file yields `Ok(None)`.
/// * Otherwise `model_or_path` is treated as a HuggingFace model ID and the file is
///   fetched through `hf_hub`; a failed fetch yields `Ok(None)` rather than an error.
///
/// # Errors
/// Fails if the file exists but cannot be read or parsed, or if the HuggingFace API
/// client cannot be created.
fn load_tokenizer_config(model_or_path: &str) -> Result<Option<TokenizerConfig>> {
    use std::path::Path;

    // Shared read-and-parse step for the local and the downloaded file.
    let read_config = |config_file: &Path| -> Result<TokenizerConfig> {
        let raw = std::fs::read_to_string(config_file)
            .context("Failed to read tokenizer_config.json")?;
        serde_json::from_str(&raw).context("Failed to parse tokenizer_config.json")
    };

    let location = Path::new(model_or_path);
    if location.is_dir() {
        let candidate = location.join("tokenizer_config.json");
        return if candidate.exists() {
            read_config(candidate.as_path()).map(Some)
        } else {
            Ok(None)
        };
    }

    // Not a local directory: try to download from HuggingFace.
    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::default())
        .with_progress(false)
        .build()
        .map_err(|e| anyhow::anyhow!("Failed to create HuggingFace API client: {}", e))?;

    match api
        .model(model_or_path.to_string())
        .get("tokenizer_config.json")
    {
        Ok(downloaded) => read_config(downloaded.as_path()).map(Some),
        // A model without a tokenizer config is not an error for the caller.
        Err(_) => Ok(None),
    }
}
/// Best-effort construction of a `PromptRenderer::Formatter` from a local model directory.
///
/// Returns `None` when `model_or_path` is not a directory, or when either loading the
/// model deployment card or building the prompt formatter from it fails.
fn try_load_prompt_renderer(model_or_path: &str) -> Option<PromptRenderer> {
    use std::path::Path;

    let model_dir = Path::new(model_or_path);
    if model_dir.is_dir() {
        ModelDeploymentCard::load_from_disk(model_dir, None)
            .ok()
            .and_then(|card| PromptFormatter::from_mdc(&card).ok())
            .map(PromptRenderer::Formatter)
    } else {
        None
    }
}
/// Renders chat prompts from a template string using minijinja.
#[derive(Clone)]
struct ChatTemplateRenderer {
    /// Template source compiled on every `apply` call.
    template: String,
    /// Value exposed to the template as `bos_token`.
    bos_token: String,
    /// Value exposed to the template as `eos_token`.
    eos_token: String,
}

impl ChatTemplateRenderer {
    /// Build a renderer, falling back to `<s>` / `</s>` when a special token is not given.
    fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
        let bos_token = bos_token.unwrap_or_else(|| String::from("<s>"));
        let eos_token = eos_token.unwrap_or_else(|| String::from("</s>"));
        Self {
            template,
            bos_token,
            eos_token,
        }
    }

    /// Render `messages` through the chat template and return the prompt text.
    ///
    /// A fresh minijinja `Environment` is created for each call.
    ///
    /// # Errors
    /// Fails if the template does not compile or rendering fails.
    fn apply(&self, messages: &[ChatTemplateMessage]) -> Result<String> {
        let mut env = Environment::new();
        env.add_template("chat", &self.template)
            .context("Failed to compile chat template")?;

        // Render with add_generation_prompt=true to match frontend behavior
        env.get_template("chat")
            .unwrap()
            .render(context! {
                messages => messages,
                add_generation_prompt => true,
                bos_token => &self.bos_token,
                eos_token => &self.eos_token,
            })
            .context("Failed to render chat template")
    }
}
/// Smallest request type that satisfies `OAIChatLikeRequest`, so the frontend prompt
/// formatter can be reused by the benchmark.
struct SimpleChatRequest {
    /// Messages handed to the prompt formatter.
    messages: Vec<ChatTemplateMessage>,
}

impl OAIChatLikeRequest for SimpleChatRequest {
    /// Fixed placeholder model name.
    fn model(&self) -> String {
        String::from("kv_router_bench")
    }

    /// The messages, serialized into a template value.
    fn messages(&self) -> Value {
        Value::from_serialize(&self.messages)
    }

    /// Always request the generation prompt.
    fn should_add_generation_prompt(&self) -> bool {
        true
    }
}
/// The two ways this benchmark can turn request content into a formatted prompt.
#[derive(Clone)]
enum PromptRenderer {
    /// Render through a `PromptFormatter`.
    Formatter(PromptFormatter),
    /// Render through the local minijinja-based `ChatTemplateRenderer`.
    Simple(ChatTemplateRenderer),
}

impl PromptRenderer {
    /// Render `content` as a conversation consisting of a single `user` message.
    ///
    /// # Errors
    /// Propagates any error from the underlying formatter or template renderer.
    fn render_user_message(&self, content: &str) -> Result<String> {
        let user_message = ChatTemplateMessage {
            role: "user".to_string(),
            content: content.to_string(),
        };

        match self {
            Self::Formatter(formatter) => {
                let request = SimpleChatRequest {
                    messages: vec![user_message],
                };
                match formatter {
                    PromptFormatter::OAI(inner) => inner.render(&request),
                }
            }
            Self::Simple(renderer) => renderer.apply(&[user_message]),
        }
    }
}
/// Message format for chat template rendering.
///
/// Serialized and passed to the chat template / prompt formatter as one entry of the
/// message list.
#[derive(Debug, Clone, Serialize)]
struct ChatTemplateMessage {
/// Speaker role; the visible code only ever constructs messages with "user".
role: String,
/// Message text.
content: String,
}
/// Resolve `model_or_path` to a `Tokenizer`.
///
/// Resolution order:
/// 1. an existing file is loaded directly as the tokenizer definition;
/// 2. an existing directory must contain `tokenizer.json`, which is loaded;
/// 3. anything else is treated as a HuggingFace model ID whose `tokenizer.json` is
///    downloaded (through the default `hf_hub` cache) and loaded.
///
/// # Errors
/// Fails if the file cannot be loaded, the directory lacks `tokenizer.json`, the
/// HuggingFace client cannot be created, or the download fails.
fn load_tokenizer(model_or_path: &str) -> Result<Tokenizer> {
    use std::path::Path;

    let location = Path::new(model_or_path);

    // Case 1: an explicit tokenizer file.
    if location.is_file() {
        return match Tokenizer::from_file(location) {
            Ok(tokenizer) => Ok(tokenizer),
            Err(e) => Err(anyhow::anyhow!(
                "Failed to load tokenizer from file '{}': {}",
                model_or_path,
                e
            )),
        };
    }

    // Case 2: a model directory that should hold tokenizer.json.
    if location.is_dir() {
        let candidate = location.join("tokenizer.json");
        if !candidate.exists() {
            anyhow::bail!(
                "Directory '{}' does not contain tokenizer.json",
                model_or_path
            );
        }
        return match Tokenizer::from_file(&candidate) {
            Ok(tokenizer) => Ok(tokenizer),
            Err(e) => Err(anyhow::anyhow!(
                "Failed to load tokenizer from '{}': {}",
                candidate.display(),
                e
            )),
        };
    }

    // Case 3: fall back to HuggingFace.
    println!(
        " Downloading tokenizer from HuggingFace: {}...",
        model_or_path
    );
    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::default())
        .with_progress(true)
        .build()
        .map_err(|e| anyhow::anyhow!("Failed to create HuggingFace API client: {}", e))?;
    let downloaded = api
        .model(model_or_path.to_string())
        .get("tokenizer.json")
        .map_err(|e| {
            anyhow::anyhow!(
                "Failed to download tokenizer.json from '{}': {}",
                model_or_path,
                e
            )
        })?;

    Tokenizer::from_file(&downloaded)
        .map_err(|e| anyhow::anyhow!("Failed to load downloaded tokenizer: {}", e))
}
/// A prefix prompt together with everything derived from it.
#[derive(Clone, Debug)]
struct PrefixData {
    /// Raw prefix text, before any prompt formatting.
    text: String,
    /// Text that was actually tokenized (the rendered prompt, or `text` when no renderer is used).
    formatted_text: String,
    /// Token IDs produced by tokenizing `formatted_text`.
    token_ids: Vec<u32>,
    /// `LocalBlockHash` (tokens_hash) of each complete block of `token_ids`.
    local_hashes: Vec<LocalBlockHash>,
}

impl PrefixData {
    /// Format (when a renderer is given), tokenize and hash `text`.
    ///
    /// # Errors
    /// Fails if prompt rendering or tokenization fails.
    fn from_text(
        text: String,
        tokenizer: &Tokenizer,
        kv_block_size: u32,
        prompt_renderer: Option<&PromptRenderer>,
    ) -> Result<Self> {
        let formatted_text = match prompt_renderer {
            Some(renderer) => renderer.render_user_message(&text)?,
            None => text.clone(),
        };

        let token_ids: Vec<u32> = tokenizer
            .encode(formatted_text.as_str(), false)
            .map_err(|e| anyhow::anyhow!("Failed to tokenize prefix: {}", e))?
            .get_ids()
            .to_vec();
        let local_hashes = compute_block_hashes(&token_ids, kv_block_size);

        Ok(Self {
            text,
            formatted_text,
            token_ids,
            local_hashes,
        })
    }

    /// How many complete blocks the prefix spans.
    fn num_blocks(&self) -> usize {
        self.local_hashes.len()
    }
}
/// One pre-hashed sequence, attributed to a worker, ready to be published as a KV event.
#[derive(Clone)]
struct SequenceData {
    /// Worker the stored blocks are attributed to.
    worker_id: WorkerId,
    /// Per-block token hashes.
    local_hashes: Vec<LocalBlockHash>,
    /// Per-block sequence hashes, parallel to `local_hashes`.
    external_hashes: Vec<ExternalSequenceBlockHash>,
}

impl SequenceData {
    /// Hash the exact content of a request and attribute it to `worker_id`.
    ///
    /// # Errors
    /// Fails if rendering or tokenizing `content` fails.
    fn from_request_content(
        content: &str,
        worker_id: WorkerId,
        kv_block_size: u32,
        tokenizer: &Tokenizer,
        prompt_renderer: Option<&PromptRenderer>,
    ) -> Result<Self> {
        compute_hashes_for_content(content, tokenizer, kv_block_size, prompt_renderer).map(
            |(local_hashes, external_hashes)| Self {
                worker_id,
                local_hashes,
                external_hashes,
            },
        )
    }

    /// Wrap the sequence in a `Stored` KV cache event with no parent hash and `dp_rank` 0.
    fn to_router_event(&self, event_id: u64) -> RouterEvent {
        // Pair each token hash with its sequence hash, one stored block per pair.
        let blocks = self
            .external_hashes
            .iter()
            .zip(&self.local_hashes)
            .map(|(ext, local)| KvCacheStoredBlockData {
                block_hash: *ext,
                tokens_hash: *local,
                mm_extra_info: None,
            })
            .collect();

        let stored = KvCacheStoreData {
            parent_hash: None,
            blocks,
        };
        RouterEvent::new(
            self.worker_id,
            KvCacheEvent {
                event_id,
                data: KvCacheEventData::Stored(stored),
                dp_rank: 0,
            },
        )
    }
}
/// Response from the frontend's /health endpoint.
///
/// Deserialized from the JSON body in `discover_worker_ids`.
#[derive(Debug, Deserialize)]
struct HealthResponse {
/// Parsed from the response; the `dead_code` allowance indicates it is not otherwise read.
#[allow(dead_code)]
status: String,
/// Instances reported by the endpoint; their `instance_id`s become the worker IDs.
instances: Vec<HealthInstance>,
}
/// Instance info from health endpoint.
#[derive(Debug, Deserialize)]
struct HealthInstance {
/// Identifier that `discover_worker_ids` uses as the worker ID.
instance_id: u64,
/// Parsed from the response; the `dead_code` allowance indicates it is not otherwise read.
#[allow(dead_code)]
endpoint: String,
}
/// Response from the frontend's /v1/models endpoint.
///
/// Deserialized from the JSON body in `fetch_model_name`.
#[derive(Debug, Deserialize)]
struct ModelsResponse {
/// Models listed by the endpoint; auto-detection requires exactly one entry.
data: Vec<ModelInfo>,
}
/// Model info from /v1/models endpoint.
#[derive(Debug, Deserialize)]
struct ModelInfo {
/// Model identifier; returned by `fetch_model_name` when it is the only model listed.
id: String,
}
/// Auto-detect the model served by the frontend via `GET {frontend_url}/v1/models`.
///
/// Succeeds only when the endpoint lists exactly one model, whose ID is returned.
///
/// # Errors
/// Fails if the request cannot be sent, the status is not a success, the body cannot be
/// parsed, or the endpoint lists zero or several models (several models are printed so
/// the user can pick one with `--model`).
async fn fetch_model_name(frontend_url: &str) -> Result<String> {
    let endpoint = format!("{}/v1/models", frontend_url);
    println!(" Auto-detecting model from {}...", endpoint);

    let response = reqwest::Client::new()
        .get(&endpoint)
        .send()
        .await
        .context("Failed to connect to frontend /v1/models endpoint")?;
    if !response.status().is_success() {
        anyhow::bail!("Models endpoint returned status: {}", response.status());
    }

    let listing = response
        .json::<ModelsResponse>()
        .await
        .context("Failed to parse models response")?;

    match listing.data.as_slice() {
        [] => anyhow::bail!("No models found at endpoint. Is a backend running?"),
        [only] => {
            let model_id = only.id.clone();
            println!(" Auto-detected model: {}", model_id);
            Ok(model_id)
        }
        several => {
            println!(" Multiple models available ({}):", several.len());
            for candidate in several {
                println!(" - {}", candidate.id);
            }
            anyhow::bail!("Multiple models available. Please specify --model explicitly.")
        }
    }
}
/// Ask the frontend's `/health` endpoint which workers are registered.
///
/// Returns the reported `instance_id`s, sorted and with duplicates removed.
///
/// # Errors
/// Fails if the request cannot be sent, the status is not a success, the body cannot be
/// parsed, or no instance is reported.
async fn discover_worker_ids(frontend_url: &str) -> Result<Vec<WorkerId>> {
    let endpoint = format!("{}/health", frontend_url);
    println!(" Discovering workers from {}...", endpoint);

    let response = reqwest::Client::new()
        .get(&endpoint)
        .send()
        .await
        .context("Failed to connect to frontend /health endpoint")?;
    if !response.status().is_success() {
        anyhow::bail!("Health endpoint returned status: {}", response.status());
    }

    let health = response
        .json::<HealthResponse>()
        .await
        .context("Failed to parse health response")?;

    // Deduplicate (in case of multiple endpoints per worker)
    let mut unique_ids: Vec<WorkerId> = health
        .instances
        .iter()
        .map(|instance| instance.instance_id)
        .collect();
    unique_ids.sort_unstable();
    unique_ids.dedup();

    println!(" Discovered {} workers", unique_ids.len());
    if unique_ids.is_empty() {
        anyhow::bail!("No workers discovered from frontend. Are kv_stress_workers running?");
    }
    Ok(unique_ids)
}
/// Generate one `SequenceData` per request by hashing the exact content each request sends.
///
/// For every `request_id` in `0..num_sequences` this function:
/// 1. builds the request text with `build_request_content_with_prefix` (a shared prefix
///    prompt plus a request-specific suffix),
/// 2. renders (when `prompt_renderer` is given), tokenizes and hashes that text,
/// 3. attributes the resulting sequence to a worker chosen round-robin from `worker_ids`.
///
/// Because the hashes are computed from the same text the HTTP requests carry, requests
/// built with the same `request_id` and `seed` are intended to produce matching hashes.
///
/// Sequences are generated in parallel with rayon; when `show_progress` is set a progress
/// bar is displayed while they are produced.
///
/// # Errors
/// Fails if `prefix_prompts` is empty, `num_prefix_prompts` is zero, `worker_ids` is empty,
/// or rendering/tokenizing any request fails.
fn generate_sequences_for_requests(
    num_sequences: usize,
    worker_ids: &[WorkerId],
    prefix_prompts: &[String],
    num_prefix_prompts: usize,
    kv_block_size: u32,
    tokenizer: &Tokenizer,
    prompt_renderer: Option<&PromptRenderer>,
    seed: u64,
    show_progress: bool,
) -> Result<Vec<SequenceData>> {
    if prefix_prompts.is_empty() || num_prefix_prompts == 0 {
        anyhow::bail!("No prefix prompts available for request-aligned sequence generation");
    }
    // Without this guard the round-robin `% worker_ids.len()` below would panic.
    if worker_ids.is_empty() {
        anyhow::bail!("No worker IDs available for request-aligned sequence generation");
    }

    let progress = if show_progress {
        let pb = ProgressBar::new(num_sequences as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template(" [{bar:40.cyan/blue}] {pos}/{len} sequences ({eta})")
                .unwrap()
                .progress_chars("=> "),
        );
        Some(pb)
    } else {
        None
    };

    // Clone tokenizer and prompt_renderer for parallel access
    let tokenizer = tokenizer.clone();
    let prompt_renderer_clone = prompt_renderer.cloned();
    let progress_clone = progress.clone();

    // Generate sequences in parallel
    let results: Result<Vec<SequenceData>> = (0..num_sequences as u64)
        .into_par_iter()
        .map(|request_id| {
            let worker_id = worker_ids[request_id as usize % worker_ids.len()];
            let (_prefix_idx, content) = build_request_content_with_prefix(
                request_id,
                prefix_prompts,
                num_prefix_prompts,
                seed,
            );
            let seq = SequenceData::from_request_content(
                &content,
                worker_id,
                kv_block_size,
                &tokenizer,
                prompt_renderer_clone.as_ref(),
            )?;
            if let Some(ref pb) = progress_clone {
                pb.inc(1);
            }
            Ok(seq)
        })
        .collect();

    if let Some(pb) = progress {
        pb.finish_and_clear();
    }
    results
}
/// Populate the router's tree by publishing one `Stored` event per sequence to NATS.
///
/// Events are published sequentially with event IDs `0..sequences.len()`. In non-verbose
/// mode a progress bar is shown; in verbose mode a line is printed every 100 events.
/// After flushing the client, the function sleeps two seconds to give the frontend time
/// to process the events.
///
/// Returns the elapsed wall-clock time, which includes that two-second wait.
///
/// # Errors
/// Fails if an event cannot be encoded or published, or if the flush fails.
async fn build_tree_via_nats(
    nats_client: &async_nats::Client,
    namespace: &str,
    component: &str,
    sequences: &[SequenceData],
    verbose: bool,
) -> Result<Duration> {
    // Subject format must match Component.subject() from lib/runtime/src/component/component.rs
    // which returns: namespace.{namespace_name}.component.{component_name}
    let subject = format!(
        "namespace.{}.component.{}.{}",
        namespace, component, KV_EVENT_SUBJECT
    );
    let total = sequences.len();
    println!(
        "Building tree: {} sequences to subject {}...",
        total, subject
    );

    let start = Instant::now();

    // The progress bar and the verbose per-100 log lines are mutually exclusive.
    let progress = (!verbose).then(|| {
        let pb = ProgressBar::new(total as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template(" [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
                .unwrap()
                .progress_chars("=> "),
        );
        pb
    });

    for (index, seq) in sequences.iter().enumerate() {
        let event = seq.to_router_event(index as u64);
        let payload = encode_event_with_envelope(&event, KV_EVENT_SUBJECT)?;
        nats_client
            .publish(subject.clone(), payload.into())
            .await
            .context("Failed to publish to NATS")?;

        let published = index + 1;
        match &progress {
            Some(pb) => pb.set_position(published as u64),
            None if published % 100 == 0 => {
                println!(" Published {}/{} events", published, total);
            }
            None => {}
        }
    }

    if let Some(pb) = progress {
        pb.finish_and_clear();
    }

    nats_client.flush().await.context("Failed to flush NATS")?;

    // Wait for events to be processed by the frontend
    println!(" Waiting for event processing...");
    tokio::time::sleep(Duration::from_secs(2)).await;

    let elapsed = start.elapsed();
    println!("Tree construction: {:.2?}", elapsed);
    Ok(elapsed)
}
/// Result of a single HTTP request.
///
/// Produced by the request tasks spawned in `send_requests_at_rate` for requests that
/// were submitted after the warmup window.
#[derive(Debug, Clone)]
struct RequestResult {
/// Time between submitting the request and receiving the response (or the send error).
latency: Duration,
/// Time when request completed, relative to measurement start
completion_time: Duration,
/// `true` only when a response arrived and its HTTP status was a success.
success: bool,
}
/// Individual latency sample for raw data export.
///
/// Serializable counterpart of a request result, with durations flattened to integers.
#[derive(Debug, Clone, Serialize)]
struct LatencySample {
/// Latency in microseconds
latency_us: u64,
/// Completion time in milliseconds from measurement start
completion_time_ms: u64,
/// Whether the request succeeded
success: bool,
}
/// OpenAI-style chat completion request.
///
/// Serialized as the JSON body of the POST to `/v1/chat/completions`.
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
/// Name of the model the request targets.
model: String,
/// Conversation messages; the benchmark sends a single user message.
messages: Vec<ChatMessage>,
/// Generation cap; left out of the JSON entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
}
/// One message of a `ChatCompletionRequest`.
#[derive(Debug, Serialize)]
struct ChatMessage {
/// Speaker role; the visible code only ever sends "user".
role: String,
/// Message text.
content: String,
}
/// Produce the deterministic filler "document" used as prefix number `prefix_id`.
///
/// The text starts with a header line that embeds `prefix_id`, followed by
/// `target_tokens * 2` words drawn cyclically from a fixed 100-word vocabulary, and
/// always ends with `".\n"`. The same arguments always yield the same text.
///
/// The word count is an estimate: it assumes a word costs one to two tokens, so twice
/// `target_tokens` words are emitted.
fn generate_prefix_text(prefix_id: usize, target_tokens: usize) -> String {
    // Fixed vocabulary the filler cycles through.
    const FILLER_WORDS: [&str; 100] = [
        "the", "and", "for", "are", "but", "not", "you", "all", "can", "had", "her", "was",
        "one", "our", "out", "day", "get", "has", "him", "his", "how", "its", "may", "new",
        "now", "old", "see", "two", "way", "who", "boy", "did", "oil", "sit", "set", "run",
        "top", "got", "let", "put", "say", "she", "too", "use", "dad", "mom", "end", "big",
        "ask", "own", "why", "men", "read", "need", "land", "same", "here", "must", "home",
        "hand", "high", "year", "come", "made", "find", "long", "down", "look", "write", "go",
        "word", "call", "first", "water", "been", "number", "people", "over", "such", "make",
        "time", "very", "when", "would", "more", "some", "into", "them", "than", "only",
        "have", "from", "this", "that", "with", "they", "will", "each", "about", "which",
    ];

    let filler_word_count = target_tokens * 2;
    let first_word = prefix_id * 1000;

    // Header that makes each prefix distinct.
    let mut text = format!(
        "System configuration document version {} revision {}. ",
        prefix_id,
        prefix_id * 17 + 3
    );

    // Filler that pads the prefix out to the requested size.
    for offset in 0..filler_word_count {
        text.push_str(FILLER_WORDS[(first_word + offset) % FILLER_WORDS.len()]);
        text.push(' ');
    }

    // Drop trailing whitespace, then close with punctuation + newline so the prefix
    // ends on a clean boundary regardless of what is appended after it.
    let trimmed_len = text.trim_end().len();
    text.truncate(trimmed_len);
    text.push_str(".\n");
    text
}
/// Build `num_prefixes` prefix prompts together with their token IDs and block hashes.
///
/// Each prefix text comes from `generate_prefix_text` and is then formatted (when
/// `prompt_renderer` is given), tokenized and hashed by `PrefixData::from_text`. The
/// intent is that these hashes match what the frontend computes for the same text.
///
/// Work is spread across threads with rayon; when `show_progress` is set a progress bar
/// advances once per finished prefix.
///
/// # Errors
/// Fails if formatting or tokenizing any prefix fails.
fn generate_prefix_data(
    num_prefixes: usize,
    target_tokens: usize,
    tokenizer: &Tokenizer,
    kv_block_size: u32,
    prompt_renderer: Option<&PromptRenderer>,
    show_progress: bool,
) -> Result<Vec<PrefixData>> {
    let progress = show_progress.then(|| {
        let pb = ProgressBar::new(num_prefixes as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template(" [{bar:40.cyan/blue}] {pos}/{len} prefixes ({eta})")
                .unwrap()
                .progress_chars("=> "),
        );
        pb
    });

    // Owned copies that the parallel closure can share across worker threads.
    let shared_tokenizer = tokenizer.clone();
    let shared_renderer = prompt_renderer.cloned();
    let shared_progress = progress.clone();

    // Generate, tokenize and hash every prefix in parallel.
    let prefixes: Result<Vec<PrefixData>> = (0..num_prefixes)
        .into_par_iter()
        .map(|prefix_id| generate_prefix_text(prefix_id, target_tokens))
        .map(|text| {
            let prefix = PrefixData::from_text(
                text,
                &shared_tokenizer,
                kv_block_size,
                shared_renderer.as_ref(),
            );
            if let Some(pb) = shared_progress.as_ref() {
                pb.inc(1);
            }
            prefix
        })
        .collect();

    if let Some(pb) = progress {
        pb.finish_and_clear();
    }
    prefixes
}
/// Build the text of request `request_id`: a shared prefix prompt plus a unique suffix.
///
/// The prefix is chosen deterministically as
/// `(request_id ^ seed) % min(num_prefix_prompts, prefix_prompts.len())`, so requests
/// that map to the same index share a cacheable prefix. The suffix starts with a blank
/// line (`"\n\n"`) to keep a clean boundary between prefix and suffix, which is meant to
/// stop the suffix from changing how the prefix tokenizes.
///
/// Returns `(prefix_index, content)`.
///
/// # Panics
/// Panics if `prefix_prompts` is empty or `num_prefix_prompts` is zero.
fn build_request_content_with_prefix(
    request_id: u64,
    prefix_prompts: &[String],
    num_prefix_prompts: usize,
    seed: u64,
) -> (usize, String) {
    // Never index past the prompts that actually exist.
    let usable_prompts = prefix_prompts.len().min(num_prefix_prompts);
    let prefix_idx = (request_id ^ seed) as usize % usable_prompts;

    let mut content = prefix_prompts[prefix_idx].clone();
    content.push_str(&format!(
        "\n\nRequest {} query: What is the answer to question number {}?",
        request_id,
        request_id % 1000
    ));
    (prefix_idx, content)
}
/// Build the chat-completion request body for request `request_id`.
///
/// The single user message carries the text from `build_request_content_with_prefix`,
/// and `max_tokens` is fixed at 1.
fn build_routing_request_with_prefix(
    request_id: u64,
    prefix_prompts: &[String],
    num_prefix_prompts: usize,
    model: &str,
    seed: u64,
) -> ChatCompletionRequest {
    let content =
        build_request_content_with_prefix(request_id, prefix_prompts, num_prefix_prompts, seed).1;
    let user_message = ChatMessage {
        role: String::from("user"),
        content,
    };

    ChatCompletionRequest {
        model: model.to_owned(),
        messages: vec![user_message],
        max_tokens: Some(1),
    }
}
/// Send HTTP requests at a specified rate.
///
/// Issues chat-completion POSTs to `{frontend_url}/v1/chat/completions` for
/// `warmup_secs + duration_secs` seconds. After each submission the loop sleeps
/// `1 / rate` seconds, and a semaphore keeps at most `max_concurrency` requests in
/// flight. Requests submitted during the warmup window are sent and counted by the
/// per-second monitor, but only requests submitted after warmup are appended to
/// `results`.
///
/// `in_flight` and `max_in_flight` are updated as requests start and finish. Once the
/// submission loop ends, the function waits up to 30 seconds for `in_flight` to reach
/// zero before returning.
///
/// Returns the Unix timestamp (seconds since epoch) when warmup ended, or `0.0` if the
/// loop never observed the end of warmup.
async fn send_requests_at_rate(
client: reqwest::Client,
frontend_url: String,
prefix_prompts: Arc<Vec<String>>,
num_prefix_prompts: usize,
model: String,
seed: u64,
rate: f64,
duration_secs: u64,
warmup_secs: u64,
max_concurrency: usize,
results: Arc<Mutex<Vec<RequestResult>>>,
in_flight: Arc<AtomicU64>,
max_in_flight: Arc<AtomicU64>,
verbose: bool,
) -> f64 {
// Caps the number of concurrently outstanding requests.
let semaphore = Arc::new(Semaphore::new(max_concurrency));
// NOTE(review): `Duration::from_secs_f64` panics on negative or non-finite input, so a
// zero or negative `rate` would panic here — confirm callers validate `rate`.
let interval = Duration::from_secs_f64(1.0 / rate);
let start = Instant::now();
let warmup_duration = Duration::from_secs(warmup_secs);
let total_duration = Duration::from_secs(duration_secs + warmup_secs);
// Shared base instant for `completion_time`; filled in by the first measured request
// that reaches the lock.
let measurement_start = Arc::new(Mutex::new(None::<Instant>));
let mut request_id = 0u64;
let mut warmup_end_timestamp: f64 = 0.0;
let mut warmup_ended = false;
println!(
" Running for {}s ({}s warmup + {}s measurement) at {} req/sec...",
warmup_secs + duration_secs,
warmup_secs,
duration_secs,
rate
);
println!(
" Using {} prefix prompts for cache sharing (warmup populates cache)...",
num_prefix_prompts
);
// Counters for completed requests (updated by spawned tasks)
let success_count = Arc::new(AtomicU64::new(0));
let failure_count = Arc::new(AtomicU64::new(0));
// Monitor in-flight count and throughput every second
let in_flight_monitor = in_flight.clone();
let success_monitor = success_count.clone();
let failure_monitor = failure_count.clone();
let monitor_start = start;
let monitor_handle = tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(1));
interval.tick().await; // Skip first immediate tick
let mut prev_success = 0u64;
let mut prev_failure = 0u64;
loop {
interval.tick().await;
let in_flight_now = in_flight_monitor.load(Ordering::Relaxed);
let success_now = success_monitor.load(Ordering::Relaxed);
let failure_now = failure_monitor.load(Ordering::Relaxed);
// Report per-tick deltas rather than running totals.
let success_delta = success_now - prev_success;
let failure_delta = failure_now - prev_failure;
prev_success = success_now;
prev_failure = failure_now;
eprintln!(
" [t={:>3}s] in-flight: {:>4}, completed: {:>4} ok / {:>3} err",
monitor_start.elapsed().as_secs(),
in_flight_now,
success_delta,
failure_delta
);
}
});
while start.elapsed() < total_duration {
// Decided at submission time; the spawned task captures this value, so a request
// submitted during warmup is never recorded even if it completes afterwards.
let is_warmup = start.elapsed() < warmup_duration;
// Detect transition from warmup to measurement phase
if !is_warmup && !warmup_ended {
warmup_ended = true;
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap();
warmup_end_timestamp = now.as_secs_f64();
println!();
println!(" *** WARMUP COMPLETE ***");
println!(" WARMUP_END_TIMESTAMP={:.6}", warmup_end_timestamp);
println!();
}
// Waits here when `max_concurrency` requests are already outstanding.
let permit = semaphore.clone().acquire_owned().await.unwrap();
let body = build_routing_request_with_prefix(
request_id,
&prefix_prompts,
num_prefix_prompts,
&model,
seed,
);
let client = client.clone();
let url = format!("{}/v1/chat/completions", frontend_url);
let results = results.clone();
let in_flight_clone = in_flight.clone();
let max_in_flight_clone = max_in_flight.clone();
let success_clone = success_count.clone();
let failure_clone = failure_count.clone();
let measurement_start_clone = measurement_start.clone();
let req_id = request_id;
// Track in-flight
let current = in_flight_clone.fetch_add(1, Ordering::Relaxed) + 1;
max_in_flight_clone.fetch_max(current, Ordering::Relaxed);
tokio::spawn(async move {
let submit_time = Instant::now();
let response = client.post(&url).json(&body).send().await;
let complete_time = Instant::now();
in_flight_clone.fetch_sub(1, Ordering::Relaxed);
// Free the concurrency slot before doing any bookkeeping.
drop(permit);
// Determine success/failure and update counters
let success = match &response {
Ok(resp) => resp.status().is_success(),
Err(_) => false,
};
if success {
success_clone.fetch_add(1, Ordering::Relaxed);
} else {
failure_clone.fetch_add(1, Ordering::Relaxed);
}
// Only record results after warmup
if !is_warmup {
// Initialize measurement start on first non-warmup completion
let mut ms_guard = measurement_start_clone.lock().await;
if ms_guard.is_none() {
*ms_guard = Some(complete_time);
}
let measurement_base = ms_guard.unwrap();
drop(ms_guard);
let result = RequestResult {
latency: complete_time.duration_since(submit_time),
completion_time: complete_time.duration_since(measurement_base),
success,
};
if verbose {
println!(
" Request {} completed in {:?} (success: {})",
req_id, result.latency, result.success
);
}
results.lock().await.push(result);
}
});
request_id += 1;
// Pace submissions: one request per `interval`, on top of any time spent waiting
// for a permit above.
tokio::time::sleep(interval).await;
}
println!(" Submitted {} requests", request_id);
// Wait for in-flight requests
println!(" Waiting for in-flight requests...");
let drain_start = Instant::now();
// Poll every 100 ms, giving up after 30 seconds.
while in_flight.load(Ordering::Relaxed) > 0 && drain_start.elapsed() < Duration::from_secs(30) {
tokio::time::sleep(Duration::from_millis(100)).await;
}
let remaining = in_flight.load(Ordering::Relaxed);
if remaining > 0 {
println!(" {} requests still in-flight after timeout", remaining);
}
// Stop the in-flight monitor
monitor_handle.abort();
warmup_end_timestamp
}
/// Publish KV events to NATS at roughly `rate` events per second for `duration_secs`.
///
/// The sequences are cycled through repeatedly; event IDs start at 1,000,000 to stay
/// clear of the IDs used during tree construction. Encode and publish failures do not
/// stop the loop: they are counted, summarized on stderr every 10 seconds while any
/// exist, and reported once more at the end.
///
/// Nothing is published (a message is written to stderr and the function returns) when
/// `sequences` is empty or when `rate` does not yield a valid interval, i.e. it is zero,
/// negative, NaN, or so small that `1 / rate` overflows a `Duration`.
async fn publish_events_at_rate(
    nats_client: async_nats::Client,
    namespace: String,
    component: String,
    sequences: Vec<SequenceData>,
    rate: f64,
    duration_secs: u64,
) {
    // Indexing below uses `% sequences.len()`, which would panic on an empty list.
    if sequences.is_empty() {
        eprintln!(" [publish_events] No sequences provided; skipping event publishing");
        return;
    }
    // `Duration::from_secs_f64` would panic on these inputs; reject them instead.
    let Ok(interval) = Duration::try_from_secs_f64(1.0 / rate) else {
        eprintln!(
            " [publish_events] Invalid publish rate {}; skipping event publishing",
            rate
        );
        return;
    };

    // Subject format must match Component.subject() from lib/runtime/src/component/component.rs
    let subject = format!(
        "namespace.{}.component.{}.{}",
        namespace, component, KV_EVENT_SUBJECT
    );
    let start = Instant::now();
    let duration = Duration::from_secs(duration_secs);
    let start_id = 1000000u64; // Start high to avoid collision with tree construction
    let mut event_id = start_id;

    // Failure tracking
    let mut publish_failures: u64 = 0;
    let mut encode_failures: u64 = 0;
    let mut last_publish_error: Option<String> = None;
    let mut last_encode_error: Option<String> = None;

    // Periodic reporting interval (every 10 seconds)
    let report_interval = Duration::from_secs(10);
    let mut last_report = Instant::now();

    while start.elapsed() < duration {
        let seq = &sequences[(event_id as usize) % sequences.len()];
        let event = seq.to_router_event(event_id);
        match encode_event_with_envelope(&event, KV_EVENT_SUBJECT) {
            Ok(data) => {
                if let Err(e) = nats_client.publish(subject.clone(), data.into()).await {
                    publish_failures += 1;
                    last_publish_error = Some(format!("{:?}", e));
                }
            }
            Err(e) => {
                encode_failures += 1;
                last_encode_error = Some(format!("{:?}", e));
            }
        }
        event_id += 1;

        // Periodic failure report
        if last_report.elapsed() >= report_interval {
            let total_attempts = event_id - start_id;
            let total_failures = publish_failures + encode_failures;
            if total_failures > 0 {
                eprintln!(
                    " [publish_events] Periodic report: {} failures / {} attempts ({} publish, {} encode)",
                    total_failures, total_attempts, publish_failures, encode_failures
                );
            }
            last_report = Instant::now();
        }

        tokio::time::sleep(interval).await;
    }

    // Final failure report
    let total_attempts = event_id - start_id;
    let total_failures = publish_failures + encode_failures;
    if total_failures > 0 {
        eprintln!(
            " [publish_events] Final report: {} failures / {} attempts ({:.2}% failure rate)",
            total_failures,
            total_attempts,
            (total_failures as f64 / total_attempts as f64) * 100.0
        );
        eprintln!(
            " Publish failures: {}, Encode failures: {}",
            publish_failures, encode_failures
        );
        if let Some(ref err) = last_publish_error {
            eprintln!(" Last publish error: {}", err);
        }
        if let Some(ref err) = last_encode_error {
            eprintln!(" Last encode error: {}", err);
        }
    } else {
        println!(
            " [publish_events] Completed: {} events published with no failures",
            total_attempts
        );
    }
}
/// Summary of a set of latency measurements.
struct LatencyStats {
    /// Smallest sample.
    min: Duration,
    /// Largest sample.
    max: Duration,
    /// Sample at sorted index `n / 2`.
    p50: Duration,
    /// Sample at sorted index `n * 95 / 100`.
    p95: Duration,
    /// Sample at sorted index `n * 99 / 100`.
    p99: Duration,
}

impl LatencyStats {
    /// Summarize `durations`, or return `None` when there are no samples.
    ///
    /// Percentiles are picked by index from a sorted copy of the input; the input slice
    /// itself is left untouched.
    fn from_durations(durations: &[Duration]) -> Option<Self> {
        let mut ordered = durations.to_vec();
        ordered.sort_unstable();

        let count = ordered.len();
        let (&min, &max) = (ordered.first()?, ordered.last()?);
        let at_percent = |percent: usize| ordered[count * percent / 100];

        Some(Self {
            min,
            max,
            p50: ordered[count / 2],
            p95: at_percent(95),
            p99: at_percent(99),
        })
    }
}
/// Time-bucketed latency statistics for tracking latency over time.
///
/// One value describes the requests whose completion time fell inside a single bucket;
/// `compute_time_bucket_stats` produces these and skips buckets with no requests.
#[derive(Debug, Clone, Serialize)]
struct TimeBucketStats {
/// Bucket start time in seconds from measurement start
bucket_start_sec: u64,
/// Bucket end time in seconds (start of the next bucket)
bucket_end_sec: u64,
/// Number of requests completed in this bucket
count: usize,
/// Latency stats for this bucket (in microseconds)
latency_min_us: u64,
/// Median latency of the bucket, in microseconds.
latency_p50_us: u64,
/// 95th-percentile latency of the bucket, in microseconds.
latency_p95_us: u64,
/// Largest latency in the bucket, in microseconds.
latency_max_us: u64,
}
/// Compute per-bucket latency statistics.
///
/// Each result is assigned to the bucket `completion_time.as_secs() / bucket_size_secs`,
/// and the latencies of every non-empty bucket are summarized with `LatencyStats`.
/// Buckets are returned in ascending time order; buckets without any result are omitted.
///
/// Returns an empty vector when `results` is empty or when `bucket_size_secs` is zero
/// (a zero-width bucket cannot hold anything, and dividing by it would panic).
fn compute_time_bucket_stats(
    results: &[RequestResult],
    bucket_size_secs: u64,
) -> Vec<TimeBucketStats> {
    if results.is_empty() || bucket_size_secs == 0 {
        return Vec::new();
    }

    // Group latencies by completion-time bucket. The ordered map keeps buckets sorted
    // and only materializes the buckets that actually received a result.
    let mut latencies_by_bucket: std::collections::BTreeMap<u64, Vec<Duration>> =
        std::collections::BTreeMap::new();
    for result in results {
        let bucket_idx = result.completion_time.as_secs() / bucket_size_secs;
        latencies_by_bucket
            .entry(bucket_idx)
            .or_default()
            .push(result.latency);
    }

    // Compute stats for each bucket
    latencies_by_bucket
        .into_iter()
        .filter_map(|(bucket_idx, latencies)| {
            let stats = LatencyStats::from_durations(&latencies)?;
            Some(TimeBucketStats {
                bucket_start_sec: bucket_idx * bucket_size_secs,
                bucket_end_sec: (bucket_idx + 1) * bucket_size_secs,
                count: latencies.len(),
                latency_min_us: stats.min.as_micros() as u64,
                latency_p50_us: stats.p50.as_micros() as u64,
                latency_p95_us: stats.p95.as_micros() as u64,
                latency_max_us: stats.max.as_micros() as u64,
            })
        })
        .collect()
}
/// Print the per-bucket latency table to stdout.
///
/// Latencies are stored in microseconds and shown in milliseconds with one decimal.
/// When there are no buckets, a single "no data" line is printed instead of the table.
fn print_time_bucket_report(buckets: &[TimeBucketStats]) {
    if buckets.is_empty() {
        println!(" No time bucket data available");
        return;
    }

    let as_millis = |micros: u64| micros as f64 / 1000.0;

    println!(
        " {:>8} {:>8} {:>12} {:>12} {:>12} {:>12}",
        "Time(s)", "Count", "Min(ms)", "P50(ms)", "P95(ms)", "Max(ms)"
    );
    println!(" {}", "-".repeat(68));

    buckets.iter().for_each(|row| {
        println!(
            " {:>3}-{:<4} {:>8} {:>12.1} {:>12.1} {:>12.1} {:>12.1}",
            row.bucket_start_sec,
            row.bucket_end_sec,
            row.count,
            as_millis(row.latency_min_us),
            as_millis(row.latency_p50_us),
            as_millis(row.latency_p95_us),
            as_millis(row.latency_max_us),
        );
    });
}
/// Stress test results.
///
/// Serializable summary of one run; `print_report` renders the same data as text.
/// The two vector fields are left out of the serialized output when they are empty.
#[derive(Debug, Serialize)]
struct StressResults {
// Configuration
tree_size: usize,
num_sequences: usize,
// The report prints `num_sequences * depth` as the number of blocks.
depth: usize,
num_workers: usize,
// Tree construction
tree_construction_time_ms: u64,
// Request metrics
requests_submitted: u64,
requests_completed: u64,
requests_failed: u64,
// Latency stats (in microseconds)
latency_min_us: u64,
latency_p50_us: u64,
latency_p95_us: u64,
latency_p99_us: u64,
latency_max_us: u64,
// Throughput
achieved_request_rate: f64,
// Saturation
max_in_flight: u64,
// Time-bucketed latency stats for tracking latency over time
#[serde(skip_serializing_if = "Vec::is_empty")]
time_buckets: Vec<TimeBucketStats>,
// Raw latency samples for detailed graphing
#[serde(skip_serializing_if = "Vec::is_empty")]
raw_samples: Vec<LatencySample>,
}
impl StressResults {
    /// Writes the human-readable run summary to stdout.
    ///
    /// Sections appear in a fixed order: banner, tree construction, request
    /// statistics, end-to-end latency percentiles, the optional latency-over-time
    /// table (only when time buckets were collected), and saturation.
    fn print_report(&self) {
        // Pull out the fields the report uses; the remaining ones are not printed.
        let Self {
            num_sequences,
            depth,
            tree_construction_time_ms,
            requests_submitted,
            requests_completed,
            requests_failed,
            latency_min_us,
            latency_p50_us,
            latency_p95_us,
            latency_p99_us,
            latency_max_us,
            achieved_request_rate,
            max_in_flight,
            time_buckets,
            ..
        } = self;

        println!("\n========================================");
        println!("KV Router Full Stress Test Results");
        println!("========================================\n");

        println!("Tree Construction:");
        println!(" Sequences: {}", num_sequences);
        println!(" Blocks: {}", num_sequences * depth);
        println!(" Time: {}ms", tree_construction_time_ms);
        println!();

        println!("Request Statistics:");
        println!(" Submitted: {}", requests_submitted);
        println!(" Completed: {}", requests_completed);
        println!(" Failed: {}", requests_failed);
        println!(" Throughput: {:.1} req/sec", achieved_request_rate);
        println!();

        println!("End-to-End Latency (includes HTTP overhead):");
        println!(" min: {:>12}us", latency_min_us);
        println!(" p50: {:>12}us", latency_p50_us);
        println!(" p95: {:>12}us", latency_p95_us);
        println!(" p99: {:>12}us", latency_p99_us);
        println!(" max: {:>12}us", latency_max_us);
        println!();

        let has_buckets = !time_buckets.is_empty();
        if has_buckets {
            println!("Latency Over Time:");
            print_time_bucket_report(time_buckets);
            println!();
        }

        println!("Saturation:");
        println!(" Max in-flight: {}", max_in_flight);
    }
}
/// Entry point of the full KV router stress test.
///
/// The run proceeds in phases:
/// - Phase 0: load the tokenizer and, if possible, a prompt renderer (from the
///   model deployment card, else from `tokenizer_config.json`, local or downloaded).
/// - Phase 1: generate prefix data with pre-computed block hashes.
/// - Phase 2: discover worker IDs from the frontend and generate sequences.
/// - Phase 3: build the tree via NATS, unless tree construction is skipped.
/// - Phase 4: send HTTP requests at the configured rate (optionally publishing
///   events in the background), then aggregate, print, and optionally write results.
///
/// # Errors
///
/// Returns an error when `depth` is zero, when model detection, tokenizer
/// loading, prefix/sequence generation, worker discovery, or a NATS connection
/// fails, or when the results cannot be serialized or written to the output path.
#[tokio::main]
async fn main() -> Result<()> {
    let args = Args::parse();
    // A depth of zero would panic on division; report it as a regular error instead.
    let num_sequences = args
        .tree_size
        .checked_div(args.depth)
        .context("--depth must be greater than 0")?;
    println!("KV Router Full Stress Test");
    println!("==========================\n");
    // Resolve model name: use provided value or auto-detect from /v1/models
    let model = match args.model {
        Some(m) => m,
        None => {
            println!("Model Detection:");
            fetch_model_name(&args.frontend_url).await?
        }
    };
    // Tokenizer path defaults to model if not specified
    let tokenizer_path = args.tokenizer_path.as_ref().unwrap_or(&model);
    println!("Configuration:");
    println!(
        " Tree size: {} blocks ({} sequences x {} depth)",
        args.tree_size, num_sequences, args.depth
    );
    println!(" Workers: {}", args.num_workers);
    println!(
        " Prefix prompt ratio: {:.1}%",
        args.prefix_prompt_ratio * 100.0
    );
    println!(" Prefix prompts: {}", args.num_prefix_prompts);
    println!(" Seed: {}", args.seed);
    println!(" Request rate: {:.1} req/sec", args.request_rate);
    println!(" Event rate: {:.1} events/sec", args.event_rate);
    println!(" Duration: {}s", args.duration);
    println!(" Warmup: {}s", args.warmup);
    println!(" Concurrency: {}", args.concurrency);
    println!(" Model: {}", model);
    println!(" KV block size: {}", args.kv_block_size);
    println!(" Tokenizer: {}", tokenizer_path);
    println!(" Namespace: {}", args.namespace);
    println!(" Component: {}", args.component);
    println!(
        " NATS subject: namespace.{}.component.{}.kv-events",
        args.namespace, args.component
    );
    if args.skip_tree_construction {
        println!(" Tree construction: SKIPPED (using external kv_stress_worker)");
    }
    println!();
    // Create HTTP client
    let http_client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()
        .context("Failed to create HTTP client")?;
    // Phase 0: Load tokenizer and prompt formatter
    println!("Phase 0: Loading Tokenizer and Prompt Formatter");
    println!(" Loading tokenizer from {}...", tokenizer_path);
    let tokenizer = load_tokenizer(tokenizer_path)?;
    println!(" Tokenizer loaded successfully");
    // Prefer a renderer resolved from the model name, then from the tokenizer path.
    let mut prompt_renderer =
        try_load_prompt_renderer(&model).or_else(|| try_load_prompt_renderer(tokenizer_path));
    if prompt_renderer.is_some() {
        println!(" Prompt formatter loaded from ModelDeploymentCard");
    } else {
        // Fallback to tokenizer_config.json - use the same HfTokenizerConfigJsonFormatter as the frontend
        // Try local path first, then download from HuggingFace
        let contents: Option<String> = {
            let config_path = std::path::Path::new(tokenizer_path).join("tokenizer_config.json");
            if config_path.exists() {
                std::fs::read_to_string(&config_path).ok()
            } else if !std::path::Path::new(tokenizer_path).exists() {
                // Might be a HuggingFace model ID - try to download tokenizer_config.json
                println!(
                    " Downloading tokenizer_config.json from HuggingFace: {}...",
                    tokenizer_path
                );
                // NOTE(review): this uses hf_hub's `sync` API from async code; it presumably
                // blocks this task during startup - confirm that is acceptable.
                if let Ok(api) = hf_hub::api::sync::Api::new() {
                    let repo = api.model(tokenizer_path.to_string());
                    repo.get("tokenizer_config.json")
                        .ok()
                        .and_then(|path| std::fs::read_to_string(&path).ok())
                } else {
                    None
                }
            } else {
                None
            }
        };
        // Last-resort renderer built directly from the chat template in the tokenizer config.
        let try_simple_fallback = |path: &str| -> Option<PromptRenderer> {
            let config = load_tokenizer_config(path).ok()??;
            let template = config.chat_template.clone()?;
            Some(PromptRenderer::Simple(ChatTemplateRenderer::new(
                template,
                config.bos_token_str(),
                config.eos_token_str(),
            )))
        };
        if let Some(contents) = contents {
            match serde_json::from_str::<ChatTemplate>(&contents) {
                Ok(chat_template) => {
                    match PromptFormatter::from_parts(chat_template, ContextMixins::new(&[])) {
                        Ok(formatter) => {
                            println!(
                                " Prompt formatter loaded from tokenizer_config.json (using frontend-compatible renderer)"
                            );
                            prompt_renderer = Some(PromptRenderer::Formatter(formatter));
                        }
                        Err(e) => {
                            println!(" WARNING: Failed to create prompt formatter: {}", e);
                            println!(
                                " Using fallback Simple renderer (may not match frontend)"
                            );
                            prompt_renderer = try_simple_fallback(tokenizer_path);
                        }
                    }
                }
                Err(e) => {
                    println!(
                        " WARNING: Failed to parse tokenizer_config.json as ChatTemplate: {}",
                        e
                    );
                    println!(" Using fallback Simple renderer (may not match frontend)");
                    prompt_renderer = try_simple_fallback(tokenizer_path);
                }
            }
        } else {
            println!(" WARNING: No tokenizer_config.json found. Hashes may not match frontend!");
            println!(
                " The frontend applies chat templates to /v1/chat/completions requests."
            );
        }
    }
    // Target tokens for prefix (block_size tokens per block)
    let target_prefix_tokens = (args.depth as f64 * args.prefix_prompt_ratio).round() as usize
        * args.kv_block_size as usize;
    // Phase 1: Generate prefix data with computed hashes
    println!("\nPhase 1: Generating Prefix Data");
    println!(
        " Generating {} prefixes (~{} tokens each)...",
        args.num_prefix_prompts, target_prefix_tokens
    );
    let prefix_data = generate_prefix_data(
        args.num_prefix_prompts,
        target_prefix_tokens,
        &tokenizer,
        args.kv_block_size,
        prompt_renderer.as_ref(),
        true, // show_progress
    )?;
    // Print prefix stats
    for (i, prefix) in prefix_data.iter().enumerate() {
        if i < 3 || args.verbose {
            println!(
                " Prefix {}: {} tokens, {} blocks, first hash: {:016x}",
                i,
                prefix.token_ids.len(),
                prefix.num_blocks(),
                prefix.local_hashes.first().map(|h| h.0).unwrap_or(0)
            );
            if args.verbose {
                // Show a preview of the formatted text (first 80 chars, escape newlines).
                // Built only here because it is printed only in verbose mode.
                let preview: String = prefix
                    .formatted_text
                    .chars()
                    .take(80)
                    .map(|c| if c == '\n' { ' ' } else { c })
                    .collect();
                println!(" Formatted: {}...", preview);
            }
        }
    }
    if prefix_data.len() > 3 && !args.verbose {
        println!(" ... ({} more prefixes)", prefix_data.len() - 3);
    }
    // Show first prefix's formatted text sample if chat template was applied
    if prompt_renderer.is_some() && !prefix_data.is_empty() {
        let formatted = &prefix_data[0].formatted_text;
        let sample: String = formatted.chars().take(200).collect();
        println!(" Sample formatted prefix (first 200 chars):");
        for line in sample.lines().take(5) {
            println!(" | {}", line);
        }
        // The sample is cut by characters, so decide truncation by characters too;
        // a byte-length check would mislabel short multi-byte text as truncated.
        if formatted.chars().nth(200).is_some() {
            println!(" | ...");
        }
    }
    // Extract prefix texts for HTTP requests and request-aligned hashing
    let prefix_prompts: Vec<String> = prefix_data.iter().map(|p| p.text.clone()).collect();
    // Phase 2: Discover workers and generate sequences
    println!("\nPhase 2: Discover Workers & Generate Sequences");
    // Discover actual worker IDs from the frontend
    let discovered_worker_ids = discover_worker_ids(&args.frontend_url).await?;
    if discovered_worker_ids.len() != args.num_workers {
        println!(
            " NOTE: Discovered {} workers but --num-workers was set to {}. Using discovered workers.",
            discovered_worker_ids.len(),
            args.num_workers
        );
    }
    println!(
        " Generating {} sequences with shared prefixes...",
        num_sequences
    );
    let sequences = generate_sequences_for_requests(
        num_sequences,
        &discovered_worker_ids,
        &prefix_prompts,
        args.num_prefix_prompts,
        args.kv_block_size,
        &tokenizer,
        prompt_renderer.as_ref(),
        args.seed,
        true, // show_progress
    )?;
    println!(
        " Generated {} sequences distributed across {} workers",
        sequences.len(),
        discovered_worker_ids.len()
    );
    // Phase 3: Build tree via NATS (unless skipped)
    let tree_construction_time = if args.skip_tree_construction {
        println!("\nPhase 3: Tree Construction via NATS - SKIPPED");
        println!(" Using external kv_stress_worker for tree construction");
        Duration::ZERO
    } else {
        println!("\nPhase 3: Tree Construction via NATS");
        // Connect to NATS
        println!(" Connecting to NATS at {}...", args.nats_url);
        let nats_client = async_nats::connect(&args.nats_url)
            .await
            .context("Failed to connect to NATS")?;
        println!(" Connected to NATS");
        build_tree_via_nats(
            &nats_client,
            &args.namespace,
            &args.component,
            &sequences,
            args.verbose,
        )
        .await?
    };
    // Phase 4: Stress Test
    println!("\nPhase 4: Stress Test");
    println!(
        " Using {} prefix prompts for HTTP requests (hashes pre-computed)...",
        prefix_prompts.len()
    );
    let prefix_prompts = Arc::new(prefix_prompts);
    let results = Arc::new(Mutex::new(Vec::new()));
    let in_flight = Arc::new(AtomicU64::new(0));
    let max_in_flight = Arc::new(AtomicU64::new(0));
    // Spawn event publisher if rate > 0 and tree construction wasn't skipped
    let event_handle = if args.event_rate > 0.0 && !args.skip_tree_construction {
        // Need to connect to NATS for ongoing events
        let nats = async_nats::connect(&args.nats_url)
            .await
            .context("Failed to connect to NATS for event publishing")?;
        let ns = args.namespace.clone();
        let comp = args.component.clone();
        let seqs = sequences.clone();
        let rate = args.event_rate;
        let dur = args.duration + args.warmup;
        Some(tokio::spawn(async move {
            publish_events_at_rate(nats, ns, comp, seqs, rate, dur).await;
        }))
    } else {
        None
    };
    // Run request generator
    // During warmup, requests populate the cache via mocker engine.
    // During measurement, requests with the same prefixes get cache hits.
    let warmup_end_ts = send_requests_at_rate(
        http_client,
        args.frontend_url.clone(),
        prefix_prompts,
        args.num_prefix_prompts,
        model.clone(),
        args.seed,
        args.request_rate,
        args.duration,
        args.warmup,
        args.concurrency,
        results.clone(),
        in_flight.clone(),
        max_in_flight.clone(),
        args.verbose,
    )
    .await;
    // Print the timestamp again at the end for easy copy-paste
    println!();
    println!("To filter FE logs for post-warmup only:");
    println!(
        " python analyze_frontend_log.py frontend.log --after-warmup {:.6}",
        warmup_end_ts
    );
    // Wait for event publisher. A join error means the task panicked or was
    // cancelled; report it rather than dropping it, but keep the run's results.
    if let Some(handle) = event_handle {
        if let Err(e) = handle.await {
            println!(" WARNING: Event publisher task did not finish cleanly: {}", e);
        }
    }
    // Collect results
    let results = results.lock().await;
    let latencies: Vec<Duration> = results.iter().map(|r| r.latency).collect();
    let successful_results: Vec<&RequestResult> = results.iter().filter(|r| r.success).collect();
    let failed_count = results.len() - successful_results.len();
    // Compute actual measurement duration from completion times of successful requests.
    // This accounts for the drain phase where in-flight requests complete after submission stops.
    let actual_duration_secs = successful_results
        .iter()
        .map(|r| r.completion_time.as_secs_f64())
        .fold(0.0_f64, |a, b| a.max(b))
        .max(1.0); // Avoid division by zero
    let stats = LatencyStats::from_durations(&latencies);
    // Compute time-bucketed stats for latency-over-time tracking
    let time_buckets = if args.bucket_size > 0 {
        compute_time_bucket_stats(&results, args.bucket_size)
    } else {
        Vec::new()
    };
    // Collect raw latency samples if requested
    let raw_samples: Vec<LatencySample> = if args.include_raw_samples {
        results
            .iter()
            .map(|r| LatencySample {
                latency_us: r.latency.as_micros() as u64,
                completion_time_ms: r.completion_time.as_millis() as u64,
                success: r.success,
            })
            .collect()
    } else {
        Vec::new()
    };
    let stress_results = StressResults {
        tree_size: args.tree_size,
        num_sequences,
        depth: args.depth,
        num_workers: discovered_worker_ids.len(),
        tree_construction_time_ms: tree_construction_time.as_millis() as u64,
        requests_submitted: results.len() as u64,
        requests_completed: successful_results.len() as u64,
        requests_failed: failed_count as u64,
        // Each latency field falls back to 0 when no statistics could be computed.
        latency_min_us: stats
            .as_ref()
            .map(|s| s.min.as_micros() as u64)
            .unwrap_or(0),
        latency_p50_us: stats
            .as_ref()
            .map(|s| s.p50.as_micros() as u64)
            .unwrap_or(0),
        latency_p95_us: stats
            .as_ref()
            .map(|s| s.p95.as_micros() as u64)
            .unwrap_or(0),
        latency_p99_us: stats
            .as_ref()
            .map(|s| s.p99.as_micros() as u64)
            .unwrap_or(0),
        latency_max_us: stats
            .as_ref()
            .map(|s| s.max.as_micros() as u64)
            .unwrap_or(0),
        achieved_request_rate: successful_results.len() as f64 / actual_duration_secs,
        max_in_flight: max_in_flight.load(Ordering::Relaxed),
        time_buckets,
        raw_samples,
    };
    stress_results.print_report();
    // Write JSON output if requested
    if let Some(output_path) = args.output {
        let json = serde_json::to_string_pretty(&stress_results)
            .context("Failed to serialize results to JSON")?;
        std::fs::write(&output_path, json)
            .with_context(|| format!("Failed to write results to {}", output_path))?;
        println!("\nResults written to: {}", output_path);
    }
    Ok(())
}
......@@ -4,6 +4,8 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "bench")]
use std::time::Instant;
use anyhow::Result;
use derive_builder::Builder;
......@@ -504,6 +506,9 @@ impl KvRouter {
update_states: bool,
lora_name: Option<String>,
) -> anyhow::Result<(WorkerWithDpRank, u32)> {
#[cfg(feature = "bench")]
let start = Instant::now();
// Validate that context_id is provided when update_states is true
if update_states && context_id.is_none() {
panic!("context_id must be provided if update_states is true");
......@@ -512,7 +517,11 @@ impl KvRouter {
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
#[cfg(feature = "bench")]
let hash_elapsed = start.elapsed();
let overlap_scores = self.indexer.find_matches(block_hashes).await?;
#[cfg(feature = "bench")]
let find_matches_elapsed = start.elapsed();
// Compute seq_hashes only if scheduler needs it for active blocks tracking
let maybe_seq_hashes = self
......@@ -532,6 +541,19 @@ impl KvRouter {
)
.await?;
#[cfg(feature = "bench")]
{
let total_elapsed = start.elapsed();
tracing::info!(
isl_tokens,
hash_us = hash_elapsed.as_micros() as u64,
find_matches_us = (find_matches_elapsed - hash_elapsed).as_micros() as u64,
schedule_us = (total_elapsed - find_matches_elapsed).as_micros() as u64,
total_us = total_elapsed.as_micros() as u64,
"find_best_match completed"
);
}
// Note: Routing decision recording (for approximate mode) is now handled
// by KvPushRouter::generate after select_worker returns.
......
......@@ -12,6 +12,8 @@ use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "bench")]
use std::time::Instant;
use super::KV_HIT_RATE_SUBJECT;
use super::KvRouterConfig;
......@@ -288,6 +290,9 @@ impl KvScheduler {
update_states: bool,
lora_name: Option<String>,
) -> Result<WorkerWithDpRank, KvSchedulerError> {
#[cfg(feature = "bench")]
let start = Instant::now();
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
maybe_request_id,
......@@ -306,10 +311,24 @@ impl KvScheduler {
.send(request)
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?;
#[cfg(feature = "bench")]
let send_elapsed = start.elapsed();
let response = resp_rx
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?;
#[cfg(feature = "bench")]
let total_elapsed = start.elapsed();
#[cfg(feature = "bench")]
tracing::info!(
isl_tokens,
send_us = send_elapsed.as_micros() as u64,
total_us = total_elapsed.as_micros() as u64,
"scheduler.schedule completed"
);
Ok(response.best_worker)
}
......
......@@ -1175,6 +1175,11 @@ impl ActiveSequencesMultiWorker {
HashMap<WorkerWithDpRank, usize>,
HashMap<WorkerWithDpRank, usize>,
) {
#[cfg(feature = "bench")]
let start = Instant::now();
#[cfg(feature = "bench")]
let num_workers = self.senders.len();
let mut potential_blocks = HashMap::new();
let mut potential_tokens = HashMap::new();
let token_sequence_shared = token_sequence.map(Arc::new);
......@@ -1206,6 +1211,9 @@ impl ActiveSequencesMultiWorker {
}
}
#[cfg(feature = "bench")]
let send_elapsed = start.elapsed();
// Collect results from all workers
for (worker, receiver) in receivers {
match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await {
......@@ -1222,6 +1230,17 @@ impl ActiveSequencesMultiWorker {
}
}
#[cfg(feature = "bench")]
{
let total_elapsed = start.elapsed();
tracing::info!(
num_workers,
send_us = send_elapsed.as_micros() as u64,
total_us = total_elapsed.as_micros() as u64,
"potential_blocks_and_tokens completed"
);
}
(potential_blocks, potential_tokens)
}
......
......@@ -28,7 +28,7 @@ use crate::preprocessor::media::MediaDecoder;
pub mod deepseek_v32;
mod template;
pub use template::ContextMixins;
pub use template::{ChatTemplate, ContextMixins};
#[derive(Debug)]
pub enum TokenInput {
......@@ -95,6 +95,7 @@ pub trait OAIPromptFormatter: Send + Sync + 'static {
fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String>;
}
/// Handle to a prompt formatter implementation.
///
/// Deriving `Clone` makes copies cheap: cloning the `OAI` variant clones the
/// `Arc`, so every clone refers to the same underlying formatter object.
#[derive(Clone)]
pub enum PromptFormatter {
/// Formatter backed by a shared `OAIPromptFormatter` trait object.
OAI(Arc<dyn OAIPromptFormatter>),
}
......
......@@ -14,7 +14,8 @@ mod oai;
mod tokcfg;
use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
use tokcfg::{ChatTemplate, ChatTemplateValue};
pub use tokcfg::ChatTemplate;
use tokcfg::ChatTemplateValue;
impl PromptFormatter {
pub fn from_mdc(mdc: &ModelDeploymentCard) -> Result<PromptFormatter> {
......
......@@ -28,7 +28,7 @@ pub use etcd::EtcdStore;
mod file;
pub use file::FileStore;
const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(100);
const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(1000);
/// String we use as the Key in a key-value storage operation. Simple String wrapper
/// that can encode / decode a string.
......@@ -324,7 +324,7 @@ impl Manager {
tokio::sync::mpsc::Receiver<WatchEvent>,
) {
let bucket_name = bucket_name.to_string();
let (tx, rx) = tokio::sync::mpsc::channel(128);
let (tx, rx) = tokio::sync::mpsc::channel(1024);
let watch_task = tokio::spawn(async move {
// Start listening for changes but don't poll this yet
let bucket = self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment