Unverified Commit ef2583a9 authored by Janelle Cai's avatar Janelle Cai Committed by GitHub
Browse files

feat: add prefix groups to rust mooncake bench (#6308)


Signed-off-by: default avatarJanelle Cai <jcai18@mit.edu>
Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarYan Ru Pei <yanrpei@gmail.com>
parent 0a266653
...@@ -108,7 +108,6 @@ def validate_args(args): ...@@ -108,7 +108,6 @@ def validate_args(args):
and args.cuda_version in valid_inputs[args.framework]["cuda_version"] and args.cuda_version in valid_inputs[args.framework]["cuda_version"]
): ):
return return
else:
raise ValueError( raise ValueError(
f"Invalid input combination: [framework={args.framework},target={args.target},cuda_version={args.cuda_version}]" f"Invalid input combination: [framework={args.framework},target={args.target},cuda_version={args.cuda_version}]"
) )
...@@ -116,7 +115,6 @@ def validate_args(args): ...@@ -116,7 +115,6 @@ def validate_args(args):
raise ValueError( raise ValueError(
f"Invalid input combination: [framework={args.framework},target={args.target},cuda_version={args.cuda_version}]" f"Invalid input combination: [framework={args.framework},target={args.target},cuda_version={args.cuda_version}]"
) )
return
def render(args, context, script_dir): def render(args, context, script_dir):
......
...@@ -6,8 +6,7 @@ use dynamo_kv_router::LocalBlockHash; ...@@ -6,8 +6,7 @@ use dynamo_kv_router::LocalBlockHash;
use dynamo_kv_router::indexer::{ use dynamo_kv_router::indexer::{
KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvIndexerSharded, KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvIndexerSharded,
}; };
use dynamo_kv_router::protocols::RouterEvent; use dynamo_kv_router::protocols::{RouterEvent, XXH3_SEED};
use dynamo_kv_router::protocols::XXH3_SEED;
use dynamo_kv_router::{ConcurrentRadixTree, PositionalIndexer, ThreadPoolIndexer}; use dynamo_kv_router::{ConcurrentRadixTree, PositionalIndexer, ThreadPoolIndexer};
use dynamo_tokens::compute_hash_v2; use dynamo_tokens::compute_hash_v2;
use rand::prelude::*; use rand::prelude::*;
...@@ -98,7 +97,12 @@ impl IndexerArgs { ...@@ -98,7 +97,12 @@ impl IndexerArgs {
struct Args { struct Args {
/// Path to a JSONL mooncake trace file. Each line is a JSON object with /// Path to a JSONL mooncake trace file. Each line is a JSON object with
/// fields: uuid, timestamp, hash_ids, output_length. /// fields: uuid, timestamp, hash_ids, output_length.
mooncake_trace_path: String, /// Required unless --test is passed.
mooncake_trace_path: Option<String>,
/// Run built-in self-tests instead of the benchmark.
#[clap(long)]
test: bool,
/// Number of GPU blocks available in the mock engine's KV cache. /// Number of GPU blocks available in the mock engine's KV cache.
/// Smaller values force more evictions and produce more remove events. /// Smaller values force more evictions and produce more remove events.
...@@ -131,6 +135,19 @@ struct Args { ...@@ -131,6 +135,19 @@ struct Args {
#[clap(short = 'd', long, default_value = "1")] #[clap(short = 'd', long, default_value = "1")]
inference_worker_duplication_factor: usize, inference_worker_duplication_factor: usize,
/// Factor by which to stretch each request's hash sequence length.
/// Each original hash block becomes `factor` consecutive blocks.
/// Applied before event generation and before trace duplication.
#[clap(long, default_value = "1")]
trace_length_factor: usize,
/// How many times to duplicate the raw trace data with offset hash_ids
/// before event generation. Each copy is a structurally identical prefix
/// tree with disjoint hash values, increasing the number of unique
/// prefix groups and workers.
#[clap(long, default_value = "1")]
trace_duplication_factor: usize,
/// RNG seed for reproducible worker-to-trace assignment. /// RNG seed for reproducible worker-to-trace assignment.
#[clap(long, default_value = "42")] #[clap(long, default_value = "42")]
seed: u64, seed: u64,
...@@ -207,30 +224,51 @@ struct WorkerTrace { ...@@ -207,30 +224,51 @@ struct WorkerTrace {
timestamp_us: u64, timestamp_us: u64,
} }
/// Load the mooncake trace from disk and randomly partition requests across /// Load the mooncake trace from disk into a flat list of requests.
/// `num_unique_inference_workers` worker buckets using the configured seed. fn load_mooncake_trace(path: &str) -> anyhow::Result<Vec<MooncakeRequest>> {
fn process_mooncake_trace(args: &Args) -> anyhow::Result<Vec<Vec<MooncakeRequest>>> { let file = File::open(path)?;
let mut traces: Vec<Vec<MooncakeRequest>> = Vec::new();
for _ in 0..args.num_unique_inference_workers {
traces.push(Vec::new());
}
let mut rng = StdRng::seed_from_u64(args.seed);
let file = File::open(&args.mooncake_trace_path)?;
let reader = BufReader::new(file); let reader = BufReader::new(file);
println!("Loading trace..."); println!("Loading trace...");
let progress = make_progress_bar(None); let progress = make_progress_bar(None);
let mut requests = Vec::new();
for line in reader.lines() { for line in reader.lines() {
let request = serde_json::from_str::<MooncakeRequest>(&line?)?; requests.push(serde_json::from_str::<MooncakeRequest>(&line?)?);
traces[rng.random_range(0..args.num_unique_inference_workers)].push(request);
progress.inc(1); progress.inc(1);
} }
Ok(traces) Ok(requests)
}
/// Load, transform, and partition the mooncake trace into per-worker request lists.
fn process_mooncake_trace(args: &Args) -> anyhow::Result<Vec<Vec<MooncakeRequest>>> {
let path = args
.mooncake_trace_path
.as_deref()
.ok_or_else(|| anyhow::anyhow!("mooncake_trace_path is required for benchmarking"))?;
let requests = load_mooncake_trace(path)?;
let requests = expand_trace_lengths(requests, args.trace_length_factor);
let requests = duplicate_traces(requests, args.trace_duplication_factor);
Ok(partition_trace(
requests,
args.num_unique_inference_workers,
args.seed,
))
}
/// Randomly partition a flat request list across `num_workers` worker buckets.
fn partition_trace(
requests: Vec<MooncakeRequest>,
num_workers: usize,
seed: u64,
) -> Vec<Vec<MooncakeRequest>> {
let mut rng = StdRng::seed_from_u64(seed);
let mut traces: Vec<Vec<MooncakeRequest>> = (0..num_workers).map(|_| Vec::new()).collect();
for request in requests {
traces[rng.random_range(0..num_workers)].push(request);
}
traces
} }
/// Linearly rescale all timestamps in a worker's trace so the total span equals /// Linearly rescale all timestamps in a worker's trace so the total span equals
...@@ -246,6 +284,70 @@ fn scale_mooncake_trace(trace: &Vec<MooncakeRequest>, duration: u64) -> Vec<Moon ...@@ -246,6 +284,70 @@ fn scale_mooncake_trace(trace: &Vec<MooncakeRequest>, duration: u64) -> Vec<Moon
.collect::<Vec<MooncakeRequest>>() .collect::<Vec<MooncakeRequest>>()
} }
/// Stretch each request's hash sequence by the given factor, simulating longer
/// prefix chains with the same tree structure.
///
/// Each hash `h` becomes `factor` consecutive hashes:
/// `h * factor`, `h * factor + 1`, ..., `h * factor + (factor - 1)`.
/// Two sequences that shared a k-block prefix now share a k*factor-block prefix.
fn expand_trace_lengths(requests: Vec<MooncakeRequest>, factor: usize) -> Vec<MooncakeRequest> {
if factor <= 1 {
return requests;
}
println!("Expanding trace lengths by {}x", factor);
requests
.into_iter()
.map(|mut request| {
request.hash_ids = request
.hash_ids
.iter()
.flat_map(|&h| {
let base = h * factor as u64;
(0..factor as u64).map(move |offset| base + offset)
})
.collect();
request
})
.collect()
}
/// Duplicate all worker traces with offset hash_ids, creating `factor`
/// structurally identical copies of the prefix tree with disjoint hash spaces.
///
/// Copy `d` (1-indexed) offsets every hash_id by `(max_hash_id + 1) * d`.
/// The original traces (copy 0) are kept as-is.
fn duplicate_traces(requests: Vec<MooncakeRequest>, factor: usize) -> Vec<MooncakeRequest> {
if factor <= 1 {
return requests;
}
let max_hash_id = requests
.iter()
.flat_map(|r| r.hash_ids.iter().copied())
.max()
.unwrap_or(0);
let offset_base = max_hash_id + 1;
println!(
"Duplicating traces: {}x (hash offset base: {})",
factor, offset_base
);
let mut out = Vec::with_capacity(requests.len() * factor);
for r in &requests {
for d in 0..factor {
let offset = offset_base * d as u64;
out.push(MooncakeRequest {
hash_ids: r.hash_ids.iter().map(|&h| h + offset).collect(),
..r.clone()
});
}
}
out
}
/// Expand a request's block-level hash_ids into per-token IDs by repeating each /// Expand a request's block-level hash_ids into per-token IDs by repeating each
/// hash_id `block_size` times. /// hash_id `block_size` times.
fn tokens_from_request(request: &MooncakeRequest, block_size: u32) -> Vec<u32> { fn tokens_from_request(request: &MooncakeRequest, block_size: u32) -> Vec<u32> {
...@@ -654,10 +756,90 @@ async fn run_benchmark( ...@@ -654,10 +756,90 @@ async fn run_benchmark(
Ok(()) Ok(())
} }
fn run_tests() -> anyhow::Result<()> {
use std::collections::HashSet;
use std::io::Write;
let path =
std::env::temp_dir().join(format!("mooncake_bench_test_{}.jsonl", std::process::id()));
{
let mut f = File::create(&path)?;
for (i, (hash_ids, output_length)) in
[(&[0u64, 1, 2] as &[u64], 10u64), (&[0, 1, 3, 4], 10)]
.iter()
.enumerate()
{
writeln!(
f,
"{}",
serde_json::json!({
"timestamp": i as u64,
"hash_ids": hash_ids,
"output_length": output_length,
})
)?;
}
}
let args = Args::parse_from([
"test",
"--test",
path.to_str().unwrap(),
"--num-unique-inference-workers",
"2",
"--trace-length-factor",
"2",
"--trace-duplication-factor",
"2",
"--seed",
"42",
]);
let traces = process_mooncake_trace(&args)?;
std::fs::remove_file(&path).ok();
let mut all_hashes: Vec<Vec<u64>> = traces
.into_iter()
.flat_map(|w| w.into_iter().map(|r| r.hash_ids))
.collect();
all_hashes.sort();
// expand(2): [0,1,2] → [0,1,2,3,4,5], [0,1,3,4] → [0,1,2,3,6,7,8,9]
// duplicate(2): max=9, offset=10
let mut expected = vec![
vec![0, 1, 2, 3, 4, 5],
vec![10, 11, 12, 13, 14, 15],
vec![0, 1, 2, 3, 6, 7, 8, 9],
vec![10, 11, 12, 13, 16, 17, 18, 19],
];
expected.sort();
assert_eq!(all_hashes, expected, "hash_ids mismatch");
// Verify prefix structure within each copy.
let copy0: Vec<&Vec<u64>> = all_hashes.iter().filter(|h| h[0] == 0).collect();
let copy1: Vec<&Vec<u64>> = all_hashes.iter().filter(|h| h[0] == 10).collect();
assert_eq!(copy0.len(), 2);
assert_eq!(copy1.len(), 2);
assert_eq!(copy0[0][..4], copy0[1][..4], "copy 0 shared prefix broken");
assert_eq!(copy1[0][..4], copy1[1][..4], "copy 1 shared prefix broken");
// Verify disjointness between copies.
let set0: HashSet<u64> = copy0.iter().flat_map(|h| h.iter().copied()).collect();
let set1: HashSet<u64> = copy1.iter().flat_map(|h| h.iter().copied()).collect();
assert!(set0.is_disjoint(&set1), "copies are not hash-disjoint");
println!("All tests passed.");
Ok(())
}
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
let args = Args::parse(); let args = Args::parse();
if args.test {
return run_tests();
}
let traces = process_mooncake_trace(&args)?; let traces = process_mooncake_trace(&args)?;
let events = generate_events(&traces, &args).await?; let events = generate_events(&traces, &args).await?;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment