// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use super::core::ReplayWorkerCore; use super::normalize_trace_requests; use crate::common::protocols::{DirectRequest, MockEngineArgs}; use crate::replay::{TraceCollector, TraceSimulationReport}; use anyhow::bail; use std::collections::VecDeque; use uuid::Uuid; #[derive(Debug, Clone, Copy)] enum SingleReplayMode { Trace, Concurrency { max_in_flight: usize }, } struct SingleRuntime { current_time_ms: f64, pending: VecDeque, worker: ReplayWorkerCore, collector: TraceCollector, mode: SingleReplayMode, } impl SingleRuntime { fn new(args: MockEngineArgs, pending: VecDeque, mode: SingleReplayMode) -> Self { Self { current_time_ms: 0.0, pending, worker: ReplayWorkerCore::new(args), collector: TraceCollector::default(), mode, } } fn enqueue_trace_arrivals(&mut self) { loop { let Some(next_arrival_ms) = self .pending .front() .and_then(|request| request.arrival_timestamp_ms) else { break; }; if next_arrival_ms > self.current_time_ms { break; } let request = self .pending .pop_front() .expect("front request must exist when arrival is available"); let arrival_ms = request .arrival_timestamp_ms .expect("trace replay requests must have an arrival timestamp"); self.record_arrival(request, arrival_ms); } } fn enqueue_concurrency_arrivals(&mut self, max_in_flight: usize) { while self.worker.num_requests() < max_in_flight { let Some(mut request) = self.pending.pop_front() else { break; }; request.arrival_timestamp_ms = Some(self.current_time_ms); self.record_arrival(request, self.current_time_ms); } } fn record_arrival(&mut self, request: DirectRequest, arrival_ms: f64) -> Uuid { let input_length = request.tokens.len(); let output_length = request.max_output_tokens; let uuid = self.worker.receive(request); self.collector .on_arrival(uuid, arrival_ms, input_length, output_length); uuid } fn is_done(&self) -> bool { self.pending.is_empty() && self.worker.is_empty() } fn advance_to_next_trace_arrival(&mut self) -> anyhow::Result<()> { let Some(next_arrival_ms) = self .pending .front() .and_then(|request| request.arrival_timestamp_ms) else { bail!("trace replay reached an idle state without a pending arrival"); }; self.current_time_ms = next_arrival_ms; Ok(()) } fn drive_worker(&mut self, admit_arrivals_between_steps: bool) { let pass = self .worker .execute_pass(&mut self.collector, self.current_time_ms); self.current_time_ms = pass.end_ms; if admit_arrivals_between_steps { self.enqueue_trace_arrivals(); } } fn run(mut self) -> anyhow::Result { while !self.is_done() { match self.mode { SingleReplayMode::Trace => { self.enqueue_trace_arrivals(); if self.worker.is_empty() { self.advance_to_next_trace_arrival()?; self.enqueue_trace_arrivals(); continue; } self.drive_worker(true); } SingleReplayMode::Concurrency { max_in_flight } => { self.enqueue_concurrency_arrivals(max_in_flight); if self.worker.is_empty() { break; } self.drive_worker(false); } } } Ok(self.collector) } } pub(crate) fn simulate_trace_single( args: MockEngineArgs, requests: Vec, arrival_speedup_ratio: f64, ) -> anyhow::Result { let args = args.normalized()?; let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?; let collector = SingleRuntime::new(args, pending, SingleReplayMode::Trace).run()?; Ok(collector.finish()) } pub(crate) fn simulate_concurrency_single( args: MockEngineArgs, requests: Vec, max_in_flight: usize, ) -> anyhow::Result { let args = args.normalized()?; let pending = VecDeque::from(requests); let collector = SingleRuntime::new( args, pending, SingleReplayMode::Concurrency { max_in_flight }, ) .run()?; Ok(collector.finish()) } #[cfg(test)] pub(super) fn run_trace_single_collect( args: MockEngineArgs, requests: Vec, arrival_speedup_ratio: f64, ) -> TraceCollector { let pending = normalize_trace_requests(requests, arrival_speedup_ratio).unwrap(); SingleRuntime::new(args, pending, SingleReplayMode::Trace) .run() .unwrap() } #[cfg(test)] pub(super) fn run_concurrency_single_collect( args: MockEngineArgs, requests: Vec, max_in_flight: usize, ) -> TraceCollector { SingleRuntime::new( args, VecDeque::from(requests), SingleReplayMode::Concurrency { max_in_flight }, ) .run() .unwrap() } #[cfg(test)] mod tests { use super::*; use crate::replay::{TraceRequestStatsSnapshot, TraceSimulationReport}; use rstest::rstest; use std::collections::{HashMap, VecDeque}; use uuid::Uuid; #[derive(Debug)] struct ManualReplayResult { report: TraceSimulationReport, snapshots: HashMap, idle_jump_ms: f64, first_decode_end_ms: f64, } #[derive(Debug)] struct ManualConcurrencyResult { report: TraceSimulationReport, snapshots: HashMap, } fn enqueue_trace_arrivals_manual( pending: &mut VecDeque, worker: &mut ReplayWorkerCore, collector: &mut TraceCollector, current_time_ms: f64, ) { loop { let Some(next_arrival_ms) = pending .front() .and_then(|request| request.arrival_timestamp_ms) else { break; }; if next_arrival_ms > current_time_ms { break; } let request = pending .pop_front() .expect("front request must exist when arrival is available"); let arrival_ms = request .arrival_timestamp_ms .expect("trace replay requests must have an arrival timestamp"); let input_length = request.tokens.len(); let output_length = request.max_output_tokens; let uuid = worker.receive(request); collector.on_arrival(uuid, arrival_ms, input_length, output_length); } } fn enqueue_concurrency_arrivals_manual( pending: &mut VecDeque, worker: &mut ReplayWorkerCore, collector: &mut TraceCollector, current_time_ms: f64, max_in_flight: usize, ) { while worker.num_requests() < max_in_flight { let Some(mut request) = pending.pop_front() else { break; }; request.arrival_timestamp_ms = Some(current_time_ms); let input_length = request.tokens.len(); let output_length = request.max_output_tokens; let uuid = worker.receive(request); collector.on_arrival(uuid, current_time_ms, input_length, output_length); } } fn replay_args(enable_prefix_caching: bool, enable_chunked_prefill: bool) -> MockEngineArgs { MockEngineArgs::builder() .block_size(4) .num_gpu_blocks(32) .max_num_batched_tokens(Some(8)) .max_num_seqs(Some(2)) .enable_prefix_caching(enable_prefix_caching) .enable_chunked_prefill(enable_chunked_prefill) .speedup_ratio(0.0) .build() .unwrap() } fn replay_fixture() -> Vec { vec![ DirectRequest { tokens: vec![1, 1, 1, 1, 2, 2, 2, 2], max_output_tokens: 2, uuid: Some(Uuid::from_u128(11)), dp_rank: 0, arrival_timestamp_ms: Some(100.0), }, DirectRequest { tokens: vec![1, 1, 1, 1, 2, 2, 2, 2], max_output_tokens: 2, uuid: Some(Uuid::from_u128(22)), dp_rank: 0, arrival_timestamp_ms: Some(101.0), }, DirectRequest { tokens: vec![9, 9, 9, 9, 8, 8, 8, 8], max_output_tokens: 2, uuid: Some(Uuid::from_u128(33)), dp_rank: 0, arrival_timestamp_ms: Some(500.0), }, ] } fn run_trace_manually( args: &MockEngineArgs, requests: Vec, ) -> ManualReplayResult { let mut requests = requests; requests.sort_by(|left, right| { let left_ts = left.arrival_timestamp_ms.unwrap(); let right_ts = right.arrival_timestamp_ms.unwrap(); left_ts.total_cmp(&right_ts) }); let first_arrival_ms = requests.first().unwrap().arrival_timestamp_ms.unwrap(); let mut pending = VecDeque::from( requests .into_iter() .map(|mut request| { request.arrival_timestamp_ms = Some(request.arrival_timestamp_ms.unwrap() - first_arrival_ms); request }) .collect::>(), ); let mut worker = ReplayWorkerCore::new(args.clone()); let mut collector = TraceCollector::default(); let mut current_time_ms = 0.0; let mut idle_jump_ms = 0.0; let mut first_decode_end_ms = 0.0; while !pending.is_empty() || !worker.is_empty() { enqueue_trace_arrivals_manual( &mut pending, &mut worker, &mut collector, current_time_ms, ); if worker.is_empty() { let next_arrival_ms = pending.front().unwrap().arrival_timestamp_ms.unwrap(); current_time_ms = next_arrival_ms; if idle_jump_ms == 0.0 && current_time_ms > 0.0 { idle_jump_ms = current_time_ms; } enqueue_trace_arrivals_manual( &mut pending, &mut worker, &mut collector, current_time_ms, ); continue; } let pass = worker.execute_pass(&mut collector, current_time_ms); if first_decode_end_ms == 0.0 && !pass.output_signals.is_empty() { first_decode_end_ms = pass.end_ms; } current_time_ms = pass.end_ms; enqueue_trace_arrivals_manual( &mut pending, &mut worker, &mut collector, current_time_ms, ); } let snapshots = [ Uuid::from_u128(11), Uuid::from_u128(22), Uuid::from_u128(33), ] .into_iter() .map(|uuid| (uuid, collector.snapshot(uuid).unwrap())) .collect(); ManualReplayResult { report: collector.finish(), snapshots, idle_jump_ms, first_decode_end_ms, } } fn run_concurrency_manually( args: &MockEngineArgs, requests: Vec, max_in_flight: usize, ) -> ManualConcurrencyResult { let mut pending = VecDeque::from(requests); let mut worker = ReplayWorkerCore::new(args.clone()); let mut collector = TraceCollector::default(); let mut current_time_ms = 0.0; while !pending.is_empty() || !worker.is_empty() { enqueue_concurrency_arrivals_manual( &mut pending, &mut worker, &mut collector, current_time_ms, max_in_flight, ); if worker.is_empty() { break; } let pass = worker.execute_pass(&mut collector, current_time_ms); current_time_ms = pass.end_ms; } let snapshots = [ Uuid::from_u128(11), Uuid::from_u128(22), Uuid::from_u128(33), ] .into_iter() .map(|uuid| (uuid, collector.snapshot(uuid).unwrap())) .collect(); ManualConcurrencyResult { report: collector.finish(), snapshots, } } fn assert_report_close(left: &TraceSimulationReport, right: &TraceSimulationReport) { let epsilon = 1e-9; assert_eq!( left.request_counts.num_requests, right.request_counts.num_requests ); assert_eq!( left.request_counts.completed_requests, right.request_counts.completed_requests ); assert_eq!( left.request_counts.total_input_tokens, right.request_counts.total_input_tokens ); assert_eq!( left.request_counts.total_output_tokens, right.request_counts.total_output_tokens ); assert!((left.throughput.duration_ms - right.throughput.duration_ms).abs() <= epsilon); assert!( (left.throughput.request_throughput_rps - right.throughput.request_throughput_rps) .abs() <= epsilon ); assert!( (left.throughput.input_throughput_tok_s - right.throughput.input_throughput_tok_s) .abs() <= epsilon ); assert!( (left.throughput.output_throughput_tok_s - right.throughput.output_throughput_tok_s) .abs() <= epsilon ); assert!( (left.throughput.total_throughput_tok_s - right.throughput.total_throughput_tok_s) .abs() <= epsilon ); assert!( (left.prefix_cache_reused_ratio - right.prefix_cache_reused_ratio).abs() <= epsilon ); assert!((left.latency.ttft.mean_ms - right.latency.ttft.mean_ms).abs() <= epsilon); assert!((left.latency.ttft.min_ms - right.latency.ttft.min_ms).abs() <= epsilon); assert!((left.latency.ttft.max_ms - right.latency.ttft.max_ms).abs() <= epsilon); assert!((left.latency.ttft.median_ms - right.latency.ttft.median_ms).abs() <= epsilon); assert!((left.latency.ttft.p75_ms - right.latency.ttft.p75_ms).abs() <= epsilon); assert!((left.latency.ttft.p90_ms - right.latency.ttft.p90_ms).abs() <= epsilon); assert!((left.latency.ttft.p95_ms - right.latency.ttft.p95_ms).abs() <= epsilon); assert!((left.latency.ttft.p99_ms - right.latency.ttft.p99_ms).abs() <= epsilon); assert!((left.latency.ttft.std_ms - right.latency.ttft.std_ms).abs() <= epsilon); assert!((left.latency.ttst.mean_ms - right.latency.ttst.mean_ms).abs() <= epsilon); assert!((left.latency.ttst.min_ms - right.latency.ttst.min_ms).abs() <= epsilon); assert!((left.latency.ttst.max_ms - right.latency.ttst.max_ms).abs() <= epsilon); assert!((left.latency.ttst.median_ms - right.latency.ttst.median_ms).abs() <= epsilon); assert!((left.latency.ttst.p75_ms - right.latency.ttst.p75_ms).abs() <= epsilon); assert!((left.latency.ttst.p90_ms - right.latency.ttst.p90_ms).abs() <= epsilon); assert!((left.latency.ttst.p95_ms - right.latency.ttst.p95_ms).abs() <= epsilon); assert!((left.latency.ttst.p99_ms - right.latency.ttst.p99_ms).abs() <= epsilon); assert!((left.latency.ttst.std_ms - right.latency.ttst.std_ms).abs() <= epsilon); assert!((left.latency.tpot.mean_ms - right.latency.tpot.mean_ms).abs() <= epsilon); assert!((left.latency.tpot.min_ms - right.latency.tpot.min_ms).abs() <= epsilon); assert!((left.latency.tpot.max_ms - right.latency.tpot.max_ms).abs() <= epsilon); assert!((left.latency.tpot.median_ms - right.latency.tpot.median_ms).abs() <= epsilon); assert!((left.latency.tpot.p75_ms - right.latency.tpot.p75_ms).abs() <= epsilon); assert!((left.latency.tpot.p90_ms - right.latency.tpot.p90_ms).abs() <= epsilon); assert!((left.latency.tpot.p95_ms - right.latency.tpot.p95_ms).abs() <= epsilon); assert!((left.latency.tpot.p99_ms - right.latency.tpot.p99_ms).abs() <= epsilon); assert!((left.latency.tpot.std_ms - right.latency.tpot.std_ms).abs() <= epsilon); assert!( (left.latency.itl.distribution.mean_ms - right.latency.itl.distribution.mean_ms).abs() <= epsilon ); assert!( (left.latency.itl.distribution.min_ms - right.latency.itl.distribution.min_ms).abs() <= epsilon ); assert!( (left.latency.itl.distribution.max_ms - right.latency.itl.distribution.max_ms).abs() <= epsilon ); assert!( (left.latency.itl.distribution.median_ms - right.latency.itl.distribution.median_ms) .abs() <= epsilon ); assert!( (left.latency.itl.distribution.p75_ms - right.latency.itl.distribution.p75_ms).abs() <= epsilon ); assert!( (left.latency.itl.distribution.p90_ms - right.latency.itl.distribution.p90_ms).abs() <= epsilon ); assert!( (left.latency.itl.distribution.p95_ms - right.latency.itl.distribution.p95_ms).abs() <= epsilon ); assert!( (left.latency.itl.distribution.p99_ms - right.latency.itl.distribution.p99_ms).abs() <= epsilon ); assert!( (left.latency.itl.distribution.std_ms - right.latency.itl.distribution.std_ms).abs() <= epsilon ); assert!((left.latency.itl.max_ms - right.latency.itl.max_ms).abs() <= epsilon); assert!((left.latency.e2e.mean_ms - right.latency.e2e.mean_ms).abs() <= epsilon); assert!((left.latency.e2e.min_ms - right.latency.e2e.min_ms).abs() <= epsilon); assert!((left.latency.e2e.max_ms - right.latency.e2e.max_ms).abs() <= epsilon); assert!((left.latency.e2e.median_ms - right.latency.e2e.median_ms).abs() <= epsilon); assert!((left.latency.e2e.p75_ms - right.latency.e2e.p75_ms).abs() <= epsilon); assert!((left.latency.e2e.p90_ms - right.latency.e2e.p90_ms).abs() <= epsilon); assert!((left.latency.e2e.p95_ms - right.latency.e2e.p95_ms).abs() <= epsilon); assert!((left.latency.e2e.p99_ms - right.latency.e2e.p99_ms).abs() <= epsilon); assert!((left.latency.e2e.std_ms - right.latency.e2e.std_ms).abs() <= epsilon); assert!( (left.latency.output_token_throughput_per_user.mean_ms - right.latency.output_token_throughput_per_user.mean_ms) .abs() <= epsilon ); assert!( (left.latency.output_token_throughput_per_user.min_ms - right.latency.output_token_throughput_per_user.min_ms) .abs() <= epsilon ); assert!( (left.latency.output_token_throughput_per_user.max_ms - right.latency.output_token_throughput_per_user.max_ms) .abs() <= epsilon ); assert!( (left.latency.output_token_throughput_per_user.median_ms - right.latency.output_token_throughput_per_user.median_ms) .abs() <= epsilon ); assert!( (left.latency.output_token_throughput_per_user.p75_ms - right.latency.output_token_throughput_per_user.p75_ms) .abs() <= epsilon ); assert!( (left.latency.output_token_throughput_per_user.p90_ms - right.latency.output_token_throughput_per_user.p90_ms) .abs() <= epsilon ); assert!( (left.latency.output_token_throughput_per_user.p95_ms - right.latency.output_token_throughput_per_user.p95_ms) .abs() <= epsilon ); assert!( (left.latency.output_token_throughput_per_user.p99_ms - right.latency.output_token_throughput_per_user.p99_ms) .abs() <= epsilon ); assert!( (left.latency.output_token_throughput_per_user.std_ms - right.latency.output_token_throughput_per_user.std_ms) .abs() <= epsilon ); } #[rstest] #[case(false, false)] #[case(false, true)] #[case(true, false)] #[case(true, true)] fn test_trace_replay_matches_manual_steps( #[case] enable_prefix_caching: bool, #[case] enable_chunked_prefill: bool, ) { let args = replay_args(enable_prefix_caching, enable_chunked_prefill); let manual = run_trace_manually(&args, replay_fixture()); let replay_report = simulate_trace_single(args, replay_fixture(), 1.0).unwrap(); let request_1 = manual.snapshots.get(&Uuid::from_u128(11)).unwrap(); let request_2 = manual.snapshots.get(&Uuid::from_u128(22)).unwrap(); let request_3 = manual.snapshots.get(&Uuid::from_u128(33)).unwrap(); assert_eq!(request_1.arrival_time_ms, 0.0); assert_eq!(request_2.arrival_time_ms, 1.0); assert_eq!(request_3.arrival_time_ms, 400.0); assert_eq!(manual.idle_jump_ms, 400.0); assert_eq!( request_1.first_token_ms.unwrap(), manual.first_decode_end_ms ); assert!(request_2.first_admit_ms.unwrap() >= request_2.arrival_time_ms); assert!(request_3.first_admit_ms.unwrap() >= request_3.arrival_time_ms); assert!(manual.report.latency.e2e.mean_ms >= manual.report.latency.ttft.mean_ms); if enable_prefix_caching { assert!(request_2.reused_input_tokens > 0); assert!(manual.report.prefix_cache_reused_ratio > 0.0); } else { assert_eq!(request_2.reused_input_tokens, 0); assert_eq!(manual.report.prefix_cache_reused_ratio, 0.0); } assert_report_close(&replay_report, &manual.report); } #[test] fn test_concurrency_replay_matches_manual_steps() { let args = replay_args(false, false); let requests = vec![ DirectRequest { tokens: vec![1, 2, 3, 4, 5, 6, 7, 8], max_output_tokens: 2, uuid: Some(Uuid::from_u128(11)), dp_rank: 0, arrival_timestamp_ms: Some(900.0), }, DirectRequest { tokens: vec![1, 2, 3, 4, 5, 9, 10, 11], max_output_tokens: 2, uuid: Some(Uuid::from_u128(22)), dp_rank: 0, arrival_timestamp_ms: Some(1000.0), }, DirectRequest { tokens: vec![12, 13, 14, 15, 16, 17, 18, 19], max_output_tokens: 2, uuid: Some(Uuid::from_u128(33)), dp_rank: 0, arrival_timestamp_ms: Some(100.0), }, ]; let manual = run_concurrency_manually(&args, requests.clone(), 2); let replay_report = simulate_concurrency_single(args, requests, 2).unwrap(); let request_1 = manual.snapshots.get(&Uuid::from_u128(11)).unwrap(); let request_2 = manual.snapshots.get(&Uuid::from_u128(22)).unwrap(); let request_3 = manual.snapshots.get(&Uuid::from_u128(33)).unwrap(); assert_eq!(request_1.arrival_time_ms, 0.0); assert_eq!(request_2.arrival_time_ms, 0.0); assert_eq!(request_3.arrival_time_ms, request_1.last_token_ms.unwrap()); assert!(request_3.arrival_time_ms < request_2.last_token_ms.unwrap()); assert_eq!(manual.report.request_counts.completed_requests, 3); assert_eq!(manual.report.request_counts.total_input_tokens, 24); assert_eq!(manual.report.request_counts.total_output_tokens, 6); assert_report_close(&replay_report, &manual.report); } }