"lib/bindings/vscode:/vscode.git/clone" did not exist on "f70dd6638bad4f745996be1e63be1ec186395c11"
Unverified Commit a2154ba5 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore(mocker): batch live output signal sends (#7647)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 06f17011
......@@ -16,7 +16,7 @@ use crate::scheduler::{Scheduler, SchedulerHandle, SglangScheduler};
pub fn create_engine(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
) -> Box<dyn SchedulerHandle> {
......
......@@ -18,7 +18,7 @@ Offline replay starts in `lib/mocker/src/replay/offline/mod.rs`.
`offline/mod.rs` chooses between three implementations:
- `lib/mocker/src/replay/offline/single.rs` for the special case `num_workers == 1` with the vLLM engine
- `lib/mocker/src/replay/offline/multi.rs` for everything else, including multi-worker replay and `kv_router` replay
- `lib/mocker/src/replay/offline/agg.rs` for everything else, including aggregated multi-worker replay and `kv_router` replay
- `lib/mocker/src/replay/offline/disagg.rs` for offline disaggregated prefill/decode replay
## File Map
......@@ -27,7 +27,7 @@ Offline replay starts in `lib/mocker/src/replay/offline/mod.rs`.
Chooses single-worker fast path vs multi-worker harness.
- `lib/mocker/src/replay/offline/single.rs`
Minimal replay loop for one vLLM worker.
- `lib/mocker/src/replay/offline/multi.rs`
- `lib/mocker/src/replay/offline/agg.rs`
General offline cluster simulator for multi-worker replay and KV-router replay.
- `lib/mocker/src/replay/offline/disagg.rs`
Offline two-stage replay harness with separate prefill and decode pools.
......@@ -75,7 +75,7 @@ Important details:
## Multi-Worker Harness
The general harness lives in `lib/mocker/src/replay/offline/multi.rs`. It models a cluster with:
The general aggregated harness lives in `lib/mocker/src/replay/offline/agg.rs`. It models a cluster with:
- a logical clock `now_ms`
- a pending request queue
......@@ -85,7 +85,7 @@ The general harness lives in `lib/mocker/src/replay/offline/multi.rs`. It models
### Main Loop
The harness is event-driven. It does not sleep. Instead, `OfflineRuntime` repeatedly:
The aggregated harness is event-driven. It does not sleep. Instead, `AggRuntime` repeatedly:
1. picks the next meaningful timestamp
2. advances `now_ms`
......@@ -164,7 +164,7 @@ flowchart LR
F -->|yes| G["dispatch to worker"]
F -->|no| H["store in router_pending"]
I["worker pass emits RouterEvent + OutputSignal"] --> J["OfflineRuntime::process_completed_pass"]
I["worker pass emits RouterEvent + OutputSignal"] --> J["AggRuntime::process_completed_pass"]
J --> K["apply router events to sync indexer"]
J --> L["mark_prefill_completed / free"]
L --> M["drain queued admissions"]
......
......@@ -2,7 +2,6 @@
// SPDX-License-Identifier: Apache-2.0
use super::events::{SimulationEvent, SimulationWorkerStage};
use super::normalize_trace_requests;
use super::runtime_utils::{
WorkerCompletionPayload, next_timestamp as choose_next_timestamp, pop_next_concurrency_ready,
pop_next_trace_ready, pop_ready_worker_completion, push_worker_completion,
......@@ -11,11 +10,11 @@ use super::runtime_utils::{
use super::state::OfflineWorkerSnapshot;
use super::state::{AggRequestState, OfflineWorkerState};
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::{ReplayRequestHashes, Trace, WorkloadDriver};
use crate::loadgen::{ReplayRequestHashes, WorkloadDriver};
use crate::replay::router::OfflineReplayRouter;
#[cfg(test)]
use crate::replay::router::OfflineRouterSnapshot;
use crate::replay::{ReplayRouterMode, TraceCollector, TraceSimulationReport};
use crate::replay::{ReplayRouterMode, TraceCollector};
use crate::scheduler::RouterEventVisibility;
use anyhow::bail;
use dynamo_kv_router::config::KvRouterConfig;
......@@ -24,7 +23,7 @@ use std::collections::{BinaryHeap, HashMap, VecDeque};
use uuid::Uuid;
#[derive(Debug, Clone, Copy)]
enum ReplayMode {
pub(super) enum ReplayMode {
Trace,
Concurrency { max_in_flight: usize },
}
......@@ -36,19 +35,19 @@ enum AdmissionSource {
#[cfg(test)]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct OfflineRuntimeStats {
pub(super) struct AggRuntimeStats {
dispatch_history: Vec<usize>,
dispatch_order: Vec<Uuid>,
assigned_worker_by_uuid: HashMap<Uuid, usize>,
max_in_flight_seen: usize,
prefill_marked_count: usize,
freed_count: usize,
max_router_pending: usize,
router_freed_count: usize,
max_router_pending_count: usize,
}
#[cfg(test)]
#[derive(Debug, Clone, PartialEq)]
struct OfflineRuntimeSnapshot {
struct AggRuntimeSnapshot {
now_ms: f64,
worker_active_requests: Vec<Vec<Uuid>>,
workers: Vec<OfflineWorkerSnapshot>,
......@@ -59,29 +58,29 @@ struct OfflineRuntimeSnapshot {
#[cfg(not(test))]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct OfflineRuntimeStats;
pub(super) struct AggRuntimeStats;
struct OfflineRuntime {
pub(super) struct AggRuntime {
now_ms: f64,
next_worker_idx: usize,
next_event_seq: u64,
admission: AdmissionSource,
requests: HashMap<Uuid, AggRequestState>,
queued_requests: usize,
workers: Vec<OfflineWorkerState>,
collector: TraceCollector,
events: BinaryHeap<SimulationEvent>,
mode: ReplayMode,
router: Option<OfflineReplayRouter>,
stats: OfflineRuntimeStats,
stats: AggRuntimeStats,
#[cfg(test)]
worker_active_requests: Vec<Vec<Uuid>>,
#[cfg(test)]
stepped: bool,
}
impl OfflineRuntime {
fn new(
impl AggRuntime {
/// Create an aggregated offline runtime seeded from an explicit request queue.
pub(super) fn new(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
pending: VecDeque<DirectRequest>,
......@@ -99,7 +98,8 @@ impl OfflineRuntime {
)
}
fn new_workload(
/// Create an aggregated offline runtime whose admissions come from a workload driver.
pub(super) fn new_workload(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
driver: WorkloadDriver,
......@@ -117,6 +117,7 @@ impl OfflineRuntime {
)
}
/// Shared constructor for both raw-request and workload-driven admissions.
fn new_with_source(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
......@@ -140,7 +141,6 @@ impl OfflineRuntime {
next_event_seq: 0,
admission,
requests: HashMap::new(),
queued_requests: 0,
workers: (0..num_workers)
.map(|worker_idx| {
OfflineWorkerState::new(worker_idx, args.clone(), capture_kv_events)
......@@ -151,9 +151,9 @@ impl OfflineRuntime {
mode,
router,
#[cfg(test)]
stats: OfflineRuntimeStats::default(),
stats: AggRuntimeStats::default(),
#[cfg(not(test))]
stats: OfflineRuntimeStats,
stats: AggRuntimeStats,
#[cfg(test)]
worker_active_requests: vec![Vec::new(); num_workers],
#[cfg(test)]
......@@ -161,14 +161,19 @@ impl OfflineRuntime {
})
}
/// Count all requests currently consuming cluster capacity, including router-queued ones.
fn cluster_in_flight(&self) -> usize {
self.workers
.iter()
.map(OfflineWorkerState::in_flight)
.sum::<usize>()
+ self.queued_requests
+ self
.router
.as_ref()
.map_or(0, OfflineReplayRouter::pending_count)
}
/// Track the peak cluster occupancy seen during the replay.
fn record_in_flight_peak(&mut self) {
#[cfg(test)]
{
......@@ -177,6 +182,7 @@ impl OfflineRuntime {
}
}
/// Track the maximum number of requests parked in the offline router.
fn record_router_pending(&mut self) {
#[cfg(test)]
let Some(router) = self.router.as_ref() else {
......@@ -184,11 +190,21 @@ impl OfflineRuntime {
};
#[cfg(test)]
{
self.stats.max_router_pending =
self.stats.max_router_pending.max(router.pending_count());
self.stats.max_router_pending_count = self
.stats
.max_router_pending_count
.max(router.pending_count());
}
}
/// Pick the next worker in round-robin order.
fn next_worker(&mut self) -> usize {
let worker_idx = self.next_worker_idx;
self.next_worker_idx = (self.next_worker_idx + 1) % self.workers.len();
worker_idx
}
/// Record which worker accepted a request and refresh in-flight stats.
fn record_dispatch(&mut self, _uuid: Uuid, _worker_idx: usize) {
#[cfg(test)]
{
......@@ -201,6 +217,7 @@ impl OfflineRuntime {
self.record_in_flight_peak();
}
/// Fail fast if a router admission points at a worker that does not exist.
fn validate_worker_idx(&self, worker_idx: usize) -> anyhow::Result<()> {
if worker_idx >= self.workers.len() {
bail!("offline replay selected unknown worker index {worker_idx}");
......@@ -208,6 +225,7 @@ impl OfflineRuntime {
Ok(())
}
/// Deliver a request to a worker and update the runtime's bookkeeping for that assignment.
fn dispatch_to_worker(
&mut self,
request: DirectRequest,
......@@ -222,6 +240,22 @@ impl OfflineRuntime {
Ok(())
}
/// Submit a request to the router and return an immediate admission when one is available.
fn submit_to_router(
&mut self,
request: &DirectRequest,
replay_hashes: Option<ReplayRequestHashes>,
) -> anyhow::Result<Option<usize>> {
let Some(router) = self.router.as_mut() else {
bail!("offline replay router submission requires an active router");
};
let maybe_worker_idx =
router.submit_request_with_hashes(request, replay_hashes, self.now_ms)?;
self.record_router_pending();
Ok(maybe_worker_idx)
}
/// Materialize router admissions into concrete worker dispatches.
fn dispatch_router_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> anyhow::Result<()> {
for (uuid, worker_idx) in admissions {
let request = self
......@@ -231,12 +265,12 @@ impl OfflineRuntime {
anyhow::anyhow!("offline replay missing queued request state for {uuid}")
})?
.take_queued_request(uuid)?;
self.queued_requests = self.queued_requests.saturating_sub(1);
self.dispatch_to_worker(request, uuid, worker_idx)?;
}
Ok(())
}
/// Admit one external request into the collector, optional router, and worker pool.
fn assign_request(
&mut self,
mut request: DirectRequest,
......@@ -256,17 +290,13 @@ impl OfflineRuntime {
request.max_output_tokens,
);
let Some(router) = self.router.as_mut() else {
if self.router.is_none() {
self.requests.insert(uuid, AggRequestState::new_running());
let worker_idx = self.next_worker_idx;
self.next_worker_idx = (self.next_worker_idx + 1) % self.workers.len();
let worker_idx = self.next_worker();
self.dispatch_to_worker(request, uuid, worker_idx)?;
return Ok(uuid);
};
let maybe_worker_idx =
router.submit_request_with_hashes(&request, replay_hashes, self.now_ms)?;
self.record_router_pending();
}
let maybe_worker_idx = self.submit_to_router(&request, replay_hashes)?;
if let Some(worker_idx) = maybe_worker_idx {
self.requests.insert(uuid, AggRequestState::new_running());
self.dispatch_to_worker(request, uuid, worker_idx)?;
......@@ -275,11 +305,11 @@ impl OfflineRuntime {
self.requests
.insert(uuid, AggRequestState::new_queued(request));
self.queued_requests += 1;
self.record_in_flight_peak();
Ok(uuid)
}
/// Return true once no workers, router queues, or admissions remain.
fn is_done(&self) -> bool {
self.events.is_empty()
&& self.cluster_in_flight() == 0
......@@ -290,6 +320,7 @@ impl OfflineRuntime {
&& self.workers.iter().all(OfflineWorkerState::is_drained)
}
/// Pick the next logical timestamp from either arrivals or scheduled worker completions.
fn next_timestamp(&mut self) -> Option<f64> {
let next_event_ms = self.events.peek().map(|event| event.at_ms);
let cluster_in_flight = self.cluster_in_flight();
......@@ -311,10 +342,12 @@ impl OfflineRuntime {
choose_next_timestamp(next_arrival_ms, next_event_ms)
}
/// Release completed requests from worker-local accounting after a pass finishes.
fn apply_completed_requests(&mut self, worker_idx: usize, completed_requests: usize) {
self.workers[worker_idx].mark_completed(completed_requests);
}
/// Apply router-visible KV events at the phase chosen by the scheduler core.
fn apply_router_events(&mut self, events: Vec<RouterEvent>) -> anyhow::Result<()> {
let Some(router) = self.router.as_mut() else {
return Ok(());
......@@ -325,6 +358,7 @@ impl OfflineRuntime {
Ok(())
}
/// Consume one output signal, updating router state, collector state, and completion counts.
fn process_output_signal(&mut self, signal: OutputSignal) -> anyhow::Result<()> {
let mut admissions = Vec::new();
if signal.completed {
......@@ -334,7 +368,7 @@ impl OfflineRuntime {
admissions = router.free(signal.uuid)?;
#[cfg(test)]
{
self.stats.freed_count += 1;
self.stats.router_freed_count += 1;
}
self.record_router_pending();
}
......@@ -379,6 +413,7 @@ impl OfflineRuntime {
}
#[cfg(test)]
/// Remove a request from the test-only active-request tracking for its worker.
fn remove_active_request(&mut self, uuid: Uuid) {
for active_requests in &mut self.worker_active_requests {
let Some(position) = active_requests
......@@ -392,6 +427,7 @@ impl OfflineRuntime {
}
}
/// Apply one completed pass: free request slots, publish KV events, and handle outputs.
fn process_completed_pass(
&mut self,
worker_idx: usize,
......@@ -407,6 +443,7 @@ impl OfflineRuntime {
Ok(())
}
/// Drain all worker-completion events scheduled for the current logical timestamp.
fn apply_worker_completions(&mut self) -> anyhow::Result<bool> {
let mut changed = false;
while let Some(WorkerCompletionPayload {
......@@ -426,6 +463,7 @@ impl OfflineRuntime {
Ok(changed)
}
/// Release every trace arrival whose timestamp is now visible to the global clock.
fn release_trace_arrivals(&mut self) -> anyhow::Result<bool> {
let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) {
......@@ -461,6 +499,7 @@ impl OfflineRuntime {
Ok(released_any)
}
/// Backfill closed-loop concurrency replay until the configured in-flight limit is reached.
fn top_off_concurrency(&mut self, max_in_flight: usize) -> anyhow::Result<bool> {
let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) {
......@@ -501,6 +540,7 @@ impl OfflineRuntime {
Ok(released_any)
}
/// Start passes on every idle worker that can make progress at the current timestamp.
fn drive_ready_workers(&mut self) -> anyhow::Result<bool> {
let mut changed = false;
for worker_idx in 0..self.workers.len() {
......@@ -553,6 +593,7 @@ impl OfflineRuntime {
Ok(changed)
}
/// Repeatedly process all work that becomes possible without advancing logical time.
fn drain_current_timestamp(&mut self) -> anyhow::Result<()> {
loop {
let mut changed = self.apply_worker_completions()?;
......@@ -574,7 +615,8 @@ impl OfflineRuntime {
Ok(())
}
fn run(mut self) -> anyhow::Result<(TraceCollector, OfflineRuntimeStats)> {
/// Run the aggregated offline replay until all arrivals and worker work are exhausted.
pub(super) fn run(mut self) -> anyhow::Result<(TraceCollector, AggRuntimeStats)> {
self.drain_current_timestamp()?;
while !self.is_done() {
......@@ -589,14 +631,11 @@ impl OfflineRuntime {
self.drain_current_timestamp()?;
}
if let Some(router) = self.router.as_mut() {
router.shutdown();
}
Ok((self.collector, self.stats))
}
#[cfg(test)]
/// Test helper: advance exactly one logical timestamp worth of work.
fn advance_one_timestamp(&mut self) -> anyhow::Result<bool> {
if self.is_done() {
return Ok(false);
......@@ -621,7 +660,8 @@ impl OfflineRuntime {
}
#[cfg(test)]
fn debug_snapshot(&self) -> OfflineRuntimeSnapshot {
/// Test helper: snapshot the runtime's visible request, worker, and router state.
fn debug_snapshot(&self) -> AggRuntimeSnapshot {
let mut router_pending_request_ids = self
.requests
.iter()
......@@ -637,7 +677,7 @@ impl OfflineRuntime {
.collect::<Vec<_>>();
prefill_completed.sort_unstable();
OfflineRuntimeSnapshot {
AggRuntimeSnapshot {
now_ms: self.now_ms,
worker_active_requests: self.worker_active_requests.clone(),
workers: self
......@@ -655,182 +695,17 @@ impl OfflineRuntime {
}
}
pub(crate) fn simulate_trace_multi(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
let (collector, _) = OfflineRuntime::new(
&args,
router_config,
pending,
num_workers,
ReplayMode::Trace,
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_multi(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = VecDeque::from(requests);
let (collector, _) = OfflineRuntime::new(
&args,
router_config,
pending,
num_workers,
ReplayMode::Concurrency { max_in_flight },
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_trace_workload_multi(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let driver = trace.into_trace_driver()?;
let (collector, _) = OfflineRuntime::new_workload(
&args,
router_config,
driver,
num_workers,
ReplayMode::Trace,
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_workload_multi(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let driver = trace.into_concurrency_driver()?;
let (collector, _) = OfflineRuntime::new_workload(
&args,
router_config,
driver,
num_workers,
ReplayMode::Concurrency { max_in_flight },
router_mode,
)?
.run()?;
Ok(collector.finish())
}
#[cfg(test)]
fn run_trace_multi_collect_with_stats(
args: &MockEngineArgs,
requests: Vec<DirectRequest>,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> (TraceCollector, OfflineRuntimeStats) {
let pending = normalize_trace_requests(requests, 1.0).unwrap();
OfflineRuntime::new(
args,
None,
pending,
num_workers,
ReplayMode::Trace,
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
fn run_concurrency_multi_collect_with_stats(
args: &MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> (TraceCollector, OfflineRuntimeStats) {
OfflineRuntime::new(
args,
None,
VecDeque::from(requests),
num_workers,
ReplayMode::Concurrency { max_in_flight },
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
fn run_trace_workload_multi_collect_with_stats(
args: &MockEngineArgs,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> (TraceCollector, OfflineRuntimeStats) {
OfflineRuntime::new_workload(
args,
None,
trace.into_trace_driver().unwrap(),
num_workers,
ReplayMode::Trace,
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
fn run_concurrency_workload_multi_collect_with_stats(
args: &MockEngineArgs,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> (TraceCollector, OfflineRuntimeStats) {
OfflineRuntime::new_workload(
args,
None,
trace.into_concurrency_driver().unwrap(),
num_workers,
ReplayMode::Concurrency { max_in_flight },
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
mod tests {
use super::super::single::{run_concurrency_single_collect, run_trace_single_collect};
use super::super::entrypoints::{
run_concurrency_multi_collect_with_stats, run_concurrency_single_collect,
run_concurrency_workload_multi_collect_with_stats, run_trace_multi_collect_with_stats,
run_trace_single_collect, run_trace_workload_multi_collect_with_stats,
};
use super::*;
use crate::common::protocols::{EngineType, SglangArgs};
use crate::loadgen::{SessionTrace, TurnTrace};
use crate::loadgen::{SessionTrace, Trace, TurnTrace};
use crate::replay::normalize_trace_requests;
use dynamo_kv_router::config::RouterQueuePolicy;
fn replay_args(enable_prefix_caching: bool, enable_chunked_prefill: bool) -> MockEngineArgs {
......@@ -1079,7 +954,7 @@ mod tests {
#[test]
fn test_multi_worker_trace_kv_router_debug_snapshot_tracks_queue_and_cached_dispatch() {
let args = queueing_router_args(RouterQueuePolicy::Fcfs);
let mut runtime = OfflineRuntime::new(
let mut runtime = AggRuntime::new(
&args,
None,
normalize_trace_requests(
......@@ -1459,8 +1334,8 @@ mod tests {
);
assert_eq!(stats.prefill_marked_count, 1);
assert_eq!(stats.freed_count, 2);
assert_eq!(stats.max_router_pending, 0);
assert_eq!(stats.router_freed_count, 2);
assert_eq!(stats.max_router_pending_count, 0);
}
#[test]
......@@ -1503,7 +1378,7 @@ mod tests {
.unwrap()
.min(request_2.first_token_ms.unwrap());
assert!(stats.max_router_pending > 0);
assert!(stats.max_router_pending_count > 0);
assert!(request_3.first_admit_ms.unwrap() > request_3.arrival_time_ms);
assert_eq!(request_3.first_admit_ms.unwrap(), first_unblock_ms);
assert!(request_3.first_admit_ms.unwrap() < request_1.last_token_ms.unwrap());
......@@ -1556,8 +1431,8 @@ mod tests {
ReplayRouterMode::KvRouter,
);
assert!(fcfs_stats.max_router_pending > 0);
assert!(lcfs_stats.max_router_pending > 0);
assert!(fcfs_stats.max_router_pending_count > 0);
assert!(lcfs_stats.max_router_pending_count > 0);
assert_eq!(
&fcfs_stats.dispatch_order[..2],
&[Uuid::from_u128(10), Uuid::from_u128(20)]
......@@ -1628,8 +1503,8 @@ mod tests {
let lcfs_request_30 = lcfs_collector.snapshot(Uuid::from_u128(30)).unwrap();
let lcfs_request_40 = lcfs_collector.snapshot(Uuid::from_u128(40)).unwrap();
assert!(fcfs_stats.max_router_pending > 0);
assert!(lcfs_stats.max_router_pending > 0);
assert!(fcfs_stats.max_router_pending_count > 0);
assert!(lcfs_stats.max_router_pending_count > 0);
assert_eq!(
&fcfs_stats.dispatch_order[2..4],
&[Uuid::from_u128(30), Uuid::from_u128(40)]
......@@ -1683,7 +1558,7 @@ mod tests {
);
assert_eq!(stats.max_in_flight_seen, 3);
assert!(stats.max_router_pending > 0);
assert!(stats.max_router_pending_count > 0);
}
#[test]
......@@ -1805,7 +1680,7 @@ mod tests {
run_trace_multi_collect_with_stats(&args, requests, 1, ReplayRouterMode::KvRouter);
assert_eq!(stats.dispatch_history, vec![0, 0, 0]);
assert_eq!(stats.max_router_pending, 0);
assert_eq!(stats.max_router_pending_count, 0);
for uuid in [11_u128, 22, 33] {
assert_eq!(
multi.snapshot(Uuid::from_u128(uuid)),
......@@ -1898,7 +1773,7 @@ mod tests {
);
assert_eq!(stats.dispatch_history, vec![0, 0, 0]);
assert_eq!(stats.max_router_pending, 0);
assert_eq!(stats.max_router_pending_count, 0);
for uuid in [11_u128, 22, 33] {
assert_eq!(
multi.snapshot(Uuid::from_u128(uuid)),
......
......@@ -9,7 +9,6 @@ use dynamo_kv_router::protocols::RouterEvent;
use uuid::Uuid;
use super::events::{SimulationEvent, SimulationWorkerStage};
use super::normalize_trace_requests;
use super::runtime_utils::{
WorkerCompletionPayload, next_timestamp as choose_next_timestamp, pop_next_concurrency_ready,
pop_next_trace_ready, pop_ready_decode_handoff, pop_ready_worker_completion,
......@@ -21,15 +20,13 @@ use super::state::DisaggPhase;
use super::state::DisaggRequestSnapshot;
use super::state::{DisaggRequestState, OfflineWorkerState};
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::{ReplayRequestHashes, Trace, WorkloadDriver};
use crate::loadgen::{ReplayRequestHashes, WorkloadDriver};
use crate::replay::router::OfflineReplayRouter;
use crate::replay::{
OfflineDisaggReplayConfig, ReplayRouterMode, TraceCollector, TraceSimulationReport,
};
use crate::replay::{OfflineDisaggReplayConfig, ReplayRouterMode, TraceCollector};
use crate::scheduler::RouterEventVisibility;
#[derive(Debug, Clone, Copy)]
enum ReplayMode {
pub(super) enum ReplayMode {
Trace,
Concurrency { max_in_flight: usize },
}
......@@ -41,23 +38,23 @@ enum AdmissionSource {
#[cfg(test)]
#[derive(Debug, Default, Clone, PartialEq)]
struct DisaggRuntimeStats {
pub(super) struct DisaggRuntimeStats {
request_snapshots: HashMap<Uuid, DisaggRequestSnapshot>,
prefill_assignments: HashMap<Uuid, usize>,
decode_assignments: HashMap<Uuid, usize>,
handoff_ms: HashMap<Uuid, f64>,
prefill_marked_count: usize,
prefill_freed_count: usize,
decode_freed_count: usize,
max_prefill_router_pending: usize,
max_decode_router_pending: usize,
prefill_router_freed_count: usize,
decode_router_freed_count: usize,
max_prefill_router_pending_count: usize,
max_decode_router_pending_count: usize,
}
#[cfg(not(test))]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct DisaggRuntimeStats;
pub(super) struct DisaggRuntimeStats;
struct DisaggRuntime {
pub(super) struct DisaggRuntime {
now_ms: f64,
next_prefill_worker_idx: usize,
next_decode_worker_idx: usize,
......@@ -75,7 +72,8 @@ struct DisaggRuntime {
}
impl DisaggRuntime {
fn new(
/// Create a disaggregated offline runtime seeded from an explicit request queue.
pub(super) fn new(
config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
pending: VecDeque<DirectRequest>,
......@@ -91,7 +89,8 @@ impl DisaggRuntime {
)
}
fn new_workload(
/// Create a disaggregated offline runtime whose admissions come from a workload driver.
pub(super) fn new_workload(
config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
driver: WorkloadDriver,
......@@ -107,6 +106,7 @@ impl DisaggRuntime {
)
}
/// Shared constructor for both raw-request and workload-driven admissions.
fn new_with_source(
config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
......@@ -169,6 +169,7 @@ impl DisaggRuntime {
})
}
/// Count all requests consuming cluster capacity across prefill, decode, and router queues.
fn cluster_in_flight(&self) -> usize {
self.prefill_workers
.iter()
......@@ -189,6 +190,7 @@ impl DisaggRuntime {
.map_or(0, OfflineReplayRouter::pending_count)
}
/// Pick the next prefill worker in round-robin order.
fn next_prefill_worker(&mut self) -> usize {
let worker_idx = self.next_prefill_worker_idx;
self.next_prefill_worker_idx =
......@@ -196,28 +198,33 @@ impl DisaggRuntime {
worker_idx
}
/// Pick the next decode worker in round-robin order.
fn next_decode_worker(&mut self) -> usize {
let worker_idx = self.next_decode_worker_idx;
self.next_decode_worker_idx = (self.next_decode_worker_idx + 1) % self.decode_workers.len();
worker_idx
}
/// Track the peak number of requests parked in each stage router.
fn record_router_pending(&mut self) {
#[cfg(test)]
{
self.stats.max_prefill_router_pending = self.stats.max_prefill_router_pending.max(
self.prefill_router
.as_ref()
.map_or(0, OfflineReplayRouter::pending_count),
);
self.stats.max_decode_router_pending = self.stats.max_decode_router_pending.max(
self.decode_router
.as_ref()
.map_or(0, OfflineReplayRouter::pending_count),
);
self.stats.max_prefill_router_pending_count =
self.stats.max_prefill_router_pending_count.max(
self.prefill_router
.as_ref()
.map_or(0, OfflineReplayRouter::pending_count),
);
self.stats.max_decode_router_pending_count =
self.stats.max_decode_router_pending_count.max(
self.decode_router
.as_ref()
.map_or(0, OfflineReplayRouter::pending_count),
);
}
}
/// Fail fast if a router admission targets a non-existent stage worker.
fn validate_worker_idx(&self, stage: SimulationWorkerStage, worker_idx: usize) -> Result<()> {
let worker_count = match stage {
SimulationWorkerStage::Prefill => self.prefill_workers.len(),
......@@ -230,18 +237,21 @@ impl DisaggRuntime {
Ok(())
}
/// Borrow immutable request state with a structured missing-request error.
fn state(&self, uuid: Uuid) -> Result<&DisaggRequestState> {
self.requests
.get(&uuid)
.ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))
}
/// Borrow mutable request state with a structured missing-request error.
fn state_mut(&mut self, uuid: Uuid) -> Result<&mut DisaggRequestState> {
self.requests
.get_mut(&uuid)
.ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))
}
/// Dispatch a request's prefill stage onto a specific prefill worker.
fn dispatch_prefill(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> {
self.validate_worker_idx(SimulationWorkerStage::Prefill, worker_idx)?;
let request = self.state(uuid)?.build_prefill_request()?;
......@@ -254,6 +264,7 @@ impl DisaggRuntime {
Ok(())
}
/// Dispatch a request's decode stage onto a specific decode worker.
fn dispatch_decode(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> {
self.validate_worker_idx(SimulationWorkerStage::Decode, worker_idx)?;
let request = self.state(uuid)?.build_decode_request()?;
......@@ -266,6 +277,7 @@ impl DisaggRuntime {
Ok(())
}
/// Turn prefill router admissions into concrete worker dispatches.
fn dispatch_prefill_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> Result<()> {
for (uuid, worker_idx) in admissions {
if !self.state(uuid)?.is_queued_prefill() {
......@@ -276,6 +288,7 @@ impl DisaggRuntime {
Ok(())
}
/// Turn decode router admissions into concrete worker dispatches.
fn dispatch_decode_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> Result<()> {
for (uuid, worker_idx) in admissions {
if !self.state(uuid)?.is_queued_decode() {
......@@ -286,21 +299,42 @@ impl DisaggRuntime {
Ok(())
}
fn enqueue_decode(&mut self, uuid: Uuid) -> Result<()> {
/// Submit a request to the prefill router and return an immediate admission when available.
fn submit_to_prefill_router(
&mut self,
uuid: Uuid,
replay_hashes: Option<ReplayRequestHashes>,
) -> Result<Option<usize>> {
let request = self.state(uuid)?.original_request()?.clone();
let Some(prefill_router) = self.prefill_router.as_mut() else {
bail!("offline disagg replay prefill submission requires an active router");
};
let maybe_worker_idx =
prefill_router.submit_request_with_hashes(&request, replay_hashes, self.now_ms)?;
self.record_router_pending();
Ok(maybe_worker_idx)
}
/// Submit a request to the decode router and return an immediate admission when available.
fn submit_to_decode_router(&mut self, uuid: Uuid) -> Result<Option<usize>> {
let request = self.state(uuid)?.original_request()?.clone();
let Some(decode_router) = self.decode_router.as_mut() else {
bail!("offline disagg replay decode submission requires an active router");
};
let maybe_worker_idx =
decode_router.submit_request_with_hashes(&request, None, self.now_ms)?;
self.record_router_pending();
Ok(maybe_worker_idx)
}
/// Queue or dispatch a request into decode, depending on whether a decode router is active.
fn enqueue_decode(&mut self, uuid: Uuid) -> Result<()> {
if self.decode_router.is_none() {
let worker_idx = self.next_decode_worker();
self.dispatch_decode(uuid, worker_idx)?;
return Ok(());
};
let maybe_worker_idx = {
let requests = &self.requests;
let request = requests
.get(&uuid)
.ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))?
.original_request()?;
decode_router.submit_request_with_hashes(request, None, self.now_ms)?
};
self.record_router_pending();
}
let maybe_worker_idx = self.submit_to_decode_router(uuid)?;
#[cfg(test)]
{
self.stats.handoff_ms.insert(uuid, self.now_ms);
......@@ -314,6 +348,7 @@ impl DisaggRuntime {
Ok(())
}
/// Admit one external request into prefill-side state, collector state, and optional router.
fn on_external_arrival(
&mut self,
mut request: DirectRequest,
......@@ -333,26 +368,19 @@ impl DisaggRuntime {
self.requests
.insert(uuid, DisaggRequestState::new(request, arrival_time_ms));
let Some(prefill_router) = self.prefill_router.as_mut() else {
if self.prefill_router.is_none() {
let worker_idx = self.next_prefill_worker();
self.dispatch_prefill(uuid, worker_idx)?;
return Ok(uuid);
};
let maybe_worker_idx = {
let requests = &self.requests;
let request = requests
.get(&uuid)
.ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))?
.original_request()?;
prefill_router.submit_request_with_hashes(request, replay_hashes, self.now_ms)?
};
self.record_router_pending();
}
let maybe_worker_idx = self.submit_to_prefill_router(uuid, replay_hashes)?;
if let Some(worker_idx) = maybe_worker_idx {
self.dispatch_prefill(uuid, worker_idx)?;
}
Ok(uuid)
}
/// Return true once both stages, both routers, and all admissions are fully drained.
fn is_done(&self) -> bool {
self.events.is_empty()
&& self.cluster_in_flight() == 0
......@@ -370,6 +398,7 @@ impl DisaggRuntime {
.all(OfflineWorkerState::is_drained)
}
/// Pick the next logical timestamp from arrivals, worker completions, or decode handoffs.
fn next_timestamp(&mut self) -> Option<f64> {
let next_event_ms = self.events.peek().map(|event| event.at_ms);
let cluster_in_flight = self.cluster_in_flight();
......@@ -391,6 +420,7 @@ impl DisaggRuntime {
choose_next_timestamp(next_arrival_ms, next_event_ms)
}
/// Apply prefill-side KV router events at the scheduler-selected visibility phase.
fn apply_prefill_router_events(&mut self, events: Vec<RouterEvent>) -> Result<()> {
let Some(prefill_router) = self.prefill_router.as_mut() else {
return Ok(());
......@@ -401,6 +431,7 @@ impl DisaggRuntime {
Ok(())
}
/// Process one prefill output signal, including router updates and decode handoff scheduling.
fn process_prefill_signal(&mut self, signal: OutputSignal) -> Result<()> {
if !signal.completed {
return Ok(());
......@@ -424,7 +455,7 @@ impl DisaggRuntime {
};
#[cfg(test)]
{
self.stats.prefill_freed_count += 1;
self.stats.prefill_router_freed_count += 1;
}
self.record_router_pending();
self.dispatch_prefill_admissions(admissions)?;
......@@ -433,6 +464,7 @@ impl DisaggRuntime {
self.enqueue_decode_after_handoff(signal.uuid, signal.handoff_delay_ms)
}
/// Process one decode output signal, including decode router frees and request completion.
fn process_decode_signal(&mut self, signal: OutputSignal) -> Result<()> {
if !signal.completed {
return Ok(());
......@@ -442,7 +474,7 @@ impl DisaggRuntime {
let admissions = decode_router.free(signal.uuid)?;
#[cfg(test)]
{
self.stats.decode_freed_count += 1;
self.stats.decode_router_freed_count += 1;
}
admissions
} else {
......@@ -457,6 +489,7 @@ impl DisaggRuntime {
Ok(())
}
/// Apply the side effects of a finished prefill pass.
fn process_prefill_pass(
&mut self,
worker_idx: usize,
......@@ -472,6 +505,7 @@ impl DisaggRuntime {
Ok(())
}
/// Apply the side effects of a finished decode pass.
fn process_decode_pass(
&mut self,
worker_idx: usize,
......@@ -485,6 +519,7 @@ impl DisaggRuntime {
Ok(())
}
/// Drain all worker-completion events scheduled for the current logical timestamp.
fn apply_worker_completions(&mut self) -> Result<bool> {
let mut changed = false;
while let Some(WorkerCompletionPayload {
......@@ -518,6 +553,7 @@ impl DisaggRuntime {
Ok(changed)
}
/// Drain all delayed decode handoff events scheduled for the current logical timestamp.
fn apply_decode_handoffs(&mut self) -> Result<bool> {
let mut changed = false;
while let Some(uuid) = pop_ready_decode_handoff(&mut self.events, self.now_ms) {
......@@ -527,6 +563,7 @@ impl DisaggRuntime {
Ok(changed)
}
/// Either enqueue decode immediately or schedule a delayed handoff event on the event heap.
fn enqueue_decode_after_handoff(
&mut self,
uuid: Uuid,
......@@ -547,6 +584,7 @@ impl DisaggRuntime {
self.enqueue_decode(uuid)
}
/// Release every trace arrival whose timestamp is now visible to the global clock.
fn release_trace_arrivals(&mut self) -> Result<bool> {
let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) {
......@@ -581,6 +619,7 @@ impl DisaggRuntime {
Ok(released_any)
}
/// Backfill closed-loop concurrency replay until the staged cluster reaches its cap.
fn top_off_concurrency(&mut self, max_in_flight: usize) -> Result<bool> {
let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) {
......@@ -620,6 +659,7 @@ impl DisaggRuntime {
Ok(released_any)
}
/// Start passes on every idle prefill worker that can make progress at the current timestamp.
fn drive_prefill_workers(&mut self) -> Result<bool> {
let mut changed = false;
for worker_idx in 0..self.prefill_workers.len() {
......@@ -668,6 +708,7 @@ impl DisaggRuntime {
Ok(changed)
}
/// Start passes on every idle decode worker that can make progress at the current timestamp.
fn drive_decode_workers(&mut self) -> Result<bool> {
let mut changed = false;
for worker_idx in 0..self.decode_workers.len() {
......@@ -710,6 +751,7 @@ impl DisaggRuntime {
Ok(changed)
}
/// Repeatedly process all work that becomes possible without advancing logical time.
fn drain_current_timestamp(&mut self) -> Result<()> {
loop {
let mut changed = self.apply_worker_completions()?;
......@@ -732,6 +774,7 @@ impl DisaggRuntime {
Ok(())
}
/// Finalize test-only request snapshots before returning.
fn finish_test_stats(&mut self) {
#[cfg(test)]
{
......@@ -743,7 +786,8 @@ impl DisaggRuntime {
}
}
fn run(mut self) -> Result<(TraceCollector, DisaggRuntimeStats)> {
/// Run the staged offline replay until both prefill and decode pipelines are drained.
pub(super) fn run(mut self) -> Result<(TraceCollector, DisaggRuntimeStats)> {
self.drain_current_timestamp()?;
while !self.is_done() {
......@@ -793,124 +837,9 @@ fn derive_decode_router_config(
config
}
pub(crate) fn simulate_trace_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
let (collector, _) = DisaggRuntime::new(
&config,
router_config,
pending,
ReplayMode::Trace,
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let pending = VecDeque::from(requests);
let (collector, _) = DisaggRuntime::new(
&config,
router_config,
pending,
ReplayMode::Concurrency { max_in_flight },
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_trace_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let driver = WorkloadDriver::new_trace(trace)?;
let (collector, _) = DisaggRuntime::new_workload(
&config,
router_config,
driver,
ReplayMode::Trace,
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let driver = WorkloadDriver::new_concurrency(trace)?;
let (collector, _) = DisaggRuntime::new_workload(
&config,
router_config,
driver,
ReplayMode::Concurrency { max_in_flight },
router_mode,
)?
.run()?;
Ok(collector.finish())
}
#[cfg(test)]
fn run_trace_collect(
config: &OfflineDisaggReplayConfig,
requests: Vec<DirectRequest>,
router_config: Option<KvRouterConfig>,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
let pending = normalize_trace_requests(requests, arrival_speedup_ratio).unwrap();
DisaggRuntime::new(
config,
router_config,
pending,
ReplayMode::Trace,
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
fn run_concurrency_collect(
config: &OfflineDisaggReplayConfig,
requests: Vec<DirectRequest>,
router_config: Option<KvRouterConfig>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
DisaggRuntime::new(
config,
router_config,
VecDeque::from(requests),
ReplayMode::Concurrency { max_in_flight },
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
mod tests {
use super::super::entrypoints::{run_concurrency_collect, run_trace_collect};
use super::*;
use crate::common::protocols::{MockEngineArgs, WorkerType};
......@@ -1093,8 +1022,8 @@ mod tests {
);
assert_eq!(stats.prefill_marked_count, 1);
assert_eq!(stats.prefill_freed_count, 1);
assert_eq!(stats.decode_freed_count, 1);
assert_eq!(stats.prefill_router_freed_count, 1);
assert_eq!(stats.decode_router_freed_count, 1);
}
#[test]
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use anyhow::Result;
use dynamo_kv_router::config::KvRouterConfig;
#[cfg(test)]
use super::agg::AggRuntimeStats;
use super::agg::{AggRuntime, ReplayMode as AggReplayMode};
#[cfg(test)]
use super::disagg::DisaggRuntimeStats;
use super::disagg::{DisaggRuntime, ReplayMode as DisaggReplayMode};
use super::normalize_trace_requests;
use super::single::{SingleReplayMode, SingleRuntime};
use crate::common::protocols::{DirectRequest, EngineType, MockEngineArgs};
use crate::loadgen::{Trace, WorkloadDriver};
use crate::replay::OfflineDisaggReplayConfig;
#[cfg(test)]
use crate::replay::TraceCollector;
use crate::replay::{ReplayRouterMode, TraceSimulationReport};
pub(crate) fn simulate_trace(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == EngineType::Vllm {
simulate_trace_single(args, requests, arrival_speedup_ratio)
} else {
simulate_trace_multi(
args,
router_config,
requests,
num_workers,
arrival_speedup_ratio,
router_mode,
)
}
}
pub(crate) fn simulate_concurrency(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == EngineType::Vllm {
simulate_concurrency_single(args, requests, max_in_flight)
} else {
simulate_concurrency_multi(
args,
router_config,
requests,
max_in_flight,
num_workers,
router_mode,
)
}
}
pub(crate) fn simulate_trace_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == EngineType::Vllm {
simulate_trace_workload_single(args, trace)
} else {
simulate_trace_workload_multi(args, router_config, trace, num_workers, router_mode)
}
}
pub(crate) fn simulate_concurrency_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == EngineType::Vllm {
simulate_concurrency_workload_single(args, trace, max_in_flight)
} else {
simulate_concurrency_workload_multi(
args,
router_config,
trace,
max_in_flight,
num_workers,
router_mode,
)
}
}
pub(crate) fn simulate_trace_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
let (collector, _) = DisaggRuntime::new(
&config,
router_config,
pending,
DisaggReplayMode::Trace,
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let pending = VecDeque::from(requests);
let (collector, _) = DisaggRuntime::new(
&config,
router_config,
pending,
DisaggReplayMode::Concurrency { max_in_flight },
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_trace_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let driver = WorkloadDriver::new_trace(trace)?;
let (collector, _) = DisaggRuntime::new_workload(
&config,
router_config,
driver,
DisaggReplayMode::Trace,
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let driver = WorkloadDriver::new_concurrency(trace)?;
let (collector, _) = DisaggRuntime::new_workload(
&config,
router_config,
driver,
DisaggReplayMode::Concurrency { max_in_flight },
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_trace_single(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
let collector = SingleRuntime::new(args, pending, SingleReplayMode::Trace).run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_single(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = VecDeque::from(requests);
let collector = SingleRuntime::new(
args,
pending,
SingleReplayMode::Concurrency { max_in_flight },
)
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_trace_workload_single(
args: MockEngineArgs,
trace: Trace,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let collector =
SingleRuntime::new_workload(args, trace.into_trace_driver()?, SingleReplayMode::Trace)
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_workload_single(
args: MockEngineArgs,
trace: Trace,
max_in_flight: usize,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let collector = SingleRuntime::new_workload(
args,
trace.into_concurrency_driver()?,
SingleReplayMode::Concurrency { max_in_flight },
)
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_trace_multi(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
let (collector, _) = AggRuntime::new(
&args,
router_config,
pending,
num_workers,
AggReplayMode::Trace,
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_multi(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = VecDeque::from(requests);
let (collector, _) = AggRuntime::new(
&args,
router_config,
pending,
num_workers,
AggReplayMode::Concurrency { max_in_flight },
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_trace_workload_multi(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let (collector, _) = AggRuntime::new_workload(
&args,
router_config,
trace.into_trace_driver()?,
num_workers,
AggReplayMode::Trace,
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_workload_multi(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let (collector, _) = AggRuntime::new_workload(
&args,
router_config,
trace.into_concurrency_driver()?,
num_workers,
AggReplayMode::Concurrency { max_in_flight },
router_mode,
)?
.run()?;
Ok(collector.finish())
}
#[cfg(test)]
pub(super) fn run_trace_single_collect(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
) -> TraceCollector {
let pending = normalize_trace_requests(requests, arrival_speedup_ratio).unwrap();
SingleRuntime::new(args, pending, SingleReplayMode::Trace)
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_concurrency_single_collect(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
) -> TraceCollector {
SingleRuntime::new(
args,
VecDeque::from(requests),
SingleReplayMode::Concurrency { max_in_flight },
)
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_trace_workload_single_collect(
args: MockEngineArgs,
trace: Trace,
) -> TraceCollector {
SingleRuntime::new_workload(
args,
trace.into_trace_driver().unwrap(),
SingleReplayMode::Trace,
)
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_concurrency_workload_single_collect(
args: MockEngineArgs,
trace: Trace,
max_in_flight: usize,
) -> TraceCollector {
SingleRuntime::new_workload(
args,
trace.into_concurrency_driver().unwrap(),
SingleReplayMode::Concurrency { max_in_flight },
)
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_trace_multi_collect_with_stats(
args: &MockEngineArgs,
requests: Vec<DirectRequest>,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> (TraceCollector, AggRuntimeStats) {
let pending = normalize_trace_requests(requests, 1.0).unwrap();
AggRuntime::new(
args,
None,
pending,
num_workers,
AggReplayMode::Trace,
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_concurrency_multi_collect_with_stats(
args: &MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> (TraceCollector, AggRuntimeStats) {
AggRuntime::new(
args,
None,
VecDeque::from(requests),
num_workers,
AggReplayMode::Concurrency { max_in_flight },
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_trace_workload_multi_collect_with_stats(
args: &MockEngineArgs,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> (TraceCollector, AggRuntimeStats) {
AggRuntime::new_workload(
args,
None,
trace.into_trace_driver().unwrap(),
num_workers,
AggReplayMode::Trace,
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_concurrency_workload_multi_collect_with_stats(
args: &MockEngineArgs,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> (TraceCollector, AggRuntimeStats) {
AggRuntime::new_workload(
args,
None,
trace.into_concurrency_driver().unwrap(),
num_workers,
AggReplayMode::Concurrency { max_in_flight },
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_trace_collect(
config: &OfflineDisaggReplayConfig,
requests: Vec<DirectRequest>,
router_config: Option<KvRouterConfig>,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
let pending = normalize_trace_requests(requests, arrival_speedup_ratio).unwrap();
DisaggRuntime::new(
config,
router_config,
pending,
DisaggReplayMode::Trace,
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_concurrency_collect(
config: &OfflineDisaggReplayConfig,
requests: Vec<DirectRequest>,
router_config: Option<KvRouterConfig>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
DisaggRuntime::new(
config,
router_config,
VecDeque::from(requests),
DisaggReplayMode::Concurrency { max_in_flight },
router_mode,
)
.unwrap()
.run()
.unwrap()
}
......@@ -34,17 +34,9 @@ pub(crate) struct SimulationEvent {
pub(crate) kind: SimulationEventKind,
}
impl SimulationEvent {
fn kind_priority(&self) -> u8 {
0
}
}
impl PartialEq for SimulationEvent {
fn eq(&self, other: &Self) -> bool {
self.at_ms.to_bits() == other.at_ms.to_bits()
&& self.seq_no == other.seq_no
&& self.kind_priority() == other.kind_priority()
self.at_ms.to_bits() == other.at_ms.to_bits() && self.seq_no == other.seq_no
}
}
......@@ -62,7 +54,6 @@ impl Ord for SimulationEvent {
.at_ms
.partial_cmp(&self.at_ms)
.unwrap_or(Ordering::Equal)
.then_with(|| self.kind_priority().cmp(&other.kind_priority()))
.then_with(|| other.seq_no.cmp(&self.seq_no))
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::Trace;
use crate::replay::OfflineDisaggReplayConfig;
pub(crate) use crate::replay::normalize_trace_requests;
use crate::replay::{ReplayRouterMode, TraceSimulationReport};
use dynamo_kv_router::config::KvRouterConfig;
pub(crate) mod agg;
pub(crate) mod core;
pub(crate) mod disagg;
mod entrypoints;
pub(crate) mod events;
pub(crate) mod multi;
pub(crate) mod runtime_utils;
pub(crate) mod single;
pub(crate) mod state;
pub(crate) fn simulate_trace(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm {
single::simulate_trace_single(args, requests, arrival_speedup_ratio)
} else {
multi::simulate_trace_multi(
args,
router_config,
requests,
num_workers,
arrival_speedup_ratio,
router_mode,
)
}
}
pub(crate) fn simulate_concurrency(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm {
single::simulate_concurrency_single(args, requests, max_in_flight)
} else {
multi::simulate_concurrency_multi(
args,
router_config,
requests,
max_in_flight,
num_workers,
router_mode,
)
}
}
pub(crate) fn simulate_trace_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm {
single::simulate_trace_workload_single(args, trace)
} else {
multi::simulate_trace_workload_multi(args, router_config, trace, num_workers, router_mode)
}
}
pub(crate) fn simulate_concurrency_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm {
single::simulate_concurrency_workload_single(args, trace, max_in_flight)
} else {
multi::simulate_concurrency_workload_multi(
args,
router_config,
trace,
max_in_flight,
num_workers,
router_mode,
)
}
}
pub(crate) fn simulate_trace_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_trace_disagg(
config,
router_config,
requests,
arrival_speedup_ratio,
router_mode,
)
}
pub(crate) fn simulate_concurrency_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_concurrency_disagg(config, router_config, requests, max_in_flight, router_mode)
}
pub(crate) fn simulate_trace_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_trace_workload_disagg(config, router_config, trace, router_mode)
}
pub(crate) fn simulate_concurrency_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_concurrency_workload_disagg(
config,
router_config,
trace,
max_in_flight,
router_mode,
)
}
pub(crate) use entrypoints::{
simulate_concurrency, simulate_concurrency_disagg, simulate_concurrency_workload,
simulate_concurrency_workload_disagg, simulate_trace, simulate_trace_disagg,
simulate_trace_workload, simulate_trace_workload_disagg,
};
......@@ -2,16 +2,15 @@
// SPDX-License-Identifier: Apache-2.0
use super::core::ReplayWorkerCore;
use super::normalize_trace_requests;
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::{Trace, WorkloadDriver};
use crate::replay::{TraceCollector, TraceSimulationReport};
use crate::loadgen::WorkloadDriver;
use crate::replay::TraceCollector;
use anyhow::bail;
use std::collections::VecDeque;
use uuid::Uuid;
#[derive(Debug, Clone, Copy)]
enum SingleReplayMode {
pub(super) enum SingleReplayMode {
Trace,
Concurrency { max_in_flight: usize },
}
......@@ -21,7 +20,7 @@ enum AdmissionSource {
Workload(WorkloadDriver),
}
struct SingleRuntime {
pub(super) struct SingleRuntime {
current_time_ms: f64,
admission: AdmissionSource,
worker: ReplayWorkerCore,
......@@ -30,11 +29,19 @@ struct SingleRuntime {
}
impl SingleRuntime {
fn new(args: MockEngineArgs, pending: VecDeque<DirectRequest>, mode: SingleReplayMode) -> Self {
pub(super) fn new(
args: MockEngineArgs,
pending: VecDeque<DirectRequest>,
mode: SingleReplayMode,
) -> Self {
Self::new_with_source(args, AdmissionSource::Requests(pending), mode)
}
fn new_workload(args: MockEngineArgs, driver: WorkloadDriver, mode: SingleReplayMode) -> Self {
pub(super) fn new_workload(
args: MockEngineArgs,
driver: WorkloadDriver,
mode: SingleReplayMode,
) -> Self {
Self::new_with_source(args, AdmissionSource::Workload(driver), mode)
}
......@@ -166,7 +173,7 @@ impl SingleRuntime {
}
}
fn run(mut self) -> anyhow::Result<TraceCollector> {
pub(super) fn run(mut self) -> anyhow::Result<TraceCollector> {
while !self.is_done() {
match self.mode {
SingleReplayMode::Trace => {
......@@ -196,119 +203,14 @@ impl SingleRuntime {
}
}
pub(crate) fn simulate_trace_single(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
let collector = SingleRuntime::new(args, pending, SingleReplayMode::Trace).run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_single(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = VecDeque::from(requests);
let collector = SingleRuntime::new(
args,
pending,
SingleReplayMode::Concurrency { max_in_flight },
)
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_trace_workload_single(
args: MockEngineArgs,
trace: Trace,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let collector =
SingleRuntime::new_workload(args, trace.into_trace_driver()?, SingleReplayMode::Trace)
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_workload_single(
args: MockEngineArgs,
trace: Trace,
max_in_flight: usize,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let collector = SingleRuntime::new_workload(
args,
trace.into_concurrency_driver()?,
SingleReplayMode::Concurrency { max_in_flight },
)
.run()?;
Ok(collector.finish())
}
#[cfg(test)]
pub(super) fn run_trace_single_collect(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
) -> TraceCollector {
let pending = normalize_trace_requests(requests, arrival_speedup_ratio).unwrap();
SingleRuntime::new(args, pending, SingleReplayMode::Trace)
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_concurrency_single_collect(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
) -> TraceCollector {
SingleRuntime::new(
args,
VecDeque::from(requests),
SingleReplayMode::Concurrency { max_in_flight },
)
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_trace_workload_single_collect(
args: MockEngineArgs,
trace: Trace,
) -> TraceCollector {
SingleRuntime::new_workload(
args,
trace.into_trace_driver().unwrap(),
SingleReplayMode::Trace,
)
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_concurrency_workload_single_collect(
args: MockEngineArgs,
trace: Trace,
max_in_flight: usize,
) -> TraceCollector {
SingleRuntime::new_workload(
args,
trace.into_concurrency_driver().unwrap(),
SingleReplayMode::Concurrency { max_in_flight },
)
.run()
.unwrap()
}
#[cfg(test)]
mod tests {
use super::super::entrypoints::{
run_concurrency_workload_single_collect, run_trace_workload_single_collect,
simulate_concurrency_single, simulate_trace_single,
};
use super::*;
use crate::loadgen::{SessionTrace, TurnTrace};
use crate::loadgen::{SessionTrace, Trace, TurnTrace};
use crate::replay::{TraceRequestStatsSnapshot, TraceSimulationReport};
use rstest::rstest;
use std::collections::{HashMap, VecDeque};
......
......@@ -13,11 +13,53 @@ use crate::scheduler::AdmissionEvent;
use super::state::{ArrivalEvent, RequestRegistry, SharedLiveRuntimeStats, now_ms};
async fn process_output_signal(
output: OutputSignal,
batch_time_ms: f64,
collector: &mut TraceCollector,
requests: &RequestRegistry,
router: &ReplayRouter,
stats: &SharedLiveRuntimeStats,
) {
collector.on_token(output.uuid, batch_time_ms);
let Some(state) = requests.get(&output.uuid) else {
return;
};
if state.mark_first_token_once() {
match router.on_first_token(output.uuid).await {
Ok(true) => stats.record_prefill_marked(),
Ok(false) => {}
Err(error) => tracing::warn!(
uuid = %output.uuid,
error = %error,
"online replay failed to mark prefill completed"
),
}
}
if !output.completed || !state.mark_completed_once() {
return;
}
match router.on_complete(output.uuid).await {
Ok(true) => stats.record_freed(),
Ok(false) => {}
Err(error) => tracing::warn!(
uuid = %output.uuid,
error = %error,
"online replay failed to free completed request"
),
}
state.notify_completion();
}
pub(super) async fn run_demux(
start: Instant,
mut arrival_rx: mpsc::UnboundedReceiver<ArrivalEvent>,
mut admission_rx: mpsc::UnboundedReceiver<AdmissionEvent>,
mut output_rx: mpsc::UnboundedReceiver<OutputSignal>,
mut output_rx: mpsc::UnboundedReceiver<Vec<OutputSignal>>,
requests: RequestRegistry,
router: Arc<ReplayRouter>,
stats: Arc<SharedLiveRuntimeStats>,
......@@ -55,33 +97,18 @@ pub(super) async fn run_demux(
}
output = output_rx.recv(), if outputs_open => {
match output {
Some(output) => {
collector.on_token(output.uuid, now_ms(start));
if let Some(state) = requests.get(&output.uuid) {
if state.mark_first_token_once() {
match router.on_first_token(output.uuid).await {
Ok(true) => stats.record_prefill_marked(),
Ok(false) => {}
Err(error) => tracing::warn!(
uuid = %output.uuid,
error = %error,
"online replay failed to mark prefill completed"
),
}
}
if output.completed && state.mark_completed_once() {
match router.on_complete(output.uuid).await {
Ok(true) => stats.record_freed(),
Ok(false) => {}
Err(error) => tracing::warn!(
uuid = %output.uuid,
error = %error,
"online replay failed to free completed request"
),
}
state.notify_completion();
}
Some(output_batch) => {
let batch_time_ms = now_ms(start);
for output in output_batch {
process_output_signal(
output,
batch_time_ms,
&mut collector,
&requests,
&router,
&stats,
)
.await;
}
}
None => outputs_open = false,
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use anyhow::{Result, anyhow, bail};
use dynamo_kv_router::config::KvRouterConfig;
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::{Trace, WorkloadDriver};
use crate::replay::{ReplayRouterMode, TraceSimulationReport, normalize_trace_requests};
use super::live_runtime::LiveRuntime;
use super::state::{LiveReplayMode, LiveRuntimeStats};
fn total_turns(trace: &Trace) -> usize {
trace
.sessions
.iter()
.map(|session| session.turns.len())
.sum()
}
fn run_live_runtime(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
pending: VecDeque<DirectRequest>,
num_workers: usize,
mode: LiveReplayMode,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.map_err(|e| anyhow!("failed to create online replay runtime: {e}"))?;
runtime.block_on(async move {
LiveRuntime::new(args, router_config, pending, num_workers, mode, router_mode)?
.run()
.await
})
}
fn run_live_workload_runtime(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
driver: WorkloadDriver,
total_turns: usize,
num_workers: usize,
mode: LiveReplayMode,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.map_err(|e| anyhow!("failed to create online replay runtime: {e}"))?;
runtime.block_on(async move {
LiveRuntime::new(
args,
router_config,
VecDeque::new(),
num_workers,
mode,
router_mode,
)?
.run_workload(driver, total_turns)
.await
})
}
pub(crate) fn simulate_trace_requests(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
let (report, _) = run_live_runtime(
args,
router_config,
pending,
num_workers,
LiveReplayMode::Trace,
router_mode,
)?;
Ok(report)
}
pub(crate) fn simulate_concurrency_requests(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
if requests.is_empty() {
bail!("online concurrency replay requires at least one request");
}
let pending = VecDeque::from(requests);
let (report, _) = run_live_runtime(
args,
router_config,
pending,
num_workers,
LiveReplayMode::Concurrency { max_in_flight },
router_mode,
)?;
Ok(report)
}
pub(crate) fn simulate_trace_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let total_turns = total_turns(&trace);
let (report, _) = run_live_workload_runtime(
args,
router_config,
trace.into_trace_driver()?,
total_turns,
num_workers,
LiveReplayMode::Trace,
router_mode,
)?;
Ok(report)
}
pub(crate) fn simulate_concurrency_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let total_turns = total_turns(&trace);
let (report, _) = run_live_workload_runtime(
args,
router_config,
trace.into_concurrency_driver()?,
total_turns,
num_workers,
LiveReplayMode::Concurrency { max_in_flight },
router_mode,
)?;
Ok(report)
}
#[cfg(test)]
pub(super) fn simulate_trace_requests_with_stats(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let args = args.normalized()?;
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
run_live_runtime(
args,
None,
pending,
num_workers,
LiveReplayMode::Trace,
router_mode,
)
}
#[cfg(test)]
pub(super) fn simulate_concurrency_requests_with_stats(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let args = args.normalized()?;
let pending = VecDeque::from(requests);
run_live_runtime(
args,
None,
pending,
num_workers,
LiveReplayMode::Concurrency { max_in_flight },
router_mode,
)
}
#[cfg(test)]
pub(super) fn simulate_trace_workload_with_stats(
args: MockEngineArgs,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let args = args.normalized()?;
let total_turns = total_turns(&trace);
run_live_workload_runtime(
args,
None,
trace.into_trace_driver()?,
total_turns,
num_workers,
LiveReplayMode::Trace,
router_mode,
)
}
#[cfg(test)]
pub(super) fn simulate_concurrency_workload_with_stats(
args: MockEngineArgs,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let args = args.normalized()?;
let total_turns = total_turns(&trace);
run_live_workload_runtime(
args,
None,
trace.into_concurrency_driver()?,
total_turns,
num_workers,
LiveReplayMode::Concurrency { max_in_flight },
router_mode,
)
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use std::sync::Arc;
use anyhow::{Result, anyhow, bail};
use anyhow::{Result, anyhow};
use dashmap::DashMap;
use dynamo_kv_router::config::KvRouterConfig;
use tokio::sync::{Notify, Semaphore, mpsc};
......@@ -13,9 +12,9 @@ use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::{Trace, WorkloadDriver};
use crate::loadgen::WorkloadDriver;
use crate::replay::router::ReplayRouter;
use crate::replay::{ReplayRouterMode, TraceSimulationReport, normalize_trace_requests};
use crate::replay::{ReplayRouterMode, TraceSimulationReport};
use crate::scheduler::{AdmissionEvent, EngineScheduler, SchedulerHandle};
use super::demux::run_demux;
......@@ -25,11 +24,11 @@ use super::state::{
};
use super::task::{RequestTaskContext, run_request_task, wait_for_workload_progress};
struct LiveRuntime {
pending: VecDeque<DirectRequest>,
pub(super) struct LiveRuntime {
pending: std::collections::VecDeque<DirectRequest>,
senders: Arc<[mpsc::UnboundedSender<DirectRequest>]>,
schedulers: Vec<EngineScheduler>,
output_rx: mpsc::UnboundedReceiver<OutputSignal>,
output_rx: mpsc::UnboundedReceiver<Vec<OutputSignal>>,
admission_rx: mpsc::UnboundedReceiver<AdmissionEvent>,
cancel_token: CancellationToken,
start: Instant,
......@@ -38,16 +37,17 @@ struct LiveRuntime {
}
impl LiveRuntime {
fn new(
/// Build the shared router, worker schedulers, and demux inputs for one live replay run.
pub(super) fn new(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
pending: VecDeque<DirectRequest>,
pending: std::collections::VecDeque<DirectRequest>,
num_workers: usize,
mode: LiveReplayMode,
router_mode: ReplayRouterMode,
) -> Result<Self> {
let cancel_token = CancellationToken::new();
let (output_tx, output_rx) = mpsc::unbounded_channel();
let (output_tx, output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let (admission_tx, admission_rx) = mpsc::unbounded_channel();
let router = Arc::new(ReplayRouter::new(
router_mode,
......@@ -70,10 +70,6 @@ impl LiveRuntime {
senders.push(scheduler.request_sender());
schedulers.push(scheduler);
}
drop(output_tx);
drop(admission_tx);
Ok(Self {
pending,
senders: Arc::from(senders),
......@@ -87,7 +83,8 @@ impl LiveRuntime {
})
}
async fn run(mut self) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
/// Replay a finite queue of requests and return the final trace report plus debug stats.
pub(super) async fn run(mut self) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let requests = Arc::new(DashMap::with_capacity(self.pending.len()));
let stats = Arc::new(SharedLiveRuntimeStats::default());
let (arrival_tx, arrival_rx) = mpsc::unbounded_channel();
......@@ -160,7 +157,8 @@ impl LiveRuntime {
Ok((report, stats.snapshot()))
}
async fn run_workload(
/// Drive a multi-turn workload driver until it is drained and all spawned request tasks finish.
pub(super) async fn run_workload(
mut self,
driver: WorkloadDriver,
total_turns: usize,
......@@ -283,237 +281,3 @@ impl LiveRuntime {
Ok((report, stats.snapshot()))
}
}
fn run_live_runtime(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
pending: VecDeque<DirectRequest>,
num_workers: usize,
mode: LiveReplayMode,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.map_err(|e| anyhow!("failed to create online replay runtime: {e}"))?;
runtime.block_on(async move {
LiveRuntime::new(args, router_config, pending, num_workers, mode, router_mode)?
.run()
.await
})
}
fn run_live_workload_runtime(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
driver: WorkloadDriver,
total_turns: usize,
num_workers: usize,
mode: LiveReplayMode,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.map_err(|e| anyhow!("failed to create online replay runtime: {e}"))?;
runtime.block_on(async move {
LiveRuntime::new(
args,
router_config,
VecDeque::new(),
num_workers,
mode,
router_mode,
)?
.run_workload(driver, total_turns)
.await
})
}
pub(crate) fn simulate_trace_requests(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
let (report, _) = run_live_runtime(
args,
router_config,
pending,
num_workers,
LiveReplayMode::Trace,
router_mode,
)?;
Ok(report)
}
pub(crate) fn simulate_concurrency_requests(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
if requests.is_empty() {
bail!("online concurrency replay requires at least one request");
}
let pending = VecDeque::from(requests);
let (report, _) = run_live_runtime(
args,
router_config,
pending,
num_workers,
LiveReplayMode::Concurrency { max_in_flight },
router_mode,
)?;
Ok(report)
}
pub(crate) fn simulate_trace_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let total_turns = trace
.sessions
.iter()
.map(|session| session.turns.len())
.sum();
let (report, _) = run_live_workload_runtime(
args,
router_config,
trace.into_trace_driver()?,
total_turns,
num_workers,
LiveReplayMode::Trace,
router_mode,
)?;
Ok(report)
}
pub(crate) fn simulate_concurrency_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let total_turns = trace
.sessions
.iter()
.map(|session| session.turns.len())
.sum();
let (report, _) = run_live_workload_runtime(
args,
router_config,
trace.into_concurrency_driver()?,
total_turns,
num_workers,
LiveReplayMode::Concurrency { max_in_flight },
router_mode,
)?;
Ok(report)
}
#[cfg(test)]
pub(super) fn simulate_trace_requests_with_stats(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let args = args.normalized()?;
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
run_live_runtime(
args,
None,
pending,
num_workers,
LiveReplayMode::Trace,
router_mode,
)
}
#[cfg(test)]
pub(super) fn simulate_concurrency_requests_with_stats(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let args = args.normalized()?;
let pending = VecDeque::from(requests);
run_live_runtime(
args,
None,
pending,
num_workers,
LiveReplayMode::Concurrency { max_in_flight },
router_mode,
)
}
#[cfg(test)]
pub(super) fn simulate_trace_workload_with_stats(
args: MockEngineArgs,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let args = args.normalized()?;
let total_turns = trace
.sessions
.iter()
.map(|session| session.turns.len())
.sum();
run_live_workload_runtime(
args,
None,
trace.into_trace_driver()?,
total_turns,
num_workers,
LiveReplayMode::Trace,
router_mode,
)
}
#[cfg(test)]
pub(super) fn simulate_concurrency_workload_with_stats(
args: MockEngineArgs,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let args = args.normalized()?;
let total_turns = trace
.sessions
.iter()
.map(|session| session.turns.len())
.sum();
run_live_workload_runtime(
args,
None,
trace.into_concurrency_driver()?,
total_turns,
num_workers,
LiveReplayMode::Concurrency { max_in_flight },
router_mode,
)
}
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
mod demux;
mod entrypoints;
mod live_runtime;
mod state;
mod task;
......@@ -9,7 +10,7 @@ mod task;
#[cfg(test)]
mod tests;
pub(crate) use live_runtime::{
pub(crate) use entrypoints::{
simulate_concurrency_requests, simulate_concurrency_workload, simulate_trace_requests,
simulate_trace_workload,
};
......@@ -16,7 +16,7 @@ use crate::loadgen::{SessionTrace, Trace, TurnTrace};
use crate::replay::ReplayRouterMode;
use crate::replay::router::ReplayRouter;
use super::live_runtime::{
use super::entrypoints::{
simulate_concurrency_requests_with_stats, simulate_concurrency_workload_with_stats,
simulate_trace_requests, simulate_trace_requests_with_stats,
simulate_trace_workload_with_stats,
......
......@@ -248,14 +248,14 @@ impl OfflineReplayRouter {
pub(crate) fn mark_prefill_completed(&mut self, uuid: Uuid) -> Result<Vec<(Uuid, usize)>> {
self.slots
.mark_prefill_completed_sync(&uuid.to_string())
.mark_prefill_completed(&uuid.to_string())
.map_err(anyhow::Error::from)?;
self.drain_pending()
}
pub(crate) fn free(&mut self, uuid: Uuid) -> Result<Vec<(Uuid, usize)>> {
self.slots
.free_sync(&uuid.to_string())
.free(&uuid.to_string())
.map_err(anyhow::Error::from)?;
self.drain_pending()
}
......@@ -316,8 +316,6 @@ impl OfflineReplayRouter {
}
}
pub(crate) fn shutdown(&mut self) {}
fn enqueue_key(&self, now_ms: f64, request: &PendingRequest) -> ReplayQueueKey {
let arrival_offset = Duration::from_secs_f64((now_ms.max(0.0)) / 1000.0);
self.policy.enqueue_key(
......@@ -400,7 +398,7 @@ impl OfflineReplayRouter {
let request_id = request.request_id();
self.slots
.add_request_sync(SequenceRequest {
.add_request(SequenceRequest {
request_id,
token_sequence: request.token_seq,
isl: request.isl_tokens,
......
......@@ -108,7 +108,7 @@ impl EngineScheduler {
pub(crate) fn new_with_admission(
args: crate::common::protocols::MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
......
......@@ -36,7 +36,7 @@ impl SglangScheduler {
pub fn new(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
) -> Self {
......@@ -53,7 +53,7 @@ impl SglangScheduler {
pub(crate) fn new_with_admission(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
......@@ -71,7 +71,7 @@ impl SglangScheduler {
fn new_internal(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
......@@ -121,11 +121,12 @@ impl SglangScheduler {
if pass.router_event_visibility == RouterEventVisibility::PassEnd {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
}
flush_output_signals(&output_tx, &pass.output_signals);
let active_decode_blocks = pass.active_decode_blocks;
flush_output_signals(&output_tx, pass.output_signals);
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
let _ = metrics_tx.send(MockerMetrics::new(
dp_rank,
pass.active_decode_blocks,
active_decode_blocks,
total_blocks,
));
}
......@@ -182,14 +183,16 @@ async fn receive_requests(
}
fn flush_output_signals(
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
output_signals: &[OutputSignal],
output_tx: &Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
output_signals: Vec<OutputSignal>,
) {
let Some(tx) = output_tx.as_ref() else {
return;
};
for signal in output_signals {
let _ = tx.send(signal.clone());
if output_signals.is_empty() {
return;
}
let _ = tx.send(output_signals);
}
......@@ -94,7 +94,7 @@ mod scheduling {
.build()
.unwrap();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let scheduler =
SglangScheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None);
......@@ -117,8 +117,8 @@ mod scheduling {
loop {
tokio::select! {
Some(_) = output_rx.recv() => {
received += 1;
Some(output_batch) = output_rx.recv() => {
received += output_batch.len();
if received >= expected_signals {
break;
}
......@@ -535,7 +535,7 @@ mod core_behavior {
async fn assert_sglang_scheduler_completes_all(
scheduler: &SglangScheduler,
output_rx: &mut mpsc::UnboundedReceiver<OutputSignal>,
output_rx: &mut mpsc::UnboundedReceiver<Vec<OutputSignal>>,
num_requests: usize,
prompt_len: usize,
max_output_tokens: usize,
......@@ -567,8 +567,8 @@ async fn assert_sglang_scheduler_completes_all(
loop {
tokio::select! {
biased;
Some(_) = output_rx.recv() => {
received_tokens += 1;
Some(output_batch) = output_rx.recv() => {
received_tokens += output_batch.len();
if received_tokens >= expected_tokens {
break;
}
......@@ -604,7 +604,7 @@ mod router_events {
#[case] schedule_policy: &str,
#[case] page_size: usize,
) {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let args = MockEngineArgs::builder()
.num_gpu_blocks(500)
.block_size(64)
......@@ -818,7 +818,7 @@ mod router_events {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let (sink, forward_task) = harness.spawn_forwarder();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let scheduler = SglangScheduler::new(
MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
......@@ -849,8 +849,8 @@ mod router_events {
loop {
tokio::select! {
Some(_) = output_rx.recv() => {
seen += 1;
Some(output_batch) = output_rx.recv() => {
seen += output_batch.len();
if seen == expected {
break;
}
......
......@@ -215,7 +215,7 @@ pub(crate) fn removed_event_count(events: &[RouterEvent]) -> usize {
/// common prefix to exercise prefix caching / radix tree reuse.
pub(crate) async fn assert_scheduler_completes_all(
scheduler: &dyn SchedulerHandle,
output_rx: &mut mpsc::UnboundedReceiver<OutputSignal>,
output_rx: &mut mpsc::UnboundedReceiver<Vec<OutputSignal>>,
num_requests: usize,
input_len: usize,
max_output_tokens: usize,
......@@ -260,8 +260,8 @@ pub(crate) async fn assert_scheduler_completes_all(
loop {
tokio::select! {
biased;
Some(_) = output_rx.recv() => {
received_tokens += 1;
Some(output_batch) = output_rx.recv() => {
received_tokens += output_batch.len();
if received_tokens >= expected_tokens {
break;
}
......
......@@ -63,7 +63,7 @@ impl Scheduler {
pub fn new(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
) -> Self {
......@@ -80,7 +80,7 @@ impl Scheduler {
pub(crate) fn new_with_admission(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
......@@ -98,7 +98,7 @@ impl Scheduler {
fn new_internal(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
......@@ -137,7 +137,7 @@ impl Scheduler {
if pass.router_event_visibility == RouterEventVisibility::PassEnd {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
}
flush_output_signals(&mut core, &output_tx, &pass.output_signals);
flush_output_signals(&mut core, &output_tx, pass.output_signals);
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
let _ = metrics_tx.send(MockerMetrics::new(
dp_rank,
......@@ -198,17 +198,20 @@ async fn receive_requests(
fn flush_output_signals(
core: &mut VllmCore,
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
output_signals: &[OutputSignal],
output_tx: &Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
output_signals: Vec<OutputSignal>,
) {
let Some(tx) = output_tx.as_ref() else {
return;
};
for signal in output_signals {
if tx.send(signal.clone()).is_ok() {
continue;
if output_signals.is_empty() {
return;
}
if let Err(error) = tx.send(output_signals) {
for signal in error.0 {
core.drop_request(signal.uuid);
}
core.drop_request(signal.uuid);
}
}
......@@ -506,7 +506,7 @@ mod live_scheduler {
#[case] enable_prefix_caching: bool,
#[case] enable_chunked_prefill: bool,
) {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let args = MockEngineArgs::builder()
.num_gpu_blocks(500)
......@@ -539,7 +539,7 @@ mod live_scheduler {
let num_requests = 10;
let token_length = 65;
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let args = MockEngineArgs::builder()
.num_gpu_blocks(100)
......@@ -576,8 +576,8 @@ mod live_scheduler {
let _metrics = metrics_rx.borrow().clone();
tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
}
Some(_signal) = output_rx.recv() => {
received_tokens += 1;
Some(output_batch) = output_rx.recv() => {
received_tokens += output_batch.len();
timeout.set(tokio::time::sleep(Duration::from_millis(500)));
}
_ = &mut timeout => break,
......@@ -592,7 +592,7 @@ mod live_scheduler {
#[tokio::test]
async fn test_receiver_drop_cleans_up_resources() {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let args = MockEngineArgs::builder()
.num_gpu_blocks(10)
.block_size(64)
......@@ -612,8 +612,8 @@ mod live_scheduler {
let mut received_count = 0;
while received_count < 129 {
if output_rx.recv().await.is_some() {
received_count += 1;
if let Some(output_batch) = output_rx.recv().await {
received_count += output_batch.len();
continue;
}
panic!("Channel closed before receiving 129 tokens");
......@@ -639,7 +639,7 @@ mod live_scheduler {
#[tokio::test]
async fn test_live_scheduler_forwards_buffered_kv_token_ids() {
let sink = Arc::new(CapturingKvSink::default());
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(12)
......@@ -667,10 +667,14 @@ mod live_scheduler {
arrival_timestamp_ms: None,
});
let signal = tokio::time::timeout(Duration::from_secs(2), output_rx.recv())
let output_batch = tokio::time::timeout(Duration::from_secs(2), output_rx.recv())
.await
.expect("scheduler should emit output")
.expect("output channel should stay open");
let signal = output_batch
.into_iter()
.next()
.expect("live scheduler should emit one output signal");
assert!(signal.completed);
tokio::time::sleep(Duration::from_millis(50)).await;
......@@ -691,7 +695,7 @@ mod live_scheduler {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let (sink, forward_task) = harness.spawn_forwarder();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let scheduler = Scheduler::new(
MockEngineArgs::builder()
.block_size(4)
......@@ -726,8 +730,8 @@ mod live_scheduler {
loop {
tokio::select! {
Some(_) = output_rx.recv() => {
seen += 1;
Some(output_batch) = output_rx.recv() => {
seen += output_batch.len();
if seen == expected {
break;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment