"vscode:/vscode.git/clone" did not exist on "d4cb783c10ffc091af7f09a3b052dceadc06d075"
Unverified Commit 73a9a53f authored by Janelle Cai's avatar Janelle Cai Committed by GitHub
Browse files

feat(router): branch sharded kv indexer (#7859)


Signed-off-by: default avatarHannah Zhang <hannahz@nvidia.com>
Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarHannah Zhang <hannahz@nvidia.com>
parent af32579e
...@@ -50,7 +50,7 @@ dynamo-mocker = { workspace = true } ...@@ -50,7 +50,7 @@ dynamo-mocker = { workspace = true }
[dev-dependencies] [dev-dependencies]
async-trait = { workspace = true } async-trait = { workspace = true }
dynamo-kv-router = { workspace = true, features = ["bench"] } dynamo-kv-router = { workspace = true, features = ["bench", "shard-metrics"] }
dynamo-tokens = { workspace = true } dynamo-tokens = { workspace = true }
minstant = "0.1.7" minstant = "0.1.7"
plotters = { version = "0.3", default-features = false, features = ["svg_backend", "line_series", "point_series", "full_palette"] } plotters = { version = "0.3", default-features = false, features = ["svg_backend", "line_series", "point_series", "full_palette"] }
......
...@@ -42,11 +42,11 @@ pub struct CommonArgs { ...@@ -42,11 +42,11 @@ pub struct CommonArgs {
pub test: bool, pub test: bool,
/// Number of GPU blocks available in the mock engine's KV cache. /// Number of GPU blocks available in the mock engine's KV cache.
#[clap(long, default_value = "1048576")] #[clap(long, default_value = "16384")]
pub num_gpu_blocks: usize, pub num_gpu_blocks: usize,
/// Number of tokens per KV cache block. /// Number of tokens per KV cache block.
#[clap(long, default_value = "512")] #[clap(long, default_value = "128")]
pub block_size: u32, pub block_size: u32,
/// Wall-clock duration (ms) over which the trace is replayed during event generation. /// Wall-clock duration (ms) over which the trace is replayed during event generation.
...@@ -58,7 +58,7 @@ pub struct CommonArgs { ...@@ -58,7 +58,7 @@ pub struct CommonArgs {
pub benchmark_duration_ms: u64, pub benchmark_duration_ms: u64,
/// Number of unique simulated inference workers. /// Number of unique simulated inference workers.
#[clap(short, long, default_value = "256")] #[clap(short, long, default_value = "1000")]
pub num_unique_inference_workers: usize, pub num_unique_inference_workers: usize,
/// How many times to duplicate unique workers during the benchmark phase. /// How many times to duplicate unique workers during the benchmark phase.
...@@ -124,10 +124,28 @@ pub struct MooncakeRequest { ...@@ -124,10 +124,28 @@ pub struct MooncakeRequest {
#[serde(default)] #[serde(default)]
pub input_length: usize, pub input_length: usize,
pub hash_ids: Vec<u64>, pub hash_ids: Vec<u64>,
#[serde(alias = "output_length", alias = "osl")]
pub output_length: u64, pub output_length: u64,
} }
#[derive(Deserialize)]
struct RawMooncakeRecord {
#[serde(default)]
timestamp: Option<f64>,
#[serde(default)]
delay: Option<f64>,
hash_ids: Vec<u64>,
#[serde(alias = "output_length", alias = "osl")]
output_length: u64,
}
/// Load the mooncake trace from disk into a flat list of requests. /// Load the mooncake trace from disk into a flat list of requests.
///
/// Supports two JSONL formats:
/// - Legacy: every record has an integer `timestamp` field (absolute ms).
/// - aiperf: first record has `timestamp` (float), subsequent records have
/// `delay` (float ms since previous). Absolute timestamps are reconstructed
/// by accumulating delays.
pub fn load_mooncake_trace(path: &str) -> anyhow::Result<Vec<MooncakeRequest>> { pub fn load_mooncake_trace(path: &str) -> anyhow::Result<Vec<MooncakeRequest>> {
let file = File::open(path)?; let file = File::open(path)?;
let reader = BufReader::new(file); let reader = BufReader::new(file);
...@@ -136,8 +154,24 @@ pub fn load_mooncake_trace(path: &str) -> anyhow::Result<Vec<MooncakeRequest>> { ...@@ -136,8 +154,24 @@ pub fn load_mooncake_trace(path: &str) -> anyhow::Result<Vec<MooncakeRequest>> {
let progress = make_progress_bar(None); let progress = make_progress_bar(None);
let mut requests = Vec::new(); let mut requests = Vec::new();
let mut cursor_ms: f64 = 0.0;
for line in reader.lines() { for line in reader.lines() {
requests.push(serde_json::from_str::<MooncakeRequest>(&line?)?); let raw: RawMooncakeRecord = serde_json::from_str(&line?)?;
if let Some(ts) = raw.timestamp {
cursor_ms = ts;
} else if let Some(d) = raw.delay {
cursor_ms += d;
}
requests.push(MooncakeRequest {
uuid: Uuid::new_v4(),
timestamp: cursor_ms as u64,
input_length: 0,
hash_ids: raw.hash_ids,
output_length: raw.output_length,
});
progress.inc(1); progress.inc(1);
} }
...@@ -155,6 +189,14 @@ pub fn partition_trace( ...@@ -155,6 +189,14 @@ pub fn partition_trace(
for request in requests { for request in requests {
traces[rng.random_range(0..num_workers)].push(request); traces[rng.random_range(0..num_workers)].push(request);
} }
// Sort each worker's trace by timestamp so that scale_mooncake_trace and
// generate_kv_events see monotonically increasing timestamps. Without this,
// mixing requests from multiple sessions (each starting at timestamp=0) into
// one worker produces non-monotonic sequences; u64 underflow in the delta
// computation then creates sleep durations measured in centuries.
for trace in &mut traces {
trace.sort_by_key(|r| r.timestamp);
}
traces traces
} }
......
This diff is collapsed.
...@@ -17,7 +17,9 @@ default = [] ...@@ -17,7 +17,9 @@ default = []
metrics = ["dep:prometheus"] metrics = ["dep:prometheus"]
runtime-protocols = ["dep:dynamo-runtime"] runtime-protocols = ["dep:dynamo-runtime"]
bench = [] bench = []
shard-metrics = []
standalone-indexer = ["dep:axum", "dep:serde_json", "dep:reqwest", "dep:zmq"] standalone-indexer = ["dep:axum", "dep:serde_json", "dep:reqwest", "dep:zmq"]
indexer-runtime = ["metrics", "runtime-protocols", "standalone-indexer"]
[dependencies] [dependencies]
# repo # repo
......
This diff is collapsed.
...@@ -31,6 +31,8 @@ ...@@ -31,6 +31,8 @@
//! //!
//! This module provides a scalable and efficient way to manage and retrieve data blocks for LLM inference, leveraging a global KV cache to optimize performance. //! This module provides a scalable and efficient way to manage and retrieve data blocks for LLM inference, leveraging a global KV cache to optimize performance.
mod branch_sharded;
fn warn_on_unit_block_size(indexer_type: &'static str, kv_block_size: u32) { fn warn_on_unit_block_size(indexer_type: &'static str, kv_block_size: u32) {
if kv_block_size == 1 { if kv_block_size == 1 {
tracing::warn!( tracing::warn!(
...@@ -40,7 +42,6 @@ fn warn_on_unit_block_size(indexer_type: &'static str, kv_block_size: u32) { ...@@ -40,7 +42,6 @@ fn warn_on_unit_block_size(indexer_type: &'static str, kv_block_size: u32) {
); );
} }
} }
mod kv_indexer; mod kv_indexer;
mod local; mod local;
mod metrics; mod metrics;
...@@ -58,6 +59,7 @@ pub mod radix_tree; ...@@ -58,6 +59,7 @@ pub mod radix_tree;
mod tests; mod tests;
// Re-export everything that was public in the old single-file module. // Re-export everything that was public in the old single-file module.
pub use branch_sharded::*;
pub use kv_indexer::*; pub use kv_indexer::*;
pub use local::*; pub use local::*;
pub use metrics::*; pub use metrics::*;
......
...@@ -12,7 +12,9 @@ use dashmap::DashMap; ...@@ -12,7 +12,9 @@ use dashmap::DashMap;
use rustc_hash::FxBuildHasher; use rustc_hash::FxBuildHasher;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use super::{KvIndexerInterface, KvIndexerMetrics, KvRouterError, SyncIndexer, WorkerTask}; use super::{
KvIndexerInterface, KvIndexerMetrics, KvRouterError, ShardSizeSnapshot, SyncIndexer, WorkerTask,
};
use crate::protocols::*; use crate::protocols::*;
/// Generic wrapper that provides [`KvIndexerInterface`] for any [`SyncIndexer`] backend. /// Generic wrapper that provides [`KvIndexerInterface`] for any [`SyncIndexer`] backend.
...@@ -133,6 +135,15 @@ impl<T: SyncIndexer> ThreadPoolIndexer<T> { ...@@ -133,6 +135,15 @@ impl<T: SyncIndexer> ThreadPoolIndexer<T> {
&self.backend &self.backend
} }
/// Get a cloned `Arc` to the underlying backend.
///
/// Useful when a caller needs to hand off an owned `Arc<T>` to a blocking
/// task (e.g. `tokio::task::spawn_blocking`) without cloning the backend
/// itself.
pub fn backend_arc(&self) -> Arc<T> {
Arc::clone(&self.backend)
}
/// Wait for all worker channels to drain. /// Wait for all worker channels to drain.
/// ///
/// Used primarily for testing and benchmarking to ensure all queued events /// Used primarily for testing and benchmarking to ensure all queued events
...@@ -365,4 +376,17 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> { ...@@ -365,4 +376,17 @@ impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
} }
curr_size curr_size
} }
fn shard_sizes(&self) -> Vec<ShardSizeSnapshot> {
vec![ShardSizeSnapshot {
shard_idx: 0,
worker_count: self.backend.worker_count(),
block_count: self.backend.block_count(),
node_count: self.backend.node_count(),
}]
}
fn node_edge_lengths(&self) -> Vec<usize> {
self.backend.node_edge_lengths()
}
} }
...@@ -8,6 +8,23 @@ use std::sync::Arc; ...@@ -8,6 +8,23 @@ use std::sync::Arc;
use super::{KvIndexerMetrics, KvRouterError, WorkerTask}; use super::{KvIndexerMetrics, KvRouterError, WorkerTask};
use crate::protocols::*; use crate::protocols::*;
/// Per-shard size snapshot returned by [`KvIndexerInterface::shard_sizes`].
///
/// `worker_count` and `block_count` are always populated.
/// `node_count` is populated only when the `shard-metrics` feature is enabled
/// on the `dynamo-kv-router` crate; otherwise it is `0`.
#[derive(Debug, Clone)]
pub struct ShardSizeSnapshot {
/// Zero-based shard index.
pub shard_idx: usize,
/// Distinct `(worker_id, dp_rank)` pairs stored in this shard.
pub worker_count: usize,
/// Total cached blocks across all workers in this shard.
pub block_count: usize,
/// Radix-tree node count (only non-zero with `shard-metrics` feature).
pub node_count: usize,
}
#[async_trait] #[async_trait]
pub trait KvIndexerInterface { pub trait KvIndexerInterface {
/// Find matches for a given sequence of `LocalBlockHash`es. /// Find matches for a given sequence of `LocalBlockHash`es.
...@@ -93,6 +110,32 @@ pub trait KvIndexerInterface { ...@@ -93,6 +110,32 @@ pub trait KvIndexerInterface {
/// Returns the amount of events still in the queue at the time of the flush. /// Returns the amount of events still in the queue at the time of the flush.
/// Used primarily for debugging. /// Used primarily for debugging.
async fn flush(&self) -> usize; async fn flush(&self) -> usize;
/// Return a human-readable timing breakdown of `find_matches` overhead.
///
/// Implementations that track per-phase timing (e.g. scatter/gather overhead
/// vs. actual shard work) override this to return a multi-line report string.
/// The default returns an empty string so callers can skip printing it.
fn timing_report(&self) -> String {
String::new()
}
/// Return a size snapshot for each shard.
///
/// Single-shard indexers return one entry (shard 0). Multi-shard indexers
/// return one entry per shard. Non-sharded indexers (and implementations
/// that don't override this) return an empty `Vec`.
///
/// See [`ShardSizeSnapshot`] for the fields exposed per shard.
fn shard_sizes(&self) -> Vec<ShardSizeSnapshot> {
vec![]
}
/// Edge lengths (hashes per node) for every non-root node.
/// Returns an empty vec for backends that don't support this.
fn node_edge_lengths(&self) -> Vec<usize> {
vec![]
}
} }
// ============================================================================ // ============================================================================
...@@ -136,4 +179,26 @@ pub trait SyncIndexer: Send + Sync + 'static { ...@@ -136,4 +179,26 @@ pub trait SyncIndexer: Send + Sync + 'static {
fn dump_events(&self) -> Option<Vec<RouterEvent>> { fn dump_events(&self) -> Option<Vec<RouterEvent>> {
None None
} }
/// Number of distinct workers registered in this backend.
fn worker_count(&self) -> usize {
0
}
/// Total cached blocks across all workers.
fn block_count(&self) -> usize {
0
}
/// Number of radix-tree nodes created since construction.
/// Only meaningful when the `shard-metrics` feature is enabled; returns 0 otherwise.
fn node_count(&self) -> usize {
0
}
/// Edge lengths (hashes per node) for every non-root node in the tree.
/// Returns an empty vec for backends that don't support this.
fn node_edge_lengths(&self) -> Vec<usize> {
vec![]
}
} }
...@@ -44,7 +44,7 @@ pub use self::sequence::{ActiveSequences, RequestId}; ...@@ -44,7 +44,7 @@ pub use self::sequence::{ActiveSequences, RequestId};
pub use concurrent_radix_tree::ConcurrentRadixTree; pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use concurrent_radix_tree_compressed::ConcurrentRadixTreeCompressed; pub use concurrent_radix_tree_compressed::ConcurrentRadixTreeCompressed;
pub use config::{KvRouterConfig, RouterConfigOverride, RouterPrefillLoadModel, RouterQueuePolicy}; pub use config::{KvRouterConfig, RouterConfigOverride, RouterPrefillLoadModel, RouterQueuePolicy};
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer}; pub use indexer::{BranchShardedIndexer, MaybeError, SyncIndexer, ThreadPoolIndexer};
pub use nested_map::PositionalIndexer; pub use nested_map::PositionalIndexer;
pub use protocols::{ pub use protocols::{
KvCacheEventError, LocalBlockHash, OverlapScores, RouterEvent, RouterEventSink, KvCacheEventError, LocalBlockHash, OverlapScores, RouterEvent, RouterEventSink,
......
...@@ -162,9 +162,20 @@ impl Trace { ...@@ -162,9 +162,20 @@ impl Trace {
let hash_ids = raw let hash_ids = raw
.hash_ids .hash_ids
.ok_or_else(|| anyhow!("trace line {} is missing hash_ids", line_idx + 1))?; .ok_or_else(|| anyhow!("trace line {} is missing hash_ids", line_idx + 1))?;
// Clamp input_length to the synthesizable capacity: in the mooncake
// trace format, input_length is the full prompt token count which may
// exceed hash_ids.len() * block_size (cached portion only).
let synthesizable_capacity =
hash_ids
.len()
.checked_mul(trace_block_size)
.ok_or_else(|| {
anyhow!("trace line {} synthesized capacity overflow", line_idx + 1)
})?;
let input_length = raw let input_length = raw
.input_length .input_length
.unwrap_or(hash_ids.len() * trace_block_size); .unwrap_or(synthesizable_capacity)
.min(synthesizable_capacity);
let output_length = raw let output_length = raw
.output_length .output_length
.ok_or_else(|| anyhow!("trace line {} is missing output_length", line_idx + 1))?; .ok_or_else(|| anyhow!("trace line {} is missing output_length", line_idx + 1))?;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment