mooncake_bench.rs 37.8 KB
Newer Older
Yan Ru Pei's avatar
Yan Ru Pei committed
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
6
7
#[path = "common/mod.rs"]
mod common;
use common::*;

Yan Ru Pei's avatar
Yan Ru Pei committed
8
9
use clap::{Parser, Subcommand};
use dynamo_kv_router::LocalBlockHash;
10
11
12
use dynamo_kv_router::indexer::{
    KvIndexer, KvIndexerInterface, KvIndexerMetrics, ShardSizeSnapshot,
};
13
use dynamo_kv_router::protocols::{KvCacheEvent, KvCacheEventData, RouterEvent};
14
use dynamo_kv_router::{
15
16
    BranchShardedIndexer, ConcurrentRadixTree, ConcurrentRadixTreeCompressed, PositionalIndexer,
    ThreadPoolIndexer,
17
};
18
use dynamo_mocker::loadgen::Trace;
19
use serde::Serialize;
20
21
22
23
use std::sync::{
    Arc,
    atomic::{AtomicBool, Ordering},
};
Yan Ru Pei's avatar
Yan Ru Pei committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;

/// Indexer backend selection and its backend-specific parameters.
#[derive(Subcommand, Debug, Clone)]
enum IndexerArgs {
    /// Single-threaded radix tree indexer.
    RadixTree {},

    /// Position-based nested map indexer with jump search.
    NestedMap {
        /// Number of positions to skip during jump search before scanning back.
        #[clap(long, default_value = "8")]
        jump_size: usize,

        /// Number of OS threads that consume and apply KV cache events.
        #[clap(long, default_value = "16")]
        num_event_workers: usize,
    },

    /// Lock-based concurrent radix tree indexer.
    ConcurrentRadixTree {
        /// Number of OS threads that consume and apply KV cache events.
        #[clap(long, default_value = "16")]
        num_event_workers: usize,
    },
50
51
52
53
54
55
56

    /// Compressed concurrent radix tree indexer (compressed edges).
    ConcurrentRadixTreeCompressed {
        /// Number of OS threads that consume and apply KV cache events.
        #[clap(long, default_value = "16")]
        num_event_workers: usize,
    },
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81

    /// Branch-sharded CRTC: N independent CRTC shards assigned via an explicit routing
    /// table keyed on the first K block hashes. New branches are assigned to the
    /// least-loaded shard. find_matches touches exactly ONE shard (no scatter-gather).
    /// Unknown branch keys return empty scores immediately without any dispatch.
    BranchShardedCrtc {
        /// Number of independent CRTC shards.
        #[clap(long, default_value = "2")]
        num_shards: usize,

        /// Number of OS event-worker threads per shard.
        #[clap(long, default_value = "4")]
        num_event_workers_per_shard: usize,

        /// Number of prefix blocks hashed to identify a branch. K=2 is the
        /// recommended default: depth=1 often produces too few distinct branch
        /// keys, while depth=2 gives a much larger set of distinguishable branches.
        #[clap(long, default_value = "2")]
        prefix_depth: usize,

        /// Number of OS threads per shard dedicated to find_matches (read isolation).
        /// 0 (default): reads run inline on the calling tokio thread.
        #[clap(long, default_value = "0")]
        num_read_threads_per_shard: usize,
    },
Yan Ru Pei's avatar
Yan Ru Pei committed
82
83
84
85
}

impl IndexerArgs {
    /// Construct the concrete indexer from the parsed CLI args.
86
    fn build(self, block_size: u32) -> Arc<dyn KvIndexerInterface + Send + Sync> {
Yan Ru Pei's avatar
Yan Ru Pei committed
87
88
89
90
        let cancel_token = CancellationToken::new();
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        match self {
            IndexerArgs::RadixTree {} => {
91
                Arc::new(KvIndexer::new(cancel_token, block_size, metrics))
Yan Ru Pei's avatar
Yan Ru Pei committed
92
93
94
95
96
97
98
            }
            IndexerArgs::NestedMap {
                jump_size,
                num_event_workers,
            } => Arc::new(ThreadPoolIndexer::new(
                PositionalIndexer::new(jump_size),
                num_event_workers,
99
                block_size,
Yan Ru Pei's avatar
Yan Ru Pei committed
100
            )),
101
102
103
            IndexerArgs::ConcurrentRadixTree { num_event_workers } => Arc::new(
                ThreadPoolIndexer::new(ConcurrentRadixTree::new(), num_event_workers, block_size),
            ),
104
105
106
107
108
109
110
            IndexerArgs::ConcurrentRadixTreeCompressed { num_event_workers } => {
                Arc::new(ThreadPoolIndexer::new(
                    ConcurrentRadixTreeCompressed::new(),
                    num_event_workers,
                    block_size,
                ))
            }
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
            IndexerArgs::BranchShardedCrtc {
                num_shards,
                num_event_workers_per_shard,
                prefix_depth,
                num_read_threads_per_shard: _,
            } => {
                let shards = (0..num_shards)
                    .map(|_| {
                        ThreadPoolIndexer::new(
                            ConcurrentRadixTreeCompressed::new(),
                            num_event_workers_per_shard,
                            block_size,
                        )
                    })
                    .collect();
                Arc::new(BranchShardedIndexer::new_with_options(
                    shards,
                    prefix_depth,
                    block_size,
                ))
            }
Yan Ru Pei's avatar
Yan Ru Pei committed
132
133
        }
    }
134

135
136
    fn supports_remove(_name: &str) -> bool {
        true
137
138
139
    }

    fn is_multi_threaded(name: &str) -> bool {
140
141
        matches!(
            name,
142
143
144
145
            "nested-map"
                | "concurrent-radix-tree"
                | "concurrent-radix-tree-compressed"
                | "branch-sharded-crtc"
146
        )
147
148
    }

149
    /// Construct an indexer from a short name string.
150
151
    fn from_name(
        name: &str,
152
153
        block_size: u32,
        num_event_workers: usize,
154
    ) -> anyhow::Result<Arc<dyn KvIndexerInterface + Send + Sync>> {
155
        let nw = num_event_workers;
156
157
158
159
160
161
162
163
164
        let indexer_args = match name {
            "radix-tree" => IndexerArgs::RadixTree {},
            "nested-map" => IndexerArgs::NestedMap {
                jump_size: 8,
                num_event_workers: nw,
            },
            "concurrent-radix-tree" => IndexerArgs::ConcurrentRadixTree {
                num_event_workers: nw,
            },
165
166
167
            "concurrent-radix-tree-compressed" => IndexerArgs::ConcurrentRadixTreeCompressed {
                num_event_workers: nw,
            },
168
169
170
171
172
173
            "branch-sharded-crtc" => IndexerArgs::BranchShardedCrtc {
                num_shards: 2,
                num_event_workers_per_shard: nw,
                prefix_depth: 2,
                num_read_threads_per_shard: 0,
            },
174
            _ => anyhow::bail!(
175
176
177
                "Unknown indexer '{}'. Valid names: radix-tree, radix-tree-sharded, \
                 nested-map, concurrent-radix-tree, concurrent-radix-tree-compressed, \
                 branch-sharded-crtc",
178
179
180
                name
            ),
        };
181
        Ok(indexer_args.build(block_size))
182
    }
Yan Ru Pei's avatar
Yan Ru Pei committed
183
184
185
186
187
}

#[derive(Parser, Debug)]
#[clap(version, about, long_about = None)]
struct Args {
188
189
    #[clap(flatten)]
    common: CommonArgs,
Yan Ru Pei's avatar
Yan Ru Pei committed
190

191
    /// Output path for the sweep plot SVG.
192
    #[clap(long, default_value = "sweep_plot.svg")]
193
194
195
196
    sweep_output: String,

    /// Comma-separated list of indexer names to benchmark and compare on the
    /// same plot. Overrides the subcommand indexer when present. Valid names:
197
    /// radix-tree, nested-map, concurrent-radix-tree,
198
    /// concurrent-radix-tree-compressed.
199
200
201
202
    #[clap(long, value_delimiter = ',')]
    compare: Vec<String>,

    /// Number of OS threads for event processing in compare mode. Applies to
203
204
205
    /// indexers that use a thread pool (nested-map, concurrent-radix-tree,
    /// concurrent-radix-tree-compressed).
    /// Ignored by radix-tree.
206
207
208
    #[clap(long, default_value = "16")]
    num_event_workers: usize,

209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
    /// Number of additional concurrent tokio tasks that issue find_matches in a
    /// tight loop to stress the read path.  These tasks run alongside the normal
    /// trace-replay workers.  Set to 0 (default) to disable.
    #[clap(long, default_value = "0")]
    find_matches_concurrency: usize,

    /// Output path for the shard-size CSV produced when `shard-metrics` feature
    /// is enabled.  Rows: `elapsed_ms,shard_idx,worker_count,block_count,node_count`.
    /// An SVG plot is written alongside it (<path>.svg).
    /// Omit or leave empty to disable shard-size sampling.
    #[clap(long, default_value = "")]
    shard_metrics_csv: String,

    /// How often (ms) to sample shard sizes when `--shard-metrics-csv` is set.
    #[clap(long, default_value = "200")]
    shard_metrics_interval_ms: u64,

Yan Ru Pei's avatar
Yan Ru Pei committed
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
    /// Indexer backend to benchmark (defaults to radix-tree if not specified).
    #[clap(subcommand)]
    indexer: Option<IndexerArgs>,
}

impl Args {
    /// Resolve the backend for a single-indexer run: a copy of the parsed
    /// subcommand when one was given, otherwise the plain radix tree.
    fn get_indexer(&self) -> IndexerArgs {
        match &self.indexer {
            Some(selected) => selected.clone(),
            None => IndexerArgs::RadixTree {},
        }
    }
}

/// A single entry in a worker's merged benchmark timeline.
///
/// During replay a `Request` is passed to the indexer's `find_matches` and
/// timed, while an `Event` is passed to `apply_event` and not timed.
#[derive(Clone)]
enum WorkerTraceEntry {
    /// A find_matches request with pre-computed block hashes.
    Request(Vec<LocalBlockHash>),
    /// A KV cache event (store/remove/clear) to apply to the indexer.
    Event(KvCacheEvent),
}

/// A timestamped entry in a worker's benchmark trace, used to replay requests
/// and events at the correct relative timing.
#[derive(Clone)]
struct WorkerTrace {
    /// The request or event to submit to the indexer.
    entry: WorkerTraceEntry,
    /// Submission time in microseconds. `prepare_worker_traces` rescales this
    /// so the worker's last entry lands at the benchmark duration.
    timestamp_us: u64,
}

/// Merge each worker's request trace and event trace into a single
/// time-ordered sequence of `WorkerTrace` entries suitable for benchmark
/// replay.
///
/// Timestamps are rescaled from the original trace / simulation durations
/// into the benchmark duration (microseconds).
fn prepare_worker_traces(
262
    artifacts: Vec<WorkerReplayArtifacts>,
263
    benchmark_duration_ms: u64,
Yan Ru Pei's avatar
Yan Ru Pei committed
264
) -> Vec<Vec<WorkerTrace>> {
265
    artifacts
Yan Ru Pei's avatar
Yan Ru Pei committed
266
        .into_iter()
267
268
269
        .map(|artifact| {
            let mut merged = artifact
                .requests
Yan Ru Pei's avatar
Yan Ru Pei committed
270
271
                .into_iter()
                .map(|request| WorkerTrace {
272
273
                    timestamp_us: request.timestamp_us,
                    entry: WorkerTraceEntry::Request(request.replay_hashes.local_block_hashes),
Yan Ru Pei's avatar
Yan Ru Pei committed
274
                })
275
276
277
278
279
                .chain(artifact.kv_events.into_iter().map(|event| WorkerTrace {
                    timestamp_us: event.timestamp_us,
                    entry: WorkerTraceEntry::Event(event.event),
                }))
                .collect::<Vec<_>>();
Yan Ru Pei's avatar
Yan Ru Pei committed
280
            merged.sort_by_key(|entry| entry.timestamp_us);
281
282
283
284
285
286
287
288
            let max_timestamp_us = merged.last().map(|entry| entry.timestamp_us).unwrap_or(0);
            for entry in &mut merged {
                entry.timestamp_us = if max_timestamp_us == 0 {
                    0
                } else {
                    entry.timestamp_us * benchmark_duration_ms * 1000 / max_timestamp_us
                };
            }
Yan Ru Pei's avatar
Yan Ru Pei committed
289
290
291
292
293
            merged
        })
        .collect()
}

294
295
296
297
298
299
300
/// Serializable record pairing a duration with the benchmark results measured
/// for it.
#[derive(Serialize)]
struct SweepStepResult {
    /// Duration in milliseconds — presumably the benchmark duration used for
    /// this sweep step; confirm against the code that constructs it.
    duration_ms: u64,
    /// Measured results; flattened so their fields are serialized at the same
    /// level as `duration_ms`.
    #[serde(flatten)]
    results: BenchmarkResults,
}

301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
// ---------------------------------------------------------------------------
// Shard-size sampling (always compiled; only called when a CSV path is given)
// ---------------------------------------------------------------------------

/// A single row in the shard-size time-series CSV.
#[derive(Clone)]
struct ShardSampleRow {
    /// Milliseconds since the sampler task started when this sample was taken.
    elapsed_ms: u64,
    /// One shard's sizes as reported by the indexer's `shard_sizes()`.
    snapshot: ShardSizeSnapshot,
}

/// Spawn a background tokio task that samples `indexer.shard_sizes()` every
/// `interval_ms` milliseconds until `cancel` is triggered.
///
/// Each tick appends one row per snapshot returned by `shard_sizes()`, all
/// stamped with the milliseconds elapsed since the task started. An
/// `interval_ms` of 0 is treated as 1 ms, because a zero-length tokio
/// interval panics.
///
/// Returns a `JoinHandle` that resolves to all collected samples.
fn start_shard_sampler(
    indexer: Arc<dyn KvIndexerInterface + Send + Sync>,
    interval_ms: u64,
    cancel: tokio_util::sync::CancellationToken,
) -> tokio::task::JoinHandle<Vec<ShardSampleRow>> {
    // `tokio::time::interval` panics on a zero period, and `interval_ms`
    // comes straight from the command line.
    let period = Duration::from_millis(interval_ms.max(1));

    tokio::spawn(async move {
        let mut rows = Vec::new();
        let start = Instant::now();
        let mut interval = tokio::time::interval(period);
        // If a tick is late, push later ticks back instead of firing a burst
        // of catch-up samples.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        loop {
            tokio::select! {
                _ = interval.tick() => {
                    let elapsed_ms = start.elapsed().as_millis() as u64;
                    for snap in indexer.shard_sizes() {
                        rows.push(ShardSampleRow { elapsed_ms, snapshot: snap });
                    }
                }
                _ = cancel.cancelled() => break,
            }
        }
        rows
    })
}

/// Write the collected shard-size samples to a CSV file.
///
/// The file at `path` is created (or truncated) and receives a header line
/// followed by one line per row:
/// `elapsed_ms,shard_idx,worker_count,block_count,node_count`.
///
/// # Errors
///
/// Returns any I/O error from creating, writing, or flushing the file.
fn write_shard_metrics_csv(rows: &[ShardSampleRow], path: &str) -> anyhow::Result<()> {
    use std::io::{BufWriter, Write};

    // Buffer the output: one row per `writeln!` on a bare `File` costs a
    // write syscall per sample.
    let mut f = BufWriter::new(std::fs::File::create(path)?);
    writeln!(
        f,
        "elapsed_ms,shard_idx,worker_count,block_count,node_count"
    )?;
    for r in rows {
        writeln!(
            f,
            "{},{},{},{},{}",
            r.elapsed_ms,
            r.snapshot.shard_idx,
            r.snapshot.worker_count,
            r.snapshot.block_count,
            r.snapshot.node_count,
        )?;
    }
    // Flush explicitly: dropping a `BufWriter` discards flush errors, and the
    // success message below must not be printed for a truncated file.
    f.flush()?;
    println!("Shard metrics CSV written to {path}");
    Ok(())
}

/// Plot per-shard `worker_count` and `block_count` over time and write an SVG.
///
/// Draws two panels stacked vertically:
/// - Top: workers per shard over time
/// - Bottom: blocks per shard over time
///
/// Each shard gets a distinct colour; shards are identified by their `shard_idx`.
/// The palette has six entries and repeats when there are more shards.
///
/// Does nothing (and writes no file) when `rows` is empty.
///
/// # Errors
///
/// Returns any error reported by the plotting backend while drawing or
/// writing `svg_path`.
fn plot_shard_metrics(rows: &[ShardSampleRow], svg_path: &str) -> anyhow::Result<()> {
    use plotters::prelude::*;

    if rows.is_empty() {
        return Ok(());
    }

    // Collect the set of shard indices present.
    let mut shard_indices: Vec<usize> = rows.iter().map(|r| r.snapshot.shard_idx).collect();
    shard_indices.sort_unstable();
    shard_indices.dedup();

    // The sampler's first tick is taken at (about) 0 ms, so a short run can
    // yield only samples with `elapsed_ms == 0`. Clamp to 1 so the x axis is
    // never the zero-width range `0..0`.
    let max_elapsed = rows
        .iter()
        .map(|r| r.elapsed_ms)
        .max()
        .unwrap_or(0)
        .max(1);
    // `rows` is non-empty here, so the fallbacks below are never taken; the
    // `+ 1` applied to the y ranges keeps those axes non-empty.
    let max_workers = rows
        .iter()
        .map(|r| r.snapshot.worker_count)
        .max()
        .unwrap_or(0);
    let max_blocks = rows
        .iter()
        .map(|r| r.snapshot.block_count)
        .max()
        .unwrap_or(0);

    let colors: Vec<RGBColor> = vec![
        RGBColor(31, 119, 180),
        RGBColor(255, 127, 14),
        RGBColor(44, 160, 44),
        RGBColor(214, 39, 40),
        RGBColor(148, 103, 189),
        RGBColor(140, 86, 75),
    ];

    let root = SVGBackend::new(svg_path, (900, 700)).into_drawing_area();
    root.fill(&WHITE)?;

    // Split the 700 px tall canvas into two equal panels.
    let (upper, lower) = root.split_vertically(350);

    // --- Top panel: workers per shard ---
    let mut chart = ChartBuilder::on(&upper)
        .caption("Workers per shard over time", ("sans-serif", 18))
        .margin(15)
        .x_label_area_size(30)
        .y_label_area_size(60)
        .build_cartesian_2d(0u64..max_elapsed, 0usize..max_workers + 1)?;
    chart
        .configure_mesh()
        .x_desc("Elapsed (ms)")
        .y_desc("Workers")
        .draw()?;

    for (i, &shard_idx) in shard_indices.iter().enumerate() {
        let color = colors[i % colors.len()];
        let points: Vec<(u64, usize)> = rows
            .iter()
            .filter(|r| r.snapshot.shard_idx == shard_idx)
            .map(|r| (r.elapsed_ms, r.snapshot.worker_count))
            .collect();
        let label = format!("shard {shard_idx}");
        chart
            .draw_series(LineSeries::new(points, &color))?
            .label(label)
            .legend(move |(x, y)| {
                plotters::element::PathElement::new(
                    vec![(x, y), (x + 20, y)],
                    color.stroke_width(2),
                )
            });
    }
    chart
        .configure_series_labels()
        .background_style(WHITE.mix(0.8))
        .border_style(BLACK)
        .draw()?;

    // --- Bottom panel: blocks per shard ---
    let mut chart2 = ChartBuilder::on(&lower)
        .caption("Blocks per shard over time", ("sans-serif", 18))
        .margin(15)
        .x_label_area_size(30)
        .y_label_area_size(60)
        .build_cartesian_2d(0u64..max_elapsed, 0usize..max_blocks + 1)?;
    chart2
        .configure_mesh()
        .x_desc("Elapsed (ms)")
        .y_desc("Cached blocks")
        .draw()?;

    for (i, &shard_idx) in shard_indices.iter().enumerate() {
        let color = colors[i % colors.len()];
        let points: Vec<(u64, usize)> = rows
            .iter()
            .filter(|r| r.snapshot.shard_idx == shard_idx)
            .map(|r| (r.elapsed_ms, r.snapshot.block_count))
            .collect();
        let label = format!("shard {shard_idx}");
        chart2
            .draw_series(LineSeries::new(points, &color))?
            .label(label)
            .legend(move |(x, y)| {
                plotters::element::PathElement::new(
                    vec![(x, y), (x + 20, y)],
                    color.stroke_width(2),
                )
            });
    }
    chart2
        .configure_series_labels()
        .background_style(WHITE.mix(0.8))
        .border_style(BLACK)
        .draw()?;

    root.present()?;
    println!("Shard metrics plot written to {svg_path}");
    Ok(())
}

Yan Ru Pei's avatar
Yan Ru Pei committed
488
489
490
491
492
493
494
495
/// Run the benchmark: replay each worker's merged trace against the indexer,
/// measuring find_matches latency and event processing throughput.
///
/// Workers are spawned as tokio tasks, each replaying its trace at the
/// original inter-entry timing. After all workers finish, the event queue is
/// flushed and latency percentiles / throughput stats are printed.
async fn run_benchmark(
    indexer: Arc<dyn KvIndexerInterface + Send + Sync>,
496
    artifacts: Vec<WorkerReplayArtifacts>,
Yan Ru Pei's avatar
Yan Ru Pei committed
497
    args: &Args,
498
    benchmark_duration_ms: u64,
499
    count_events: bool,
500
    find_matches_concurrency: usize,
501
) -> anyhow::Result<BenchmarkResults> {
502
    let worker_traces = prepare_worker_traces(artifacts, benchmark_duration_ms);
503
    let worker_traces = worker_traces.into_iter().map(Arc::new).collect::<Vec<_>>();
Yan Ru Pei's avatar
Yan Ru Pei committed
504
505
506
507
508
509

    let progress = make_progress_bar(Some(
        worker_traces
            .iter()
            .map(|trace| trace.len() as u64)
            .sum::<u64>()
510
            * args.common.inference_worker_duplication_factor as u64,
Yan Ru Pei's avatar
Yan Ru Pei committed
511
512
513
    ));

    let mut tasks = Vec::new();
514
    for replica in 0..args.common.inference_worker_duplication_factor {
Yan Ru Pei's avatar
Yan Ru Pei committed
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
        for (worker_id, worker_trace) in worker_traces.iter().enumerate() {
            let indexer = indexer.clone();
            let trace = worker_trace.clone();
            let progress = progress.clone();
            let worker_id = worker_id + replica * worker_traces.len();
            tasks.push(tokio::spawn(async move {
                let mut request_latencies = Vec::with_capacity(trace.len());

                let submit = |entry: WorkerTrace| async {
                    match entry.entry {
                        WorkerTraceEntry::Request(request) => {
                            let start = minstant::Instant::now();
                            indexer.find_matches(request).await?;
                            Ok::<Option<u64>, anyhow::Error>(
                                Some(start.elapsed().as_nanos() as u64),
                            )
                        }
                        WorkerTraceEntry::Event(event) => {
                            indexer
534
                                .apply_event(RouterEvent::new(worker_id as u64, event))
Yan Ru Pei's avatar
Yan Ru Pei committed
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
                                .await;
                            Ok(None)
                        }
                    }
                };

                let mut target = Instant::now();

                let mut trace = trace.iter().peekable();

                let mut local_count = 0;

                while let Some(entry) = trace.next() {
                    let mut processed = 1;
                    let entry_timestamp_us = entry.timestamp_us;

                    if let Some(latency) = submit(entry.clone()).await? {
                        request_latencies.push(latency);
                    }

                    while let Some(next) = trace.peek() {
                        if next.timestamp_us == entry_timestamp_us {
                            if let Some(latency) = submit(trace.next().unwrap().clone()).await? {
                                request_latencies.push(latency);
                            }
                            processed += 1;
                        } else {
                            break;
                        }
                    }

                    if let Some(next) = trace.peek() {
                        target += Duration::from_micros(next.timestamp_us - entry_timestamp_us);
                    }

                    if target > Instant::now() {
                        tokio::time::sleep_until(target).await;
                    }

                    local_count += processed;

                    if local_count > 100 {
                        progress.inc(local_count);
                        local_count = 0;
                    }
                }

                progress.inc(local_count);

                Ok::<_, anyhow::Error>(request_latencies)
            }));
        }
    }

589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
    // Spawn additional concurrent find_matches callers if requested.
    // These tasks run alongside the trace-replay workers to stress the read path.
    let fm_stop = Arc::new(AtomicBool::new(false));
    let mut fm_tasks = Vec::new();
    if find_matches_concurrency > 0 {
        // Collect all request sequences as a pool for random selection.
        let seq_pool: Arc<Vec<Vec<LocalBlockHash>>> = Arc::new(
            worker_traces
                .iter()
                .flat_map(|t| t.iter())
                .filter_map(|entry| match &entry.entry {
                    WorkerTraceEntry::Request(hashes) => Some(hashes.clone()),
                    _ => None,
                })
                .collect(),
        );

        if !seq_pool.is_empty() {
            for task_id in 0..find_matches_concurrency {
                let indexer = indexer.clone();
                let pool = Arc::clone(&seq_pool);
                let stop = Arc::clone(&fm_stop);
                fm_tasks.push(tokio::spawn(async move {
                    let mut latencies = Vec::new();
                    let mut idx = task_id % pool.len();
                    while !stop.load(Ordering::Relaxed) {
                        let seq = pool[idx].clone();
                        let start = minstant::Instant::now();
                        let _ = indexer.find_matches(seq).await;
                        latencies.push(start.elapsed().as_nanos() as u64);
                        idx = (idx + 1) % pool.len();
                    }
                    latencies
                }));
            }
        }
    }

Yan Ru Pei's avatar
Yan Ru Pei committed
627
628
629
630
631
632
    let mut latencies = Vec::new();

    for task in tasks {
        latencies.extend(task.await??);
    }

633
634
635
636
637
638
639
640
    // Signal concurrent find_matches callers to stop and collect their latencies.
    fm_stop.store(true, Ordering::Relaxed);
    for task in fm_tasks {
        if let Ok(fm_latencies) = task.await {
            latencies.extend(fm_latencies);
        }
    }

641
    if progress.elapsed() > Duration::from_millis(benchmark_duration_ms * 11 / 10) {
Yan Ru Pei's avatar
Yan Ru Pei committed
642
643
644
645
646
        eprintln!(
            "WARNING: The benchmarker is unable to keep up with the request/event generation rate. Rerun with a larger --benchmark-duration-ms."
        )
    }

647
    let total_duration = progress.elapsed();
Yan Ru Pei's avatar
Yan Ru Pei committed
648
649
650
651
652
653
654
655
656
657

    let total_events = worker_traces
        .iter()
        .map(|trace| {
            trace
                .iter()
                .filter(|trace| matches!(trace.entry, WorkerTraceEntry::Event(_)))
                .count()
        })
        .sum::<usize>()
658
        * args.common.inference_worker_duplication_factor;
Yan Ru Pei's avatar
Yan Ru Pei committed
659
660

    let total_requests = worker_traces.iter().map(|trace| trace.len()).sum::<usize>()
661
        * args.common.inference_worker_duplication_factor
Yan Ru Pei's avatar
Yan Ru Pei committed
662
663
        - total_events;

664
665
666
667
668
669
670
671
    let total_request_blocks: usize = worker_traces
        .iter()
        .flat_map(|t| t.iter())
        .filter_map(|entry| match &entry.entry {
            WorkerTraceEntry::Request(hashes) => Some(hashes.len()),
            _ => None,
        })
        .sum::<usize>()
672
        * args.common.inference_worker_duplication_factor;
Yan Ru Pei's avatar
Yan Ru Pei committed
673

674
675
676
677
678
679
680
681
682
683
684
    let total_event_blocks: usize = worker_traces
        .iter()
        .flat_map(|t| t.iter())
        .filter_map(|entry| match &entry.entry {
            WorkerTraceEntry::Event(ev) => match &ev.data {
                KvCacheEventData::Stored(s) => Some(s.blocks.len()),
                _ => Some(0),
            },
            _ => None,
        })
        .sum::<usize>()
685
        * args.common.inference_worker_duplication_factor;
Yan Ru Pei's avatar
Yan Ru Pei committed
686

687
688
    let counted_events = if count_events { total_events } else { 0 };
    let counted_event_blocks = if count_events { total_event_blocks } else { 0 };
Yan Ru Pei's avatar
Yan Ru Pei committed
689

690
691
    let total_blocks = total_request_blocks + counted_event_blocks;
    let total_ops = total_requests + counted_events;
692
693
694
695
    let offered_ops_throughput = total_ops as f32 / benchmark_duration_ms as f32 * 1000.0;
    let ops_throughput = total_ops as f32 / total_duration.as_millis() as f32 * 1000.0;
    let offered_block_throughput = total_blocks as f32 / benchmark_duration_ms as f32 * 1000.0;
    let block_throughput = total_blocks as f32 / total_duration.as_millis() as f32 * 1000.0;
Yan Ru Pei's avatar
Yan Ru Pei committed
696
697

    latencies.sort_unstable();
698
699
700
701
702
    let latency_p99_us = if latencies.is_empty() {
        0.0
    } else {
        latencies[latencies.len() * 99 / 100] as f32 / 1000.0
    };
703

Yan Ru Pei's avatar
Yan Ru Pei committed
704
    println!(
705
706
707
708
709
710
        "Offered Ops Throughput: {} ops/s | Achieved: {} ops/s (requests + events)",
        offered_ops_throughput as u64, ops_throughput as u64,
    );
    println!(
        "Offered Block Throughput: {} block ops/s | Achieved: {} block ops/s",
        offered_block_throughput as u64, block_throughput as u64,
Yan Ru Pei's avatar
Yan Ru Pei committed
711
    );
712
713
714
715
716
717
718
719
720
721
    println!("Latency p99: {}us", latency_p99_us);

    Ok(BenchmarkResults {
        offered_ops_throughput,
        ops_throughput,
        offered_block_throughput,
        block_throughput,
        latency_p99_us,
    })
}
Yan Ru Pei's avatar
Yan Ru Pei committed
722

723
async fn run_tests() -> anyhow::Result<()> {
724
    use std::collections::HashSet;
725
    use std::fs::File;
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
    use std::io::Write;

    let path =
        std::env::temp_dir().join(format!("mooncake_bench_test_{}.jsonl", std::process::id()));
    {
        let mut f = File::create(&path)?;
        for (i, (hash_ids, output_length)) in
            [(&[0u64, 1, 2] as &[u64], 10u64), (&[0, 1, 3, 4], 10)]
                .iter()
                .enumerate()
        {
            writeln!(
                f,
                "{}",
                serde_json::json!({
                    "timestamp": i as u64,
742
                    "input_length": hash_ids.len(),
743
744
745
746
747
748
749
                    "hash_ids": hash_ids,
                    "output_length": output_length,
                })
            )?;
        }
    }

750
    let traces = process_mooncake_trace(path.to_str().unwrap(), 512, 2, 2, 2, 42)?;
751
752
753
754
    std::fs::remove_file(&path).ok();

    let mut all_hashes: Vec<Vec<u64>> = traces
        .into_iter()
755
756
        .flat_map(|worker| worker.sessions.into_iter())
        .flat_map(|session| session.turns.into_iter().map(|turn| turn.hash_ids))
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
        .collect();
    all_hashes.sort();

    // expand(2): [0,1,2] → [0,1,2,3,4,5], [0,1,3,4] → [0,1,2,3,6,7,8,9]
    // duplicate(2): max=9, offset=10
    let mut expected = vec![
        vec![0, 1, 2, 3, 4, 5],
        vec![10, 11, 12, 13, 14, 15],
        vec![0, 1, 2, 3, 6, 7, 8, 9],
        vec![10, 11, 12, 13, 16, 17, 18, 19],
    ];
    expected.sort();
    assert_eq!(all_hashes, expected, "hash_ids mismatch");

    // Verify prefix structure within each copy.
    let copy0: Vec<&Vec<u64>> = all_hashes.iter().filter(|h| h[0] == 0).collect();
    let copy1: Vec<&Vec<u64>> = all_hashes.iter().filter(|h| h[0] == 10).collect();
    assert_eq!(copy0.len(), 2);
    assert_eq!(copy1.len(), 2);
    assert_eq!(copy0[0][..4], copy0[1][..4], "copy 0 shared prefix broken");
    assert_eq!(copy1[0][..4], copy1[1][..4], "copy 1 shared prefix broken");

    // Verify disjointness between copies.
    let set0: HashSet<u64> = copy0.iter().flat_map(|h| h.iter().copied()).collect();
    let set1: HashSet<u64> = copy1.iter().flat_map(|h| h.iter().copied()).collect();
    assert!(set0.is_disjoint(&set1), "copies are not hash-disjoint");

784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
    let replay_trace = Trace {
        block_size: 2,
        sessions: vec![dynamo_mocker::loadgen::SessionTrace {
            session_id: "session-a".to_string(),
            first_arrival_timestamp_ms: Some(0.0),
            turns: vec![
                dynamo_mocker::loadgen::TurnTrace {
                    input_length: 4,
                    max_output_tokens: 2,
                    hash_ids: vec![1, 2],
                    delay_after_previous_ms: 0.0,
                },
                dynamo_mocker::loadgen::TurnTrace {
                    input_length: 4,
                    max_output_tokens: 2,
                    hash_ids: vec![3, 4],
                    delay_after_previous_ms: 5.0,
                },
            ],
        }],
    };
    let artifacts = generate_replay_artifacts(&[replay_trace], 1024, 2, 5).await?;
    assert_eq!(artifacts.len(), 1);
    assert_eq!(artifacts[0].requests.len(), 2);
    let first_uuid = artifacts[0].requests[0].uuid;
    let first_completion_ms = artifacts[0]
        .output_signals
        .iter()
        .find(|signal| signal.signal.uuid == first_uuid && signal.signal.completed)
        .expect("first request must complete")
        .timestamp_us as f64
        / 1000.0;
    assert!(
        artifacts[0].requests[1].scheduled_ready_at_ms + 0.1 >= first_completion_ms + 5.0,
        "expected second request to wait for completion plus delay"
    );

821
822
823
824
    println!("All tests passed.");
    Ok(())
}

Yan Ru Pei's avatar
Yan Ru Pei committed
825
826
827
828
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let args = Args::parse();

829
    if args.common.test {
830
        return run_tests().await;
831
832
    }

833
834
835
836
837
838
839
    let path = match args.common.mooncake_trace_path.as_deref() {
        Some(p) => p,
        None => {
            eprintln!("No mooncake_trace_path provided, skipping benchmark");
            return Ok(());
        }
    };
840
841
    let traces = process_mooncake_trace(
        path,
842
        args.common.block_size,
843
844
845
846
847
        args.common.trace_length_factor,
        args.common.trace_duplication_factor,
        args.common.num_unique_inference_workers,
        args.common.seed,
    )?;
848
    let artifacts = generate_replay_artifacts(
849
850
851
852
853
854
        &traces,
        args.common.num_gpu_blocks,
        args.common.block_size,
        args.common.trace_simulation_duration_ms,
    )
    .await?;
Yan Ru Pei's avatar
Yan Ru Pei committed
855

856
857
858
859
860
    let indexer_names: Vec<String> = if args.compare.is_empty() {
        let name = match args.get_indexer() {
            IndexerArgs::RadixTree {} => "radix-tree",
            IndexerArgs::NestedMap { .. } => "nested-map",
            IndexerArgs::ConcurrentRadixTree { .. } => "concurrent-radix-tree",
861
            IndexerArgs::ConcurrentRadixTreeCompressed { .. } => "concurrent-radix-tree-compressed",
862
            IndexerArgs::BranchShardedCrtc { .. } => "branch-sharded-crtc",
863
864
865
866
867
868
        };
        vec![name.to_string()]
    } else {
        args.compare.clone()
    };

869
870
871
872
873
874
    if args.common.sweep {
        let durations_low_to_high = compute_sweep_durations(
            args.common.sweep_min_ms,
            args.common.sweep_max_ms,
            args.common.sweep_steps,
        );
875
        let durations_high_to_low: Vec<u64> = durations_low_to_high.iter().copied().rev().collect();
876
877
878
879
880
881
882
883

        let mut all_results: Vec<(&str, Vec<(u64, BenchmarkResults)>)> = Vec::new();

        for name in &indexer_names {
            println!("\n{}", "=".repeat(60));
            println!("Benchmarking indexer: {}", name);
            println!("{}", "=".repeat(60));

884
885
886
887
888
889
890
            let multi_threaded = IndexerArgs::is_multi_threaded(name);
            let durations = if multi_threaded {
                &durations_high_to_low
            } else {
                &durations_low_to_high
            };

891
            let mut results: Vec<(u64, BenchmarkResults)> = Vec::new();
892
            let mut consecutive_keeping_up = 0u32;
893

894
            for &dur_ms in durations {
895
896
                println!("\n=== Sweep: benchmark_duration_ms = {} ===", dur_ms);
                let indexer = if args.compare.is_empty() {
897
                    args.get_indexer().build(args.common.block_size)
898
                } else {
899
                    IndexerArgs::from_name(name, args.common.block_size, args.num_event_workers)?
900
                };
901
                let count_events = IndexerArgs::supports_remove(name);
902
903
904
905
906
907
908
909
910
                let result = run_benchmark(
                    indexer,
                    artifacts.clone(),
                    &args,
                    dur_ms,
                    count_events,
                    args.find_matches_concurrency,
                )
                .await?;
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930

                if multi_threaded {
                    if result.block_throughput >= result.offered_block_throughput * 0.95 {
                        consecutive_keeping_up += 1;
                    } else {
                        consecutive_keeping_up = 0;
                    }
                    results.push((dur_ms, result));
                    if consecutive_keeping_up >= 5 {
                        println!("Early stop: achieved >= 95% offered for 5 consecutive steps");
                        break;
                    }
                } else {
                    let saturated = result.offered_block_throughput > result.block_throughput * 5.0;
                    results.push((dur_ms, result));
                    if saturated {
                        println!("Early stop: offered throughput >5x achieved throughput");
                        break;
                    }
                }
931
932
            }

933
            results.sort_by_key(|(dur, _)| std::cmp::Reverse(*dur));
934
            print_sweep_summary(name, &results);
Yan Ru Pei's avatar
Yan Ru Pei committed
935

936
937
938
939
            all_results.push((name, results));
        }

        plot_sweep(&all_results, &args.sweep_output)?;
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965

        let json_path = args
            .sweep_output
            .replace(".png", ".json")
            .replace(".svg", ".json");
        let json_map: std::collections::BTreeMap<&str, Vec<SweepStepResult>> = all_results
            .iter()
            .map(|(name, results)| {
                let steps = results
                    .iter()
                    .map(|(dur, r)| SweepStepResult {
                        duration_ms: *dur,
                        results: BenchmarkResults {
                            offered_ops_throughput: r.offered_ops_throughput,
                            ops_throughput: r.ops_throughput,
                            offered_block_throughput: r.offered_block_throughput,
                            block_throughput: r.block_throughput,
                            latency_p99_us: r.latency_p99_us,
                        },
                    })
                    .collect();
                (*name, steps)
            })
            .collect();
        std::fs::write(&json_path, serde_json::to_string_pretty(&json_map)?)?;
        println!("Sweep results saved to {}", json_path);
966
    } else {
967
968
        drop(traces);

969
970
971
        for name in &indexer_names {
            println!("\nBenchmarking indexer: {}", name);
            let indexer = if args.compare.is_empty() {
972
                args.get_indexer().build(args.common.block_size)
973
            } else {
974
                IndexerArgs::from_name(name, args.common.block_size, args.num_event_workers)?
975
            };
976
            let count_events = IndexerArgs::supports_remove(name);
977
978
979
980
981
982
983
984
985
986
987
988
989

            // Start shard-size sampler if a CSV path was provided.
            let shard_cancel = CancellationToken::new();
            let shard_sampler = if !args.shard_metrics_csv.is_empty() {
                Some(start_shard_sampler(
                    indexer.clone(),
                    args.shard_metrics_interval_ms,
                    shard_cancel.clone(),
                ))
            } else {
                None
            };

990
            run_benchmark(
991
                indexer.clone(),
992
                artifacts.clone(),
993
                &args,
994
                args.common.benchmark_duration_ms,
995
                count_events,
996
                args.find_matches_concurrency,
997
998
            )
            .await?;
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052

            // Stop sampler and write CSV + plot.
            shard_cancel.cancel();
            if let Some(handle) = shard_sampler {
                let rows = handle.await?;
                // In compare mode, prefix the indexer name to distinguish outputs.
                let csv_path = if args.compare.len() > 1 {
                    let stem = args.shard_metrics_csv.trim_end_matches(".csv");
                    format!("{stem}_{name}.csv")
                } else {
                    args.shard_metrics_csv.clone()
                };
                write_shard_metrics_csv(&rows, &csv_path)?;
                let svg = format!("{}.svg", csv_path.trim_end_matches(".csv"));
                plot_shard_metrics(&rows, &svg)?;
            }

            let report = indexer.timing_report();
            if !report.is_empty() {
                println!("{}", report);
            }
            let sizes = indexer.shard_sizes();
            if sizes.len() > 1 {
                let total_blocks: usize = sizes.iter().map(|s| s.block_count).sum();
                let total_nodes: usize = sizes.iter().map(|s| s.node_count).sum();
                println!("Shard block distribution:");
                for s in &sizes {
                    let pct = if total_blocks > 0 {
                        100.0 * s.block_count as f64 / total_blocks as f64
                    } else {
                        0.0
                    };
                    println!(
                        "  shard {}: {} blocks ({:.1}%), {} workers, {} nodes",
                        s.shard_idx, s.block_count, pct, s.worker_count, s.node_count
                    );
                }
                if total_nodes > 0 {
                    println!("  total nodes across shards: {}", total_nodes);
                }
            }

            let mut edge_lengths = indexer.node_edge_lengths();
            if !edge_lengths.is_empty() {
                let avg = edge_lengths.iter().sum::<usize>() as f64 / edge_lengths.len() as f64;
                edge_lengths.sort_unstable();
                let p99 = edge_lengths[edge_lengths.len() * 99 / 100];
                println!(
                    "Node edge lengths ({} nodes): avg={:.1} hashes/node, p99={} hashes/node",
                    edge_lengths.len(),
                    avg,
                    p99,
                );
            }
1053
1054
        }
    }
Yan Ru Pei's avatar
Yan Ru Pei committed
1055
1056
1057

    Ok(())
}