// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Rust-native offline replay benchmark entrypoint.
//!
//! Useful for profiling replay itself without the Python CLI wrapper. This keeps
//! the default mocker perf model unless CLI overrides are provided.
//!
//! Run with: cargo bench --package dynamo-bench --bench offline_replay_bench -- --help

use std::fs::File;
use std::path::PathBuf;
use std::time::Instant;

use anyhow::{Context, Result};
use clap::{Parser, ValueEnum};
use dynamo_mocker::common::protocols::MockEngineArgs;
use dynamo_mocker::replay::{ReplayRouterMode, simulate_trace_file_with_router_mode};

// CLI-facing selector for the replay router mode (used by `Args::router_mode`).
//
// Each variant is converted to the `ReplayRouterMode` variant of the same name
// by the `From<RouterModeArg>` impl in this file.
//
// NOTE(review): these are deliberately plain `//` comments. clap's derive is
// documented to surface `///` text on `ValueEnum` variants as help output, so
// converting them to doc comments would change `--help` -- confirm before doing so.
#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)]
enum RouterModeArg {
    // Converted to `ReplayRouterMode::RoundRobin`.
    RoundRobin,
    // Converted to `ReplayRouterMode::KvRouter`; this is the default for `--router-mode`.
    KvRouter,
}

/// Translates the CLI-level router selection into the replay library's
/// router mode, variant for variant.
impl From<RouterModeArg> for ReplayRouterMode {
    fn from(mode: RouterModeArg) -> Self {
        // Exhaustive on purpose: adding a CLI variant must force a decision here.
        match mode {
            RouterModeArg::RoundRobin => Self::RoundRobin,
            RouterModeArg::KvRouter => Self::KvRouter,
        }
    }
}

35
36
37
38
39
/// Reports whether the process was launched with nothing to do.
///
/// Returns `true` when, after the program name, there are no arguments at all
/// or every argument is exactly `--bench`; returns `false` as soon as any other
/// argument is present.
fn is_bench_harness_invocation() -> bool {
    // `all` is vacuously true for an empty iterator, which covers the
    // no-argument case without a separate emptiness check.
    std::env::args_os()
        .skip(1)
        .all(|argument| argument == "--bench")
}

40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
// Command-line options for the offline replay benchmark, parsed via clap's
// `Parser` derive.
//
// `trace_file` is the only positional argument; every other field is a `--long`
// flag. Fields typed `Option<_>` carry no default here: `build_engine_args`
// only forwards them to the engine-args builder when they are present.
//
// NOTE(review): the `///` comments below are presumably rendered by clap as
// per-argument help text, so they are left verbatim; maintainer notes in this
// struct use plain `//` to stay out of `--help`.
#[derive(Parser, Debug)]
#[command(name = "offline_replay_bench")]
#[command(about = "Run offline replay directly in Rust for benchmarking and profiling")]
struct Args {
    /// Mooncake trace JSONL file
    trace_file: PathBuf,

    /// Number of aggregated workers
    #[arg(long, default_value_t = 4)]
    num_workers: usize,

    /// Router mode for multi-worker replay
    #[arg(long, value_enum, default_value_t = RouterModeArg::KvRouter)]
    router_mode: RouterModeArg,

    /// Compress trace arrival timestamps by this factor
    #[arg(long, default_value_t = 4.0)]
    arrival_speedup_ratio: f64,

    /// Trace hash block size used to expand hash_ids into tokens
    #[arg(long, default_value_t = 512)]
    trace_block_size: usize,

    /// Engine/router block size used for replay hashing and mock execution
    #[arg(long, default_value_t = 64)]
    block_size: usize,

    /// Override max running requests per worker
    #[arg(long)]
    max_num_seqs: Option<usize>,

    /// Override batched token budget per worker pass
    #[arg(long)]
    max_num_batched_tokens: Option<usize>,

    /// Global speedup multiplier for the default perf model
    #[arg(long)]
    speedup_ratio: Option<f64>,

    /// Additional decode-only speedup multiplier
    #[arg(long)]
    decode_speedup_ratio: Option<f64>,

    /// Explicit planner profile NPZ to use for perf-model timing
    #[arg(long)]
    planner_profile_data: Option<PathBuf>,

    /// Optional path to write the full replay report as pretty JSON
    #[arg(long)]
    report_json: Option<PathBuf>,

    /// Number of times to rerun the same replay in-process
    #[arg(long, default_value_t = 1)]
    iterations: usize,

    /// Ignored -- passed by cargo bench
    // Accepted only so that parsing does not reject the flag; nothing reads it.
    #[arg(long, hide = true)]
    bench: bool,
}

/// Assembles the mock engine configuration from the parsed CLI options.
///
/// `block_size` is always applied. Each optional override (`max_num_seqs`,
/// `max_num_batched_tokens`, `speedup_ratio`, `decode_speedup_ratio`,
/// `planner_profile_data`) is applied only when the flag was supplied, so
/// omitted flags leave the builder's own values untouched.
///
/// # Errors
///
/// Returns an error when the builder rejects the configuration (wrapped with
/// the context "failed to build replay engine args") or when `normalized()`
/// fails on the built value.
fn build_engine_args(args: &Args) -> Result<MockEngineArgs> {
    let mut engine = MockEngineArgs::builder().block_size(args.block_size);

    if let Some(seq_limit) = args.max_num_seqs {
        engine = engine.max_num_seqs(Some(seq_limit));
    }
    if let Some(token_budget) = args.max_num_batched_tokens {
        engine = engine.max_num_batched_tokens(Some(token_budget));
    }
    if let Some(ratio) = args.speedup_ratio {
        engine = engine.speedup_ratio(ratio);
    }
    if let Some(ratio) = args.decode_speedup_ratio {
        engine = engine.decode_speedup_ratio(ratio);
    }
    if let Some(profile_path) = &args.planner_profile_data {
        engine = engine.planner_profile_data(Some(profile_path.clone()));
    }

    let built = engine
        .build()
        .context("failed to build replay engine args")?;
    built.normalized()
}

fn main() -> Result<()> {
125
126
127
128
129
    if is_bench_harness_invocation() {
        eprintln!("offline_replay_bench: skipping no-arg harness invocation");
        return Ok(());
    }

130
131
132
133
134
135
136
137
    let args = Args::parse();
    let engine_args = build_engine_args(&args)?;
    let started_at = Instant::now();
    let mut last_report = None;
    for _ in 0..args.iterations {
        last_report = Some(simulate_trace_file_with_router_mode(
            engine_args.clone(),
            None,
138
            None,
139
            &args.trace_file,
140
            args.trace_block_size,
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
            args.num_workers,
            args.arrival_speedup_ratio,
            args.router_mode.into(),
        )?);
    }
    let report = last_report.expect("iterations must be at least 1");
    let process_wall_time_ms = started_at.elapsed().as_secs_f64() * 1000.0;

    if let Some(report_path) = args.report_json.as_ref() {
        let file = File::create(report_path)
            .with_context(|| format!("failed to create report file at {:?}", report_path))?;
        serde_json::to_writer_pretty(file, &report)
            .with_context(|| format!("failed to write report JSON to {:?}", report_path))?;
        println!("Saved report to {}", report_path.display());
    }

    println!("Offline replay report");
    println!(
        "  completed_requests: {}",
        report.request_counts.completed_requests
    );
    println!(
        "  request_throughput_rps: {:.6}",
        report.throughput.request_throughput_rps
    );
    println!(
        "  output_throughput_tok_s: {:.6}",
        report.throughput.output_throughput_tok_s
    );
    println!("  mean_ttft_ms: {:.6}", report.latency.ttft.mean_ms);
    println!("  mean_e2e_latency_ms: {:.6}", report.latency.e2e.mean_ms);
    println!(
        "  prefix_cache_reused_ratio: {:.6}",
        report.prefix_cache_reused_ratio
    );
    println!("  wall_time_ms: {:.6}", report.throughput.wall_time_ms);
    println!("  process_wall_time_ms: {:.6}", process_wall_time_ms);

    Ok(())
}