Unverified commit b2c59aa4, authored by Yan Ru Pei and committed by GitHub
Browse files

feat(replay): add shared loadgen workload paths [DYN-2510] (#7593)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 2b36b175
......@@ -360,6 +360,22 @@ impl TraceCollector {
reused_input_tokens: stats.reused_input_tokens,
})
}
/// Test-only view of every tracked request as a plain stats snapshot.
///
/// Snapshots are produced in the iteration order of `self.requests`.
#[cfg(test)]
pub(crate) fn snapshots(&self) -> Vec<TraceRequestStatsSnapshot> {
    let mut collected = Vec::new();
    for entry in self.requests.values() {
        collected.push(TraceRequestStatsSnapshot {
            arrival_time_ms: entry.arrival_time_ms,
            first_admit_ms: entry.first_admit_ms,
            first_token_ms: entry.first_token_ms(),
            last_token_ms: entry.last_token_ms(),
            input_length: entry.input_length,
            output_length: entry.output_length,
            reused_input_tokens: entry.reused_input_tokens,
        });
    }
    collected
}
}
fn mean(values: &[f64]) -> f64 {
......
......@@ -7,7 +7,6 @@ use std::time::Instant;
use anyhow::{Result, bail};
use dynamo_kv_router::config::KvRouterConfig;
use super::loader::load_trace_requests;
use super::online;
use super::validate::{
validate_offline_concurrency_args, validate_offline_replay_args,
......@@ -15,6 +14,7 @@ use super::validate::{
};
use super::{ReplayRouterMode, TraceSimulationReport};
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::Trace;
pub fn simulate_trace_file(
args: MockEngineArgs,
......@@ -42,14 +42,15 @@ pub fn simulate_trace_file_with_router_mode(
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
validate_offline_replay_args(&args, num_workers, router_mode)?;
let requests = load_trace_requests(trace_path, args.block_size, true)?;
let trace = Trace::from_mooncake(trace_path, args.block_size)?
.normalize_session_starts()?
.speed_up_timing(arrival_speedup_ratio)?;
let started_at = Instant::now();
let report = crate::replay::offline::simulate_trace(
let report = crate::replay::offline::simulate_trace_workload(
args,
router_config,
requests,
trace,
num_workers,
arrival_speedup_ratio,
router_mode,
)?;
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
......@@ -81,15 +82,10 @@ pub fn simulate_trace_live_file_with_router_mode(
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
validate_online_replay_args(&args, num_workers)?;
let requests = load_trace_requests(trace_path, args.block_size, true)?;
online::simulate_trace_requests(
args,
router_config,
requests,
num_workers,
arrival_speedup_ratio,
router_mode,
)
let trace = Trace::from_mooncake(trace_path, args.block_size)?
.normalize_session_starts()?
.speed_up_timing(arrival_speedup_ratio)?;
online::simulate_trace_workload(args, router_config, trace, num_workers, router_mode)
}
pub fn simulate_trace_requests(
......@@ -199,12 +195,13 @@ pub fn simulate_concurrency_file_with_router_mode(
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let requests = load_trace_requests(trace_path, args.block_size, false)?;
validate_offline_concurrency_args(&args, num_workers, max_in_flight, router_mode)?;
let trace = Trace::from_mooncake(trace_path, args.block_size)?;
let started_at = Instant::now();
let report = simulate_concurrency_requests_with_router_mode(
let report = simulate_concurrency_workload_with_router_mode(
args,
router_config,
requests,
trace,
max_in_flight,
num_workers,
router_mode,
......@@ -238,11 +235,11 @@ pub fn simulate_concurrency_live_file_with_router_mode(
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
validate_online_concurrency_args(&args, num_workers, max_in_flight)?;
let requests = load_trace_requests(trace_path, args.block_size, false)?;
online::simulate_concurrency_requests(
let trace = Trace::from_mooncake(trace_path, args.block_size)?;
online::simulate_concurrency_workload(
args,
router_config,
requests,
trace,
max_in_flight,
num_workers,
router_mode,
......@@ -328,3 +325,135 @@ pub fn simulate_concurrency_requests_with_router_mode(
router_mode,
)
}
/// Replays a `Trace` workload offline in trace mode with default routing.
///
/// Forwards to [`simulate_trace_workload_with_router_mode`] with no router
/// configuration and `ReplayRouterMode::RoundRobin`.
pub fn simulate_trace_workload(
    args: MockEngineArgs,
    trace: Trace,
    num_workers: usize,
) -> Result<TraceSimulationReport> {
    let router_config = None;
    let router_mode = ReplayRouterMode::RoundRobin;
    simulate_trace_workload_with_router_mode(args, router_config, trace, num_workers, router_mode)
}
/// Replays a `Trace` workload offline in trace mode with an explicit router setup.
///
/// The arguments are normalized and validated first; the replay itself is then
/// timed and the elapsed wall-clock time (in milliseconds) is attached to the
/// returned report.
///
/// # Errors
/// Propagates failures from normalization, validation, or the offline replay.
pub fn simulate_trace_workload_with_router_mode(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    validate_offline_replay_args(&normalized_args, num_workers, router_mode)?;

    let timer = Instant::now();
    let report = crate::replay::offline::simulate_trace_workload(
        normalized_args,
        router_config,
        trace,
        num_workers,
        router_mode,
    )?;
    let wall_time_ms = timer.elapsed().as_secs_f64() * 1000.0;
    Ok(report.with_wall_time_ms(wall_time_ms))
}
/// Replays a `Trace` workload through the online (live) path in trace mode
/// with default routing.
///
/// Forwards to [`simulate_trace_live_workload_with_router_mode`] with no
/// router configuration and `ReplayRouterMode::RoundRobin`.
pub fn simulate_trace_live_workload(
    args: MockEngineArgs,
    trace: Trace,
    num_workers: usize,
) -> Result<TraceSimulationReport> {
    let router_config = None;
    let router_mode = ReplayRouterMode::RoundRobin;
    simulate_trace_live_workload_with_router_mode(
        args,
        router_config,
        trace,
        num_workers,
        router_mode,
    )
}
/// Replays a `Trace` workload through the online (live) path in trace mode
/// with an explicit router setup.
///
/// # Errors
/// Propagates failures from normalization, validation, or the online replay.
pub fn simulate_trace_live_workload_with_router_mode(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    validate_online_replay_args(&normalized_args, num_workers)?;
    online::simulate_trace_workload(
        normalized_args,
        router_config,
        trace,
        num_workers,
        router_mode,
    )
}
/// Replays a `Trace` workload offline in concurrency mode with default routing.
///
/// Forwards to [`simulate_concurrency_workload_with_router_mode`] with no
/// router configuration and `ReplayRouterMode::RoundRobin`.
pub fn simulate_concurrency_workload(
    args: MockEngineArgs,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
) -> Result<TraceSimulationReport> {
    let router_config = None;
    let router_mode = ReplayRouterMode::RoundRobin;
    simulate_concurrency_workload_with_router_mode(
        args,
        router_config,
        trace,
        max_in_flight,
        num_workers,
        router_mode,
    )
}
pub fn simulate_concurrency_workload_with_router_mode(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
validate_offline_concurrency_args(&args, num_workers, max_in_flight, router_mode)?;
crate::replay::offline::simulate_concurrency_workload(
args,
router_config,
trace,
max_in_flight,
num_workers,
router_mode,
)
}
/// Replays a `Trace` workload through the online (live) path in concurrency
/// mode with default routing.
///
/// Forwards to [`simulate_concurrency_live_workload_with_router_mode`] with no
/// router configuration and `ReplayRouterMode::RoundRobin`.
pub fn simulate_concurrency_live_workload(
    args: MockEngineArgs,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
) -> Result<TraceSimulationReport> {
    let router_config = None;
    let router_mode = ReplayRouterMode::RoundRobin;
    simulate_concurrency_live_workload_with_router_mode(
        args,
        router_config,
        trace,
        max_in_flight,
        num_workers,
        router_mode,
    )
}
/// Replays a `Trace` workload through the online (live) path in concurrency
/// mode with an explicit router setup.
///
/// # Errors
/// Propagates failures from normalization, validation, or the online replay.
pub fn simulate_concurrency_live_workload_with_router_mode(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    validate_online_concurrency_args(&normalized_args, num_workers, max_in_flight)?;
    online::simulate_concurrency_workload(
        normalized_args,
        router_config,
        trace,
        max_in_flight,
        num_workers,
        router_mode,
    )
}
......@@ -3,7 +3,6 @@
mod collector;
mod entrypoints;
mod loader;
pub(crate) mod offline;
mod online;
mod router;
......@@ -30,11 +29,15 @@ pub use entrypoints::{
simulate_concurrency_file, simulate_concurrency_file_with_router_mode,
simulate_concurrency_live_file, simulate_concurrency_live_file_with_router_mode,
simulate_concurrency_live_requests, simulate_concurrency_live_requests_with_router_mode,
simulate_concurrency_live_workload, simulate_concurrency_live_workload_with_router_mode,
simulate_concurrency_requests, simulate_concurrency_requests_with_router_mode,
simulate_concurrency_workload, simulate_concurrency_workload_with_router_mode,
simulate_trace_file, simulate_trace_file_with_router_mode, simulate_trace_live_file,
simulate_trace_live_file_with_router_mode, simulate_trace_live_requests,
simulate_trace_live_requests_with_router_mode, simulate_trace_requests,
simulate_trace_requests_with_router_mode,
simulate_trace_live_requests_with_router_mode, simulate_trace_live_workload,
simulate_trace_live_workload_with_router_mode, simulate_trace_requests,
simulate_trace_requests_with_router_mode, simulate_trace_workload,
simulate_trace_workload_with_router_mode,
};
pub(crate) fn normalize_trace_requests(
......
......@@ -9,7 +9,7 @@ The goal is to simulate trace execution without spinning up async runtimes, netw
The public replay entrypoints live one level up in `lib/mocker/src/replay/entrypoints.rs`. They:
- normalize `MockEngineArgs`
- load or accept `DirectRequest`s
- load or accept `DirectRequest`s or `loadgen::Trace` workloads
- validate replay arguments
- dispatch to offline or online replay
......@@ -42,7 +42,10 @@ The single-worker path is intentionally simple and only used when:
- `num_workers == 1`
- engine type is `vllm`
That path avoids the cluster event queue and router machinery entirely.
That path avoids the cluster event queue and router machinery entirely, but it now supports both:
- flat request replay
- workload-driven replay through `WorkloadDriver` for multi-turn/session traces
```mermaid
flowchart TD
......@@ -63,6 +66,8 @@ Important details:
- Trace mode uses `normalize_trace_requests` in `lib/mocker/src/replay/mod.rs` so the first request starts at `0 ms`, then applies `arrival_speedup_ratio`.
- Concurrency mode ignores original arrival spacing and keeps the worker filled up to `max_in_flight`.
- Workload trace mode honors first-turn timestamps and inter-turn delays.
- Workload concurrency mode ignores first-turn timestamps but still enforces inter-turn delays after completion.
- The worker itself is still the real mocker engine core; only the scheduling loop is simplified.
## Multi-Worker Harness
......@@ -178,13 +183,15 @@ In round-robin mode, this capture is skipped because nothing consumes those even
Both single and multi harnesses support two admission modes:
- Trace mode
- respects input arrival timestamps
- timestamps are normalized so the first request starts at `0 ms`
- `arrival_speedup_ratio` compresses or stretches inter-arrival gaps
- for flat requests, respects input arrival timestamps
- for workloads, respects first-turn timestamps and inter-turn delays
- timestamps are normalized so the first request or first session starts at `0 ms`
- `arrival_speedup_ratio` compresses or stretches inter-arrival gaps and inter-turn delays
- Concurrency mode
- ignores original spacing
- ignores original first-turn spacing
- keeps up to `max_in_flight` requests resident in the cluster
- for workloads, still unlocks follow-up turns only after completion plus inter-turn delay
- stamps synthetic arrival times as requests are admitted
This split is why `lib/mocker/src/replay/offline/mod.rs` exposes both:
......
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::Trace;
pub(crate) use crate::replay::normalize_trace_requests;
use crate::replay::{ReplayRouterMode, TraceSimulationReport};
use dynamo_kv_router::config::KvRouterConfig;
......@@ -55,3 +56,39 @@ pub(crate) fn simulate_concurrency(
)
}
}
/// Chooses the offline harness for a trace-mode workload replay.
///
/// A single worker running the vLLM engine type goes through the `single`
/// harness; every other combination goes through the `multi` harness, which
/// also receives the router configuration and mode.
pub(crate) fn simulate_trace_workload(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
    let single_vllm_worker =
        num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm;
    if !single_vllm_worker {
        return multi::simulate_trace_workload_multi(
            args,
            router_config,
            trace,
            num_workers,
            router_mode,
        );
    }
    single::simulate_trace_workload_single(args, trace)
}
/// Chooses the offline harness for a concurrency-mode workload replay.
///
/// A single worker running the vLLM engine type goes through the `single`
/// harness; every other combination goes through the `multi` harness, which
/// also receives the router configuration and mode.
pub(crate) fn simulate_concurrency_workload(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
    let single_vllm_worker =
        num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm;
    if !single_vllm_worker {
        return multi::simulate_concurrency_workload_multi(
            args,
            router_config,
            trace,
            max_in_flight,
            num_workers,
            router_mode,
        );
    }
    single::simulate_concurrency_workload_single(args, trace, max_in_flight)
}
This diff is collapsed.
......@@ -4,6 +4,7 @@
use super::core::ReplayWorkerCore;
use super::normalize_trace_requests;
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::{Trace, WorkloadDriver};
use crate::replay::{TraceCollector, TraceSimulationReport};
use anyhow::bail;
use std::collections::VecDeque;
......@@ -15,9 +16,14 @@ enum SingleReplayMode {
Concurrency { max_in_flight: usize },
}
/// Where a `SingleRuntime` draws its next requests from.
enum AdmissionSource {
/// A flat queue of `DirectRequest`s that is consumed from the front.
Requests(VecDeque<DirectRequest>),
/// A `WorkloadDriver`: it is polled for ready requests via `pop_ready` and
/// informed of finished requests via `on_complete`.
Workload(WorkloadDriver),
}
struct SingleRuntime {
current_time_ms: f64,
pending: VecDeque<DirectRequest>,
admission: AdmissionSource,
worker: ReplayWorkerCore,
collector: TraceCollector,
mode: SingleReplayMode,
......@@ -25,9 +31,21 @@ struct SingleRuntime {
impl SingleRuntime {
fn new(args: MockEngineArgs, pending: VecDeque<DirectRequest>, mode: SingleReplayMode) -> Self {
Self::new_with_source(args, AdmissionSource::Requests(pending), mode)
}
fn new_workload(args: MockEngineArgs, driver: WorkloadDriver, mode: SingleReplayMode) -> Self {
Self::new_with_source(args, AdmissionSource::Workload(driver), mode)
}
fn new_with_source(
args: MockEngineArgs,
admission: AdmissionSource,
mode: SingleReplayMode,
) -> Self {
Self {
current_time_ms: 0.0,
pending,
admission,
worker: ReplayWorkerCore::new(args),
collector: TraceCollector::default(),
mode,
......@@ -35,36 +53,67 @@ impl SingleRuntime {
}
fn enqueue_trace_arrivals(&mut self) {
loop {
let Some(next_arrival_ms) = self
.pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
else {
break;
};
if next_arrival_ms > self.current_time_ms {
break;
let mut ready_requests = Vec::new();
match &mut self.admission {
AdmissionSource::Requests(pending) => loop {
let Some(next_arrival_ms) = pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
else {
break;
};
if next_arrival_ms > self.current_time_ms {
break;
}
let request = pending
.pop_front()
.expect("front request must exist when arrival is available");
let arrival_ms = request
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
ready_requests.push((request, arrival_ms));
},
AdmissionSource::Workload(driver) => {
ready_requests.extend(
driver
.pop_ready(self.current_time_ms, usize::MAX)
.into_iter()
.map(|ready| (ready.request, ready.scheduled_ready_at_ms)),
);
}
}
let request = self
.pending
.pop_front()
.expect("front request must exist when arrival is available");
let arrival_ms = request
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
for (request, arrival_ms) in ready_requests {
self.record_arrival(request, arrival_ms);
}
}
fn enqueue_concurrency_arrivals(&mut self, max_in_flight: usize) {
while self.worker.num_requests() < max_in_flight {
let Some(mut request) = self.pending.pop_front() else {
break;
};
let available = max_in_flight.saturating_sub(self.worker.num_requests());
let mut ready_requests = Vec::new();
request.arrival_timestamp_ms = Some(self.current_time_ms);
match &mut self.admission {
AdmissionSource::Requests(pending) => {
for _ in 0..available {
let Some(mut request) = pending.pop_front() else {
break;
};
request.arrival_timestamp_ms = Some(self.current_time_ms);
ready_requests.push(request);
}
}
AdmissionSource::Workload(driver) => {
ready_requests.extend(
driver
.pop_ready(self.current_time_ms, available)
.into_iter()
.map(|ready| ready.request),
);
}
}
for request in ready_requests {
self.record_arrival(request, self.current_time_ms);
}
}
......@@ -79,15 +128,21 @@ impl SingleRuntime {
}
fn is_done(&self) -> bool {
self.pending.is_empty() && self.worker.is_empty()
self.worker.is_empty()
&& match &self.admission {
AdmissionSource::Requests(pending) => pending.is_empty(),
AdmissionSource::Workload(driver) => driver.is_drained(),
}
}
fn advance_to_next_trace_arrival(&mut self) -> anyhow::Result<()> {
let Some(next_arrival_ms) = self
.pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
else {
let next_arrival_ms = match &mut self.admission {
AdmissionSource::Requests(pending) => pending
.front()
.and_then(|request| request.arrival_timestamp_ms),
AdmissionSource::Workload(driver) => driver.next_ready_time_ms(),
};
let Some(next_arrival_ms) = next_arrival_ms else {
bail!("trace replay reached an idle state without a pending arrival");
};
self.current_time_ms = next_arrival_ms;
......@@ -99,6 +154,13 @@ impl SingleRuntime {
.worker
.execute_pass(&mut self.collector, self.current_time_ms);
self.current_time_ms = pass.end_ms;
if let AdmissionSource::Workload(driver) = &mut self.admission {
for signal in pass.output_signals.iter().filter(|signal| signal.completed) {
driver
.on_complete(signal.uuid, self.current_time_ms)
.expect("completed workload request must belong to a session");
}
}
if admit_arrivals_between_steps {
self.enqueue_trace_arrivals();
}
......@@ -119,7 +181,11 @@ impl SingleRuntime {
SingleReplayMode::Concurrency { max_in_flight } => {
self.enqueue_concurrency_arrivals(max_in_flight);
if self.worker.is_empty() {
break;
if self.is_done() {
break;
}
self.advance_to_next_trace_arrival()?;
continue;
}
self.drive_worker(false);
}
......@@ -157,6 +223,32 @@ pub(crate) fn simulate_concurrency_single(
Ok(collector.finish())
}
/// Runs a trace-mode workload replay on the single-worker runtime and returns
/// the finished report.
///
/// # Errors
/// Propagates failures from argument normalization, from building the trace
/// driver, or from the runtime loop.
pub(crate) fn simulate_trace_workload_single(
    args: MockEngineArgs,
    trace: Trace,
) -> anyhow::Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    let driver = trace.into_trace_driver()?;
    let collector =
        SingleRuntime::new_workload(normalized_args, driver, SingleReplayMode::Trace).run()?;
    Ok(collector.finish())
}
/// Runs a concurrency-mode workload replay on the single-worker runtime and
/// returns the finished report.
///
/// # Errors
/// Propagates failures from argument normalization, from building the
/// concurrency driver, or from the runtime loop.
pub(crate) fn simulate_concurrency_workload_single(
    args: MockEngineArgs,
    trace: Trace,
    max_in_flight: usize,
) -> anyhow::Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    let driver = trace.into_concurrency_driver()?;
    let mode = SingleReplayMode::Concurrency { max_in_flight };
    let collector = SingleRuntime::new_workload(normalized_args, driver, mode).run()?;
    Ok(collector.finish())
}
#[cfg(test)]
pub(super) fn run_trace_single_collect(
args: MockEngineArgs,
......@@ -184,9 +276,39 @@ pub(super) fn run_concurrency_single_collect(
.unwrap()
}
/// Test helper: runs a trace-mode workload on the single-worker runtime and
/// hands back the raw collector (panicking on any failure).
#[cfg(test)]
pub(super) fn run_trace_workload_single_collect(
    args: MockEngineArgs,
    trace: Trace,
) -> TraceCollector {
    let driver = trace.into_trace_driver().unwrap();
    SingleRuntime::new_workload(args, driver, SingleReplayMode::Trace)
        .run()
        .unwrap()
}
/// Test helper: runs a concurrency-mode workload on the single-worker runtime
/// and hands back the raw collector (panicking on any failure).
#[cfg(test)]
pub(super) fn run_concurrency_workload_single_collect(
    args: MockEngineArgs,
    trace: Trace,
    max_in_flight: usize,
) -> TraceCollector {
    let driver = trace.into_concurrency_driver().unwrap();
    let mode = SingleReplayMode::Concurrency { max_in_flight };
    SingleRuntime::new_workload(args, driver, mode)
        .run()
        .unwrap()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::loadgen::{SessionTrace, TurnTrace};
use crate::replay::{TraceRequestStatsSnapshot, TraceSimulationReport};
use rstest::rstest;
use std::collections::{HashMap, VecDeque};
......@@ -295,6 +417,42 @@ mod tests {
]
}
/// Builds a two-session workload: `session-a` has two turns (the second with a
/// 5 ms delay after the first), and `session-b` has a single turn that first
/// arrives at 1 ms.
fn multiturn_trace_fixture() -> Trace {
    let opening_turn = TurnTrace {
        input_length: 3,
        max_output_tokens: 2,
        hash_ids: vec![1, 2, 3],
        delay_after_previous_ms: 0.0,
    };
    let follow_up_turn = TurnTrace {
        input_length: 5,
        max_output_tokens: 2,
        hash_ids: vec![4, 5, 6, 7, 8],
        delay_after_previous_ms: 5.0,
    };
    let session_a = SessionTrace {
        session_id: String::from("session-a"),
        first_arrival_timestamp_ms: Some(0.0),
        turns: vec![opening_turn, follow_up_turn],
    };

    let lone_turn = TurnTrace {
        input_length: 4,
        max_output_tokens: 2,
        hash_ids: vec![9, 10, 11, 12],
        delay_after_previous_ms: 0.0,
    };
    let session_b = SessionTrace {
        session_id: String::from("session-b"),
        first_arrival_timestamp_ms: Some(1.0),
        turns: vec![lone_turn],
    };

    Trace {
        block_size: 1,
        sessions: vec![session_a, session_b],
    }
}
fn run_trace_manually(
args: &MockEngineArgs,
requests: Vec<DirectRequest>,
......@@ -674,4 +832,44 @@ mod tests {
assert_report_close(&replay_report, &manual.report);
}
/// In trace mode the first turns keep their original timestamps, and the
/// follow-up turn only arrives once the previous turn has finished and the
/// inter-turn delay has passed.
#[test]
fn test_trace_workload_single_unlocks_follow_up_turn_after_completion() {
    let args = replay_args(false, true);
    let collector = run_trace_workload_single_collect(args, multiturn_trace_fixture());
    let snapshots = collector.snapshots();

    // Requests are told apart by their input lengths from the fixture; only
    // the first snapshot seen for each length is kept.
    let (mut first, mut second, mut other) = (None, None, None);
    for stats in &snapshots {
        match stats.input_length {
            3 if first.is_none() => first = Some(stats),
            5 if second.is_none() => second = Some(stats),
            4 if other.is_none() => other = Some(stats),
            _ => {}
        }
    }
    let first = first.unwrap();
    let second = second.unwrap();
    let other = other.unwrap();

    assert_eq!(first.arrival_time_ms, 0.0);
    assert_eq!(other.arrival_time_ms, 1.0);
    assert!(second.arrival_time_ms >= first.last_token_ms.unwrap() + 5.0);
}
/// In concurrency mode the original first-turn timestamps are ignored, but a
/// follow-up turn must still wait for the previous turn to complete plus the
/// configured inter-turn delay.
#[test]
fn test_concurrency_workload_single_ignores_first_turn_timestamps_but_keeps_delay() {
    let args = replay_args(false, true);
    let collector = run_concurrency_workload_single_collect(args, multiturn_trace_fixture(), 1);
    let snapshots = collector.snapshots();
    let report = collector.finish();

    let arrival_times = snapshots
        .iter()
        .map(|stats| stats.arrival_time_ms)
        .collect::<Vec<_>>();
    // The two turns of `session-a` are identified by their fixture input lengths.
    let first = snapshots
        .iter()
        .find(|stats| stats.input_length == 3)
        .unwrap();
    let second = snapshots
        .iter()
        .find(|stats| stats.input_length == 5)
        .unwrap();

    assert!(arrival_times.contains(&0.0));
    assert!(arrival_times.iter().all(|arrival| *arrival >= 0.0));
    // The "keeps delay" half of the test name: the second turn may not arrive
    // earlier than 5 ms after the first turn's last token.
    assert!(second.arrival_time_ms >= first.last_token_ms.unwrap() + 5.0);
    assert_eq!(report.request_counts.completed_requests, 3);
}
}
......@@ -12,6 +12,15 @@ pub(crate) struct OfflineWorkerState {
in_flight: usize,
}
/// Test-only, comparable copy of an `OfflineWorkerState`'s observable status,
/// as produced by `OfflineWorkerState::debug_snapshot`.
#[cfg(test)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct OfflineWorkerSnapshot {
/// Copy of the worker's `busy` flag.
pub(crate) busy: bool,
/// Copy of the worker's `in_flight` counter.
pub(crate) in_flight: usize,
/// Result of the worker's `is_ready()` at snapshot time.
pub(crate) ready: bool,
/// Result of the worker's `is_drained()` at snapshot time.
pub(crate) drained: bool,
}
impl OfflineWorkerState {
pub(crate) fn new(worker_idx: usize, args: MockEngineArgs, capture_kv_events: bool) -> Self {
let core = match args.engine_type {
......@@ -81,4 +90,14 @@ impl OfflineWorkerState {
) -> EnginePassResult {
self.core.execute_pass(collector, now_ms)
}
/// Test-only: captures the worker's current flags and derived readiness into
/// an `OfflineWorkerSnapshot`.
#[cfg(test)]
pub(crate) fn debug_snapshot(&self) -> OfflineWorkerSnapshot {
    let ready = self.is_ready();
    let drained = self.is_drained();
    OfflineWorkerSnapshot {
        busy: self.busy,
        in_flight: self.in_flight,
        ready,
        drained,
    }
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::time::Instant;
use crate::common::protocols::OutputSignal;
use crate::replay::router::ReplayRouter;
use crate::replay::{TraceCollector, TraceSimulationReport};
use crate::scheduler::AdmissionEvent;
use super::state::{ArrivalEvent, RequestRegistry, SharedLiveRuntimeStats, now_ms};
/// Drains the arrival, admission, and output channels into a `TraceCollector`
/// and returns the finished report once all three channels have closed.
///
/// Channels are polled in a fixed priority order (`biased`): arrivals first,
/// then admissions, then outputs. For every output signal the token is
/// recorded and, if the request is still registered in `requests`:
/// - the first token triggers `router.on_first_token` exactly once, and
/// - a completed signal triggers `router.on_complete` exactly once, after
///   which waiters on the request's completion are notified.
///
/// Router failures are logged as warnings and do not stop the loop. The
/// returned report carries the wall time measured from `start`.
pub(super) async fn run_demux(
    start: Instant,
    mut arrival_rx: mpsc::UnboundedReceiver<ArrivalEvent>,
    mut admission_rx: mpsc::UnboundedReceiver<AdmissionEvent>,
    mut output_rx: mpsc::UnboundedReceiver<OutputSignal>,
    requests: RequestRegistry,
    router: Arc<ReplayRouter>,
    stats: Arc<SharedLiveRuntimeStats>,
) -> TraceSimulationReport {
    let mut collector = TraceCollector::default();
    let mut arrivals_open = true;
    let mut admissions_open = true;
    let mut outputs_open = true;

    while arrivals_open || admissions_open || outputs_open {
        tokio::select! {
            biased;

            arrival = arrival_rx.recv(), if arrivals_open => {
                match arrival {
                    Some(arrival) => collector.on_arrival(
                        arrival.uuid,
                        arrival.at_ms,
                        arrival.input_tokens,
                        arrival.output_tokens,
                    ),
                    None => arrivals_open = false,
                }
            }

            admission = admission_rx.recv(), if admissions_open => {
                match admission {
                    Some(admission) => {
                        collector.on_admit(admission.uuid, now_ms(start), admission.reused_input_tokens);
                    }
                    None => admissions_open = false,
                }
            }

            output = output_rx.recv(), if outputs_open => {
                match output {
                    Some(output) => {
                        collector.on_token(output.uuid, now_ms(start));
                        // Clone the `Arc` out of the registry so the DashMap
                        // shard guard is released before any `.await` below.
                        // Holding the guard across an await point can deadlock
                        // against request tasks that insert into or remove
                        // from the same shard.
                        let state = requests
                            .get(&output.uuid)
                            .map(|entry| Arc::clone(entry.value()));
                        if let Some(state) = state {
                            if state.mark_first_token_once() {
                                match router.on_first_token(output.uuid).await {
                                    Ok(true) => stats.record_prefill_marked(),
                                    Ok(false) => {}
                                    Err(error) => tracing::warn!(
                                        uuid = %output.uuid,
                                        error = %error,
                                        "online replay failed to mark prefill completed"
                                    ),
                                }
                            }
                            if output.completed && state.mark_completed_once() {
                                match router.on_complete(output.uuid).await {
                                    Ok(true) => stats.record_freed(),
                                    Ok(false) => {}
                                    Err(error) => tracing::warn!(
                                        uuid = %output.uuid,
                                        error = %error,
                                        "online replay failed to free completed request"
                                    ),
                                }
                                state.notify_completion();
                            }
                        }
                    }
                    None => outputs_open = false,
                }
            }
        }
    }

    collector.finish().with_wall_time_ms(now_ms(start))
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod runtime;
mod demux;
mod live_runtime;
mod state;
mod task;
pub(crate) use runtime::{simulate_concurrency_requests, simulate_trace_requests};
#[cfg(test)]
mod tests;
pub(crate) use live_runtime::{
simulate_concurrency_requests, simulate_concurrency_workload, simulate_trace_requests,
simulate_trace_workload,
};
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use anyhow::{Result, anyhow};
use dashmap::DashMap;
use tokio::sync::{Notify, mpsc};
use tokio::time::Instant;
use uuid::Uuid;
use crate::common::protocols::DirectRequest;
use crate::loadgen::WorkloadDriver;
/// Admission mode used by the online (live) replay runtime.
#[derive(Clone, Copy, Debug)]
pub(super) enum LiveReplayMode {
/// Trace mode; carries no extra settings.
Trace,
/// Concurrency mode.
/// NOTE(review): `max_in_flight` is presumably the cap on simultaneously
/// dispatched requests — the limiter itself is not defined in this file.
Concurrency { max_in_flight: usize },
}
/// Plain, comparable snapshot of the counters kept by `SharedLiveRuntimeStats`
/// (see `SharedLiveRuntimeStats::snapshot`).
#[derive(Debug, Default, PartialEq, Eq)]
pub(super) struct LiveRuntimeStats {
/// Worker indices in the order dispatches were recorded.
pub(super) dispatch_history: Vec<usize>,
/// Highest in-flight count observed at dispatch time.
pub(super) max_in_flight_seen: usize,
/// Number of `record_prefill_marked` calls.
pub(super) prefill_marked_count: usize,
/// Number of `record_freed` calls.
pub(super) freed_count: usize,
}
/// Counters for the live replay runtime that are updated through `&self`
/// (a mutex-guarded history plus atomic counters).
#[derive(Default)]
pub(super) struct SharedLiveRuntimeStats {
// Worker index of every recorded dispatch, in recording order.
dispatch_history: Mutex<Vec<usize>>,
// Dispatches recorded minus completions recorded.
current_in_flight: AtomicUsize,
// Running maximum of `current_in_flight`, updated on each dispatch.
max_in_flight_seen: AtomicUsize,
// Incremented by `record_prefill_marked`.
prefill_marked_count: AtomicUsize,
// Incremented by `record_freed`.
freed_count: AtomicUsize,
}
impl SharedLiveRuntimeStats {
    /// Records that a request was dispatched to `worker_idx`: appends to the
    /// dispatch history, bumps the in-flight count, and raises the observed
    /// maximum if the new count exceeds it.
    pub(super) fn record_dispatch(&self, worker_idx: usize) {
        {
            // Scope the guard so the lock is released before the atomics run.
            let mut history = self.dispatch_history.lock().unwrap();
            history.push(worker_idx);
        }
        let previous = self.current_in_flight.fetch_add(1, Ordering::AcqRel);
        self.max_in_flight_seen
            .fetch_max(previous + 1, Ordering::AcqRel);
    }

    /// Records that one in-flight request finished.
    pub(super) fn record_completion(&self) {
        self.current_in_flight.fetch_sub(1, Ordering::AcqRel);
    }

    /// Counts one request whose prefill was marked complete.
    pub(super) fn record_prefill_marked(&self) {
        self.prefill_marked_count.fetch_add(1, Ordering::AcqRel);
    }

    /// Counts one request that was freed.
    pub(super) fn record_freed(&self) {
        self.freed_count.fetch_add(1, Ordering::AcqRel);
    }

    /// Copies the current counters into a plain `LiveRuntimeStats` value.
    pub(super) fn snapshot(&self) -> LiveRuntimeStats {
        // The history lock stays held while the atomic counters are read.
        let history = self.dispatch_history.lock().unwrap();
        LiveRuntimeStats {
            dispatch_history: history.clone(),
            max_in_flight_seen: self.max_in_flight_seen.load(Ordering::Acquire),
            prefill_marked_count: self.prefill_marked_count.load(Ordering::Acquire),
            freed_count: self.freed_count.load(Ordering::Acquire),
        }
    }
}
/// Per-request flags for the live replay: whether a first token and a
/// completion have been observed, plus a notifier for completion waiters.
#[derive(Default)]
pub(super) struct RequestState {
// Set by `mark_first_token_once`.
first_token_seen: AtomicBool,
// Set by `mark_completed_once`; polled by `wait_for_completion`.
completed_seen: AtomicBool,
// Signalled by `notify_completion`.
completion_notify: Notify,
}
impl RequestState {
    /// Marks the first token as seen; returns `true` only for the call that
    /// flipped the flag.
    pub(super) fn mark_first_token_once(&self) -> bool {
        let already_seen = self.first_token_seen.swap(true, Ordering::AcqRel);
        !already_seen
    }

    /// Marks the request as completed; returns `true` only for the call that
    /// flipped the flag.
    pub(super) fn mark_completed_once(&self) -> bool {
        let already_seen = self.completed_seen.swap(true, Ordering::AcqRel);
        !already_seen
    }

    /// Wakes every task currently waiting in `wait_for_completion`.
    pub(super) fn notify_completion(&self) {
        self.completion_notify.notify_waiters();
    }

    /// Resolves once the completed flag has been set.
    pub(super) async fn wait_for_completion(&self) {
        loop {
            // The notification future is created BEFORE the flag is checked,
            // so a completion that lands between the check and the await is
            // not missed.
            let wakeup = self.completion_notify.notified();
            if self.completed_seen.load(Ordering::Acquire) {
                break;
            }
            wakeup.await;
        }
    }
}
/// Message describing one request arrival, sent over the arrival channel by
/// `record_arrival`.
#[derive(Clone, Copy)]
pub(super) struct ArrivalEvent {
/// UUID of the arriving request.
pub(super) uuid: Uuid,
/// Arrival time in milliseconds, as supplied by the caller of `record_arrival`.
pub(super) at_ms: f64,
/// Number of input tokens (`request.tokens.len()`).
pub(super) input_tokens: usize,
/// The request's `max_output_tokens`.
pub(super) output_tokens: usize,
}
/// Shared concurrent map from a request's UUID to its `RequestState`.
pub(super) type RequestRegistry = Arc<DashMap<Uuid, Arc<RequestState>>>;
/// State shared between request tasks and the workload dispatcher for
/// workload-driven live replay.
pub(super) struct WorkloadDispatchState {
/// The workload driver, guarded by a mutex.
pub(super) driver: Mutex<WorkloadDriver>,
/// Notifier used to wake tasks waiting for workload progress.
pub(super) wakeup: Notify,
/// Reference instant from which replay-relative milliseconds are measured
/// (see `now_ms`).
pub(super) start: Instant,
}
/// Milliseconds elapsed since `start`, as a floating-point value.
pub(super) fn now_ms(start: Instant) -> f64 {
    let elapsed = start.elapsed();
    elapsed.as_secs_f64() * 1000.0
}
/// Returns the request's UUID.
///
/// # Errors
/// Fails when the request carries no UUID.
pub(super) fn request_uuid(request: &DirectRequest) -> Result<Uuid> {
    match request.uuid {
        Some(uuid) => Ok(uuid),
        None => Err(anyhow!(
            "online replay requires requests to have stable UUIDs"
        )),
    }
}
/// Sends an `ArrivalEvent` for `request` on the arrival channel and returns
/// the request's UUID.
///
/// # Errors
/// Fails when the request has no UUID or when the arrival channel is closed.
pub(super) fn record_arrival(
    arrival_tx: &mpsc::UnboundedSender<ArrivalEvent>,
    request: &DirectRequest,
    arrival_at_ms: f64,
) -> Result<Uuid> {
    let uuid = request_uuid(request)?;
    let event = ArrivalEvent {
        uuid,
        at_ms: arrival_at_ms,
        input_tokens: request.tokens.len(),
        output_tokens: request.max_output_tokens,
    };
    if arrival_tx.send(event).is_err() {
        return Err(anyhow!("online replay arrival channel closed"));
    }
    Ok(uuid)
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use anyhow::{Result, anyhow, bail};
use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};
use tokio::time::Instant;
use crate::common::protocols::DirectRequest;
use crate::replay::router::ReplayRouter;
use super::state::{
LiveReplayMode, RequestRegistry, RequestState, SharedLiveRuntimeStats, WorkloadDispatchState,
now_ms, request_uuid,
};
/// Shared handles that every request task needs (see `run_request_task`).
#[derive(Clone)]
pub(super) struct RequestTaskContext {
/// One request channel per worker, indexed by the worker index chosen by the router.
pub(super) senders: Arc<[mpsc::UnboundedSender<DirectRequest>]>,
/// Router used to pick the worker for each request.
pub(super) router: Arc<ReplayRouter>,
/// Registry of per-request state for requests currently being processed.
pub(super) requests: RequestRegistry,
/// Dispatch and completion counters.
pub(super) stats: Arc<SharedLiveRuntimeStats>,
/// Workload dispatch state; when present, completions are reported to its
/// driver and its wakeup notifier is signalled.
pub(super) workload: Option<Arc<WorkloadDispatchState>>,
}
/// Waits until the workload dispatcher has a reason to look at the driver again.
///
/// - If a next ready time is known and a dispatch slot could be used (always
///   in trace mode; in concurrency mode only when the semaphore has at least
///   one free permit), waits for whichever comes first: the ready time
///   (measured from `start`) or the `wake` future.
/// - Otherwise waits for `wake` alone.
///
/// A negative or NaN `next_ready_ms` is treated as "ready now" instead of
/// panicking inside `Duration::from_secs_f64`.
///
/// # Panics
/// Panics if `mode` is `Concurrency` and `semaphore` is `None`, and — as
/// before — if `next_ready_ms` is too large to represent as a `Duration`.
pub(super) async fn wait_for_workload_progress<F>(
    mode: LiveReplayMode,
    semaphore: Option<&Semaphore>,
    next_ready_ms: Option<f64>,
    start: Instant,
    mut wake: Pin<&mut F>,
) where
    F: Future<Output = ()>,
{
    // With no free permit nothing can be dispatched even once the next request
    // becomes ready, so only an explicit wakeup (e.g. a completion) matters.
    let permits_exhausted = match (mode, semaphore) {
        (LiveReplayMode::Trace, _) => false,
        (LiveReplayMode::Concurrency { .. }, Some(semaphore)) => {
            semaphore.available_permits() == 0
        }
        (LiveReplayMode::Concurrency { .. }, None) => {
            unreachable!("concurrency mode must have a semaphore");
        }
    };

    match next_ready_ms {
        Some(next_ready_ms) if !permits_exhausted => {
            // `f64::max` also maps NaN to 0.0, so the conversion cannot panic
            // on a negative or NaN offset.
            let offset_secs = next_ready_ms.max(0.0) / 1000.0;
            let deadline = start + tokio::time::Duration::from_secs_f64(offset_secs);
            tokio::select! {
                _ = tokio::time::sleep_until(deadline) => {}
                _ = wake.as_mut() => {}
            }
        }
        _ => wake.as_mut().await,
    }
}
/// Dispatches one request to a router-selected worker and waits for it to
/// complete.
///
/// Steps:
/// 1. resolve the request UUID and ask the router for a worker index;
/// 2. register a fresh `RequestState` and send the request to that worker
///    (the registration is rolled back if the send fails);
/// 3. record the dispatch, wait for completion, record the completion, and
///    unregister the request;
/// 4. in workload mode, report the completion time to the workload driver,
///    release the concurrency permit, and then wake the dispatcher.
///
/// The permit is released BEFORE the wakeup so that a dispatcher woken by the
/// notification already sees the freed slot. The wakeup is also sent when the
/// driver rejects the completion, so the dispatcher is not left waiting.
///
/// # Errors
/// Fails when the request has no UUID, worker selection fails or returns an
/// out-of-range index, the worker channel is closed, or the workload driver
/// rejects the completion.
pub(super) async fn run_request_task(
    ctx: RequestTaskContext,
    request: DirectRequest,
    permit: Option<OwnedSemaphorePermit>,
) -> Result<()> {
    let uuid = request_uuid(&request)?;
    let worker_idx = ctx
        .router
        .select_worker(&request, ctx.senders.len())
        .await?;
    if worker_idx >= ctx.senders.len() {
        bail!("online replay selected unknown worker index {worker_idx}");
    }

    let state = Arc::new(RequestState::default());
    ctx.requests.insert(uuid, Arc::clone(&state));
    if let Err(error) = ctx.senders[worker_idx].send(request) {
        ctx.requests.remove(&uuid);
        return Err(anyhow!(
            "online replay failed to dispatch request to worker {worker_idx}: {error}"
        ));
    }

    ctx.stats.record_dispatch(worker_idx);
    state.wait_for_completion().await;
    ctx.stats.record_completion();
    ctx.requests.remove(&uuid);

    let Some(workload) = ctx.workload.as_ref() else {
        drop(permit);
        return Ok(());
    };

    let completion_ms = now_ms(workload.start);
    // The mutex guard is a temporary of this statement and is released before
    // the permit drop and the wakeup below.
    let recorded = workload
        .driver
        .lock()
        .unwrap()
        .on_complete(uuid, completion_ms);
    drop(permit);
    workload.wakeup.notify_waiters();
    recorded?;
    Ok(())
}
This diff is collapsed.
......@@ -6,4 +6,6 @@ mod online;
mod shared;
pub(crate) use offline::OfflineReplayRouter;
#[cfg(test)]
pub(crate) use offline::OfflineRouterSnapshot;
pub(crate) use online::ReplayRouter;
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment