Unverified Commit 039d35ff authored by Janelle Cai's avatar Janelle Cai Committed by GitHub
Browse files

test: indexer and full router benchmarks (#5784)

parent 051f18a4
......@@ -46,3 +46,8 @@ tokio = { workspace = true, features = ["rt", "macros", "time"] }
name = "radix_tree_microbench"
harness = false
required-features = ["bench"]
[[bench]]
name = "kv_indexer_bench"
harness = false
required-features = ["bench"]
This diff is collapsed.
......@@ -15,16 +15,12 @@
use clap::{Parser, ValueEnum};
use dynamo_kv_router::{
OverlapScores, RadixTree, RouterEvent, compute_block_hash_for_seq,
RadixTree, RouterEvent,
bench_utils::{LatencyStats, SequenceData, generate_sequences},
compute_block_hash_for_seq,
flat_hashmap::FlatHashMap,
protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData,
KvCacheStoreData, KvCacheStoredBlockData, LocalBlockHash, WorkerId,
compute_seq_hash_for_block,
},
protocols::LocalBlockHash,
};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::time::{Duration, Instant};
/// Unified interface for RadixTree and FlatHashMap benchmarking.
......@@ -206,114 +202,6 @@ struct Args {
flat_hashmap: bool,
}
/// One benchmark sequence: the worker that owns it plus its per-block hashes.
#[derive(Clone)]
struct SequenceData {
    /// Worker the sequence is attributed to in generated events.
    worker_id: WorkerId,
    /// Per-block local (token) hashes.
    local_hashes: Vec<LocalBlockHash>,
    /// Per-block external hashes, index-aligned with `local_hashes`.
    external_hashes: Vec<ExternalSequenceBlockHash>,
}

impl SequenceData {
    /// Builds a sequence from its local hashes.
    ///
    /// The external hashes are derived by feeding `local_hashes` through
    /// `compute_seq_hash_for_block` (cumulative hashes) and wrapping each result in
    /// `ExternalSequenceBlockHash`. This ensures FlatHashMap can correctly identify
    /// block positions.
    fn from_local_hashes(worker_id: WorkerId, local_hashes: Vec<LocalBlockHash>) -> Self {
        let external_hashes = compute_seq_hash_for_block(&local_hashes)
            .into_iter()
            .map(ExternalSequenceBlockHash)
            .collect();

        Self {
            worker_id,
            local_hashes,
            external_hashes,
        }
    }

    /// Produces a `Stored` router event carrying every block of this sequence.
    ///
    /// Blocks are emitted in order, pairing each local hash with the external hash at
    /// the same index. The event has no parent hash and uses `dp_rank` 0.
    fn to_store_event(&self, event_id: u64) -> RouterEvent {
        let blocks = self
            .local_hashes
            .iter()
            .zip(&self.external_hashes)
            .map(|(&tokens_hash, &block_hash)| KvCacheStoredBlockData {
                tokens_hash,
                block_hash,
                mm_extra_info: None,
            })
            .collect();

        let data = KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash: None,
            blocks,
        });

        RouterEvent {
            worker_id: self.worker_id,
            event: KvCacheEvent {
                event_id,
                data,
                dp_rank: 0,
            },
        }
    }

    /// Produces a `Removed` router event listing all external hashes of this sequence.
    fn to_remove_event(&self, event_id: u64) -> RouterEvent {
        let data = KvCacheEventData::Removed(KvCacheRemoveData {
            block_hashes: self.external_hashes.clone(),
        });

        RouterEvent {
            worker_id: self.worker_id,
            event: KvCacheEvent {
                event_id,
                data,
                dp_rank: 0,
            },
        }
    }
}
/// Generate sequences with shared prefix prompts.
///
/// Sequence `seq_id` is assigned to worker `seq_id % num_workers`. When both
/// `num_prefix_prompts` and the rounded prefix length (`depth * prefix_prompt_ratio`) are
/// non-zero, each sequence draws a prefix group from an RNG seeded with `seed`; its first
/// `prefix_length` blocks then use hashes shared by every sequence of that group, and the
/// remaining blocks use `(seq_id << 32) | block_idx`. External hashes are always derived
/// through `SequenceData::from_local_hashes`.
///
/// # Arguments
/// * `num_sequences` - Number of sequences to generate
/// * `depth` - Number of blocks per sequence
/// * `num_workers` - Number of workers to distribute sequences across
/// * `prefix_prompt_ratio` - Fraction of each sequence's blocks that form the shared prefix
/// * `num_prefix_prompts` - Number of distinct prefix groups
/// * `seed` - Random seed for reproducibility
///
/// # Panics
/// Panics if `num_workers` is zero while `num_sequences` is non-zero.
fn generate_sequences(
    num_sequences: usize,
    depth: usize,
    num_workers: usize,
    prefix_prompt_ratio: f64,
    num_prefix_prompts: usize,
    seed: u64,
) -> Vec<SequenceData> {
    // High-word tag that marks shared-prefix block hashes.
    const SHARED_PREFIX_TAG: u64 = 0xDEAD_BEEF_0000_0000;

    // The worker assignment below takes `seq_id % num_workers`; fail with a clear message
    // instead of a bare remainder-by-zero panic.
    assert!(
        num_sequences == 0 || num_workers > 0,
        "generate_sequences: num_workers must be non-zero"
    );

    let mut sequences = Vec::with_capacity(num_sequences);
    let prefix_length = (depth as f64 * prefix_prompt_ratio).round() as usize;
    let mut rng: StdRng = StdRng::seed_from_u64(seed);

    for seq_id in 0..num_sequences {
        let seq_id_u64 = seq_id as u64;
        let worker_id = (seq_id % num_workers) as WorkerId;

        // Determine prefix group for this sequence
        let group_id = if num_prefix_prompts > 0 && prefix_length > 0 {
            Some(rng.random_range(0..num_prefix_prompts) as u64)
        } else {
            None
        };

        // Build local_hashes: shared prefix (if applicable) + unique suffix
        let local_hashes: Vec<LocalBlockHash> = (0..depth)
            .map(|block_idx| {
                let block_idx_u64 = block_idx as u64;
                match group_id {
                    // Shared prefix based on group_id. The group id is XOR-ed into the
                    // tag rather than OR-ed: the tag's high word (0xDEADBEEF) already has
                    // its low four bits set, so OR would map groups 0..=15 onto the very
                    // same prefix hashes.
                    Some(gid) if block_idx < prefix_length => {
                        LocalBlockHash((SHARED_PREFIX_TAG ^ (gid << 32)) | block_idx_u64)
                    }
                    // Unique suffix (or no shared prefix)
                    _ => LocalBlockHash((seq_id_u64 << 32) | block_idx_u64),
                }
            })
            .collect();

        sequences.push(SequenceData::from_local_hashes(worker_id, local_hashes));
    }

    sequences
}
/// Build a pre-populated RadixTree (for sweep/dump benchmarks that specifically need RadixTree)
fn build_tree(sequences: &[SequenceData]) -> RadixTree {
let num_blocks: usize = sequences.iter().map(|s| s.local_hashes.len()).sum();
......@@ -381,52 +269,6 @@ fn build_index(sequences: &[SequenceData], use_flat_hashmap: bool) -> KvIndex {
index
}
/// Summary figures computed from a batch of timing samples.
#[derive(Debug)]
struct LatencyStats {
    /// Smallest sample.
    min: Duration,
    /// Largest sample.
    max: Duration,
    /// Sum of all samples divided by the sample count.
    avg: Duration,
    /// Sample at sorted index `n / 2`.
    p50: Duration,
    /// Sample at sorted index `n * 95 / 100`.
    p95: Duration,
    /// Sample at sorted index `n * 99 / 100`.
    p99: Duration,
    /// Sample count divided by the summed sample time in seconds.
    throughput_ops_sec: f64,
}

impl LatencyStats {
    /// Sorts the samples and derives min/max/average, percentiles and throughput.
    ///
    /// # Panics
    /// Panics if `durations` is empty, because the average divides by a sample count of zero.
    fn from_durations(mut durations: Vec<Duration>) -> Self {
        durations.sort();

        let count = durations.len();
        let total = durations.iter().sum::<Duration>();
        // Computed before any indexing so the empty case fails on the division.
        let avg = total / count as u32;
        let throughput_ops_sec = count as f64 / total.as_secs_f64();

        // Picks the sample sitting at the given percentage of the sorted list.
        let at_percent = |pct: usize| durations[count * pct / 100];

        Self {
            min: durations[0],
            max: durations[count - 1],
            avg,
            p50: durations[count / 2],
            p95: at_percent(95),
            p99: at_percent(99),
            throughput_ops_sec,
        }
    }

    /// Writes the statistics to stdout under a heading naming `operation`.
    ///
    /// `blocks_per_op` scales the per-operation throughput into a blocks/sec figure.
    fn print(&self, operation: &str, blocks_per_op: usize) {
        let blocks_per_sec = self.throughput_ops_sec * blocks_per_op as f64;

        println!("\n{} Latency Statistics:", operation);
        println!(" min: {:>12?}", self.min);
        println!(" avg: {:>12?}", self.avg);
        println!(" p50: {:>12?}", self.p50);
        println!(" p95: {:>12?}", self.p95);
        println!(" p99: {:>12?}", self.p99);
        println!(" max: {:>12?}", self.max);
        println!(" throughput: {:.2} ops/sec", self.throughput_ops_sec);
        println!(" throughput: {:.2} blocks/sec", blocks_per_sec);
    }
}
/// Benchmark compute_block_hash_for_seq operation
fn bench_hash(args: &Args) {
println!("\n=== Benchmarking COMPUTE_BLOCK_HASH (per-request hot path) ===");
......@@ -464,7 +306,7 @@ fn bench_hash(args: &Args) {
}
}
let stats = LatencyStats::from_durations(durations);
let stats = LatencyStats::from_durations(durations).unwrap();
stats.print("COMPUTE_BLOCK_HASH", args.depth);
}
......@@ -487,6 +329,7 @@ fn bench_store_remove_cycle(args: &Args, time_store: bool) {
args.prefix_prompt_ratio,
args.num_prefix_prompts,
args.seed,
true, // use_cumulative_hash
);
let mut index = build_index(&sequences, args.flat_hashmap);
......@@ -524,7 +367,7 @@ fn bench_store_remove_cycle(args: &Args, time_store: bool) {
}
}
let stats = LatencyStats::from_durations(durations);
let stats = LatencyStats::from_durations(durations).unwrap();
stats.print(op_name, args.depth);
}
......@@ -548,6 +391,7 @@ fn bench_find_matches(args: &Args) {
args.prefix_prompt_ratio,
args.num_prefix_prompts,
args.seed,
true, // use_cumulative_hash
);
let index = build_index(&sequences, args.flat_hashmap);
......@@ -575,7 +419,9 @@ fn bench_find_matches(args: &Args) {
println!(" Completed {}/{} iterations", i + 1, args.iterations);
}
}
LatencyStats::from_durations(hit_durations).print("FIND_MATCHES (HIT)", args.depth);
LatencyStats::from_durations(hit_durations)
.unwrap()
.print("FIND_MATCHES (HIT)", args.depth);
// MISS case
println!("\n --- MISS case (non-existing sequences) ---");
......@@ -589,7 +435,9 @@ fn bench_find_matches(args: &Args) {
println!(" Completed {}/{} iterations", i + 1, args.iterations);
}
}
LatencyStats::from_durations(miss_durations).print("FIND_MATCHES (MISS)", args.depth);
LatencyStats::from_durations(miss_durations)
.unwrap()
.print("FIND_MATCHES (MISS)", args.depth);
// PARTIAL case
println!("\n --- PARTIAL case (prefix match only) ---");
......@@ -604,7 +452,9 @@ fn bench_find_matches(args: &Args) {
println!(" Completed {}/{} iterations", i + 1, args.iterations);
}
}
LatencyStats::from_durations(partial_durations).print("FIND_MATCHES (PARTIAL)", args.depth);
LatencyStats::from_durations(partial_durations)
.unwrap()
.print("FIND_MATCHES (PARTIAL)", args.depth);
// EARLY_EXIT case
println!("\n --- EARLY_EXIT case ---");
......@@ -617,6 +467,7 @@ fn bench_find_matches(args: &Args) {
}
}
LatencyStats::from_durations(early_exit_durations)
.unwrap()
.print("FIND_MATCHES (EARLY_EXIT)", args.depth);
}
......@@ -845,6 +696,7 @@ fn bench_sweep(args: &Args) {
args.prefix_prompt_ratio,
num_prefix_prompts,
args.seed,
true, // use_cumulative_hash
);
let tree_sequences = &all_sequences[..num_sequences];
let extra_sequences = &all_sequences[num_sequences..];
......@@ -956,6 +808,7 @@ fn bench_dump(args: &Args) {
args.prefix_prompt_ratio,
args.num_prefix_prompts,
args.seed,
true, // use_cumulative_hash
);
let tree = build_tree(&sequences);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Benchmark utilities for kv-router benchmarks.
//!
//! This module provides shared data structures for benchmarking:
//! - `LatencyStats`: Statistics for latency measurements
//! - `SequenceData`: Pre-generated sequence data for benchmarking
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, RouterEvent, WorkerId, compute_seq_hash_for_block,
};
use rand::{Rng, SeedableRng, rngs::StdRng};
use std::time::Duration;
/// Aggregate latency figures derived from a set of timing samples.
#[derive(Debug, Clone)]
pub struct LatencyStats {
    /// Smallest sample.
    pub min: Duration,
    /// Largest sample.
    pub max: Duration,
    /// Sum of all samples divided by the sample count.
    pub avg: Duration,
    /// Sample at sorted index `n / 2`.
    pub p50: Duration,
    /// Sample at sorted index `n * 95 / 100`.
    pub p95: Duration,
    /// Sample at sorted index `n * 99 / 100`.
    pub p99: Duration,
    /// Sample count divided by the summed sample time in seconds.
    pub throughput_ops_sec: f64,
}

impl LatencyStats {
    /// Derives statistics from the given samples.
    ///
    /// The samples are sorted and percentiles are read directly from the sorted list.
    /// Yields `None` when no samples were supplied.
    pub fn from_durations(mut durations: Vec<Duration>) -> Option<Self> {
        durations.sort();

        // An empty input has no first/last element, which short-circuits to `None`.
        let (&min, &max) = (durations.first()?, durations.last()?);

        let count = durations.len();
        let total = durations.iter().sum::<Duration>();

        // Picks the sample sitting at the given percentage of the sorted list.
        let at_percent = |pct: usize| durations[count * pct / 100];

        Some(Self {
            min,
            max,
            avg: total / count as u32,
            p50: durations[count / 2],
            p95: at_percent(95),
            p99: at_percent(99),
            throughput_ops_sec: count as f64 / total.as_secs_f64(),
        })
    }

    /// Writes the statistics to stdout under a heading naming `operation`.
    ///
    /// `blocks_per_op` scales the per-operation throughput into a blocks/sec figure.
    pub fn print(&self, operation: &str, blocks_per_op: usize) {
        let blocks_per_sec = self.throughput_ops_sec * blocks_per_op as f64;

        println!("\n{} Latency Statistics:", operation);
        println!(" min: {:>12?}", self.min);
        println!(" avg: {:>12?}", self.avg);
        println!(" p50: {:>12?}", self.p50);
        println!(" p95: {:>12?}", self.p95);
        println!(" p99: {:>12?}", self.p99);
        println!(" max: {:>12?}", self.max);
        println!(" throughput: {:.2} ops/sec", self.throughput_ops_sec);
        println!(" throughput: {:.2} blocks/sec", blocks_per_sec);
    }
}
/// One benchmark sequence: the worker that owns it plus its per-block hashes.
#[derive(Clone)]
pub struct SequenceData {
    /// Worker the sequence is attributed to in generated events.
    pub worker_id: WorkerId,
    /// Per-block local (token) hashes.
    pub local_hashes: Vec<LocalBlockHash>,
    /// Per-block external hashes, index-aligned with `local_hashes`.
    pub external_hashes: Vec<ExternalSequenceBlockHash>,
}

impl SequenceData {
    /// Builds a sequence of `depth` blocks with synthetic hashes.
    ///
    /// Block `i` gets the raw value `(seq_id << 32) | i` for both its local and its
    /// external hash.
    pub fn new(seq_id: u64, worker_id: WorkerId, depth: usize) -> Self {
        // Raw hash value shared by the local and external hash of one block.
        let raw_hash = |block_idx: usize| (seq_id << 32) | (block_idx as u64);

        let local_hashes = (0..depth)
            .map(|block_idx| LocalBlockHash(raw_hash(block_idx)))
            .collect();
        let external_hashes = (0..depth)
            .map(|block_idx| ExternalSequenceBlockHash(raw_hash(block_idx)))
            .collect();

        Self {
            worker_id,
            local_hashes,
            external_hashes,
        }
    }

    /// Builds a sequence from its local hashes.
    ///
    /// The external hashes are derived by feeding `local_hashes` through
    /// `compute_seq_hash_for_block` (cumulative hashes) and wrapping each result in
    /// `ExternalSequenceBlockHash`. This ensures FlatHashMap can correctly identify
    /// block positions.
    pub fn from_local_hashes(worker_id: WorkerId, local_hashes: Vec<LocalBlockHash>) -> Self {
        let external_hashes = compute_seq_hash_for_block(&local_hashes)
            .into_iter()
            .map(ExternalSequenceBlockHash)
            .collect();

        Self {
            worker_id,
            local_hashes,
            external_hashes,
        }
    }

    /// Produces a `Stored` router event carrying every block of this sequence.
    ///
    /// Blocks are emitted in order, pairing each local hash with the external hash at
    /// the same index. The event has no parent hash and uses `dp_rank` 0.
    pub fn to_store_event(&self, event_id: u64) -> RouterEvent {
        let blocks = self
            .local_hashes
            .iter()
            .zip(&self.external_hashes)
            .map(|(&tokens_hash, &block_hash)| KvCacheStoredBlockData {
                tokens_hash,
                block_hash,
                mm_extra_info: None,
            })
            .collect();

        let data = KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash: None,
            blocks,
        });

        RouterEvent {
            worker_id: self.worker_id,
            event: KvCacheEvent {
                event_id,
                data,
                dp_rank: 0,
            },
        }
    }

    /// Produces a `Removed` router event listing all external hashes of this sequence.
    pub fn to_remove_event(&self, event_id: u64) -> RouterEvent {
        let data = KvCacheEventData::Removed(KvCacheRemoveData {
            block_hashes: self.external_hashes.clone(),
        });

        RouterEvent {
            worker_id: self.worker_id,
            event: KvCacheEvent {
                event_id,
                data,
                dp_rank: 0,
            },
        }
    }
}
/// Generate sequences with shared prefix prompts.
///
/// Sequence `seq_id` is assigned to worker `seq_id % num_workers`. When both
/// `num_prefix_groups` and the rounded prefix length (`depth * prefix_ratio`) are non-zero,
/// each sequence draws a prefix group from an RNG seeded with `seed`; its first
/// `prefix_length` blocks then use hashes shared by every sequence of that group, and the
/// remaining blocks use `(seq_id << 32) | block_idx`.
///
/// # Arguments
/// * `num_sequences` - Number of sequences to generate
/// * `depth` - Number of blocks per sequence
/// * `num_workers` - Number of workers to distribute sequences across
/// * `prefix_ratio` - Ratio of blocks that share a prefix (0.0 to 1.0)
/// * `num_prefix_groups` - Number of distinct prefix groups
/// * `seed` - Random seed for reproducibility
/// * `use_cumulative_hash` - If true, use `from_local_hashes` for proper cumulative hashes;
///   otherwise each external hash carries the same raw value as its local hash
///
/// # Panics
/// Panics if `num_workers` is zero while `num_sequences` is non-zero.
pub fn generate_sequences(
    num_sequences: usize,
    depth: usize,
    num_workers: usize,
    prefix_ratio: f64,
    num_prefix_groups: usize,
    seed: u64,
    use_cumulative_hash: bool,
) -> Vec<SequenceData> {
    // High-word tag that marks shared-prefix block hashes.
    const SHARED_PREFIX_TAG: u64 = 0xDEAD_BEEF_0000_0000;

    // The worker assignment below takes `seq_id % num_workers`; fail with a clear message
    // instead of a bare remainder-by-zero panic.
    assert!(
        num_sequences == 0 || num_workers > 0,
        "generate_sequences: num_workers must be non-zero"
    );

    let mut sequences = Vec::with_capacity(num_sequences);
    let prefix_length = (depth as f64 * prefix_ratio).round() as usize;
    let mut rng: StdRng = StdRng::seed_from_u64(seed);

    for seq_id in 0..num_sequences {
        let seq_id_u64 = seq_id as u64;
        let worker_id = (seq_id % num_workers) as WorkerId;

        // Determine prefix group for this sequence
        let group_id = if num_prefix_groups > 0 && prefix_length > 0 {
            Some(rng.random_range(0..num_prefix_groups) as u64)
        } else {
            None
        };

        // Raw hash value of one block: shared prefix (if applicable) or unique suffix.
        let raw_hash = |block_idx: usize| -> u64 {
            let block_idx_u64 = block_idx as u64;
            match group_id {
                // Shared prefix based on group_id. The group id is XOR-ed into the tag
                // rather than OR-ed: the tag's high word (0xDEADBEEF) already has its low
                // four bits set, so OR would map groups 0..=15 onto the very same prefix
                // hashes.
                Some(gid) if block_idx < prefix_length => {
                    (SHARED_PREFIX_TAG ^ (gid << 32)) | block_idx_u64
                }
                // Unique suffix (or no shared prefix)
                _ => (seq_id_u64 << 32) | block_idx_u64,
            }
        };

        let local_hashes: Vec<LocalBlockHash> = (0..depth)
            .map(|block_idx| LocalBlockHash(raw_hash(block_idx)))
            .collect();

        let sequence = if use_cumulative_hash {
            SequenceData::from_local_hashes(worker_id, local_hashes)
        } else {
            // Non-cumulative mode: external hashes reuse the raw per-block values.
            let external_hashes: Vec<ExternalSequenceBlockHash> = (0..depth)
                .map(|block_idx| ExternalSequenceBlockHash(raw_hash(block_idx)))
                .collect();
            SequenceData {
                worker_id,
                local_hashes,
                external_hashes,
            }
        };
        sequences.push(sequence);
    }

    sequences
}
/// Compute median of durations.
///
/// Returns the element that would sit at index `len / 2` if `durations` were sorted, so
/// for an even number of samples this is the upper of the two middle values. The input
/// slice is left untouched; selection runs on a private copy.
///
/// Returns `Duration::ZERO` when `durations` is empty.
pub fn median(durations: &[Duration]) -> Duration {
    if durations.is_empty() {
        return Duration::ZERO;
    }

    let mut scratch = durations.to_vec();
    let mid = scratch.len() / 2;
    // Only one order statistic is needed, so partition around it instead of fully sorting.
    let (_, middle, _) = scratch.select_nth_unstable(mid);
    *middle
}
......@@ -31,6 +31,9 @@
//!
//! This module provides a scalable and efficient way to manage and retrieve data blocks for LLM inference, leveraging a global KV cache to optimize performance.
#[cfg(feature = "bench")]
use std::time::Instant;
use async_trait::async_trait;
#[cfg(feature = "metrics")]
pub use dynamo_runtime::protocols::maybe_error::MaybeError;
......@@ -335,6 +338,25 @@ pub struct MatchRequest {
early_exit: bool,
/// A channel sender to send the `OverlapScores` response.
resp: oneshot::Sender<OverlapScores>,
/// Timestamp when the request was created (for queue wait time measurement)
#[cfg(feature = "bench")]
created_at: Instant,
}
impl MatchRequest {
    /// Assembles a match request from its parts.
    ///
    /// With the `bench` feature enabled, the request is also stamped with the current
    /// `Instant` as its creation time.
    fn new(
        sequence: Vec<LocalBlockHash>,
        early_exit: bool,
        resp: oneshot::Sender<OverlapScores>,
    ) -> Self {
        // Creation timestamp, only tracked in bench builds.
        #[cfg(feature = "bench")]
        let created_at = Instant::now();

        Self {
            sequence,
            early_exit,
            resp,
            #[cfg(feature = "bench")]
            created_at,
        }
    }
}
/// A request to dump the tree as events
......@@ -551,10 +573,16 @@ impl KvIndexer {
Some(event) = event_rx.recv() => {
let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let event_id = event.event.event_id;
let worker_id = event.worker_id;
// Only clone if we need the event for prune_manager afterward
let event_for_prune = prune_manager.is_some().then(|| event.clone());
let result = trie.apply_event(event);
let result_is_ok = result.is_ok();
let tree_size = trie.current_size();
tracing::trace!(
"Applied KV event to global radix tree: event_type={event_type}, event_id={event_id}, worker_id={worker_id}, success={result_is_ok}, global_radix_tree_size={tree_size}"
);
metrics.increment_event_applied(event_type, result);
// Track blocks in PruneManager if TTL is enabled and event was stored successfully
......@@ -643,7 +671,24 @@ impl KvIndexer {
}
Some(req) = match_rx.recv() => {
#[cfg(feature = "bench")]
let queue_wait = req.created_at.elapsed();
#[cfg(feature = "bench")]
let seq_len = req.sequence.len();
#[cfg(feature = "bench")]
let process_start = Instant::now();
let matches = trie.find_matches(req.sequence, req.early_exit);
#[cfg(feature = "bench")]
let process_time = process_start.elapsed();
#[cfg(feature = "bench")]
tracing::info!(
seq_len,
queue_wait_us = queue_wait.as_micros() as u64,
process_us = process_time.as_micros() as u64,
"indexer: processed find_matches"
);
let _ = req.resp.send(matches);
}
......@@ -742,12 +787,11 @@ impl KvIndexerInterface for KvIndexer {
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
#[cfg(feature = "bench")]
let start = Instant::now();
let seq_len = sequence.len();
let (resp_tx, resp_rx) = oneshot::channel();
let req = MatchRequest {
sequence,
early_exit: false,
resp: resp_tx,
};
let req = MatchRequest::new(sequence, false, resp_tx);
if let Err(e) = self.match_tx.send(req).await {
tracing::error!(
......@@ -757,9 +801,23 @@ impl KvIndexerInterface for KvIndexer {
return Err(KvRouterError::IndexerOffline);
}
resp_rx
let result = resp_rx
.await
.map_err(|_| KvRouterError::IndexerDroppedRequest)
.map_err(|_| KvRouterError::IndexerDroppedRequest);
#[cfg(feature = "bench")]
{
let elapsed = start.elapsed();
tracing::info!(
seq_len,
elapsed_us = elapsed.as_micros() as u64,
"find_matches completed"
);
}
#[cfg(not(feature = "bench"))]
let _ = seq_len;
result
}
async fn find_matches_for_request(
......@@ -1131,6 +1189,24 @@ pub struct ShardedMatchRequest {
sequence: Vec<LocalBlockHash>,
early_exit: bool,
resp: mpsc::Sender<OverlapScores>,
#[cfg(feature = "bench")]
created_at: Instant,
}
impl ShardedMatchRequest {
    /// Assembles a sharded match request from its parts.
    ///
    /// With the `bench` feature enabled, the request is also stamped with the current
    /// `Instant` as its creation time.
    fn new(
        sequence: Vec<LocalBlockHash>,
        early_exit: bool,
        resp: mpsc::Sender<OverlapScores>,
    ) -> Self {
        // Creation timestamp, only tracked in bench builds.
        #[cfg(feature = "bench")]
        let created_at = Instant::now();

        Self {
            sequence,
            early_exit,
            resp,
            #[cfg(feature = "bench")]
            created_at,
        }
    }
}
/// A sharded KV Indexer that partitions the RadixTree across multiple independent shards.
......@@ -1374,7 +1450,24 @@ impl KvIndexerSharded {
}
Ok(req) = shard_broadcast_rx.recv() => {
#[cfg(feature = "bench")]
let queue_wait = req.created_at.elapsed();
#[cfg(feature = "bench")]
let seq_len = req.sequence.len();
#[cfg(feature = "bench")]
let process_start = Instant::now();
let matches = trie.find_matches(req.sequence, req.early_exit);
#[cfg(feature = "bench")]
let process_time = process_start.elapsed();
#[cfg(feature = "bench")]
tracing::info!(
seq_len,
queue_wait_us = queue_wait.as_micros() as u64,
process_us = process_time.as_micros() as u64,
"sharded indexer: processed find_matches"
);
if let Err(e) = req.resp.send(matches).await {
tracing::trace!("Failed to send match response: {:?}", e);
}
......@@ -1442,14 +1535,18 @@ impl KvIndexerInterface for KvIndexerSharded {
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
#[cfg(feature = "bench")]
let start = Instant::now();
#[cfg(feature = "bench")]
let seq_len = sequence.len();
#[cfg(feature = "bench")]
let num_shards = self.event_tx.len();
'match_loop: loop {
let (match_tx, mut match_rx) = mpsc::channel(self.event_tx.len());
let sharded_req = ShardedMatchRequest::new(sequence.clone(), false, match_tx);
self.request_broadcast_tx
.send(ShardedMatchRequest {
sequence: sequence.clone(),
early_exit: false,
resp: match_tx,
})
.send(sharded_req)
.map_err(|_| KvRouterError::IndexerOffline)?;
let mut scores = OverlapScores::new();
......@@ -1482,6 +1579,17 @@ impl KvIndexerInterface for KvIndexerSharded {
}
}
}
#[cfg(feature = "bench")]
{
let elapsed = start.elapsed();
tracing::info!(
seq_len,
num_shards,
elapsed_us = elapsed.as_micros() as u64,
"find_matches (sharded) completed"
);
}
return Ok(scores);
}
}
......
......@@ -7,6 +7,8 @@
//! efficient KV cache lookup and routing in distributed LLM inference systems.
pub mod approx;
#[cfg(feature = "bench")]
pub mod bench_utils;
pub mod flat_hashmap;
pub mod indexer;
pub mod protocols;
......
......@@ -27,7 +27,8 @@ cuda = ["dep:cudarc"]
integration = ["dynamo-runtime/integration"]
media-nixl = ["dep:nixl-sys", "dep:flate2"]
media-ffmpeg = ["dep:video-rs", "dep:ffmpeg-next", "dep:memfile", "media-nixl"]
kv-router-stress = ["dep:clap", "dep:indicatif"]
bench = ["dynamo-kv-router/bench"]
kv-router-stress = ["dep:clap", "dep:indicatif", "bench"]
[[bench]]
name = "tokenizer"
......@@ -38,6 +39,11 @@ name = "transfer_context_v2"
harness = false
required-features = ["block-manager", "testing-cuda"]
[[bench]]
name = "kv_router_bench"
harness = false
required-features = ["kv-router-stress"]
[dependencies]
# repo
dynamo-runtime = { workspace = true }
......
This diff is collapsed.
......@@ -4,6 +4,8 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "bench")]
use std::time::Instant;
use anyhow::Result;
use derive_builder::Builder;
......@@ -504,6 +506,9 @@ impl KvRouter {
update_states: bool,
lora_name: Option<String>,
) -> anyhow::Result<(WorkerWithDpRank, u32)> {
#[cfg(feature = "bench")]
let start = Instant::now();
// Validate that context_id is provided when update_states is true
if update_states && context_id.is_none() {
panic!("context_id must be provided if update_states is true");
......@@ -512,7 +517,11 @@ impl KvRouter {
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
#[cfg(feature = "bench")]
let hash_elapsed = start.elapsed();
let overlap_scores = self.indexer.find_matches(block_hashes).await?;
#[cfg(feature = "bench")]
let find_matches_elapsed = start.elapsed();
// Compute seq_hashes only if scheduler needs it for active blocks tracking
let maybe_seq_hashes = self
......@@ -532,6 +541,19 @@ impl KvRouter {
)
.await?;
#[cfg(feature = "bench")]
{
let total_elapsed = start.elapsed();
tracing::info!(
isl_tokens,
hash_us = hash_elapsed.as_micros() as u64,
find_matches_us = (find_matches_elapsed - hash_elapsed).as_micros() as u64,
schedule_us = (total_elapsed - find_matches_elapsed).as_micros() as u64,
total_us = total_elapsed.as_micros() as u64,
"find_best_match completed"
);
}
// Note: Routing decision recording (for approximate mode) is now handled
// by KvPushRouter::generate after select_worker returns.
......
......@@ -12,6 +12,8 @@ use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "bench")]
use std::time::Instant;
use super::KV_HIT_RATE_SUBJECT;
use super::KvRouterConfig;
......@@ -288,6 +290,9 @@ impl KvScheduler {
update_states: bool,
lora_name: Option<String>,
) -> Result<WorkerWithDpRank, KvSchedulerError> {
#[cfg(feature = "bench")]
let start = Instant::now();
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
maybe_request_id,
......@@ -306,10 +311,24 @@ impl KvScheduler {
.send(request)
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?;
#[cfg(feature = "bench")]
let send_elapsed = start.elapsed();
let response = resp_rx
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?;
#[cfg(feature = "bench")]
let total_elapsed = start.elapsed();
#[cfg(feature = "bench")]
tracing::info!(
isl_tokens,
send_us = send_elapsed.as_micros() as u64,
total_us = total_elapsed.as_micros() as u64,
"scheduler.schedule completed"
);
Ok(response.best_worker)
}
......
......@@ -1175,6 +1175,11 @@ impl ActiveSequencesMultiWorker {
HashMap<WorkerWithDpRank, usize>,
HashMap<WorkerWithDpRank, usize>,
) {
#[cfg(feature = "bench")]
let start = Instant::now();
#[cfg(feature = "bench")]
let num_workers = self.senders.len();
let mut potential_blocks = HashMap::new();
let mut potential_tokens = HashMap::new();
let token_sequence_shared = token_sequence.map(Arc::new);
......@@ -1206,6 +1211,9 @@ impl ActiveSequencesMultiWorker {
}
}
#[cfg(feature = "bench")]
let send_elapsed = start.elapsed();
// Collect results from all workers
for (worker, receiver) in receivers {
match tokio::time::timeout(tokio::time::Duration::from_secs(1), receiver).await {
......@@ -1222,6 +1230,17 @@ impl ActiveSequencesMultiWorker {
}
}
#[cfg(feature = "bench")]
{
let total_elapsed = start.elapsed();
tracing::info!(
num_workers,
send_us = send_elapsed.as_micros() as u64,
total_us = total_elapsed.as_micros() as u64,
"potential_blocks_and_tokens completed"
);
}
(potential_blocks, potential_tokens)
}
......
......@@ -28,7 +28,7 @@ use crate::preprocessor::media::MediaDecoder;
pub mod deepseek_v32;
mod template;
pub use template::ContextMixins;
pub use template::{ChatTemplate, ContextMixins};
#[derive(Debug)]
pub enum TokenInput {
......@@ -95,6 +95,7 @@ pub trait OAIPromptFormatter: Send + Sync + 'static {
fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String>;
}
#[derive(Clone)]
pub enum PromptFormatter {
OAI(Arc<dyn OAIPromptFormatter>),
}
......
......@@ -14,7 +14,8 @@ mod oai;
mod tokcfg;
use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
use tokcfg::{ChatTemplate, ChatTemplateValue};
pub use tokcfg::ChatTemplate;
use tokcfg::ChatTemplateValue;
impl PromptFormatter {
pub fn from_mdc(mdc: &ModelDeploymentCard) -> Result<PromptFormatter> {
......
......@@ -28,7 +28,7 @@ pub use etcd::EtcdStore;
mod file;
pub use file::FileStore;
const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(100);
const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(1000);
/// String we use as the Key in a key-value storage operation. Simple String wrapper
/// that can encode / decode a string.
......@@ -324,7 +324,7 @@ impl Manager {
tokio::sync::mpsc::Receiver<WatchEvent>,
) {
let bucket_name = bucket_name.to_string();
let (tx, rx) = tokio::sync::mpsc::channel(128);
let (tx, rx) = tokio::sync::mpsc::channel(1024);
let watch_task = tokio::spawn(async move {
// Start listening for changes but don't poll this yet
let bucket = self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment