// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use std::sync::Arc; use anyhow::{Result, anyhow}; use dashmap::DashMap; use dynamo_kv_router::config::KvRouterConfig; use tokio::sync::{Notify, Semaphore, mpsc}; use tokio::task::JoinSet; use tokio::time::Instant; use tokio_util::sync::CancellationToken; use crate::common::protocols::{DirectRequest, FpmPublisher, MockEngineArgs, OutputSignal}; use crate::loadgen::WorkloadDriver; use crate::replay::{ReplayPrefillLoadEstimator, ReplayRouterMode, TraceSimulationReport}; use crate::scheduler::{AdmissionEvent, EngineScheduler, SchedulerHandle}; use super::ReplayRouter; use super::demux::run_demux; use super::state::{ LiveReplayMode, LiveRuntimeStats, SharedLiveRuntimeStats, WorkloadDispatchState, now_ms, record_arrival, }; use super::task::{ InFlightGuard, RequestTaskContext, run_request_task, wait_for_workload_progress, }; pub(super) struct LiveRuntime { pending: std::collections::VecDeque, senders: Arc<[mpsc::UnboundedSender]>, schedulers: Vec, output_rx: mpsc::UnboundedReceiver>, admission_rx: mpsc::UnboundedReceiver, cancel_token: CancellationToken, start: Instant, mode: LiveReplayMode, router: Arc, } impl LiveRuntime { /// Build the shared router, worker schedulers, and demux inputs for one live replay run. 
pub(super) fn new( args: MockEngineArgs, router_config: Option, prefill_load_estimator: Option, pending: std::collections::VecDeque, num_workers: usize, mode: LiveReplayMode, router_mode: ReplayRouterMode, ) -> Result { let cancel_token = CancellationToken::new(); let (output_tx, output_rx) = mpsc::unbounded_channel::>(); let (admission_tx, admission_rx) = mpsc::unbounded_channel(); let router = Arc::new(ReplayRouter::new( router_mode, &args, router_config, prefill_load_estimator, num_workers, )); let mut schedulers = Vec::with_capacity(num_workers); let mut senders = Vec::with_capacity(num_workers); for worker_idx in 0..num_workers { let scheduler = EngineScheduler::new_with_admission( args.clone(), 0, Some(output_tx.clone()), router.sink(worker_idx as _), Some(cancel_token.clone()), Some(admission_tx.clone()), FpmPublisher::default(), ); senders.push(scheduler.request_sender()); schedulers.push(scheduler); } Ok(Self { pending, senders: Arc::from(senders), schedulers, output_rx, admission_rx, cancel_token, start: Instant::now(), mode, router, }) } /// Replay a finite queue of requests and return the final trace report plus debug stats. 
pub(super) async fn run(mut self) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    // Cancel the schedulers on *every* exit path. An early `?` return would otherwise skip the
    // explicit `cancel()` below and leave them running. `cancel()` is idempotent, so the normal
    // path is unaffected. The binding must be named (not `_`) so it lives to the end of scope.
    let _cancel_on_exit = self.cancel_token.clone().drop_guard();

    let requests = Arc::new(DashMap::with_capacity(self.pending.len()));
    let stats = Arc::new(SharedLiveRuntimeStats::default());
    let (arrival_tx, arrival_rx) = mpsc::unbounded_channel();
    let start = self.start;
    let router = Arc::clone(&self.router);

    let demux_task = {
        let requests = Arc::clone(&requests);
        let router = Arc::clone(&router);
        let stats = Arc::clone(&stats);
        let output_rx = self.output_rx;
        let admission_rx = self.admission_rx;
        tokio::spawn(async move {
            run_demux(start, arrival_rx, admission_rx, output_rx, requests, router, stats).await
        })
    };

    let mut tasks = JoinSet::new();
    let task_ctx = RequestTaskContext {
        senders: Arc::clone(&self.senders),
        router: Arc::clone(&router),
        requests: Arc::clone(&requests),
        stats: Arc::clone(&stats),
        workload: None,
    };

    match self.mode {
        LiveReplayMode::Trace => {
            while let Some(request) = self.pending.pop_front() {
                let arrival_ms = request.arrival_timestamp_ms.unwrap_or(0.0);
                tokio::time::sleep_until(trace_deadline(start, arrival_ms)?).await;
                record_arrival(&arrival_tx, &request, arrival_ms)?;
                tasks.spawn(run_request_task(task_ctx.clone(), request, None));
            }
        }
        LiveReplayMode::Concurrency { max_in_flight } => {
            // `Semaphore::new(0)` never hands out a permit, so the first acquire below would
            // wait forever.
            if max_in_flight == 0 && !self.pending.is_empty() {
                return Err(anyhow!(
                    "online replay concurrency limit must be at least 1 (got max_in_flight = 0)"
                ));
            }
            let semaphore = Arc::new(Semaphore::new(max_in_flight));
            while let Some(request) = self.pending.pop_front() {
                let permit = Arc::clone(&semaphore)
                    .acquire_owned()
                    .await
                    .map_err(|_| anyhow!("online replay concurrency semaphore closed"))?;
                record_arrival(&arrival_tx, &request, now_ms(start))?;
                let task_ctx = task_ctx.clone();
                tasks.spawn(async move {
                    // Hold the permit for the whole request so at most `max_in_flight`
                    // requests are outstanding at once.
                    let _permit = permit;
                    run_request_task(task_ctx, request, None).await
                });
            }
        }
    }

    while let Some(result) = tasks.join_next().await {
        result.map_err(|e| anyhow!("online replay request task failed: {e}"))??;
    }
    // Shutdown order: close the arrival stream, stop and drop the schedulers, then wait for the
    // demux (presumably it finishes once its input channels close; see `run_demux`), and only
    // then shut down the router it shares.
    drop(arrival_tx);
    self.cancel_token.cancel();
    self.schedulers.clear();
    let report = demux_task
        .await
        .map_err(|e| anyhow!("online replay demux task failed: {e}"))?;
    router.shutdown().await?;
    Ok((report, stats.snapshot()))
}

/// Drive a multi-turn workload driver until it is drained and all spawned request tasks finish.
///
/// Turns are released as soon as the driver reports them ready. In
/// [`LiveReplayMode::Concurrency`] the driver is capped at `max_in_flight`, and every dispatched
/// turn carries an [`InFlightGuard`]. Arrivals are stamped with the turn's scheduled ready time
/// in trace mode and with the current replay time otherwise. `total_turns` is only used to
/// pre-size the request map.
///
/// # Errors
///
/// Propagates failures from recording arrivals, from request tasks that fail or panic, from the
/// demux task (panic or cancellation), and from router shutdown.
pub(super) async fn run_workload(
    mut self,
    mut driver: WorkloadDriver,
    total_turns: usize,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    // Cancel the schedulers on every exit path, including early `?` returns.
    let _cancel_on_exit = self.cancel_token.clone().drop_guard();

    let requests = Arc::new(DashMap::with_capacity(total_turns.max(1)));
    let stats = Arc::new(SharedLiveRuntimeStats::default());
    let (arrival_tx, arrival_rx) = mpsc::unbounded_channel();
    let start = self.start;
    let router = Arc::clone(&self.router);

    let demux_task = {
        let requests = Arc::clone(&requests);
        let router = Arc::clone(&router);
        let stats = Arc::clone(&stats);
        let output_rx = self.output_rx;
        let admission_rx = self.admission_rx;
        tokio::spawn(async move {
            run_demux(start, arrival_rx, admission_rx, output_rx, requests, router, stats).await
        })
    };

    let cap_enabled = match self.mode {
        LiveReplayMode::Trace => false,
        LiveReplayMode::Concurrency { max_in_flight } => {
            driver.set_max_in_flight(max_in_flight);
            true
        }
    };
    let workload = Arc::new(WorkloadDispatchState {
        driver: std::sync::Mutex::new(driver),
        wakeup: Notify::new(),
        start,
    });

    let mut tasks = JoinSet::new();
    let task_ctx = RequestTaskContext {
        senders: Arc::clone(&self.senders),
        router: Arc::clone(&router),
        requests: Arc::clone(&requests),
        stats: Arc::clone(&stats),
        workload: Some(Arc::clone(&workload)),
    };

    loop {
        let now = now_ms(start);
        // The mutex guard is a temporary, so the driver lock is released at the end of this
        // statement, before any task is spawned.
        let ready_turns = workload.driver.lock().unwrap().pop_ready(now, usize::MAX);
        if !ready_turns.is_empty() {
            for ready_turn in ready_turns {
                let guard = cap_enabled
                    .then(|| InFlightGuard::new(Arc::clone(&workload), ready_turn.request_uuid));
                let arrival_at_ms = match self.mode {
                    LiveReplayMode::Trace => ready_turn.scheduled_ready_at_ms,
                    LiveReplayMode::Concurrency { .. } => now_ms(start),
                };
                record_arrival(&arrival_tx, &ready_turn.request, arrival_at_ms)?;
                tasks.spawn(run_request_task(task_ctx.clone(), ready_turn.request, guard));
            }
            // Poll the driver again immediately; only wait once nothing is ready.
            continue;
        }

        // Create the `Notified` future *before* inspecting the driver, so a wakeup signalled
        // between the check below and the wait is not lost.
        let wake = workload.wakeup.notified();
        tokio::pin!(wake);
        let (is_drained, next_ready_ms) = {
            let mut driver = workload.driver.lock().unwrap();
            (driver.is_drained(), driver.next_ready_time_ms())
        };
        if is_drained {
            break;
        }
        wait_for_workload_progress(next_ready_ms, start, wake.as_mut()).await;
    }

    while let Some(result) = tasks.join_next().await {
        result.map_err(|e| anyhow!("online replay request task failed: {e}"))??;
    }
    // Same shutdown order as the finite-queue path: arrivals, schedulers, demux, router.
    drop(arrival_tx);
    self.cancel_token.cancel();
    self.schedulers.clear();
    let report = demux_task
        .await
        .map_err(|e| anyhow!("online replay demux task failed: {e}"))?;
    router.shutdown().await?;
    Ok((report, stats.snapshot()))
}
}

/// Absolute deadline for a trace request arriving `arrival_ms` milliseconds after `start`.
///
/// Offsets at or below zero are already due and map to `start`.
/// `Duration::from_secs_f64` panics on NaN, infinite or out-of-range input, and `Instant + Duration`
/// panics on overflow. Here, a malformed timestamp is reported as an error instead.
fn trace_deadline(start: Instant, arrival_ms: f64) -> Result<Instant> {
    // `<= 0.0` also catches `-0.0` and `-inf`; NaN falls through and is rejected below.
    let secs = if arrival_ms <= 0.0 { 0.0 } else { arrival_ms / 1000.0 };
    let offset = tokio::time::Duration::try_from_secs_f64(secs)
        .map_err(|e| anyhow!("invalid trace arrival timestamp {arrival_ms} ms: {e}"))?;
    start
        .checked_add(offset)
        .ok_or_else(|| anyhow!("trace arrival timestamp {arrival_ms} ms is too far in the future"))
}