// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use std::sync::Arc; use std::sync::Mutex; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use anyhow::{Result, anyhow}; use dashmap::DashMap; use tokio::sync::{Notify, mpsc}; use tokio::time::Instant; use uuid::Uuid; use crate::common::protocols::DirectRequest; use crate::loadgen::WorkloadDriver; #[derive(Clone, Copy, Debug)] pub(super) enum LiveReplayMode { Trace, Concurrency { max_in_flight: usize }, } #[derive(Debug, Default, PartialEq, Eq)] pub(super) struct LiveRuntimeStats { pub(super) dispatch_history: Vec, pub(super) max_in_flight_seen: usize, pub(super) prefill_marked_count: usize, pub(super) freed_count: usize, } #[derive(Default)] pub(super) struct SharedLiveRuntimeStats { dispatch_history: Mutex>, current_in_flight: AtomicUsize, max_in_flight_seen: AtomicUsize, prefill_marked_count: AtomicUsize, freed_count: AtomicUsize, } impl SharedLiveRuntimeStats { pub(super) fn record_dispatch(&self, worker_idx: usize) { self.dispatch_history.lock().unwrap().push(worker_idx); let current = self.current_in_flight.fetch_add(1, Ordering::AcqRel) + 1; self.max_in_flight_seen.fetch_max(current, Ordering::AcqRel); } pub(super) fn record_completion(&self) { self.current_in_flight.fetch_sub(1, Ordering::AcqRel); } pub(super) fn record_prefill_marked(&self) { self.prefill_marked_count.fetch_add(1, Ordering::AcqRel); } pub(super) fn record_freed(&self) { self.freed_count.fetch_add(1, Ordering::AcqRel); } pub(super) fn snapshot(&self) -> LiveRuntimeStats { LiveRuntimeStats { dispatch_history: self.dispatch_history.lock().unwrap().clone(), max_in_flight_seen: self.max_in_flight_seen.load(Ordering::Acquire), prefill_marked_count: self.prefill_marked_count.load(Ordering::Acquire), freed_count: self.freed_count.load(Ordering::Acquire), } } } #[derive(Default)] pub(super) struct RequestState { first_token_seen: AtomicBool, completed_seen: AtomicBool, completion_notify: Notify, } impl RequestState { pub(super) fn mark_first_token_once(&self) -> bool { !self.first_token_seen.swap(true, Ordering::AcqRel) } pub(super) fn mark_completed_once(&self) -> bool { !self.completed_seen.swap(true, Ordering::AcqRel) } pub(super) fn notify_completion(&self) { self.completion_notify.notify_waiters(); } pub(super) async fn wait_for_completion(&self) { loop { let notified = self.completion_notify.notified(); if self.completed_seen.load(Ordering::Acquire) { return; } notified.await; } } } #[derive(Clone, Copy)] pub(super) struct ArrivalEvent { pub(super) uuid: Uuid, pub(super) at_ms: f64, pub(super) input_tokens: usize, pub(super) output_tokens: usize, } pub(super) type RequestRegistry = Arc>>; pub(super) struct WorkloadDispatchState { pub(super) driver: Mutex, pub(super) wakeup: Notify, pub(super) start: Instant, } pub(super) fn now_ms(start: Instant) -> f64 { start.elapsed().as_secs_f64() * 1000.0 } pub(super) fn request_uuid(request: &DirectRequest) -> Result { request .uuid .ok_or_else(|| anyhow!("online replay requires requests to have stable UUIDs")) } pub(super) fn record_arrival( arrival_tx: &mpsc::UnboundedSender, request: &DirectRequest, arrival_at_ms: f64, ) -> Result { let uuid = request_uuid(request)?; let input_tokens = request.tokens.len(); let output_tokens = request.max_output_tokens; arrival_tx .send(ArrivalEvent { uuid, at_ms: arrival_at_ms, input_tokens, output_tokens, }) .map_err(|_| anyhow!("online replay arrival channel closed"))?; Ok(uuid) }