// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Rust-native offline replay benchmark entrypoint.
//!
//! Useful for profiling replay itself without the Python CLI wrapper. This
//! bench intentionally uses the mocker's internal polynomial perf model so the
//! measurements stay focused on replay and router overhead.
//!
//! Run with: cargo bench --package dynamo-bench --bench offline_replay_bench -- --help

use std::fs::File;
use std::path::PathBuf;
use std::time::Instant;

use anyhow::{Context, Result};
use clap::{Parser, ValueEnum};
use dynamo_mocker::common::protocols::MockEngineArgs;
use dynamo_mocker::replay::{ReplayRouterMode, simulate_trace_file_with_router_mode};

// Router mode as selectable on the command line (used by `Args::router_mode`).
// Each variant is converted to the same-named `ReplayRouterMode` variant by the
// `From` impl below.
//
// NOTE: plain `//` comments are used on purpose; clap's `ValueEnum` derive turns
// `///` doc comments on variants into `--help` text, which would change output.
#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)]
enum RouterModeArg {
    // Selects `ReplayRouterMode::RoundRobin`.
    RoundRobin,
    // Selects `ReplayRouterMode::KvRouter`; this is the CLI default.
    KvRouter,
}

// Bridges the CLI-level router selection to the replay library's router mode.
impl From<RouterModeArg> for ReplayRouterMode {
    /// Maps each CLI variant onto the replay variant with the same name.
    fn from(mode: RouterModeArg) -> Self {
        match mode {
            RouterModeArg::RoundRobin => Self::RoundRobin,
            RouterModeArg::KvRouter => Self::KvRouter,
        }
    }
}

36
37
38
39
40
/// Returns `true` when the process received no arguments beyond the program
/// name, or when every such argument is exactly `--bench`.
///
/// `Iterator::all` yields `true` for an empty iterator, so the no-argument
/// case needs no separate check.
fn is_bench_harness_invocation() -> bool {
    std::env::args_os()
        .skip(1) // drop the program name
        .all(|arg| arg == "--bench")
}

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
// Command-line arguments for the offline replay bench, parsed with clap's
// derive API.
//
// The `///` comments on the fields are picked up by clap as `--help` text, so
// they are effectively runtime strings; explanatory notes for maintainers are
// therefore written as plain `//` comments.
#[derive(Parser, Debug)]
#[command(name = "offline_replay_bench")]
#[command(about = "Run offline replay directly in Rust for benchmarking and profiling")]
struct Args {
    /// Mooncake trace JSONL file
    trace_file: PathBuf,

    /// Number of aggregated workers
    #[arg(long, default_value_t = 4)]
    num_workers: usize,

    /// Router mode for multi-worker replay
    #[arg(long, value_enum, default_value_t = RouterModeArg::KvRouter)]
    router_mode: RouterModeArg,

    /// Compress trace arrival timestamps by this factor
    #[arg(long, default_value_t = 4.0)]
    arrival_speedup_ratio: f64,

    /// Trace hash block size used to expand hash_ids into tokens
    #[arg(long, default_value_t = 512)]
    trace_block_size: usize,

    /// Engine/router block size used for replay hashing and mock execution
    #[arg(long, default_value_t = 64)]
    block_size: usize,

    /// Override max running requests per worker
    #[arg(long)]
    max_num_seqs: Option<usize>,

    /// Override batched token budget per worker pass
    #[arg(long)]
    max_num_batched_tokens: Option<usize>,

    /// Global speedup multiplier for the default perf model
    #[arg(long)]
    speedup_ratio: Option<f64>,

    /// Additional decode-only speedup multiplier
    #[arg(long)]
    decode_speedup_ratio: Option<f64>,

    /// Optional path to write the full replay report as pretty JSON
    #[arg(long)]
    report_json: Option<PathBuf>,

    /// Number of times to rerun the same replay in-process
    #[arg(long, default_value_t = 1)]
    iterations: usize,

    // Accepted only so that the `--bench` flag does not fail argument parsing;
    // the value itself is never read in this file.
    /// Ignored -- passed by cargo bench
    #[arg(long, hide = true)]
    bench: bool,
}

/// Assembles the mock engine configuration used for replay from the parsed
/// command-line arguments.
///
/// `block_size` is always set. Each optional override is forwarded to the
/// builder only when it was supplied, so absent flags leave the builder's own
/// value untouched.
///
/// # Errors
///
/// Fails if the builder cannot produce a `MockEngineArgs` (reported with the
/// context "failed to build replay engine args") or if `normalized()` returns
/// an error.
fn build_engine_args(args: &Args) -> Result<MockEngineArgs> {
    let mut builder = MockEngineArgs::builder().block_size(args.block_size);

    if let Some(limit) = args.max_num_seqs {
        builder = builder.max_num_seqs(Some(limit));
    }
    if let Some(budget) = args.max_num_batched_tokens {
        builder = builder.max_num_batched_tokens(Some(budget));
    }
    if let Some(ratio) = args.speedup_ratio {
        builder = builder.speedup_ratio(ratio);
    }
    if let Some(ratio) = args.decode_speedup_ratio {
        builder = builder.decode_speedup_ratio(ratio);
    }

    let engine_args = builder
        .build()
        .context("failed to build replay engine args")?;
    engine_args.normalized()
}

fn main() -> Result<()> {
119
120
121
122
123
    if is_bench_harness_invocation() {
        eprintln!("offline_replay_bench: skipping no-arg harness invocation");
        return Ok(());
    }

124
125
126
127
128
129
130
131
    let args = Args::parse();
    let engine_args = build_engine_args(&args)?;
    let started_at = Instant::now();
    let mut last_report = None;
    for _ in 0..args.iterations {
        last_report = Some(simulate_trace_file_with_router_mode(
            engine_args.clone(),
            None,
132
            None,
133
            &args.trace_file,
134
            args.trace_block_size,
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
            args.num_workers,
            args.arrival_speedup_ratio,
            args.router_mode.into(),
        )?);
    }
    let report = last_report.expect("iterations must be at least 1");
    let process_wall_time_ms = started_at.elapsed().as_secs_f64() * 1000.0;

    if let Some(report_path) = args.report_json.as_ref() {
        let file = File::create(report_path)
            .with_context(|| format!("failed to create report file at {:?}", report_path))?;
        serde_json::to_writer_pretty(file, &report)
            .with_context(|| format!("failed to write report JSON to {:?}", report_path))?;
        println!("Saved report to {}", report_path.display());
    }

    println!("Offline replay report");
    println!(
        "  completed_requests: {}",
        report.request_counts.completed_requests
    );
    println!(
        "  request_throughput_rps: {:.6}",
        report.throughput.request_throughput_rps
    );
    println!(
        "  output_throughput_tok_s: {:.6}",
        report.throughput.output_throughput_tok_s
    );
    println!("  mean_ttft_ms: {:.6}", report.latency.ttft.mean_ms);
    println!("  mean_e2e_latency_ms: {:.6}", report.latency.e2e.mean_ms);
    println!(
        "  prefix_cache_reused_ratio: {:.6}",
        report.prefix_cache_reused_ratio
    );
    println!("  wall_time_ms: {:.6}", report.throughput.wall_time_ms);
    println!("  process_wall_time_ms: {:.6}", process_wall_time_ms);

    Ok(())
}