Commit a2154ba5 (unverified), authored by Yan Ru Pei, committed by GitHub
Browse files

chore(mocker): batch live output signal sends (#7647)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 06f17011
...@@ -16,7 +16,7 @@ use crate::scheduler::{Scheduler, SchedulerHandle, SglangScheduler}; ...@@ -16,7 +16,7 @@ use crate::scheduler::{Scheduler, SchedulerHandle, SglangScheduler};
pub fn create_engine( pub fn create_engine(
args: MockEngineArgs, args: MockEngineArgs,
dp_rank: u32, dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>, output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers, kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>, cancellation_token: Option<CancellationToken>,
) -> Box<dyn SchedulerHandle> { ) -> Box<dyn SchedulerHandle> {
......
...@@ -18,7 +18,7 @@ Offline replay starts in `lib/mocker/src/replay/offline/mod.rs`. ...@@ -18,7 +18,7 @@ Offline replay starts in `lib/mocker/src/replay/offline/mod.rs`.
`offline/mod.rs` chooses between three implementations: `offline/mod.rs` chooses between three implementations:
- `lib/mocker/src/replay/offline/single.rs` for the special case `num_workers == 1` with the vLLM engine - `lib/mocker/src/replay/offline/single.rs` for the special case `num_workers == 1` with the vLLM engine
- `lib/mocker/src/replay/offline/multi.rs` for everything else, including multi-worker replay and `kv_router` replay - `lib/mocker/src/replay/offline/agg.rs` for everything else, including aggregated multi-worker replay and `kv_router` replay
- `lib/mocker/src/replay/offline/disagg.rs` for offline disaggregated prefill/decode replay - `lib/mocker/src/replay/offline/disagg.rs` for offline disaggregated prefill/decode replay
## File Map ## File Map
...@@ -27,7 +27,7 @@ Offline replay starts in `lib/mocker/src/replay/offline/mod.rs`. ...@@ -27,7 +27,7 @@ Offline replay starts in `lib/mocker/src/replay/offline/mod.rs`.
Chooses single-worker fast path vs multi-worker harness. Chooses single-worker fast path vs multi-worker harness.
- `lib/mocker/src/replay/offline/single.rs` - `lib/mocker/src/replay/offline/single.rs`
Minimal replay loop for one vLLM worker. Minimal replay loop for one vLLM worker.
- `lib/mocker/src/replay/offline/multi.rs` - `lib/mocker/src/replay/offline/agg.rs`
General offline cluster simulator for multi-worker replay and KV-router replay. General offline cluster simulator for multi-worker replay and KV-router replay.
- `lib/mocker/src/replay/offline/disagg.rs` - `lib/mocker/src/replay/offline/disagg.rs`
Offline two-stage replay harness with separate prefill and decode pools. Offline two-stage replay harness with separate prefill and decode pools.
...@@ -75,7 +75,7 @@ Important details: ...@@ -75,7 +75,7 @@ Important details:
## Multi-Worker Harness ## Multi-Worker Harness
The general harness lives in `lib/mocker/src/replay/offline/multi.rs`. It models a cluster with: The general aggregated harness lives in `lib/mocker/src/replay/offline/agg.rs`. It models a cluster with:
- a logical clock `now_ms` - a logical clock `now_ms`
- a pending request queue - a pending request queue
...@@ -85,7 +85,7 @@ The general harness lives in `lib/mocker/src/replay/offline/multi.rs`. It models ...@@ -85,7 +85,7 @@ The general harness lives in `lib/mocker/src/replay/offline/multi.rs`. It models
### Main Loop ### Main Loop
The harness is event-driven. It does not sleep. Instead, `OfflineRuntime` repeatedly: The aggregated harness is event-driven. It does not sleep. Instead, `AggRuntime` repeatedly:
1. picks the next meaningful timestamp 1. picks the next meaningful timestamp
2. advances `now_ms` 2. advances `now_ms`
...@@ -164,7 +164,7 @@ flowchart LR ...@@ -164,7 +164,7 @@ flowchart LR
F -->|yes| G["dispatch to worker"] F -->|yes| G["dispatch to worker"]
F -->|no| H["store in router_pending"] F -->|no| H["store in router_pending"]
I["worker pass emits RouterEvent + OutputSignal"] --> J["OfflineRuntime::process_completed_pass"] I["worker pass emits RouterEvent + OutputSignal"] --> J["AggRuntime::process_completed_pass"]
J --> K["apply router events to sync indexer"] J --> K["apply router events to sync indexer"]
J --> L["mark_prefill_completed / free"] J --> L["mark_prefill_completed / free"]
L --> M["drain queued admissions"] L --> M["drain queued admissions"]
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::events::{SimulationEvent, SimulationWorkerStage}; use super::events::{SimulationEvent, SimulationWorkerStage};
use super::normalize_trace_requests;
use super::runtime_utils::{ use super::runtime_utils::{
WorkerCompletionPayload, next_timestamp as choose_next_timestamp, pop_next_concurrency_ready, WorkerCompletionPayload, next_timestamp as choose_next_timestamp, pop_next_concurrency_ready,
pop_next_trace_ready, pop_ready_worker_completion, push_worker_completion, pop_next_trace_ready, pop_ready_worker_completion, push_worker_completion,
...@@ -11,11 +10,11 @@ use super::runtime_utils::{ ...@@ -11,11 +10,11 @@ use super::runtime_utils::{
use super::state::OfflineWorkerSnapshot; use super::state::OfflineWorkerSnapshot;
use super::state::{AggRequestState, OfflineWorkerState}; use super::state::{AggRequestState, OfflineWorkerState};
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal}; use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::{ReplayRequestHashes, Trace, WorkloadDriver}; use crate::loadgen::{ReplayRequestHashes, WorkloadDriver};
use crate::replay::router::OfflineReplayRouter; use crate::replay::router::OfflineReplayRouter;
#[cfg(test)] #[cfg(test)]
use crate::replay::router::OfflineRouterSnapshot; use crate::replay::router::OfflineRouterSnapshot;
use crate::replay::{ReplayRouterMode, TraceCollector, TraceSimulationReport}; use crate::replay::{ReplayRouterMode, TraceCollector};
use crate::scheduler::RouterEventVisibility; use crate::scheduler::RouterEventVisibility;
use anyhow::bail; use anyhow::bail;
use dynamo_kv_router::config::KvRouterConfig; use dynamo_kv_router::config::KvRouterConfig;
...@@ -24,7 +23,7 @@ use std::collections::{BinaryHeap, HashMap, VecDeque}; ...@@ -24,7 +23,7 @@ use std::collections::{BinaryHeap, HashMap, VecDeque};
use uuid::Uuid; use uuid::Uuid;
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
enum ReplayMode { pub(super) enum ReplayMode {
Trace, Trace,
Concurrency { max_in_flight: usize }, Concurrency { max_in_flight: usize },
} }
...@@ -36,19 +35,19 @@ enum AdmissionSource { ...@@ -36,19 +35,19 @@ enum AdmissionSource {
#[cfg(test)] #[cfg(test)]
#[derive(Debug, Default, Clone, PartialEq, Eq)] #[derive(Debug, Default, Clone, PartialEq, Eq)]
struct OfflineRuntimeStats { pub(super) struct AggRuntimeStats {
dispatch_history: Vec<usize>, dispatch_history: Vec<usize>,
dispatch_order: Vec<Uuid>, dispatch_order: Vec<Uuid>,
assigned_worker_by_uuid: HashMap<Uuid, usize>, assigned_worker_by_uuid: HashMap<Uuid, usize>,
max_in_flight_seen: usize, max_in_flight_seen: usize,
prefill_marked_count: usize, prefill_marked_count: usize,
freed_count: usize, router_freed_count: usize,
max_router_pending: usize, max_router_pending_count: usize,
} }
#[cfg(test)] #[cfg(test)]
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
struct OfflineRuntimeSnapshot { struct AggRuntimeSnapshot {
now_ms: f64, now_ms: f64,
worker_active_requests: Vec<Vec<Uuid>>, worker_active_requests: Vec<Vec<Uuid>>,
workers: Vec<OfflineWorkerSnapshot>, workers: Vec<OfflineWorkerSnapshot>,
...@@ -59,29 +58,29 @@ struct OfflineRuntimeSnapshot { ...@@ -59,29 +58,29 @@ struct OfflineRuntimeSnapshot {
#[cfg(not(test))] #[cfg(not(test))]
#[derive(Debug, Default, Clone, PartialEq, Eq)] #[derive(Debug, Default, Clone, PartialEq, Eq)]
struct OfflineRuntimeStats; pub(super) struct AggRuntimeStats;
struct OfflineRuntime { pub(super) struct AggRuntime {
now_ms: f64, now_ms: f64,
next_worker_idx: usize, next_worker_idx: usize,
next_event_seq: u64, next_event_seq: u64,
admission: AdmissionSource, admission: AdmissionSource,
requests: HashMap<Uuid, AggRequestState>, requests: HashMap<Uuid, AggRequestState>,
queued_requests: usize,
workers: Vec<OfflineWorkerState>, workers: Vec<OfflineWorkerState>,
collector: TraceCollector, collector: TraceCollector,
events: BinaryHeap<SimulationEvent>, events: BinaryHeap<SimulationEvent>,
mode: ReplayMode, mode: ReplayMode,
router: Option<OfflineReplayRouter>, router: Option<OfflineReplayRouter>,
stats: OfflineRuntimeStats, stats: AggRuntimeStats,
#[cfg(test)] #[cfg(test)]
worker_active_requests: Vec<Vec<Uuid>>, worker_active_requests: Vec<Vec<Uuid>>,
#[cfg(test)] #[cfg(test)]
stepped: bool, stepped: bool,
} }
impl OfflineRuntime { impl AggRuntime {
fn new( /// Create an aggregated offline runtime seeded from an explicit request queue.
pub(super) fn new(
args: &MockEngineArgs, args: &MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
pending: VecDeque<DirectRequest>, pending: VecDeque<DirectRequest>,
...@@ -99,7 +98,8 @@ impl OfflineRuntime { ...@@ -99,7 +98,8 @@ impl OfflineRuntime {
) )
} }
fn new_workload( /// Create an aggregated offline runtime whose admissions come from a workload driver.
pub(super) fn new_workload(
args: &MockEngineArgs, args: &MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
driver: WorkloadDriver, driver: WorkloadDriver,
...@@ -117,6 +117,7 @@ impl OfflineRuntime { ...@@ -117,6 +117,7 @@ impl OfflineRuntime {
) )
} }
/// Shared constructor for both raw-request and workload-driven admissions.
fn new_with_source( fn new_with_source(
args: &MockEngineArgs, args: &MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
...@@ -140,7 +141,6 @@ impl OfflineRuntime { ...@@ -140,7 +141,6 @@ impl OfflineRuntime {
next_event_seq: 0, next_event_seq: 0,
admission, admission,
requests: HashMap::new(), requests: HashMap::new(),
queued_requests: 0,
workers: (0..num_workers) workers: (0..num_workers)
.map(|worker_idx| { .map(|worker_idx| {
OfflineWorkerState::new(worker_idx, args.clone(), capture_kv_events) OfflineWorkerState::new(worker_idx, args.clone(), capture_kv_events)
...@@ -151,9 +151,9 @@ impl OfflineRuntime { ...@@ -151,9 +151,9 @@ impl OfflineRuntime {
mode, mode,
router, router,
#[cfg(test)] #[cfg(test)]
stats: OfflineRuntimeStats::default(), stats: AggRuntimeStats::default(),
#[cfg(not(test))] #[cfg(not(test))]
stats: OfflineRuntimeStats, stats: AggRuntimeStats,
#[cfg(test)] #[cfg(test)]
worker_active_requests: vec![Vec::new(); num_workers], worker_active_requests: vec![Vec::new(); num_workers],
#[cfg(test)] #[cfg(test)]
...@@ -161,14 +161,19 @@ impl OfflineRuntime { ...@@ -161,14 +161,19 @@ impl OfflineRuntime {
}) })
} }
/// Count all requests currently consuming cluster capacity, including router-queued ones.
fn cluster_in_flight(&self) -> usize { fn cluster_in_flight(&self) -> usize {
self.workers self.workers
.iter() .iter()
.map(OfflineWorkerState::in_flight) .map(OfflineWorkerState::in_flight)
.sum::<usize>() .sum::<usize>()
+ self.queued_requests + self
.router
.as_ref()
.map_or(0, OfflineReplayRouter::pending_count)
} }
/// Track the peak cluster occupancy seen during the replay.
fn record_in_flight_peak(&mut self) { fn record_in_flight_peak(&mut self) {
#[cfg(test)] #[cfg(test)]
{ {
...@@ -177,6 +182,7 @@ impl OfflineRuntime { ...@@ -177,6 +182,7 @@ impl OfflineRuntime {
} }
} }
/// Track the maximum number of requests parked in the offline router.
fn record_router_pending(&mut self) { fn record_router_pending(&mut self) {
#[cfg(test)] #[cfg(test)]
let Some(router) = self.router.as_ref() else { let Some(router) = self.router.as_ref() else {
...@@ -184,11 +190,21 @@ impl OfflineRuntime { ...@@ -184,11 +190,21 @@ impl OfflineRuntime {
}; };
#[cfg(test)] #[cfg(test)]
{ {
self.stats.max_router_pending = self.stats.max_router_pending_count = self
self.stats.max_router_pending.max(router.pending_count()); .stats
.max_router_pending_count
.max(router.pending_count());
} }
} }
/// Hand out worker indices in round-robin order.
///
/// Returns the index currently stored in `next_worker_idx`, then advances that cursor by
/// one, wrapping back to zero once it reaches `self.workers.len()`.
///
/// # Panics
///
/// Panics if `self.workers` is empty, because the wrap-around takes a remainder by the
/// worker count.
fn next_worker(&mut self) -> usize {
    let worker_count = self.workers.len();
    let chosen = self.next_worker_idx;
    self.next_worker_idx = (chosen + 1) % worker_count;
    chosen
}
/// Record which worker accepted a request and refresh in-flight stats.
fn record_dispatch(&mut self, _uuid: Uuid, _worker_idx: usize) { fn record_dispatch(&mut self, _uuid: Uuid, _worker_idx: usize) {
#[cfg(test)] #[cfg(test)]
{ {
...@@ -201,6 +217,7 @@ impl OfflineRuntime { ...@@ -201,6 +217,7 @@ impl OfflineRuntime {
self.record_in_flight_peak(); self.record_in_flight_peak();
} }
/// Fail fast if a router admission points at a worker that does not exist.
fn validate_worker_idx(&self, worker_idx: usize) -> anyhow::Result<()> { fn validate_worker_idx(&self, worker_idx: usize) -> anyhow::Result<()> {
if worker_idx >= self.workers.len() { if worker_idx >= self.workers.len() {
bail!("offline replay selected unknown worker index {worker_idx}"); bail!("offline replay selected unknown worker index {worker_idx}");
...@@ -208,6 +225,7 @@ impl OfflineRuntime { ...@@ -208,6 +225,7 @@ impl OfflineRuntime {
Ok(()) Ok(())
} }
/// Deliver a request to a worker and update the runtime's bookkeeping for that assignment.
fn dispatch_to_worker( fn dispatch_to_worker(
&mut self, &mut self,
request: DirectRequest, request: DirectRequest,
...@@ -222,6 +240,22 @@ impl OfflineRuntime { ...@@ -222,6 +240,22 @@ impl OfflineRuntime {
Ok(()) Ok(())
} }
/// Submit a request to the router and return an immediate admission when one is available.
fn submit_to_router(
&mut self,
request: &DirectRequest,
replay_hashes: Option<ReplayRequestHashes>,
) -> anyhow::Result<Option<usize>> {
let Some(router) = self.router.as_mut() else {
bail!("offline replay router submission requires an active router");
};
let maybe_worker_idx =
router.submit_request_with_hashes(request, replay_hashes, self.now_ms)?;
self.record_router_pending();
Ok(maybe_worker_idx)
}
/// Materialize router admissions into concrete worker dispatches.
fn dispatch_router_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> anyhow::Result<()> { fn dispatch_router_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> anyhow::Result<()> {
for (uuid, worker_idx) in admissions { for (uuid, worker_idx) in admissions {
let request = self let request = self
...@@ -231,12 +265,12 @@ impl OfflineRuntime { ...@@ -231,12 +265,12 @@ impl OfflineRuntime {
anyhow::anyhow!("offline replay missing queued request state for {uuid}") anyhow::anyhow!("offline replay missing queued request state for {uuid}")
})? })?
.take_queued_request(uuid)?; .take_queued_request(uuid)?;
self.queued_requests = self.queued_requests.saturating_sub(1);
self.dispatch_to_worker(request, uuid, worker_idx)?; self.dispatch_to_worker(request, uuid, worker_idx)?;
} }
Ok(()) Ok(())
} }
/// Admit one external request into the collector, optional router, and worker pool.
fn assign_request( fn assign_request(
&mut self, &mut self,
mut request: DirectRequest, mut request: DirectRequest,
...@@ -256,17 +290,13 @@ impl OfflineRuntime { ...@@ -256,17 +290,13 @@ impl OfflineRuntime {
request.max_output_tokens, request.max_output_tokens,
); );
let Some(router) = self.router.as_mut() else { if self.router.is_none() {
self.requests.insert(uuid, AggRequestState::new_running()); self.requests.insert(uuid, AggRequestState::new_running());
let worker_idx = self.next_worker_idx; let worker_idx = self.next_worker();
self.next_worker_idx = (self.next_worker_idx + 1) % self.workers.len();
self.dispatch_to_worker(request, uuid, worker_idx)?; self.dispatch_to_worker(request, uuid, worker_idx)?;
return Ok(uuid); return Ok(uuid);
}; }
let maybe_worker_idx = self.submit_to_router(&request, replay_hashes)?;
let maybe_worker_idx =
router.submit_request_with_hashes(&request, replay_hashes, self.now_ms)?;
self.record_router_pending();
if let Some(worker_idx) = maybe_worker_idx { if let Some(worker_idx) = maybe_worker_idx {
self.requests.insert(uuid, AggRequestState::new_running()); self.requests.insert(uuid, AggRequestState::new_running());
self.dispatch_to_worker(request, uuid, worker_idx)?; self.dispatch_to_worker(request, uuid, worker_idx)?;
...@@ -275,11 +305,11 @@ impl OfflineRuntime { ...@@ -275,11 +305,11 @@ impl OfflineRuntime {
self.requests self.requests
.insert(uuid, AggRequestState::new_queued(request)); .insert(uuid, AggRequestState::new_queued(request));
self.queued_requests += 1;
self.record_in_flight_peak(); self.record_in_flight_peak();
Ok(uuid) Ok(uuid)
} }
/// Return true once no workers, router queues, or admissions remain.
fn is_done(&self) -> bool { fn is_done(&self) -> bool {
self.events.is_empty() self.events.is_empty()
&& self.cluster_in_flight() == 0 && self.cluster_in_flight() == 0
...@@ -290,6 +320,7 @@ impl OfflineRuntime { ...@@ -290,6 +320,7 @@ impl OfflineRuntime {
&& self.workers.iter().all(OfflineWorkerState::is_drained) && self.workers.iter().all(OfflineWorkerState::is_drained)
} }
/// Pick the next logical timestamp from either arrivals or scheduled worker completions.
fn next_timestamp(&mut self) -> Option<f64> { fn next_timestamp(&mut self) -> Option<f64> {
let next_event_ms = self.events.peek().map(|event| event.at_ms); let next_event_ms = self.events.peek().map(|event| event.at_ms);
let cluster_in_flight = self.cluster_in_flight(); let cluster_in_flight = self.cluster_in_flight();
...@@ -311,10 +342,12 @@ impl OfflineRuntime { ...@@ -311,10 +342,12 @@ impl OfflineRuntime {
choose_next_timestamp(next_arrival_ms, next_event_ms) choose_next_timestamp(next_arrival_ms, next_event_ms)
} }
/// Release completed requests from worker-local accounting after a pass finishes.
fn apply_completed_requests(&mut self, worker_idx: usize, completed_requests: usize) { fn apply_completed_requests(&mut self, worker_idx: usize, completed_requests: usize) {
self.workers[worker_idx].mark_completed(completed_requests); self.workers[worker_idx].mark_completed(completed_requests);
} }
/// Apply router-visible KV events at the phase chosen by the scheduler core.
fn apply_router_events(&mut self, events: Vec<RouterEvent>) -> anyhow::Result<()> { fn apply_router_events(&mut self, events: Vec<RouterEvent>) -> anyhow::Result<()> {
let Some(router) = self.router.as_mut() else { let Some(router) = self.router.as_mut() else {
return Ok(()); return Ok(());
...@@ -325,6 +358,7 @@ impl OfflineRuntime { ...@@ -325,6 +358,7 @@ impl OfflineRuntime {
Ok(()) Ok(())
} }
/// Consume one output signal, updating router state, collector state, and completion counts.
fn process_output_signal(&mut self, signal: OutputSignal) -> anyhow::Result<()> { fn process_output_signal(&mut self, signal: OutputSignal) -> anyhow::Result<()> {
let mut admissions = Vec::new(); let mut admissions = Vec::new();
if signal.completed { if signal.completed {
...@@ -334,7 +368,7 @@ impl OfflineRuntime { ...@@ -334,7 +368,7 @@ impl OfflineRuntime {
admissions = router.free(signal.uuid)?; admissions = router.free(signal.uuid)?;
#[cfg(test)] #[cfg(test)]
{ {
self.stats.freed_count += 1; self.stats.router_freed_count += 1;
} }
self.record_router_pending(); self.record_router_pending();
} }
...@@ -379,6 +413,7 @@ impl OfflineRuntime { ...@@ -379,6 +413,7 @@ impl OfflineRuntime {
} }
#[cfg(test)] #[cfg(test)]
/// Remove a request from the test-only active-request tracking for its worker.
fn remove_active_request(&mut self, uuid: Uuid) { fn remove_active_request(&mut self, uuid: Uuid) {
for active_requests in &mut self.worker_active_requests { for active_requests in &mut self.worker_active_requests {
let Some(position) = active_requests let Some(position) = active_requests
...@@ -392,6 +427,7 @@ impl OfflineRuntime { ...@@ -392,6 +427,7 @@ impl OfflineRuntime {
} }
} }
/// Apply one completed pass: free request slots, publish KV events, and handle outputs.
fn process_completed_pass( fn process_completed_pass(
&mut self, &mut self,
worker_idx: usize, worker_idx: usize,
...@@ -407,6 +443,7 @@ impl OfflineRuntime { ...@@ -407,6 +443,7 @@ impl OfflineRuntime {
Ok(()) Ok(())
} }
/// Drain all worker-completion events scheduled for the current logical timestamp.
fn apply_worker_completions(&mut self) -> anyhow::Result<bool> { fn apply_worker_completions(&mut self) -> anyhow::Result<bool> {
let mut changed = false; let mut changed = false;
while let Some(WorkerCompletionPayload { while let Some(WorkerCompletionPayload {
...@@ -426,6 +463,7 @@ impl OfflineRuntime { ...@@ -426,6 +463,7 @@ impl OfflineRuntime {
Ok(changed) Ok(changed)
} }
/// Release every trace arrival whose timestamp is now visible to the global clock.
fn release_trace_arrivals(&mut self) -> anyhow::Result<bool> { fn release_trace_arrivals(&mut self) -> anyhow::Result<bool> {
let mut released_any = false; let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) { if matches!(self.admission, AdmissionSource::Requests(_)) {
...@@ -461,6 +499,7 @@ impl OfflineRuntime { ...@@ -461,6 +499,7 @@ impl OfflineRuntime {
Ok(released_any) Ok(released_any)
} }
/// Backfill closed-loop concurrency replay until the configured in-flight limit is reached.
fn top_off_concurrency(&mut self, max_in_flight: usize) -> anyhow::Result<bool> { fn top_off_concurrency(&mut self, max_in_flight: usize) -> anyhow::Result<bool> {
let mut released_any = false; let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) { if matches!(self.admission, AdmissionSource::Requests(_)) {
...@@ -501,6 +540,7 @@ impl OfflineRuntime { ...@@ -501,6 +540,7 @@ impl OfflineRuntime {
Ok(released_any) Ok(released_any)
} }
/// Start passes on every idle worker that can make progress at the current timestamp.
fn drive_ready_workers(&mut self) -> anyhow::Result<bool> { fn drive_ready_workers(&mut self) -> anyhow::Result<bool> {
let mut changed = false; let mut changed = false;
for worker_idx in 0..self.workers.len() { for worker_idx in 0..self.workers.len() {
...@@ -553,6 +593,7 @@ impl OfflineRuntime { ...@@ -553,6 +593,7 @@ impl OfflineRuntime {
Ok(changed) Ok(changed)
} }
/// Repeatedly process all work that becomes possible without advancing logical time.
fn drain_current_timestamp(&mut self) -> anyhow::Result<()> { fn drain_current_timestamp(&mut self) -> anyhow::Result<()> {
loop { loop {
let mut changed = self.apply_worker_completions()?; let mut changed = self.apply_worker_completions()?;
...@@ -574,7 +615,8 @@ impl OfflineRuntime { ...@@ -574,7 +615,8 @@ impl OfflineRuntime {
Ok(()) Ok(())
} }
fn run(mut self) -> anyhow::Result<(TraceCollector, OfflineRuntimeStats)> { /// Run the aggregated offline replay until all arrivals and worker work are exhausted.
pub(super) fn run(mut self) -> anyhow::Result<(TraceCollector, AggRuntimeStats)> {
self.drain_current_timestamp()?; self.drain_current_timestamp()?;
while !self.is_done() { while !self.is_done() {
...@@ -589,14 +631,11 @@ impl OfflineRuntime { ...@@ -589,14 +631,11 @@ impl OfflineRuntime {
self.drain_current_timestamp()?; self.drain_current_timestamp()?;
} }
if let Some(router) = self.router.as_mut() {
router.shutdown();
}
Ok((self.collector, self.stats)) Ok((self.collector, self.stats))
} }
#[cfg(test)] #[cfg(test)]
/// Test helper: advance exactly one logical timestamp worth of work.
fn advance_one_timestamp(&mut self) -> anyhow::Result<bool> { fn advance_one_timestamp(&mut self) -> anyhow::Result<bool> {
if self.is_done() { if self.is_done() {
return Ok(false); return Ok(false);
...@@ -621,7 +660,8 @@ impl OfflineRuntime { ...@@ -621,7 +660,8 @@ impl OfflineRuntime {
} }
#[cfg(test)] #[cfg(test)]
fn debug_snapshot(&self) -> OfflineRuntimeSnapshot { /// Test helper: snapshot the runtime's visible request, worker, and router state.
fn debug_snapshot(&self) -> AggRuntimeSnapshot {
let mut router_pending_request_ids = self let mut router_pending_request_ids = self
.requests .requests
.iter() .iter()
...@@ -637,7 +677,7 @@ impl OfflineRuntime { ...@@ -637,7 +677,7 @@ impl OfflineRuntime {
.collect::<Vec<_>>(); .collect::<Vec<_>>();
prefill_completed.sort_unstable(); prefill_completed.sort_unstable();
OfflineRuntimeSnapshot { AggRuntimeSnapshot {
now_ms: self.now_ms, now_ms: self.now_ms,
worker_active_requests: self.worker_active_requests.clone(), worker_active_requests: self.worker_active_requests.clone(),
workers: self workers: self
...@@ -655,182 +695,17 @@ impl OfflineRuntime { ...@@ -655,182 +695,17 @@ impl OfflineRuntime {
} }
} }
/// Replay an explicit list of requests through the multi-worker offline runtime in trace mode.
///
/// The engine arguments are normalized first, then the requests are passed through
/// `normalize_trace_requests` together with `arrival_speedup_ratio`. The resulting queue
/// seeds an `OfflineRuntime` running in `ReplayMode::Trace`, and the finished collector
/// report is returned.
///
/// # Errors
///
/// Propagates failures from argument normalization, request normalization, runtime
/// construction, and the replay run itself.
pub(crate) fn simulate_trace_multi(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    let arrivals = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    let runtime = OfflineRuntime::new(
        &engine_args,
        router_config,
        arrivals,
        num_workers,
        ReplayMode::Trace,
        router_mode,
    )?;
    // Only the collector is needed here; the runtime stats are discarded.
    runtime.run().map(|(collector, _)| collector.finish())
}
/// Replay an explicit list of requests through the multi-worker offline runtime in
/// concurrency mode.
///
/// The engine arguments are normalized, the requests are converted into a queue in their
/// given order, and an `OfflineRuntime` is run with
/// `ReplayMode::Concurrency { max_in_flight }`. The finished collector report is returned.
///
/// # Errors
///
/// Propagates failures from argument normalization, runtime construction, and the replay
/// run itself.
pub(crate) fn simulate_concurrency_multi(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    let backlog: VecDeque<DirectRequest> = requests.into();
    let runtime = OfflineRuntime::new(
        &engine_args,
        router_config,
        backlog,
        num_workers,
        ReplayMode::Concurrency { max_in_flight },
        router_mode,
    )?;
    // Only the collector is needed here; the runtime stats are discarded.
    runtime.run().map(|(collector, _)| collector.finish())
}
/// Replay a `Trace` workload through the multi-worker offline runtime in trace mode.
///
/// The engine arguments are normalized, the trace is converted with `into_trace_driver`,
/// and the resulting driver feeds an `OfflineRuntime` running in `ReplayMode::Trace`. The
/// finished collector report is returned.
///
/// # Errors
///
/// Propagates failures from argument normalization, driver construction, runtime
/// construction, and the replay run itself.
pub(crate) fn simulate_trace_workload_multi(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    let workload = trace.into_trace_driver()?;
    let runtime = OfflineRuntime::new_workload(
        &engine_args,
        router_config,
        workload,
        num_workers,
        ReplayMode::Trace,
        router_mode,
    )?;
    // Only the collector is needed here; the runtime stats are discarded.
    runtime.run().map(|(collector, _)| collector.finish())
}
/// Replay a `Trace` workload through the multi-worker offline runtime in concurrency mode.
///
/// The engine arguments are normalized, the trace is converted with
/// `into_concurrency_driver`, and the resulting driver feeds an `OfflineRuntime` running
/// with `ReplayMode::Concurrency { max_in_flight }`. The finished collector report is
/// returned.
///
/// # Errors
///
/// Propagates failures from argument normalization, driver construction, runtime
/// construction, and the replay run itself.
pub(crate) fn simulate_concurrency_workload_multi(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    let workload = trace.into_concurrency_driver()?;
    let runtime = OfflineRuntime::new_workload(
        &engine_args,
        router_config,
        workload,
        num_workers,
        ReplayMode::Concurrency { max_in_flight },
        router_mode,
    )?;
    // Only the collector is needed here; the runtime stats are discarded.
    runtime.run().map(|(collector, _)| collector.finish())
}
#[cfg(test)]
/// Test helper: run a trace-mode replay over explicit requests and return both the collector
/// and the runtime stats.
///
/// Requests are normalized with a speedup ratio of `1.0` and no router config is supplied.
/// Any failure along the way panics.
fn run_trace_multi_collect_with_stats(
    args: &MockEngineArgs,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, OfflineRuntimeStats) {
    let arrivals = normalize_trace_requests(requests, 1.0).unwrap();
    let runtime = OfflineRuntime::new(
        args,
        None,
        arrivals,
        num_workers,
        ReplayMode::Trace,
        router_mode,
    )
    .unwrap();
    runtime.run().unwrap()
}
#[cfg(test)]
/// Test helper: run a concurrency-mode replay over explicit requests and return both the
/// collector and the runtime stats.
///
/// No router config is supplied. Any failure along the way panics.
fn run_concurrency_multi_collect_with_stats(
    args: &MockEngineArgs,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, OfflineRuntimeStats) {
    let backlog: VecDeque<DirectRequest> = requests.into();
    let runtime = OfflineRuntime::new(
        args,
        None,
        backlog,
        num_workers,
        ReplayMode::Concurrency { max_in_flight },
        router_mode,
    )
    .unwrap();
    runtime.run().unwrap()
}
#[cfg(test)]
/// Test helper: run a trace-mode replay driven by a `Trace` workload and return both the
/// collector and the runtime stats.
///
/// No router config is supplied. Any failure along the way panics.
fn run_trace_workload_multi_collect_with_stats(
    args: &MockEngineArgs,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, OfflineRuntimeStats) {
    let workload = trace.into_trace_driver().unwrap();
    let runtime = OfflineRuntime::new_workload(
        args,
        None,
        workload,
        num_workers,
        ReplayMode::Trace,
        router_mode,
    )
    .unwrap();
    runtime.run().unwrap()
}
#[cfg(test)]
/// Test helper: run a concurrency-mode replay driven by a `Trace` workload and return both
/// the collector and the runtime stats.
///
/// No router config is supplied. Any failure along the way panics.
fn run_concurrency_workload_multi_collect_with_stats(
    args: &MockEngineArgs,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, OfflineRuntimeStats) {
    let workload = trace.into_concurrency_driver().unwrap();
    let runtime = OfflineRuntime::new_workload(
        args,
        None,
        workload,
        num_workers,
        ReplayMode::Concurrency { max_in_flight },
        router_mode,
    )
    .unwrap();
    runtime.run().unwrap()
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::single::{run_concurrency_single_collect, run_trace_single_collect}; use super::super::entrypoints::{
run_concurrency_multi_collect_with_stats, run_concurrency_single_collect,
run_concurrency_workload_multi_collect_with_stats, run_trace_multi_collect_with_stats,
run_trace_single_collect, run_trace_workload_multi_collect_with_stats,
};
use super::*; use super::*;
use crate::common::protocols::{EngineType, SglangArgs}; use crate::common::protocols::{EngineType, SglangArgs};
use crate::loadgen::{SessionTrace, TurnTrace}; use crate::loadgen::{SessionTrace, Trace, TurnTrace};
use crate::replay::normalize_trace_requests;
use dynamo_kv_router::config::RouterQueuePolicy; use dynamo_kv_router::config::RouterQueuePolicy;
fn replay_args(enable_prefix_caching: bool, enable_chunked_prefill: bool) -> MockEngineArgs { fn replay_args(enable_prefix_caching: bool, enable_chunked_prefill: bool) -> MockEngineArgs {
...@@ -1079,7 +954,7 @@ mod tests { ...@@ -1079,7 +954,7 @@ mod tests {
#[test] #[test]
fn test_multi_worker_trace_kv_router_debug_snapshot_tracks_queue_and_cached_dispatch() { fn test_multi_worker_trace_kv_router_debug_snapshot_tracks_queue_and_cached_dispatch() {
let args = queueing_router_args(RouterQueuePolicy::Fcfs); let args = queueing_router_args(RouterQueuePolicy::Fcfs);
let mut runtime = OfflineRuntime::new( let mut runtime = AggRuntime::new(
&args, &args,
None, None,
normalize_trace_requests( normalize_trace_requests(
...@@ -1459,8 +1334,8 @@ mod tests { ...@@ -1459,8 +1334,8 @@ mod tests {
); );
assert_eq!(stats.prefill_marked_count, 1); assert_eq!(stats.prefill_marked_count, 1);
assert_eq!(stats.freed_count, 2); assert_eq!(stats.router_freed_count, 2);
assert_eq!(stats.max_router_pending, 0); assert_eq!(stats.max_router_pending_count, 0);
} }
#[test] #[test]
...@@ -1503,7 +1378,7 @@ mod tests { ...@@ -1503,7 +1378,7 @@ mod tests {
.unwrap() .unwrap()
.min(request_2.first_token_ms.unwrap()); .min(request_2.first_token_ms.unwrap());
assert!(stats.max_router_pending > 0); assert!(stats.max_router_pending_count > 0);
assert!(request_3.first_admit_ms.unwrap() > request_3.arrival_time_ms); assert!(request_3.first_admit_ms.unwrap() > request_3.arrival_time_ms);
assert_eq!(request_3.first_admit_ms.unwrap(), first_unblock_ms); assert_eq!(request_3.first_admit_ms.unwrap(), first_unblock_ms);
assert!(request_3.first_admit_ms.unwrap() < request_1.last_token_ms.unwrap()); assert!(request_3.first_admit_ms.unwrap() < request_1.last_token_ms.unwrap());
...@@ -1556,8 +1431,8 @@ mod tests { ...@@ -1556,8 +1431,8 @@ mod tests {
ReplayRouterMode::KvRouter, ReplayRouterMode::KvRouter,
); );
assert!(fcfs_stats.max_router_pending > 0); assert!(fcfs_stats.max_router_pending_count > 0);
assert!(lcfs_stats.max_router_pending > 0); assert!(lcfs_stats.max_router_pending_count > 0);
assert_eq!( assert_eq!(
&fcfs_stats.dispatch_order[..2], &fcfs_stats.dispatch_order[..2],
&[Uuid::from_u128(10), Uuid::from_u128(20)] &[Uuid::from_u128(10), Uuid::from_u128(20)]
...@@ -1628,8 +1503,8 @@ mod tests { ...@@ -1628,8 +1503,8 @@ mod tests {
let lcfs_request_30 = lcfs_collector.snapshot(Uuid::from_u128(30)).unwrap(); let lcfs_request_30 = lcfs_collector.snapshot(Uuid::from_u128(30)).unwrap();
let lcfs_request_40 = lcfs_collector.snapshot(Uuid::from_u128(40)).unwrap(); let lcfs_request_40 = lcfs_collector.snapshot(Uuid::from_u128(40)).unwrap();
assert!(fcfs_stats.max_router_pending > 0); assert!(fcfs_stats.max_router_pending_count > 0);
assert!(lcfs_stats.max_router_pending > 0); assert!(lcfs_stats.max_router_pending_count > 0);
assert_eq!( assert_eq!(
&fcfs_stats.dispatch_order[2..4], &fcfs_stats.dispatch_order[2..4],
&[Uuid::from_u128(30), Uuid::from_u128(40)] &[Uuid::from_u128(30), Uuid::from_u128(40)]
...@@ -1683,7 +1558,7 @@ mod tests { ...@@ -1683,7 +1558,7 @@ mod tests {
); );
assert_eq!(stats.max_in_flight_seen, 3); assert_eq!(stats.max_in_flight_seen, 3);
assert!(stats.max_router_pending > 0); assert!(stats.max_router_pending_count > 0);
} }
#[test] #[test]
...@@ -1805,7 +1680,7 @@ mod tests { ...@@ -1805,7 +1680,7 @@ mod tests {
run_trace_multi_collect_with_stats(&args, requests, 1, ReplayRouterMode::KvRouter); run_trace_multi_collect_with_stats(&args, requests, 1, ReplayRouterMode::KvRouter);
assert_eq!(stats.dispatch_history, vec![0, 0, 0]); assert_eq!(stats.dispatch_history, vec![0, 0, 0]);
assert_eq!(stats.max_router_pending, 0); assert_eq!(stats.max_router_pending_count, 0);
for uuid in [11_u128, 22, 33] { for uuid in [11_u128, 22, 33] {
assert_eq!( assert_eq!(
multi.snapshot(Uuid::from_u128(uuid)), multi.snapshot(Uuid::from_u128(uuid)),
...@@ -1898,7 +1773,7 @@ mod tests { ...@@ -1898,7 +1773,7 @@ mod tests {
); );
assert_eq!(stats.dispatch_history, vec![0, 0, 0]); assert_eq!(stats.dispatch_history, vec![0, 0, 0]);
assert_eq!(stats.max_router_pending, 0); assert_eq!(stats.max_router_pending_count, 0);
for uuid in [11_u128, 22, 33] { for uuid in [11_u128, 22, 33] {
assert_eq!( assert_eq!(
multi.snapshot(Uuid::from_u128(uuid)), multi.snapshot(Uuid::from_u128(uuid)),
......
...@@ -9,7 +9,6 @@ use dynamo_kv_router::protocols::RouterEvent; ...@@ -9,7 +9,6 @@ use dynamo_kv_router::protocols::RouterEvent;
use uuid::Uuid; use uuid::Uuid;
use super::events::{SimulationEvent, SimulationWorkerStage}; use super::events::{SimulationEvent, SimulationWorkerStage};
use super::normalize_trace_requests;
use super::runtime_utils::{ use super::runtime_utils::{
WorkerCompletionPayload, next_timestamp as choose_next_timestamp, pop_next_concurrency_ready, WorkerCompletionPayload, next_timestamp as choose_next_timestamp, pop_next_concurrency_ready,
pop_next_trace_ready, pop_ready_decode_handoff, pop_ready_worker_completion, pop_next_trace_ready, pop_ready_decode_handoff, pop_ready_worker_completion,
...@@ -21,15 +20,13 @@ use super::state::DisaggPhase; ...@@ -21,15 +20,13 @@ use super::state::DisaggPhase;
use super::state::DisaggRequestSnapshot; use super::state::DisaggRequestSnapshot;
use super::state::{DisaggRequestState, OfflineWorkerState}; use super::state::{DisaggRequestState, OfflineWorkerState};
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal}; use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::{ReplayRequestHashes, Trace, WorkloadDriver}; use crate::loadgen::{ReplayRequestHashes, WorkloadDriver};
use crate::replay::router::OfflineReplayRouter; use crate::replay::router::OfflineReplayRouter;
use crate::replay::{ use crate::replay::{OfflineDisaggReplayConfig, ReplayRouterMode, TraceCollector};
OfflineDisaggReplayConfig, ReplayRouterMode, TraceCollector, TraceSimulationReport,
};
use crate::scheduler::RouterEventVisibility; use crate::scheduler::RouterEventVisibility;
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
enum ReplayMode { pub(super) enum ReplayMode {
Trace, Trace,
Concurrency { max_in_flight: usize }, Concurrency { max_in_flight: usize },
} }
...@@ -41,23 +38,23 @@ enum AdmissionSource { ...@@ -41,23 +38,23 @@ enum AdmissionSource {
#[cfg(test)] #[cfg(test)]
#[derive(Debug, Default, Clone, PartialEq)] #[derive(Debug, Default, Clone, PartialEq)]
struct DisaggRuntimeStats { pub(super) struct DisaggRuntimeStats {
request_snapshots: HashMap<Uuid, DisaggRequestSnapshot>, request_snapshots: HashMap<Uuid, DisaggRequestSnapshot>,
prefill_assignments: HashMap<Uuid, usize>, prefill_assignments: HashMap<Uuid, usize>,
decode_assignments: HashMap<Uuid, usize>, decode_assignments: HashMap<Uuid, usize>,
handoff_ms: HashMap<Uuid, f64>, handoff_ms: HashMap<Uuid, f64>,
prefill_marked_count: usize, prefill_marked_count: usize,
prefill_freed_count: usize, prefill_router_freed_count: usize,
decode_freed_count: usize, decode_router_freed_count: usize,
max_prefill_router_pending: usize, max_prefill_router_pending_count: usize,
max_decode_router_pending: usize, max_decode_router_pending_count: usize,
} }
#[cfg(not(test))] #[cfg(not(test))]
#[derive(Debug, Default, Clone, PartialEq, Eq)] #[derive(Debug, Default, Clone, PartialEq, Eq)]
struct DisaggRuntimeStats; pub(super) struct DisaggRuntimeStats;
struct DisaggRuntime { pub(super) struct DisaggRuntime {
now_ms: f64, now_ms: f64,
next_prefill_worker_idx: usize, next_prefill_worker_idx: usize,
next_decode_worker_idx: usize, next_decode_worker_idx: usize,
...@@ -75,7 +72,8 @@ struct DisaggRuntime { ...@@ -75,7 +72,8 @@ struct DisaggRuntime {
} }
impl DisaggRuntime { impl DisaggRuntime {
fn new( /// Create a disaggregated offline runtime seeded from an explicit request queue.
pub(super) fn new(
config: &OfflineDisaggReplayConfig, config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
pending: VecDeque<DirectRequest>, pending: VecDeque<DirectRequest>,
...@@ -91,7 +89,8 @@ impl DisaggRuntime { ...@@ -91,7 +89,8 @@ impl DisaggRuntime {
) )
} }
fn new_workload( /// Create a disaggregated offline runtime whose admissions come from a workload driver.
pub(super) fn new_workload(
config: &OfflineDisaggReplayConfig, config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
driver: WorkloadDriver, driver: WorkloadDriver,
...@@ -107,6 +106,7 @@ impl DisaggRuntime { ...@@ -107,6 +106,7 @@ impl DisaggRuntime {
) )
} }
/// Shared constructor for both raw-request and workload-driven admissions.
fn new_with_source( fn new_with_source(
config: &OfflineDisaggReplayConfig, config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
...@@ -169,6 +169,7 @@ impl DisaggRuntime { ...@@ -169,6 +169,7 @@ impl DisaggRuntime {
}) })
} }
/// Count all requests consuming cluster capacity across prefill, decode, and router queues.
fn cluster_in_flight(&self) -> usize { fn cluster_in_flight(&self) -> usize {
self.prefill_workers self.prefill_workers
.iter() .iter()
...@@ -189,6 +190,7 @@ impl DisaggRuntime { ...@@ -189,6 +190,7 @@ impl DisaggRuntime {
.map_or(0, OfflineReplayRouter::pending_count) .map_or(0, OfflineReplayRouter::pending_count)
} }
/// Pick the next prefill worker in round-robin order.
fn next_prefill_worker(&mut self) -> usize { fn next_prefill_worker(&mut self) -> usize {
let worker_idx = self.next_prefill_worker_idx; let worker_idx = self.next_prefill_worker_idx;
self.next_prefill_worker_idx = self.next_prefill_worker_idx =
...@@ -196,28 +198,33 @@ impl DisaggRuntime { ...@@ -196,28 +198,33 @@ impl DisaggRuntime {
worker_idx worker_idx
} }
/// Pick the next decode worker in round-robin order.
fn next_decode_worker(&mut self) -> usize { fn next_decode_worker(&mut self) -> usize {
let worker_idx = self.next_decode_worker_idx; let worker_idx = self.next_decode_worker_idx;
self.next_decode_worker_idx = (self.next_decode_worker_idx + 1) % self.decode_workers.len(); self.next_decode_worker_idx = (self.next_decode_worker_idx + 1) % self.decode_workers.len();
worker_idx worker_idx
} }
/// Track the peak number of requests parked in each stage router.
fn record_router_pending(&mut self) { fn record_router_pending(&mut self) {
#[cfg(test)] #[cfg(test)]
{ {
self.stats.max_prefill_router_pending = self.stats.max_prefill_router_pending.max( self.stats.max_prefill_router_pending_count =
self.prefill_router self.stats.max_prefill_router_pending_count.max(
.as_ref() self.prefill_router
.map_or(0, OfflineReplayRouter::pending_count), .as_ref()
); .map_or(0, OfflineReplayRouter::pending_count),
self.stats.max_decode_router_pending = self.stats.max_decode_router_pending.max( );
self.decode_router self.stats.max_decode_router_pending_count =
.as_ref() self.stats.max_decode_router_pending_count.max(
.map_or(0, OfflineReplayRouter::pending_count), self.decode_router
); .as_ref()
.map_or(0, OfflineReplayRouter::pending_count),
);
} }
} }
/// Fail fast if a router admission targets a non-existent stage worker.
fn validate_worker_idx(&self, stage: SimulationWorkerStage, worker_idx: usize) -> Result<()> { fn validate_worker_idx(&self, stage: SimulationWorkerStage, worker_idx: usize) -> Result<()> {
let worker_count = match stage { let worker_count = match stage {
SimulationWorkerStage::Prefill => self.prefill_workers.len(), SimulationWorkerStage::Prefill => self.prefill_workers.len(),
...@@ -230,18 +237,21 @@ impl DisaggRuntime { ...@@ -230,18 +237,21 @@ impl DisaggRuntime {
Ok(()) Ok(())
} }
/// Borrow immutable request state with a structured missing-request error.
fn state(&self, uuid: Uuid) -> Result<&DisaggRequestState> { fn state(&self, uuid: Uuid) -> Result<&DisaggRequestState> {
self.requests self.requests
.get(&uuid) .get(&uuid)
.ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}")) .ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))
} }
/// Borrow mutable request state with a structured missing-request error.
fn state_mut(&mut self, uuid: Uuid) -> Result<&mut DisaggRequestState> { fn state_mut(&mut self, uuid: Uuid) -> Result<&mut DisaggRequestState> {
self.requests self.requests
.get_mut(&uuid) .get_mut(&uuid)
.ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}")) .ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))
} }
/// Dispatch a request's prefill stage onto a specific prefill worker.
fn dispatch_prefill(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> { fn dispatch_prefill(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> {
self.validate_worker_idx(SimulationWorkerStage::Prefill, worker_idx)?; self.validate_worker_idx(SimulationWorkerStage::Prefill, worker_idx)?;
let request = self.state(uuid)?.build_prefill_request()?; let request = self.state(uuid)?.build_prefill_request()?;
...@@ -254,6 +264,7 @@ impl DisaggRuntime { ...@@ -254,6 +264,7 @@ impl DisaggRuntime {
Ok(()) Ok(())
} }
/// Dispatch a request's decode stage onto a specific decode worker.
fn dispatch_decode(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> { fn dispatch_decode(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> {
self.validate_worker_idx(SimulationWorkerStage::Decode, worker_idx)?; self.validate_worker_idx(SimulationWorkerStage::Decode, worker_idx)?;
let request = self.state(uuid)?.build_decode_request()?; let request = self.state(uuid)?.build_decode_request()?;
...@@ -266,6 +277,7 @@ impl DisaggRuntime { ...@@ -266,6 +277,7 @@ impl DisaggRuntime {
Ok(()) Ok(())
} }
/// Turn prefill router admissions into concrete worker dispatches.
fn dispatch_prefill_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> Result<()> { fn dispatch_prefill_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> Result<()> {
for (uuid, worker_idx) in admissions { for (uuid, worker_idx) in admissions {
if !self.state(uuid)?.is_queued_prefill() { if !self.state(uuid)?.is_queued_prefill() {
...@@ -276,6 +288,7 @@ impl DisaggRuntime { ...@@ -276,6 +288,7 @@ impl DisaggRuntime {
Ok(()) Ok(())
} }
/// Turn decode router admissions into concrete worker dispatches.
fn dispatch_decode_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> Result<()> { fn dispatch_decode_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> Result<()> {
for (uuid, worker_idx) in admissions { for (uuid, worker_idx) in admissions {
if !self.state(uuid)?.is_queued_decode() { if !self.state(uuid)?.is_queued_decode() {
...@@ -286,21 +299,42 @@ impl DisaggRuntime { ...@@ -286,21 +299,42 @@ impl DisaggRuntime {
Ok(()) Ok(())
} }
fn enqueue_decode(&mut self, uuid: Uuid) -> Result<()> { /// Submit a request to the prefill router and return an immediate admission when available.
fn submit_to_prefill_router(
&mut self,
uuid: Uuid,
replay_hashes: Option<ReplayRequestHashes>,
) -> Result<Option<usize>> {
let request = self.state(uuid)?.original_request()?.clone();
let Some(prefill_router) = self.prefill_router.as_mut() else {
bail!("offline disagg replay prefill submission requires an active router");
};
let maybe_worker_idx =
prefill_router.submit_request_with_hashes(&request, replay_hashes, self.now_ms)?;
self.record_router_pending();
Ok(maybe_worker_idx)
}
/// Submit a request to the decode router and return an immediate admission when available.
fn submit_to_decode_router(&mut self, uuid: Uuid) -> Result<Option<usize>> {
let request = self.state(uuid)?.original_request()?.clone();
let Some(decode_router) = self.decode_router.as_mut() else { let Some(decode_router) = self.decode_router.as_mut() else {
bail!("offline disagg replay decode submission requires an active router");
};
let maybe_worker_idx =
decode_router.submit_request_with_hashes(&request, None, self.now_ms)?;
self.record_router_pending();
Ok(maybe_worker_idx)
}
/// Queue or dispatch a request into decode, depending on whether a decode router is active.
fn enqueue_decode(&mut self, uuid: Uuid) -> Result<()> {
if self.decode_router.is_none() {
let worker_idx = self.next_decode_worker(); let worker_idx = self.next_decode_worker();
self.dispatch_decode(uuid, worker_idx)?; self.dispatch_decode(uuid, worker_idx)?;
return Ok(()); return Ok(());
}; }
let maybe_worker_idx = { let maybe_worker_idx = self.submit_to_decode_router(uuid)?;
let requests = &self.requests;
let request = requests
.get(&uuid)
.ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))?
.original_request()?;
decode_router.submit_request_with_hashes(request, None, self.now_ms)?
};
self.record_router_pending();
#[cfg(test)] #[cfg(test)]
{ {
self.stats.handoff_ms.insert(uuid, self.now_ms); self.stats.handoff_ms.insert(uuid, self.now_ms);
...@@ -314,6 +348,7 @@ impl DisaggRuntime { ...@@ -314,6 +348,7 @@ impl DisaggRuntime {
Ok(()) Ok(())
} }
/// Admit one external request into prefill-side state, collector state, and optional router.
fn on_external_arrival( fn on_external_arrival(
&mut self, &mut self,
mut request: DirectRequest, mut request: DirectRequest,
...@@ -333,26 +368,19 @@ impl DisaggRuntime { ...@@ -333,26 +368,19 @@ impl DisaggRuntime {
self.requests self.requests
.insert(uuid, DisaggRequestState::new(request, arrival_time_ms)); .insert(uuid, DisaggRequestState::new(request, arrival_time_ms));
let Some(prefill_router) = self.prefill_router.as_mut() else { if self.prefill_router.is_none() {
let worker_idx = self.next_prefill_worker(); let worker_idx = self.next_prefill_worker();
self.dispatch_prefill(uuid, worker_idx)?; self.dispatch_prefill(uuid, worker_idx)?;
return Ok(uuid); return Ok(uuid);
}; }
let maybe_worker_idx = { let maybe_worker_idx = self.submit_to_prefill_router(uuid, replay_hashes)?;
let requests = &self.requests;
let request = requests
.get(&uuid)
.ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))?
.original_request()?;
prefill_router.submit_request_with_hashes(request, replay_hashes, self.now_ms)?
};
self.record_router_pending();
if let Some(worker_idx) = maybe_worker_idx { if let Some(worker_idx) = maybe_worker_idx {
self.dispatch_prefill(uuid, worker_idx)?; self.dispatch_prefill(uuid, worker_idx)?;
} }
Ok(uuid) Ok(uuid)
} }
/// Return true once both stages, both routers, and all admissions are fully drained.
fn is_done(&self) -> bool { fn is_done(&self) -> bool {
self.events.is_empty() self.events.is_empty()
&& self.cluster_in_flight() == 0 && self.cluster_in_flight() == 0
...@@ -370,6 +398,7 @@ impl DisaggRuntime { ...@@ -370,6 +398,7 @@ impl DisaggRuntime {
.all(OfflineWorkerState::is_drained) .all(OfflineWorkerState::is_drained)
} }
/// Pick the next logical timestamp from arrivals, worker completions, or decode handoffs.
fn next_timestamp(&mut self) -> Option<f64> { fn next_timestamp(&mut self) -> Option<f64> {
let next_event_ms = self.events.peek().map(|event| event.at_ms); let next_event_ms = self.events.peek().map(|event| event.at_ms);
let cluster_in_flight = self.cluster_in_flight(); let cluster_in_flight = self.cluster_in_flight();
...@@ -391,6 +420,7 @@ impl DisaggRuntime { ...@@ -391,6 +420,7 @@ impl DisaggRuntime {
choose_next_timestamp(next_arrival_ms, next_event_ms) choose_next_timestamp(next_arrival_ms, next_event_ms)
} }
/// Apply prefill-side KV router events at the scheduler-selected visibility phase.
fn apply_prefill_router_events(&mut self, events: Vec<RouterEvent>) -> Result<()> { fn apply_prefill_router_events(&mut self, events: Vec<RouterEvent>) -> Result<()> {
let Some(prefill_router) = self.prefill_router.as_mut() else { let Some(prefill_router) = self.prefill_router.as_mut() else {
return Ok(()); return Ok(());
...@@ -401,6 +431,7 @@ impl DisaggRuntime { ...@@ -401,6 +431,7 @@ impl DisaggRuntime {
Ok(()) Ok(())
} }
/// Process one prefill output signal, including router updates and decode handoff scheduling.
fn process_prefill_signal(&mut self, signal: OutputSignal) -> Result<()> { fn process_prefill_signal(&mut self, signal: OutputSignal) -> Result<()> {
if !signal.completed { if !signal.completed {
return Ok(()); return Ok(());
...@@ -424,7 +455,7 @@ impl DisaggRuntime { ...@@ -424,7 +455,7 @@ impl DisaggRuntime {
}; };
#[cfg(test)] #[cfg(test)]
{ {
self.stats.prefill_freed_count += 1; self.stats.prefill_router_freed_count += 1;
} }
self.record_router_pending(); self.record_router_pending();
self.dispatch_prefill_admissions(admissions)?; self.dispatch_prefill_admissions(admissions)?;
...@@ -433,6 +464,7 @@ impl DisaggRuntime { ...@@ -433,6 +464,7 @@ impl DisaggRuntime {
self.enqueue_decode_after_handoff(signal.uuid, signal.handoff_delay_ms) self.enqueue_decode_after_handoff(signal.uuid, signal.handoff_delay_ms)
} }
/// Process one decode output signal, including decode router frees and request completion.
fn process_decode_signal(&mut self, signal: OutputSignal) -> Result<()> { fn process_decode_signal(&mut self, signal: OutputSignal) -> Result<()> {
if !signal.completed { if !signal.completed {
return Ok(()); return Ok(());
...@@ -442,7 +474,7 @@ impl DisaggRuntime { ...@@ -442,7 +474,7 @@ impl DisaggRuntime {
let admissions = decode_router.free(signal.uuid)?; let admissions = decode_router.free(signal.uuid)?;
#[cfg(test)] #[cfg(test)]
{ {
self.stats.decode_freed_count += 1; self.stats.decode_router_freed_count += 1;
} }
admissions admissions
} else { } else {
...@@ -457,6 +489,7 @@ impl DisaggRuntime { ...@@ -457,6 +489,7 @@ impl DisaggRuntime {
Ok(()) Ok(())
} }
/// Apply the side effects of a finished prefill pass.
fn process_prefill_pass( fn process_prefill_pass(
&mut self, &mut self,
worker_idx: usize, worker_idx: usize,
...@@ -472,6 +505,7 @@ impl DisaggRuntime { ...@@ -472,6 +505,7 @@ impl DisaggRuntime {
Ok(()) Ok(())
} }
/// Apply the side effects of a finished decode pass.
fn process_decode_pass( fn process_decode_pass(
&mut self, &mut self,
worker_idx: usize, worker_idx: usize,
...@@ -485,6 +519,7 @@ impl DisaggRuntime { ...@@ -485,6 +519,7 @@ impl DisaggRuntime {
Ok(()) Ok(())
} }
/// Drain all worker-completion events scheduled for the current logical timestamp.
fn apply_worker_completions(&mut self) -> Result<bool> { fn apply_worker_completions(&mut self) -> Result<bool> {
let mut changed = false; let mut changed = false;
while let Some(WorkerCompletionPayload { while let Some(WorkerCompletionPayload {
...@@ -518,6 +553,7 @@ impl DisaggRuntime { ...@@ -518,6 +553,7 @@ impl DisaggRuntime {
Ok(changed) Ok(changed)
} }
/// Drain all delayed decode handoff events scheduled for the current logical timestamp.
fn apply_decode_handoffs(&mut self) -> Result<bool> { fn apply_decode_handoffs(&mut self) -> Result<bool> {
let mut changed = false; let mut changed = false;
while let Some(uuid) = pop_ready_decode_handoff(&mut self.events, self.now_ms) { while let Some(uuid) = pop_ready_decode_handoff(&mut self.events, self.now_ms) {
...@@ -527,6 +563,7 @@ impl DisaggRuntime { ...@@ -527,6 +563,7 @@ impl DisaggRuntime {
Ok(changed) Ok(changed)
} }
/// Either enqueue decode immediately or schedule a delayed handoff event on the event heap.
fn enqueue_decode_after_handoff( fn enqueue_decode_after_handoff(
&mut self, &mut self,
uuid: Uuid, uuid: Uuid,
...@@ -547,6 +584,7 @@ impl DisaggRuntime { ...@@ -547,6 +584,7 @@ impl DisaggRuntime {
self.enqueue_decode(uuid) self.enqueue_decode(uuid)
} }
/// Release every trace arrival whose timestamp is now visible to the global clock.
fn release_trace_arrivals(&mut self) -> Result<bool> { fn release_trace_arrivals(&mut self) -> Result<bool> {
let mut released_any = false; let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) { if matches!(self.admission, AdmissionSource::Requests(_)) {
...@@ -581,6 +619,7 @@ impl DisaggRuntime { ...@@ -581,6 +619,7 @@ impl DisaggRuntime {
Ok(released_any) Ok(released_any)
} }
/// Backfill closed-loop concurrency replay until the staged cluster reaches its cap.
fn top_off_concurrency(&mut self, max_in_flight: usize) -> Result<bool> { fn top_off_concurrency(&mut self, max_in_flight: usize) -> Result<bool> {
let mut released_any = false; let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) { if matches!(self.admission, AdmissionSource::Requests(_)) {
...@@ -620,6 +659,7 @@ impl DisaggRuntime { ...@@ -620,6 +659,7 @@ impl DisaggRuntime {
Ok(released_any) Ok(released_any)
} }
/// Start passes on every idle prefill worker that can make progress at the current timestamp.
fn drive_prefill_workers(&mut self) -> Result<bool> { fn drive_prefill_workers(&mut self) -> Result<bool> {
let mut changed = false; let mut changed = false;
for worker_idx in 0..self.prefill_workers.len() { for worker_idx in 0..self.prefill_workers.len() {
...@@ -668,6 +708,7 @@ impl DisaggRuntime { ...@@ -668,6 +708,7 @@ impl DisaggRuntime {
Ok(changed) Ok(changed)
} }
/// Start passes on every idle decode worker that can make progress at the current timestamp.
fn drive_decode_workers(&mut self) -> Result<bool> { fn drive_decode_workers(&mut self) -> Result<bool> {
let mut changed = false; let mut changed = false;
for worker_idx in 0..self.decode_workers.len() { for worker_idx in 0..self.decode_workers.len() {
...@@ -710,6 +751,7 @@ impl DisaggRuntime { ...@@ -710,6 +751,7 @@ impl DisaggRuntime {
Ok(changed) Ok(changed)
} }
/// Repeatedly process all work that becomes possible without advancing logical time.
fn drain_current_timestamp(&mut self) -> Result<()> { fn drain_current_timestamp(&mut self) -> Result<()> {
loop { loop {
let mut changed = self.apply_worker_completions()?; let mut changed = self.apply_worker_completions()?;
...@@ -732,6 +774,7 @@ impl DisaggRuntime { ...@@ -732,6 +774,7 @@ impl DisaggRuntime {
Ok(()) Ok(())
} }
/// Finalize test-only request snapshots before returning.
fn finish_test_stats(&mut self) { fn finish_test_stats(&mut self) {
#[cfg(test)] #[cfg(test)]
{ {
...@@ -743,7 +786,8 @@ impl DisaggRuntime { ...@@ -743,7 +786,8 @@ impl DisaggRuntime {
} }
} }
fn run(mut self) -> Result<(TraceCollector, DisaggRuntimeStats)> { /// Run the staged offline replay until both prefill and decode pipelines are drained.
pub(super) fn run(mut self) -> Result<(TraceCollector, DisaggRuntimeStats)> {
self.drain_current_timestamp()?; self.drain_current_timestamp()?;
while !self.is_done() { while !self.is_done() {
...@@ -793,124 +837,9 @@ fn derive_decode_router_config( ...@@ -793,124 +837,9 @@ fn derive_decode_router_config(
config config
} }
/// Replay a list of requests through the disaggregated runtime in trace mode.
///
/// The requests are first passed through `normalize_trace_requests` together
/// with `arrival_speedup_ratio`; the resulting queue seeds the runtime. The
/// runtime statistics returned by `run` are discarded and only the finished
/// collector report is returned.
pub(crate) fn simulate_trace_disagg(
    config: OfflineDisaggReplayConfig,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let queue = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    let runtime = DisaggRuntime::new(
        &config,
        router_config,
        queue,
        ReplayMode::Trace,
        router_mode,
    )?;
    let (collector, _) = runtime.run()?;
    let report = collector.finish();
    Ok(report)
}
/// Replay a list of requests through the disaggregated runtime in concurrency
/// mode, capped at `max_in_flight`.
///
/// The requests are queued in the order given (no normalization step); the
/// runtime statistics returned by `run` are discarded.
pub(crate) fn simulate_concurrency_disagg(
    config: OfflineDisaggReplayConfig,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let queue: VecDeque<DirectRequest> = requests.into();
    let mode = ReplayMode::Concurrency { max_in_flight };
    let runtime = DisaggRuntime::new(&config, router_config, queue, mode, router_mode)?;
    let (collector, _) = runtime.run()?;
    Ok(collector.finish())
}
/// Replay a `Trace` workload through the disaggregated runtime in trace mode.
///
/// A trace-mode `WorkloadDriver` is built from `trace` and used as the
/// admission source; the runtime statistics returned by `run` are discarded.
pub(crate) fn simulate_trace_workload_disagg(
    config: OfflineDisaggReplayConfig,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let workload = WorkloadDriver::new_trace(trace)?;
    let runtime = DisaggRuntime::new_workload(
        &config,
        router_config,
        workload,
        ReplayMode::Trace,
        router_mode,
    )?;
    let (collector, _) = runtime.run()?;
    Ok(collector.finish())
}
/// Replay a `Trace` workload through the disaggregated runtime in concurrency
/// mode, capped at `max_in_flight`.
///
/// A concurrency-mode `WorkloadDriver` is built from `trace` and used as the
/// admission source; the runtime statistics returned by `run` are discarded.
pub(crate) fn simulate_concurrency_workload_disagg(
    config: OfflineDisaggReplayConfig,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    max_in_flight: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let workload = WorkloadDriver::new_concurrency(trace)?;
    let mode = ReplayMode::Concurrency { max_in_flight };
    let runtime =
        DisaggRuntime::new_workload(&config, router_config, workload, mode, router_mode)?;
    let (collector, _) = runtime.run()?;
    let report = collector.finish();
    Ok(report)
}
/// Test helper: run a trace-mode disaggregated replay and return both the
/// collector and the runtime statistics.
///
/// Panics (via `unwrap`) if normalization, runtime construction, or the run
/// itself fails.
#[cfg(test)]
fn run_trace_collect(
    config: &OfflineDisaggReplayConfig,
    requests: Vec<DirectRequest>,
    router_config: Option<KvRouterConfig>,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
    let queue = normalize_trace_requests(requests, arrival_speedup_ratio).unwrap();
    let runtime =
        DisaggRuntime::new(config, router_config, queue, ReplayMode::Trace, router_mode).unwrap();
    runtime.run().unwrap()
}
/// Test helper: run a concurrency-mode disaggregated replay capped at
/// `max_in_flight` and return both the collector and the runtime statistics.
///
/// Panics (via `unwrap`) if runtime construction or the run itself fails.
#[cfg(test)]
fn run_concurrency_collect(
    config: &OfflineDisaggReplayConfig,
    requests: Vec<DirectRequest>,
    router_config: Option<KvRouterConfig>,
    max_in_flight: usize,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
    let queue = VecDeque::from(requests);
    let mode = ReplayMode::Concurrency { max_in_flight };
    let runtime = DisaggRuntime::new(config, router_config, queue, mode, router_mode).unwrap();
    runtime.run().unwrap()
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::entrypoints::{run_concurrency_collect, run_trace_collect};
use super::*; use super::*;
use crate::common::protocols::{MockEngineArgs, WorkerType}; use crate::common::protocols::{MockEngineArgs, WorkerType};
...@@ -1093,8 +1022,8 @@ mod tests { ...@@ -1093,8 +1022,8 @@ mod tests {
); );
assert_eq!(stats.prefill_marked_count, 1); assert_eq!(stats.prefill_marked_count, 1);
assert_eq!(stats.prefill_freed_count, 1); assert_eq!(stats.prefill_router_freed_count, 1);
assert_eq!(stats.decode_freed_count, 1); assert_eq!(stats.decode_router_freed_count, 1);
} }
#[test] #[test]
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use anyhow::Result;
use dynamo_kv_router::config::KvRouterConfig;
#[cfg(test)]
use super::agg::AggRuntimeStats;
use super::agg::{AggRuntime, ReplayMode as AggReplayMode};
#[cfg(test)]
use super::disagg::DisaggRuntimeStats;
use super::disagg::{DisaggRuntime, ReplayMode as DisaggReplayMode};
use super::normalize_trace_requests;
use super::single::{SingleReplayMode, SingleRuntime};
use crate::common::protocols::{DirectRequest, EngineType, MockEngineArgs};
use crate::loadgen::{Trace, WorkloadDriver};
use crate::replay::OfflineDisaggReplayConfig;
#[cfg(test)]
use crate::replay::TraceCollector;
use crate::replay::{ReplayRouterMode, TraceSimulationReport};
/// Entry point for offline trace replay over a list of requests.
///
/// Uses the single-worker runtime when exactly one worker is requested and the
/// engine type is vLLM; every other combination goes to the multi-worker path.
/// Note that `router_config` and `router_mode` are not forwarded on the
/// single-worker path.
pub(crate) fn simulate_trace(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let use_single_runtime = num_workers == 1 && args.engine_type == EngineType::Vllm;
    if use_single_runtime {
        return simulate_trace_single(args, requests, arrival_speedup_ratio);
    }
    simulate_trace_multi(
        args,
        router_config,
        requests,
        num_workers,
        arrival_speedup_ratio,
        router_mode,
    )
}
/// Entry point for offline concurrency replay over a list of requests.
///
/// Uses the single-worker runtime when exactly one worker is requested and the
/// engine type is vLLM; every other combination goes to the multi-worker path.
/// Note that `router_config` and `router_mode` are not forwarded on the
/// single-worker path.
pub(crate) fn simulate_concurrency(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let use_single_runtime = num_workers == 1 && args.engine_type == EngineType::Vllm;
    if use_single_runtime {
        return simulate_concurrency_single(args, requests, max_in_flight);
    }
    simulate_concurrency_multi(
        args,
        router_config,
        requests,
        max_in_flight,
        num_workers,
        router_mode,
    )
}
/// Entry point for offline trace replay driven by a `Trace` workload.
///
/// Uses the single-worker runtime when exactly one worker is requested and the
/// engine type is vLLM; every other combination goes to the multi-worker path.
/// Note that `router_config` and `router_mode` are not forwarded on the
/// single-worker path.
pub(crate) fn simulate_trace_workload(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let use_single_runtime = num_workers == 1 && args.engine_type == EngineType::Vllm;
    if use_single_runtime {
        return simulate_trace_workload_single(args, trace);
    }
    simulate_trace_workload_multi(args, router_config, trace, num_workers, router_mode)
}
/// Entry point for offline concurrency replay driven by a `Trace` workload.
///
/// Uses the single-worker runtime when exactly one worker is requested and the
/// engine type is vLLM; every other combination goes to the multi-worker path.
/// Note that `router_config` and `router_mode` are not forwarded on the
/// single-worker path.
pub(crate) fn simulate_concurrency_workload(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let use_single_runtime = num_workers == 1 && args.engine_type == EngineType::Vllm;
    if use_single_runtime {
        return simulate_concurrency_workload_single(args, trace, max_in_flight);
    }
    simulate_concurrency_workload_multi(
        args,
        router_config,
        trace,
        max_in_flight,
        num_workers,
        router_mode,
    )
}
/// Replay a list of requests through the disaggregated runtime in trace mode.
///
/// The requests are first passed through `normalize_trace_requests` together
/// with `arrival_speedup_ratio`; the resulting queue seeds the runtime. The
/// runtime statistics returned by `run` are discarded and only the finished
/// collector report is returned.
pub(crate) fn simulate_trace_disagg(
    config: OfflineDisaggReplayConfig,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let queue = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    let runtime = DisaggRuntime::new(
        &config,
        router_config,
        queue,
        DisaggReplayMode::Trace,
        router_mode,
    )?;
    let (collector, _) = runtime.run()?;
    let report = collector.finish();
    Ok(report)
}
/// Replay a list of requests through the disaggregated runtime in concurrency
/// mode, capped at `max_in_flight`.
///
/// The requests are queued in the order given (no normalization step); the
/// runtime statistics returned by `run` are discarded.
pub(crate) fn simulate_concurrency_disagg(
    config: OfflineDisaggReplayConfig,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let queue: VecDeque<DirectRequest> = requests.into();
    let mode = DisaggReplayMode::Concurrency { max_in_flight };
    let runtime = DisaggRuntime::new(&config, router_config, queue, mode, router_mode)?;
    let (collector, _) = runtime.run()?;
    Ok(collector.finish())
}
/// Replay a `Trace` workload through the disaggregated runtime in trace mode.
///
/// A trace-mode `WorkloadDriver` is built from `trace` and used as the
/// admission source; the runtime statistics returned by `run` are discarded.
pub(crate) fn simulate_trace_workload_disagg(
    config: OfflineDisaggReplayConfig,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let workload = WorkloadDriver::new_trace(trace)?;
    let runtime = DisaggRuntime::new_workload(
        &config,
        router_config,
        workload,
        DisaggReplayMode::Trace,
        router_mode,
    )?;
    let (collector, _) = runtime.run()?;
    Ok(collector.finish())
}
/// Replay a `Trace` workload through the disaggregated runtime in concurrency
/// mode, capped at `max_in_flight`.
///
/// A concurrency-mode `WorkloadDriver` is built from `trace` and used as the
/// admission source; the runtime statistics returned by `run` are discarded.
pub(crate) fn simulate_concurrency_workload_disagg(
    config: OfflineDisaggReplayConfig,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    max_in_flight: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let workload = WorkloadDriver::new_concurrency(trace)?;
    let mode = DisaggReplayMode::Concurrency { max_in_flight };
    let runtime =
        DisaggRuntime::new_workload(&config, router_config, workload, mode, router_mode)?;
    let (collector, _) = runtime.run()?;
    let report = collector.finish();
    Ok(report)
}
/// Replay a list of requests on the single-worker runtime in trace mode.
///
/// The engine arguments are normalized first, then the requests are passed
/// through `normalize_trace_requests` with `arrival_speedup_ratio` before the
/// runtime is constructed and run.
pub(crate) fn simulate_trace_single(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    arrival_speedup_ratio: f64,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    let queue = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    let runtime = SingleRuntime::new(engine_args, queue, SingleReplayMode::Trace);
    let collector = runtime.run()?;
    Ok(collector.finish())
}
/// Replays `requests` on a `SingleRuntime` in concurrency mode and returns the finished report.
///
/// `args` is normalized with `MockEngineArgs::normalized`; the requests are queued in the
/// order given and the runtime runs with `SingleReplayMode::Concurrency { max_in_flight }`.
///
/// # Errors
///
/// Propagates failures from argument normalization or from the run itself.
pub(crate) fn simulate_concurrency_single(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    let queue: VecDeque<DirectRequest> = requests.into();
    let mode = SingleReplayMode::Concurrency { max_in_flight };
    let runtime = SingleRuntime::new(normalized_args, queue, mode);
    let report = runtime.run()?.finish();
    Ok(report)
}
/// Replays a workload `trace` on a `SingleRuntime` in trace mode and returns the finished
/// report.
///
/// `args` is normalized first; the trace is then converted with `Trace::into_trace_driver`
/// and handed to `SingleRuntime::new_workload`.
///
/// # Errors
///
/// Propagates failures from argument normalization, driver construction, or the run itself.
pub(crate) fn simulate_trace_workload_single(
    args: MockEngineArgs,
    trace: Trace,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    let driver = trace.into_trace_driver()?;
    let runtime = SingleRuntime::new_workload(normalized_args, driver, SingleReplayMode::Trace);
    let report = runtime.run()?.finish();
    Ok(report)
}
/// Replays a workload `trace` on a `SingleRuntime` in concurrency mode and returns the
/// finished report.
///
/// `args` is normalized first; the trace is then converted with
/// `Trace::into_concurrency_driver` and run under
/// `SingleReplayMode::Concurrency { max_in_flight }`.
///
/// # Errors
///
/// Propagates failures from argument normalization, driver construction, or the run itself.
pub(crate) fn simulate_concurrency_workload_single(
    args: MockEngineArgs,
    trace: Trace,
    max_in_flight: usize,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    let driver = trace.into_concurrency_driver()?;
    let mode = SingleReplayMode::Concurrency { max_in_flight };
    let runtime = SingleRuntime::new_workload(normalized_args, driver, mode);
    let report = runtime.run()?.finish();
    Ok(report)
}
/// Replays `requests` through an `AggRuntime` with `num_workers` workers in trace mode.
///
/// `args` is normalized with `MockEngineArgs::normalized`, the requests are passed through
/// `normalize_trace_requests` with `arrival_speedup_ratio`, and the runtime is run with
/// `AggReplayMode::Trace`. The `AggRuntimeStats` half of the run result is discarded.
///
/// # Errors
///
/// Propagates failures from either normalization step, from `AggRuntime::new`, or from `run`.
pub(crate) fn simulate_trace_multi(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    let queue = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    let mode = AggReplayMode::Trace;
    let (trace_collector, _) = AggRuntime::new(
        &normalized_args,
        router_config,
        queue,
        num_workers,
        mode,
        router_mode,
    )?
    .run()?;
    let report = trace_collector.finish();
    Ok(report)
}
/// Replays `requests` through an `AggRuntime` with `num_workers` workers in concurrency mode.
///
/// `args` is normalized with `MockEngineArgs::normalized`; the requests are queued in the
/// order given and the runtime is run with `AggReplayMode::Concurrency { max_in_flight }`.
/// The `AggRuntimeStats` half of the run result is discarded.
///
/// # Errors
///
/// Propagates failures from argument normalization, from `AggRuntime::new`, or from `run`.
pub(crate) fn simulate_concurrency_multi(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    let queue: VecDeque<DirectRequest> = requests.into();
    let mode = AggReplayMode::Concurrency { max_in_flight };
    let (trace_collector, _) = AggRuntime::new(
        &normalized_args,
        router_config,
        queue,
        num_workers,
        mode,
        router_mode,
    )?
    .run()?;
    let report = trace_collector.finish();
    Ok(report)
}
/// Replays a workload `trace` through an `AggRuntime` with `num_workers` workers in trace mode.
///
/// `args` is normalized first; the trace is converted with `Trace::into_trace_driver` and
/// passed to `AggRuntime::new_workload` with `AggReplayMode::Trace`. The `AggRuntimeStats`
/// half of the run result is discarded.
///
/// # Errors
///
/// Propagates failures from argument normalization, driver construction, runtime
/// construction, or the run itself.
pub(crate) fn simulate_trace_workload_multi(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    let driver = trace.into_trace_driver()?;
    let mode = AggReplayMode::Trace;
    let (trace_collector, _) = AggRuntime::new_workload(
        &normalized_args,
        router_config,
        driver,
        num_workers,
        mode,
        router_mode,
    )?
    .run()?;
    let report = trace_collector.finish();
    Ok(report)
}
/// Replays a workload `trace` through an `AggRuntime` with `num_workers` workers in
/// concurrency mode.
///
/// `args` is normalized first; the trace is converted with `Trace::into_concurrency_driver`
/// and passed to `AggRuntime::new_workload` with
/// `AggReplayMode::Concurrency { max_in_flight }`. The `AggRuntimeStats` half of the run
/// result is discarded.
///
/// # Errors
///
/// Propagates failures from argument normalization, driver construction, runtime
/// construction, or the run itself.
pub(crate) fn simulate_concurrency_workload_multi(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    let driver = trace.into_concurrency_driver()?;
    let mode = AggReplayMode::Concurrency { max_in_flight };
    let (trace_collector, _) = AggRuntime::new_workload(
        &normalized_args,
        router_config,
        driver,
        num_workers,
        mode,
        router_mode,
    )?
    .run()?;
    let report = trace_collector.finish();
    Ok(report)
}
/// Test helper: runs a trace-mode `SingleRuntime` and returns the raw `TraceCollector`.
///
/// Unlike `simulate_trace_single`, `args` is used as given (no `normalized()` call) and the
/// collector is returned without calling `finish`.
///
/// # Panics
///
/// Panics if `normalize_trace_requests` or the run returns an error.
#[cfg(test)]
pub(super) fn run_trace_single_collect(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    arrival_speedup_ratio: f64,
) -> TraceCollector {
    let queue = normalize_trace_requests(requests, arrival_speedup_ratio).unwrap();
    let runtime = SingleRuntime::new(args, queue, SingleReplayMode::Trace);
    runtime.run().unwrap()
}
/// Test helper: runs a concurrency-mode `SingleRuntime` and returns the raw `TraceCollector`.
///
/// `args` is used as given (no `normalized()` call) and the requests are queued in the order
/// supplied.
///
/// # Panics
///
/// Panics if the run returns an error.
#[cfg(test)]
pub(super) fn run_concurrency_single_collect(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
) -> TraceCollector {
    let queue: VecDeque<DirectRequest> = requests.into();
    let mode = SingleReplayMode::Concurrency { max_in_flight };
    let runtime = SingleRuntime::new(args, queue, mode);
    runtime.run().unwrap()
}
/// Test helper: runs a trace-mode workload on a `SingleRuntime` and returns the raw
/// `TraceCollector`.
///
/// `args` is used as given (no `normalized()` call).
///
/// # Panics
///
/// Panics if `Trace::into_trace_driver` or the run returns an error.
#[cfg(test)]
pub(super) fn run_trace_workload_single_collect(
    args: MockEngineArgs,
    trace: Trace,
) -> TraceCollector {
    let driver = trace.into_trace_driver().unwrap();
    let runtime = SingleRuntime::new_workload(args, driver, SingleReplayMode::Trace);
    runtime.run().unwrap()
}
/// Test helper: runs a concurrency-mode workload on a `SingleRuntime` and returns the raw
/// `TraceCollector`.
///
/// `args` is used as given (no `normalized()` call).
///
/// # Panics
///
/// Panics if `Trace::into_concurrency_driver` or the run returns an error.
#[cfg(test)]
pub(super) fn run_concurrency_workload_single_collect(
    args: MockEngineArgs,
    trace: Trace,
    max_in_flight: usize,
) -> TraceCollector {
    let driver = trace.into_concurrency_driver().unwrap();
    let mode = SingleReplayMode::Concurrency { max_in_flight };
    let runtime = SingleRuntime::new_workload(args, driver, mode);
    runtime.run().unwrap()
}
/// Test helper: runs a trace-mode `AggRuntime` and returns both the collector and its stats.
///
/// The requests are normalized with a fixed speedup ratio of `1.0` and no router config is
/// supplied (`None`). `args` is borrowed and used as given.
///
/// # Panics
///
/// Panics if request normalization, `AggRuntime::new`, or the run returns an error.
#[cfg(test)]
pub(super) fn run_trace_multi_collect_with_stats(
    args: &MockEngineArgs,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, AggRuntimeStats) {
    let queue = normalize_trace_requests(requests, 1.0).unwrap();
    let mode = AggReplayMode::Trace;
    AggRuntime::new(args, None, queue, num_workers, mode, router_mode)
        .unwrap()
        .run()
        .unwrap()
}
/// Test helper: runs a concurrency-mode `AggRuntime` and returns both the collector and its
/// stats.
///
/// No router config is supplied (`None`); the requests are queued in the order given and
/// `args` is borrowed and used as given.
///
/// # Panics
///
/// Panics if `AggRuntime::new` or the run returns an error.
#[cfg(test)]
pub(super) fn run_concurrency_multi_collect_with_stats(
    args: &MockEngineArgs,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, AggRuntimeStats) {
    let queue: VecDeque<DirectRequest> = requests.into();
    let mode = AggReplayMode::Concurrency { max_in_flight };
    AggRuntime::new(args, None, queue, num_workers, mode, router_mode)
        .unwrap()
        .run()
        .unwrap()
}
/// Test helper: runs a trace-mode workload on an `AggRuntime` and returns both the collector
/// and its stats.
///
/// No router config is supplied (`None`); `args` is borrowed and used as given.
///
/// # Panics
///
/// Panics if `Trace::into_trace_driver`, `AggRuntime::new_workload`, or the run returns an
/// error.
#[cfg(test)]
pub(super) fn run_trace_workload_multi_collect_with_stats(
    args: &MockEngineArgs,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, AggRuntimeStats) {
    let driver = trace.into_trace_driver().unwrap();
    let mode = AggReplayMode::Trace;
    AggRuntime::new_workload(args, None, driver, num_workers, mode, router_mode)
        .unwrap()
        .run()
        .unwrap()
}
/// Test helper: runs a concurrency-mode workload on an `AggRuntime` and returns both the
/// collector and its stats.
///
/// No router config is supplied (`None`); `args` is borrowed and used as given.
///
/// # Panics
///
/// Panics if `Trace::into_concurrency_driver`, `AggRuntime::new_workload`, or the run returns
/// an error.
#[cfg(test)]
pub(super) fn run_concurrency_workload_multi_collect_with_stats(
    args: &MockEngineArgs,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, AggRuntimeStats) {
    let driver = trace.into_concurrency_driver().unwrap();
    let mode = AggReplayMode::Concurrency { max_in_flight };
    AggRuntime::new_workload(args, None, driver, num_workers, mode, router_mode)
        .unwrap()
        .run()
        .unwrap()
}
/// Test helper: runs a trace-mode `DisaggRuntime` and returns both the collector and its
/// stats.
///
/// The requests are passed through `normalize_trace_requests` with `arrival_speedup_ratio`;
/// `config` is borrowed and forwarded unchanged.
///
/// # Panics
///
/// Panics if request normalization, `DisaggRuntime::new`, or the run returns an error.
#[cfg(test)]
pub(super) fn run_trace_collect(
    config: &OfflineDisaggReplayConfig,
    requests: Vec<DirectRequest>,
    router_config: Option<KvRouterConfig>,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
    let queue = normalize_trace_requests(requests, arrival_speedup_ratio).unwrap();
    let mode = DisaggReplayMode::Trace;
    DisaggRuntime::new(config, router_config, queue, mode, router_mode)
        .unwrap()
        .run()
        .unwrap()
}
/// Test helper: runs a concurrency-mode `DisaggRuntime` and returns both the collector and
/// its stats.
///
/// The requests are queued in the order given; `config` is borrowed and forwarded unchanged.
///
/// # Panics
///
/// Panics if `DisaggRuntime::new` or the run returns an error.
#[cfg(test)]
pub(super) fn run_concurrency_collect(
    config: &OfflineDisaggReplayConfig,
    requests: Vec<DirectRequest>,
    router_config: Option<KvRouterConfig>,
    max_in_flight: usize,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
    let queue: VecDeque<DirectRequest> = requests.into();
    let mode = DisaggReplayMode::Concurrency { max_in_flight };
    DisaggRuntime::new(config, router_config, queue, mode, router_mode)
        .unwrap()
        .run()
        .unwrap()
}
...@@ -34,17 +34,9 @@ pub(crate) struct SimulationEvent { ...@@ -34,17 +34,9 @@ pub(crate) struct SimulationEvent {
pub(crate) kind: SimulationEventKind, pub(crate) kind: SimulationEventKind,
} }
impl SimulationEvent {
fn kind_priority(&self) -> u8 {
0
}
}
impl PartialEq for SimulationEvent { impl PartialEq for SimulationEvent {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.at_ms.to_bits() == other.at_ms.to_bits() self.at_ms.to_bits() == other.at_ms.to_bits() && self.seq_no == other.seq_no
&& self.seq_no == other.seq_no
&& self.kind_priority() == other.kind_priority()
} }
} }
...@@ -62,7 +54,6 @@ impl Ord for SimulationEvent { ...@@ -62,7 +54,6 @@ impl Ord for SimulationEvent {
.at_ms .at_ms
.partial_cmp(&self.at_ms) .partial_cmp(&self.at_ms)
.unwrap_or(Ordering::Equal) .unwrap_or(Ordering::Equal)
.then_with(|| self.kind_priority().cmp(&other.kind_priority()))
.then_with(|| other.seq_no.cmp(&self.seq_no)) .then_with(|| other.seq_no.cmp(&self.seq_no))
} }
} }
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::Trace;
use crate::replay::OfflineDisaggReplayConfig;
pub(crate) use crate::replay::normalize_trace_requests; pub(crate) use crate::replay::normalize_trace_requests;
use crate::replay::{ReplayRouterMode, TraceSimulationReport};
use dynamo_kv_router::config::KvRouterConfig;
pub(crate) mod agg;
pub(crate) mod core; pub(crate) mod core;
pub(crate) mod disagg; pub(crate) mod disagg;
mod entrypoints;
pub(crate) mod events; pub(crate) mod events;
pub(crate) mod multi;
pub(crate) mod runtime_utils; pub(crate) mod runtime_utils;
pub(crate) mod single; pub(crate) mod single;
pub(crate) mod state; pub(crate) mod state;
pub(crate) fn simulate_trace( pub(crate) use entrypoints::{
args: MockEngineArgs, simulate_concurrency, simulate_concurrency_disagg, simulate_concurrency_workload,
router_config: Option<KvRouterConfig>, simulate_concurrency_workload_disagg, simulate_trace, simulate_trace_disagg,
requests: Vec<DirectRequest>, simulate_trace_workload, simulate_trace_workload_disagg,
num_workers: usize, };
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm {
single::simulate_trace_single(args, requests, arrival_speedup_ratio)
} else {
multi::simulate_trace_multi(
args,
router_config,
requests,
num_workers,
arrival_speedup_ratio,
router_mode,
)
}
}
pub(crate) fn simulate_concurrency(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm {
single::simulate_concurrency_single(args, requests, max_in_flight)
} else {
multi::simulate_concurrency_multi(
args,
router_config,
requests,
max_in_flight,
num_workers,
router_mode,
)
}
}
pub(crate) fn simulate_trace_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm {
single::simulate_trace_workload_single(args, trace)
} else {
multi::simulate_trace_workload_multi(args, router_config, trace, num_workers, router_mode)
}
}
pub(crate) fn simulate_concurrency_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm {
single::simulate_concurrency_workload_single(args, trace, max_in_flight)
} else {
multi::simulate_concurrency_workload_multi(
args,
router_config,
trace,
max_in_flight,
num_workers,
router_mode,
)
}
}
pub(crate) fn simulate_trace_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_trace_disagg(
config,
router_config,
requests,
arrival_speedup_ratio,
router_mode,
)
}
pub(crate) fn simulate_concurrency_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_concurrency_disagg(config, router_config, requests, max_in_flight, router_mode)
}
pub(crate) fn simulate_trace_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_trace_workload_disagg(config, router_config, trace, router_mode)
}
pub(crate) fn simulate_concurrency_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_concurrency_workload_disagg(
config,
router_config,
trace,
max_in_flight,
router_mode,
)
}
...@@ -2,16 +2,15 @@ ...@@ -2,16 +2,15 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::core::ReplayWorkerCore; use super::core::ReplayWorkerCore;
use super::normalize_trace_requests;
use crate::common::protocols::{DirectRequest, MockEngineArgs}; use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::{Trace, WorkloadDriver}; use crate::loadgen::WorkloadDriver;
use crate::replay::{TraceCollector, TraceSimulationReport}; use crate::replay::TraceCollector;
use anyhow::bail; use anyhow::bail;
use std::collections::VecDeque; use std::collections::VecDeque;
use uuid::Uuid; use uuid::Uuid;
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
enum SingleReplayMode { pub(super) enum SingleReplayMode {
Trace, Trace,
Concurrency { max_in_flight: usize }, Concurrency { max_in_flight: usize },
} }
...@@ -21,7 +20,7 @@ enum AdmissionSource { ...@@ -21,7 +20,7 @@ enum AdmissionSource {
Workload(WorkloadDriver), Workload(WorkloadDriver),
} }
struct SingleRuntime { pub(super) struct SingleRuntime {
current_time_ms: f64, current_time_ms: f64,
admission: AdmissionSource, admission: AdmissionSource,
worker: ReplayWorkerCore, worker: ReplayWorkerCore,
...@@ -30,11 +29,19 @@ struct SingleRuntime { ...@@ -30,11 +29,19 @@ struct SingleRuntime {
} }
impl SingleRuntime { impl SingleRuntime {
fn new(args: MockEngineArgs, pending: VecDeque<DirectRequest>, mode: SingleReplayMode) -> Self { pub(super) fn new(
args: MockEngineArgs,
pending: VecDeque<DirectRequest>,
mode: SingleReplayMode,
) -> Self {
Self::new_with_source(args, AdmissionSource::Requests(pending), mode) Self::new_with_source(args, AdmissionSource::Requests(pending), mode)
} }
fn new_workload(args: MockEngineArgs, driver: WorkloadDriver, mode: SingleReplayMode) -> Self { pub(super) fn new_workload(
args: MockEngineArgs,
driver: WorkloadDriver,
mode: SingleReplayMode,
) -> Self {
Self::new_with_source(args, AdmissionSource::Workload(driver), mode) Self::new_with_source(args, AdmissionSource::Workload(driver), mode)
} }
...@@ -166,7 +173,7 @@ impl SingleRuntime { ...@@ -166,7 +173,7 @@ impl SingleRuntime {
} }
} }
fn run(mut self) -> anyhow::Result<TraceCollector> { pub(super) fn run(mut self) -> anyhow::Result<TraceCollector> {
while !self.is_done() { while !self.is_done() {
match self.mode { match self.mode {
SingleReplayMode::Trace => { SingleReplayMode::Trace => {
...@@ -196,119 +203,14 @@ impl SingleRuntime { ...@@ -196,119 +203,14 @@ impl SingleRuntime {
} }
} }
pub(crate) fn simulate_trace_single(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
let collector = SingleRuntime::new(args, pending, SingleReplayMode::Trace).run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_single(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = VecDeque::from(requests);
let collector = SingleRuntime::new(
args,
pending,
SingleReplayMode::Concurrency { max_in_flight },
)
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_trace_workload_single(
args: MockEngineArgs,
trace: Trace,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let collector =
SingleRuntime::new_workload(args, trace.into_trace_driver()?, SingleReplayMode::Trace)
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_workload_single(
args: MockEngineArgs,
trace: Trace,
max_in_flight: usize,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let collector = SingleRuntime::new_workload(
args,
trace.into_concurrency_driver()?,
SingleReplayMode::Concurrency { max_in_flight },
)
.run()?;
Ok(collector.finish())
}
#[cfg(test)]
pub(super) fn run_trace_single_collect(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
) -> TraceCollector {
let pending = normalize_trace_requests(requests, arrival_speedup_ratio).unwrap();
SingleRuntime::new(args, pending, SingleReplayMode::Trace)
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_concurrency_single_collect(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
) -> TraceCollector {
SingleRuntime::new(
args,
VecDeque::from(requests),
SingleReplayMode::Concurrency { max_in_flight },
)
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_trace_workload_single_collect(
args: MockEngineArgs,
trace: Trace,
) -> TraceCollector {
SingleRuntime::new_workload(
args,
trace.into_trace_driver().unwrap(),
SingleReplayMode::Trace,
)
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_concurrency_workload_single_collect(
args: MockEngineArgs,
trace: Trace,
max_in_flight: usize,
) -> TraceCollector {
SingleRuntime::new_workload(
args,
trace.into_concurrency_driver().unwrap(),
SingleReplayMode::Concurrency { max_in_flight },
)
.run()
.unwrap()
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::entrypoints::{
run_concurrency_workload_single_collect, run_trace_workload_single_collect,
simulate_concurrency_single, simulate_trace_single,
};
use super::*; use super::*;
use crate::loadgen::{SessionTrace, TurnTrace}; use crate::loadgen::{SessionTrace, Trace, TurnTrace};
use crate::replay::{TraceRequestStatsSnapshot, TraceSimulationReport}; use crate::replay::{TraceRequestStatsSnapshot, TraceSimulationReport};
use rstest::rstest; use rstest::rstest;
use std::collections::{HashMap, VecDeque}; use std::collections::{HashMap, VecDeque};
......
...@@ -13,11 +13,53 @@ use crate::scheduler::AdmissionEvent; ...@@ -13,11 +13,53 @@ use crate::scheduler::AdmissionEvent;
use super::state::{ArrivalEvent, RequestRegistry, SharedLiveRuntimeStats, now_ms}; use super::state::{ArrivalEvent, RequestRegistry, SharedLiveRuntimeStats, now_ms};
/// Applies one `OutputSignal` to the collector, the request registry, and the router.
///
/// The token is always recorded on `collector` at `batch_time_ms`. If the registry has no
/// entry for the signal's uuid, nothing else happens. Otherwise:
///
/// * the first time `mark_first_token_once` succeeds, `router.on_first_token` is awaited and
///   `stats.record_prefill_marked` is called when it reports `Ok(true)`;
/// * when the signal is `completed` and `mark_completed_once` succeeds,
///   `router.on_complete` is awaited, `stats.record_freed` is called on `Ok(true)`, and the
///   request state is notified of completion.
///
/// Router errors are logged with `tracing::warn!` and otherwise ignored.
async fn process_output_signal(
    output: OutputSignal,
    batch_time_ms: f64,
    collector: &mut TraceCollector,
    requests: &RequestRegistry,
    router: &ReplayRouter,
    stats: &SharedLiveRuntimeStats,
) {
    let uuid = output.uuid;
    collector.on_token(uuid, batch_time_ms);

    if let Some(state) = requests.get(&uuid) {
        if state.mark_first_token_once() {
            match router.on_first_token(uuid).await {
                Ok(true) => stats.record_prefill_marked(),
                Ok(false) => {}
                Err(error) => tracing::warn!(
                    uuid = %uuid,
                    error = %error,
                    "online replay failed to mark prefill completed"
                ),
            }
        }

        // The completion latch is only consulted for signals flagged as completed.
        if output.completed && state.mark_completed_once() {
            match router.on_complete(uuid).await {
                Ok(true) => stats.record_freed(),
                Ok(false) => {}
                Err(error) => tracing::warn!(
                    uuid = %uuid,
                    error = %error,
                    "online replay failed to free completed request"
                ),
            }
            state.notify_completion();
        }
    }
}
pub(super) async fn run_demux( pub(super) async fn run_demux(
start: Instant, start: Instant,
mut arrival_rx: mpsc::UnboundedReceiver<ArrivalEvent>, mut arrival_rx: mpsc::UnboundedReceiver<ArrivalEvent>,
mut admission_rx: mpsc::UnboundedReceiver<AdmissionEvent>, mut admission_rx: mpsc::UnboundedReceiver<AdmissionEvent>,
mut output_rx: mpsc::UnboundedReceiver<OutputSignal>, mut output_rx: mpsc::UnboundedReceiver<Vec<OutputSignal>>,
requests: RequestRegistry, requests: RequestRegistry,
router: Arc<ReplayRouter>, router: Arc<ReplayRouter>,
stats: Arc<SharedLiveRuntimeStats>, stats: Arc<SharedLiveRuntimeStats>,
...@@ -55,33 +97,18 @@ pub(super) async fn run_demux( ...@@ -55,33 +97,18 @@ pub(super) async fn run_demux(
} }
output = output_rx.recv(), if outputs_open => { output = output_rx.recv(), if outputs_open => {
match output { match output {
Some(output) => { Some(output_batch) => {
collector.on_token(output.uuid, now_ms(start)); let batch_time_ms = now_ms(start);
if let Some(state) = requests.get(&output.uuid) { for output in output_batch {
if state.mark_first_token_once() { process_output_signal(
match router.on_first_token(output.uuid).await { output,
Ok(true) => stats.record_prefill_marked(), batch_time_ms,
Ok(false) => {} &mut collector,
Err(error) => tracing::warn!( &requests,
uuid = %output.uuid, &router,
error = %error, &stats,
"online replay failed to mark prefill completed" )
), .await;
}
}
if output.completed && state.mark_completed_once() {
match router.on_complete(output.uuid).await {
Ok(true) => stats.record_freed(),
Ok(false) => {}
Err(error) => tracing::warn!(
uuid = %output.uuid,
error = %error,
"online replay failed to free completed request"
),
}
state.notify_completion();
}
} }
} }
None => outputs_open = false, None => outputs_open = false,
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use anyhow::{Result, anyhow, bail};
use dynamo_kv_router::config::KvRouterConfig;
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::{Trace, WorkloadDriver};
use crate::replay::{ReplayRouterMode, TraceSimulationReport, normalize_trace_requests};
use super::live_runtime::LiveRuntime;
use super::state::{LiveReplayMode, LiveRuntimeStats};
/// Counts the turns across every session in `trace`.
fn total_turns(trace: &Trace) -> usize {
    let mut turns = 0;
    for session in trace.sessions.iter() {
        turns += session.turns.len();
    }
    turns
}
/// Builds a multi-threaded Tokio runtime and blocks on a `LiveRuntime` replay of `pending`.
///
/// Returns the report together with the `LiveRuntimeStats` produced by `LiveRuntime::run`.
///
/// # Errors
///
/// Returns an error if the Tokio runtime cannot be built, or propagates any error from
/// `LiveRuntime::new` / `LiveRuntime::run`.
fn run_live_runtime(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    pending: VecDeque<DirectRequest>,
    num_workers: usize,
    mode: LiveReplayMode,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.enable_all();
    let rt = match builder.build() {
        Ok(rt) => rt,
        Err(e) => return Err(anyhow!("failed to create online replay runtime: {e}")),
    };

    rt.block_on(async move {
        let live = LiveRuntime::new(args, router_config, pending, num_workers, mode, router_mode)?;
        live.run().await
    })
}
/// Builds a multi-threaded Tokio runtime and blocks on a `LiveRuntime` workload replay.
///
/// The `LiveRuntime` is created with an empty pending queue; requests come from `driver`,
/// which is handed to `LiveRuntime::run_workload` along with `total_turns`.
///
/// # Errors
///
/// Returns an error if the Tokio runtime cannot be built, or propagates any error from
/// `LiveRuntime::new` / `LiveRuntime::run_workload`.
fn run_live_workload_runtime(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    driver: WorkloadDriver,
    total_turns: usize,
    num_workers: usize,
    mode: LiveReplayMode,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.enable_all();
    let rt = match builder.build() {
        Ok(rt) => rt,
        Err(e) => return Err(anyhow!("failed to create online replay runtime: {e}")),
    };

    rt.block_on(async move {
        let no_pending = VecDeque::new();
        let live = LiveRuntime::new(
            args,
            router_config,
            no_pending,
            num_workers,
            mode,
            router_mode,
        )?;
        live.run_workload(driver, total_turns).await
    })
}
/// Runs a live replay of `requests` in trace mode and returns only the report.
///
/// `args` is normalized with `MockEngineArgs::normalized`, the requests are passed through
/// `normalize_trace_requests` with `arrival_speedup_ratio`, and the result is executed via
/// `run_live_runtime` with `LiveReplayMode::Trace`. The runtime stats are discarded.
///
/// # Errors
///
/// Propagates failures from either normalization step or from the live run.
pub(crate) fn simulate_trace_requests(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    let queue = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    run_live_runtime(
        normalized_args,
        router_config,
        queue,
        num_workers,
        LiveReplayMode::Trace,
        router_mode,
    )
    .map(|(report, _)| report)
}
/// Runs a live replay of `requests` in concurrency mode and returns only the report.
///
/// `args` is normalized first; an empty `requests` list is then rejected before any runtime
/// is created. The replay runs via `run_live_runtime` with
/// `LiveReplayMode::Concurrency { max_in_flight }`, and the runtime stats are discarded.
///
/// # Errors
///
/// Returns an error if argument normalization fails, if `requests` is empty, or if the live
/// run fails.
pub(crate) fn simulate_concurrency_requests(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    if requests.is_empty() {
        bail!("online concurrency replay requires at least one request");
    }

    let queue: VecDeque<DirectRequest> = requests.into();
    let mode = LiveReplayMode::Concurrency { max_in_flight };
    run_live_runtime(
        normalized_args,
        router_config,
        queue,
        num_workers,
        mode,
        router_mode,
    )
    .map(|(report, _)| report)
}
/// Runs a live workload replay of `trace` in trace mode and returns only the report.
///
/// `args` is normalized first. The turn count is computed with `total_turns` before the
/// trace is consumed by `Trace::into_trace_driver`; both are forwarded to
/// `run_live_workload_runtime` with `LiveReplayMode::Trace`. The runtime stats are discarded.
///
/// # Errors
///
/// Propagates failures from argument normalization, driver construction, or the live run.
pub(crate) fn simulate_trace_workload(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    let turn_count = total_turns(&trace);
    let driver = trace.into_trace_driver()?;
    run_live_workload_runtime(
        normalized_args,
        router_config,
        driver,
        turn_count,
        num_workers,
        LiveReplayMode::Trace,
        router_mode,
    )
    .map(|(report, _)| report)
}
/// Runs a live workload replay of `trace` in concurrency mode and returns only the report.
///
/// `args` is normalized first. The turn count is computed with `total_turns` before the
/// trace is consumed by `Trace::into_concurrency_driver`; both are forwarded to
/// `run_live_workload_runtime` with `LiveReplayMode::Concurrency { max_in_flight }`. The
/// runtime stats are discarded.
///
/// # Errors
///
/// Propagates failures from argument normalization, driver construction, or the live run.
pub(crate) fn simulate_concurrency_workload(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    let turn_count = total_turns(&trace);
    let driver = trace.into_concurrency_driver()?;
    let mode = LiveReplayMode::Concurrency { max_in_flight };
    run_live_workload_runtime(
        normalized_args,
        router_config,
        driver,
        turn_count,
        num_workers,
        mode,
        router_mode,
    )
    .map(|(report, _)| report)
}
/// Test helper: live trace-mode replay of `requests` that also returns `LiveRuntimeStats`.
///
/// `args` is normalized, the requests go through `normalize_trace_requests` with
/// `arrival_speedup_ratio`, and no router config is supplied (`None`).
///
/// # Errors
///
/// Propagates failures from either normalization step or from the live run.
#[cfg(test)]
pub(super) fn simulate_trace_requests_with_stats(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let normalized_args = args.normalized()?;
    let queue = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    let mode = LiveReplayMode::Trace;
    run_live_runtime(normalized_args, None, queue, num_workers, mode, router_mode)
}
/// Test helper: live concurrency-mode replay of `requests` that also returns
/// `LiveRuntimeStats`.
///
/// `args` is normalized and no router config is supplied (`None`). Unlike
/// `simulate_concurrency_requests`, this helper performs no empty-input check of its own.
///
/// # Errors
///
/// Propagates failures from argument normalization or from the live run.
#[cfg(test)]
pub(super) fn simulate_concurrency_requests_with_stats(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let normalized_args = args.normalized()?;
    let queue: VecDeque<DirectRequest> = requests.into();
    let mode = LiveReplayMode::Concurrency { max_in_flight };
    run_live_runtime(normalized_args, None, queue, num_workers, mode, router_mode)
}
/// Test helper: live trace-mode workload replay that also returns `LiveRuntimeStats`.
///
/// `args` is normalized, the turn count is taken with `total_turns` before the trace is
/// consumed by `Trace::into_trace_driver`, and no router config is supplied (`None`).
///
/// # Errors
///
/// Propagates failures from argument normalization, driver construction, or the live run.
#[cfg(test)]
pub(super) fn simulate_trace_workload_with_stats(
    args: MockEngineArgs,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let normalized_args = args.normalized()?;
    let turn_count = total_turns(&trace);
    let driver = trace.into_trace_driver()?;
    let mode = LiveReplayMode::Trace;
    run_live_workload_runtime(
        normalized_args,
        None,
        driver,
        turn_count,
        num_workers,
        mode,
        router_mode,
    )
}
/// Test helper: live concurrency-mode workload replay that also returns `LiveRuntimeStats`.
///
/// `args` is normalized, the turn count is taken with `total_turns` before the trace is
/// consumed by `Trace::into_concurrency_driver`, and no router config is supplied (`None`).
///
/// # Errors
///
/// Propagates failures from argument normalization, driver construction, or the live run.
#[cfg(test)]
pub(super) fn simulate_concurrency_workload_with_stats(
    args: MockEngineArgs,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let normalized_args = args.normalized()?;
    let turn_count = total_turns(&trace);
    let driver = trace.into_concurrency_driver()?;
    let mode = LiveReplayMode::Concurrency { max_in_flight };
    run_live_workload_runtime(
        normalized_args,
        None,
        driver,
        turn_count,
        num_workers,
        mode,
        router_mode,
    )
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use std::sync::Arc; use std::sync::Arc;
use anyhow::{Result, anyhow, bail}; use anyhow::{Result, anyhow};
use dashmap::DashMap; use dashmap::DashMap;
use dynamo_kv_router::config::KvRouterConfig; use dynamo_kv_router::config::KvRouterConfig;
use tokio::sync::{Notify, Semaphore, mpsc}; use tokio::sync::{Notify, Semaphore, mpsc};
...@@ -13,9 +12,9 @@ use tokio::time::Instant; ...@@ -13,9 +12,9 @@ use tokio::time::Instant;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal}; use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::{Trace, WorkloadDriver}; use crate::loadgen::WorkloadDriver;
use crate::replay::router::ReplayRouter; use crate::replay::router::ReplayRouter;
use crate::replay::{ReplayRouterMode, TraceSimulationReport, normalize_trace_requests}; use crate::replay::{ReplayRouterMode, TraceSimulationReport};
use crate::scheduler::{AdmissionEvent, EngineScheduler, SchedulerHandle}; use crate::scheduler::{AdmissionEvent, EngineScheduler, SchedulerHandle};
use super::demux::run_demux; use super::demux::run_demux;
...@@ -25,11 +24,11 @@ use super::state::{ ...@@ -25,11 +24,11 @@ use super::state::{
}; };
use super::task::{RequestTaskContext, run_request_task, wait_for_workload_progress}; use super::task::{RequestTaskContext, run_request_task, wait_for_workload_progress};
struct LiveRuntime { pub(super) struct LiveRuntime {
pending: VecDeque<DirectRequest>, pending: std::collections::VecDeque<DirectRequest>,
senders: Arc<[mpsc::UnboundedSender<DirectRequest>]>, senders: Arc<[mpsc::UnboundedSender<DirectRequest>]>,
schedulers: Vec<EngineScheduler>, schedulers: Vec<EngineScheduler>,
output_rx: mpsc::UnboundedReceiver<OutputSignal>, output_rx: mpsc::UnboundedReceiver<Vec<OutputSignal>>,
admission_rx: mpsc::UnboundedReceiver<AdmissionEvent>, admission_rx: mpsc::UnboundedReceiver<AdmissionEvent>,
cancel_token: CancellationToken, cancel_token: CancellationToken,
start: Instant, start: Instant,
...@@ -38,16 +37,17 @@ struct LiveRuntime { ...@@ -38,16 +37,17 @@ struct LiveRuntime {
} }
impl LiveRuntime { impl LiveRuntime {
fn new( /// Build the shared router, worker schedulers, and demux inputs for one live replay run.
pub(super) fn new(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
pending: VecDeque<DirectRequest>, pending: std::collections::VecDeque<DirectRequest>,
num_workers: usize, num_workers: usize,
mode: LiveReplayMode, mode: LiveReplayMode,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
) -> Result<Self> { ) -> Result<Self> {
let cancel_token = CancellationToken::new(); let cancel_token = CancellationToken::new();
let (output_tx, output_rx) = mpsc::unbounded_channel(); let (output_tx, output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let (admission_tx, admission_rx) = mpsc::unbounded_channel(); let (admission_tx, admission_rx) = mpsc::unbounded_channel();
let router = Arc::new(ReplayRouter::new( let router = Arc::new(ReplayRouter::new(
router_mode, router_mode,
...@@ -70,10 +70,6 @@ impl LiveRuntime { ...@@ -70,10 +70,6 @@ impl LiveRuntime {
senders.push(scheduler.request_sender()); senders.push(scheduler.request_sender());
schedulers.push(scheduler); schedulers.push(scheduler);
} }
drop(output_tx);
drop(admission_tx);
Ok(Self { Ok(Self {
pending, pending,
senders: Arc::from(senders), senders: Arc::from(senders),
...@@ -87,7 +83,8 @@ impl LiveRuntime { ...@@ -87,7 +83,8 @@ impl LiveRuntime {
}) })
} }
async fn run(mut self) -> Result<(TraceSimulationReport, LiveRuntimeStats)> { /// Replay a finite queue of requests and return the final trace report plus debug stats.
pub(super) async fn run(mut self) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let requests = Arc::new(DashMap::with_capacity(self.pending.len())); let requests = Arc::new(DashMap::with_capacity(self.pending.len()));
let stats = Arc::new(SharedLiveRuntimeStats::default()); let stats = Arc::new(SharedLiveRuntimeStats::default());
let (arrival_tx, arrival_rx) = mpsc::unbounded_channel(); let (arrival_tx, arrival_rx) = mpsc::unbounded_channel();
...@@ -160,7 +157,8 @@ impl LiveRuntime { ...@@ -160,7 +157,8 @@ impl LiveRuntime {
Ok((report, stats.snapshot())) Ok((report, stats.snapshot()))
} }
async fn run_workload( /// Drive a multi-turn workload driver until it is drained and all spawned request tasks finish.
pub(super) async fn run_workload(
mut self, mut self,
driver: WorkloadDriver, driver: WorkloadDriver,
total_turns: usize, total_turns: usize,
...@@ -283,237 +281,3 @@ impl LiveRuntime { ...@@ -283,237 +281,3 @@ impl LiveRuntime {
Ok((report, stats.snapshot())) Ok((report, stats.snapshot()))
} }
} }
fn run_live_runtime(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
pending: VecDeque<DirectRequest>,
num_workers: usize,
mode: LiveReplayMode,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.map_err(|e| anyhow!("failed to create online replay runtime: {e}"))?;
runtime.block_on(async move {
LiveRuntime::new(args, router_config, pending, num_workers, mode, router_mode)?
.run()
.await
})
}
/// Drive a workload driver to completion on a dedicated Tokio runtime.
///
/// A fresh multi-threaded runtime is built for this call. The `LiveRuntime`
/// is created with an empty pending queue because requests come from
/// `driver`; the calling thread blocks until `LiveRuntime::run_workload`
/// resolves.
///
/// # Errors
///
/// Returns an error if the Tokio runtime cannot be built, if
/// `LiveRuntime::new` fails, or if the workload replay reports a failure.
fn run_live_workload_runtime(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    driver: WorkloadDriver,
    total_turns: usize,
    num_workers: usize,
    mode: LiveReplayMode,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.enable_all();
    let tokio_runtime = match builder.build() {
        Ok(built) => built,
        Err(e) => return Err(anyhow!("failed to create online replay runtime: {e}")),
    };

    // Construction happens inside the future so it runs within the runtime context.
    let replay = async move {
        let no_pending = VecDeque::new();
        let live = LiveRuntime::new(
            args,
            router_config,
            no_pending,
            num_workers,
            mode,
            router_mode,
        )?;
        live.run_workload(driver, total_turns).await
    };
    tokio_runtime.block_on(replay)
}
/// Replay a list of requests in trace mode and return the simulation report.
///
/// The engine arguments are normalized first, then the requests are passed
/// through `normalize_trace_requests` together with `arrival_speedup_ratio`.
/// The runtime statistics produced alongside the report are discarded.
///
/// # Errors
///
/// Propagates failures from argument normalization, request normalization,
/// and the live runtime.
pub(crate) fn simulate_trace_requests(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    let queue = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    run_live_runtime(
        normalized_args,
        router_config,
        queue,
        num_workers,
        LiveReplayMode::Trace,
        router_mode,
    )
    .map(|(report, _stats)| report)
}
/// Replay a list of requests in concurrency mode and return the simulation report.
///
/// The engine arguments are normalized first; an empty request list is then
/// rejected. The requests are replayed with `max_in_flight` carried in the
/// replay mode. The runtime statistics produced alongside the report are
/// discarded.
///
/// # Errors
///
/// Fails if argument normalization fails, if `requests` is empty, or if the
/// live runtime reports a failure.
pub(crate) fn simulate_concurrency_requests(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    if requests.is_empty() {
        bail!("online concurrency replay requires at least one request");
    }

    let queue: VecDeque<DirectRequest> = requests.into();
    let replay_mode = LiveReplayMode::Concurrency { max_in_flight };
    run_live_runtime(
        normalized_args,
        router_config,
        queue,
        num_workers,
        replay_mode,
        router_mode,
    )
    .map(|(report, _stats)| report)
}
/// Replay a multi-session trace in trace mode and return the simulation report.
///
/// The total turn count is summed over every session before the trace is
/// consumed by `into_trace_driver`. The runtime statistics produced alongside
/// the report are discarded.
///
/// # Errors
///
/// Propagates failures from argument normalization, driver construction, and
/// the live runtime.
pub(crate) fn simulate_trace_workload(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    // Count turns while the trace is still borrowed; the driver conversion below consumes it.
    let total_turns: usize = trace.sessions.iter().map(|s| s.turns.len()).sum();
    let driver = trace.into_trace_driver()?;
    run_live_workload_runtime(
        normalized_args,
        router_config,
        driver,
        total_turns,
        num_workers,
        LiveReplayMode::Trace,
        router_mode,
    )
    .map(|(report, _stats)| report)
}
/// Replay a multi-session trace in concurrency mode and return the simulation report.
///
/// The total turn count is summed over every session before the trace is
/// consumed by `into_concurrency_driver`. `max_in_flight` is carried in the
/// replay mode. The runtime statistics produced alongside the report are
/// discarded.
///
/// # Errors
///
/// Propagates failures from argument normalization, driver construction, and
/// the live runtime.
pub(crate) fn simulate_concurrency_workload(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    // Count turns while the trace is still borrowed; the driver conversion below consumes it.
    let total_turns: usize = trace.sessions.iter().map(|s| s.turns.len()).sum();
    let driver = trace.into_concurrency_driver()?;
    let replay_mode = LiveReplayMode::Concurrency { max_in_flight };
    run_live_workload_runtime(
        normalized_args,
        router_config,
        driver,
        total_turns,
        num_workers,
        replay_mode,
        router_mode,
    )
    .map(|(report, _stats)| report)
}
/// Test-only variant of trace-mode request replay that also returns the runtime stats.
///
/// Runs without a router configuration (`None`) and hands back the
/// `(report, stats)` pair from the live runtime unchanged.
///
/// # Errors
///
/// Propagates failures from argument normalization, request normalization,
/// and the live runtime.
#[cfg(test)]
pub(super) fn simulate_trace_requests_with_stats(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let normalized_args = args.normalized()?;
    let queue = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    let replay_mode = LiveReplayMode::Trace;
    run_live_runtime(
        normalized_args,
        None,
        queue,
        num_workers,
        replay_mode,
        router_mode,
    )
}
/// Test-only variant of concurrency-mode request replay that also returns the runtime stats.
///
/// Runs without a router configuration (`None`) and hands back the
/// `(report, stats)` pair from the live runtime unchanged. Unlike the
/// non-test entry point, this function performs no empty-input check of its own.
///
/// # Errors
///
/// Propagates failures from argument normalization and the live runtime.
#[cfg(test)]
pub(super) fn simulate_concurrency_requests_with_stats(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let normalized_args = args.normalized()?;
    let queue: VecDeque<DirectRequest> = requests.into();
    let replay_mode = LiveReplayMode::Concurrency { max_in_flight };
    run_live_runtime(
        normalized_args,
        None,
        queue,
        num_workers,
        replay_mode,
        router_mode,
    )
}
/// Test-only variant of trace-mode workload replay that also returns the runtime stats.
///
/// Runs without a router configuration (`None`). The total turn count is
/// summed over every session before the trace is consumed by
/// `into_trace_driver`.
///
/// # Errors
///
/// Propagates failures from argument normalization, driver construction, and
/// the live runtime.
#[cfg(test)]
pub(super) fn simulate_trace_workload_with_stats(
    args: MockEngineArgs,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let normalized_args = args.normalized()?;
    // Count turns while the trace is still borrowed; the driver conversion below consumes it.
    let total_turns: usize = trace.sessions.iter().map(|s| s.turns.len()).sum();
    let driver = trace.into_trace_driver()?;
    run_live_workload_runtime(
        normalized_args,
        None,
        driver,
        total_turns,
        num_workers,
        LiveReplayMode::Trace,
        router_mode,
    )
}
/// Test-only variant of concurrency-mode workload replay that also returns the runtime stats.
///
/// Runs without a router configuration (`None`). The total turn count is
/// summed over every session before the trace is consumed by
/// `into_concurrency_driver`; `max_in_flight` is carried in the replay mode.
///
/// # Errors
///
/// Propagates failures from argument normalization, driver construction, and
/// the live runtime.
#[cfg(test)]
pub(super) fn simulate_concurrency_workload_with_stats(
    args: MockEngineArgs,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let normalized_args = args.normalized()?;
    // Count turns while the trace is still borrowed; the driver conversion below consumes it.
    let total_turns: usize = trace.sessions.iter().map(|s| s.turns.len()).sum();
    let driver = trace.into_concurrency_driver()?;
    let replay_mode = LiveReplayMode::Concurrency { max_in_flight };
    run_live_workload_runtime(
        normalized_args,
        None,
        driver,
        total_turns,
        num_workers,
        replay_mode,
        router_mode,
    )
}
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
mod demux; mod demux;
mod entrypoints;
mod live_runtime; mod live_runtime;
mod state; mod state;
mod task; mod task;
...@@ -9,7 +10,7 @@ mod task; ...@@ -9,7 +10,7 @@ mod task;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
pub(crate) use live_runtime::{ pub(crate) use entrypoints::{
simulate_concurrency_requests, simulate_concurrency_workload, simulate_trace_requests, simulate_concurrency_requests, simulate_concurrency_workload, simulate_trace_requests,
simulate_trace_workload, simulate_trace_workload,
}; };
...@@ -16,7 +16,7 @@ use crate::loadgen::{SessionTrace, Trace, TurnTrace}; ...@@ -16,7 +16,7 @@ use crate::loadgen::{SessionTrace, Trace, TurnTrace};
use crate::replay::ReplayRouterMode; use crate::replay::ReplayRouterMode;
use crate::replay::router::ReplayRouter; use crate::replay::router::ReplayRouter;
use super::live_runtime::{ use super::entrypoints::{
simulate_concurrency_requests_with_stats, simulate_concurrency_workload_with_stats, simulate_concurrency_requests_with_stats, simulate_concurrency_workload_with_stats,
simulate_trace_requests, simulate_trace_requests_with_stats, simulate_trace_requests, simulate_trace_requests_with_stats,
simulate_trace_workload_with_stats, simulate_trace_workload_with_stats,
......
...@@ -248,14 +248,14 @@ impl OfflineReplayRouter { ...@@ -248,14 +248,14 @@ impl OfflineReplayRouter {
pub(crate) fn mark_prefill_completed(&mut self, uuid: Uuid) -> Result<Vec<(Uuid, usize)>> { pub(crate) fn mark_prefill_completed(&mut self, uuid: Uuid) -> Result<Vec<(Uuid, usize)>> {
self.slots self.slots
.mark_prefill_completed_sync(&uuid.to_string()) .mark_prefill_completed(&uuid.to_string())
.map_err(anyhow::Error::from)?; .map_err(anyhow::Error::from)?;
self.drain_pending() self.drain_pending()
} }
pub(crate) fn free(&mut self, uuid: Uuid) -> Result<Vec<(Uuid, usize)>> { pub(crate) fn free(&mut self, uuid: Uuid) -> Result<Vec<(Uuid, usize)>> {
self.slots self.slots
.free_sync(&uuid.to_string()) .free(&uuid.to_string())
.map_err(anyhow::Error::from)?; .map_err(anyhow::Error::from)?;
self.drain_pending() self.drain_pending()
} }
...@@ -316,8 +316,6 @@ impl OfflineReplayRouter { ...@@ -316,8 +316,6 @@ impl OfflineReplayRouter {
} }
} }
pub(crate) fn shutdown(&mut self) {}
fn enqueue_key(&self, now_ms: f64, request: &PendingRequest) -> ReplayQueueKey { fn enqueue_key(&self, now_ms: f64, request: &PendingRequest) -> ReplayQueueKey {
let arrival_offset = Duration::from_secs_f64((now_ms.max(0.0)) / 1000.0); let arrival_offset = Duration::from_secs_f64((now_ms.max(0.0)) / 1000.0);
self.policy.enqueue_key( self.policy.enqueue_key(
...@@ -400,7 +398,7 @@ impl OfflineReplayRouter { ...@@ -400,7 +398,7 @@ impl OfflineReplayRouter {
let request_id = request.request_id(); let request_id = request.request_id();
self.slots self.slots
.add_request_sync(SequenceRequest { .add_request(SequenceRequest {
request_id, request_id,
token_sequence: request.token_seq, token_sequence: request.token_seq,
isl: request.isl_tokens, isl: request.isl_tokens,
......
...@@ -108,7 +108,7 @@ impl EngineScheduler { ...@@ -108,7 +108,7 @@ impl EngineScheduler {
pub(crate) fn new_with_admission( pub(crate) fn new_with_admission(
args: crate::common::protocols::MockEngineArgs, args: crate::common::protocols::MockEngineArgs,
dp_rank: u32, dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>, output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers, kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>, cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>, admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
......
...@@ -36,7 +36,7 @@ impl SglangScheduler { ...@@ -36,7 +36,7 @@ impl SglangScheduler {
pub fn new( pub fn new(
args: MockEngineArgs, args: MockEngineArgs,
dp_rank: u32, dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>, output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers, kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>, cancellation_token: Option<CancellationToken>,
) -> Self { ) -> Self {
...@@ -53,7 +53,7 @@ impl SglangScheduler { ...@@ -53,7 +53,7 @@ impl SglangScheduler {
pub(crate) fn new_with_admission( pub(crate) fn new_with_admission(
args: MockEngineArgs, args: MockEngineArgs,
dp_rank: u32, dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>, output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers, kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>, cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>, admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
...@@ -71,7 +71,7 @@ impl SglangScheduler { ...@@ -71,7 +71,7 @@ impl SglangScheduler {
fn new_internal( fn new_internal(
args: MockEngineArgs, args: MockEngineArgs,
dp_rank: u32, dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>, output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers, kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>, cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>, admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
...@@ -121,11 +121,12 @@ impl SglangScheduler { ...@@ -121,11 +121,12 @@ impl SglangScheduler {
if pass.router_event_visibility == RouterEventVisibility::PassEnd { if pass.router_event_visibility == RouterEventVisibility::PassEnd {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain()); publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
} }
flush_output_signals(&output_tx, &pass.output_signals); let active_decode_blocks = pass.active_decode_blocks;
flush_output_signals(&output_tx, pass.output_signals);
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain()); publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
let _ = metrics_tx.send(MockerMetrics::new( let _ = metrics_tx.send(MockerMetrics::new(
dp_rank, dp_rank,
pass.active_decode_blocks, active_decode_blocks,
total_blocks, total_blocks,
)); ));
} }
...@@ -182,14 +183,16 @@ async fn receive_requests( ...@@ -182,14 +183,16 @@ async fn receive_requests(
} }
fn flush_output_signals( fn flush_output_signals(
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>, output_tx: &Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
output_signals: &[OutputSignal], output_signals: Vec<OutputSignal>,
) { ) {
let Some(tx) = output_tx.as_ref() else { let Some(tx) = output_tx.as_ref() else {
return; return;
}; };
for signal in output_signals { if output_signals.is_empty() {
let _ = tx.send(signal.clone()); return;
} }
let _ = tx.send(output_signals);
} }
...@@ -94,7 +94,7 @@ mod scheduling { ...@@ -94,7 +94,7 @@ mod scheduling {
.build() .build()
.unwrap(); .unwrap();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>(); let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let scheduler = let scheduler =
SglangScheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None); SglangScheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None);
...@@ -117,8 +117,8 @@ mod scheduling { ...@@ -117,8 +117,8 @@ mod scheduling {
loop { loop {
tokio::select! { tokio::select! {
Some(_) = output_rx.recv() => { Some(output_batch) = output_rx.recv() => {
received += 1; received += output_batch.len();
if received >= expected_signals { if received >= expected_signals {
break; break;
} }
...@@ -535,7 +535,7 @@ mod core_behavior { ...@@ -535,7 +535,7 @@ mod core_behavior {
async fn assert_sglang_scheduler_completes_all( async fn assert_sglang_scheduler_completes_all(
scheduler: &SglangScheduler, scheduler: &SglangScheduler,
output_rx: &mut mpsc::UnboundedReceiver<OutputSignal>, output_rx: &mut mpsc::UnboundedReceiver<Vec<OutputSignal>>,
num_requests: usize, num_requests: usize,
prompt_len: usize, prompt_len: usize,
max_output_tokens: usize, max_output_tokens: usize,
...@@ -567,8 +567,8 @@ async fn assert_sglang_scheduler_completes_all( ...@@ -567,8 +567,8 @@ async fn assert_sglang_scheduler_completes_all(
loop { loop {
tokio::select! { tokio::select! {
biased; biased;
Some(_) = output_rx.recv() => { Some(output_batch) = output_rx.recv() => {
received_tokens += 1; received_tokens += output_batch.len();
if received_tokens >= expected_tokens { if received_tokens >= expected_tokens {
break; break;
} }
...@@ -604,7 +604,7 @@ mod router_events { ...@@ -604,7 +604,7 @@ mod router_events {
#[case] schedule_policy: &str, #[case] schedule_policy: &str,
#[case] page_size: usize, #[case] page_size: usize,
) { ) {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>(); let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let args = MockEngineArgs::builder() let args = MockEngineArgs::builder()
.num_gpu_blocks(500) .num_gpu_blocks(500)
.block_size(64) .block_size(64)
...@@ -818,7 +818,7 @@ mod router_events { ...@@ -818,7 +818,7 @@ mod router_events {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID); let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let (sink, forward_task) = harness.spawn_forwarder(); let (sink, forward_task) = harness.spawn_forwarder();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>(); let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let scheduler = SglangScheduler::new( let scheduler = SglangScheduler::new(
MockEngineArgs::builder() MockEngineArgs::builder()
.engine_type(EngineType::Sglang) .engine_type(EngineType::Sglang)
...@@ -849,8 +849,8 @@ mod router_events { ...@@ -849,8 +849,8 @@ mod router_events {
loop { loop {
tokio::select! { tokio::select! {
Some(_) = output_rx.recv() => { Some(output_batch) = output_rx.recv() => {
seen += 1; seen += output_batch.len();
if seen == expected { if seen == expected {
break; break;
} }
......
...@@ -215,7 +215,7 @@ pub(crate) fn removed_event_count(events: &[RouterEvent]) -> usize { ...@@ -215,7 +215,7 @@ pub(crate) fn removed_event_count(events: &[RouterEvent]) -> usize {
/// common prefix to exercise prefix caching / radix tree reuse. /// common prefix to exercise prefix caching / radix tree reuse.
pub(crate) async fn assert_scheduler_completes_all( pub(crate) async fn assert_scheduler_completes_all(
scheduler: &dyn SchedulerHandle, scheduler: &dyn SchedulerHandle,
output_rx: &mut mpsc::UnboundedReceiver<OutputSignal>, output_rx: &mut mpsc::UnboundedReceiver<Vec<OutputSignal>>,
num_requests: usize, num_requests: usize,
input_len: usize, input_len: usize,
max_output_tokens: usize, max_output_tokens: usize,
...@@ -260,8 +260,8 @@ pub(crate) async fn assert_scheduler_completes_all( ...@@ -260,8 +260,8 @@ pub(crate) async fn assert_scheduler_completes_all(
loop { loop {
tokio::select! { tokio::select! {
biased; biased;
Some(_) = output_rx.recv() => { Some(output_batch) = output_rx.recv() => {
received_tokens += 1; received_tokens += output_batch.len();
if received_tokens >= expected_tokens { if received_tokens >= expected_tokens {
break; break;
} }
......
...@@ -63,7 +63,7 @@ impl Scheduler { ...@@ -63,7 +63,7 @@ impl Scheduler {
pub fn new( pub fn new(
args: MockEngineArgs, args: MockEngineArgs,
dp_rank: u32, dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>, output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers, kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>, cancellation_token: Option<CancellationToken>,
) -> Self { ) -> Self {
...@@ -80,7 +80,7 @@ impl Scheduler { ...@@ -80,7 +80,7 @@ impl Scheduler {
pub(crate) fn new_with_admission( pub(crate) fn new_with_admission(
args: MockEngineArgs, args: MockEngineArgs,
dp_rank: u32, dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>, output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers, kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>, cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>, admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
...@@ -98,7 +98,7 @@ impl Scheduler { ...@@ -98,7 +98,7 @@ impl Scheduler {
fn new_internal( fn new_internal(
args: MockEngineArgs, args: MockEngineArgs,
dp_rank: u32, dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>, output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers, kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>, cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>, admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
...@@ -137,7 +137,7 @@ impl Scheduler { ...@@ -137,7 +137,7 @@ impl Scheduler {
if pass.router_event_visibility == RouterEventVisibility::PassEnd { if pass.router_event_visibility == RouterEventVisibility::PassEnd {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain()); publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
} }
flush_output_signals(&mut core, &output_tx, &pass.output_signals); flush_output_signals(&mut core, &output_tx, pass.output_signals);
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain()); publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
let _ = metrics_tx.send(MockerMetrics::new( let _ = metrics_tx.send(MockerMetrics::new(
dp_rank, dp_rank,
...@@ -198,17 +198,20 @@ async fn receive_requests( ...@@ -198,17 +198,20 @@ async fn receive_requests(
fn flush_output_signals( fn flush_output_signals(
core: &mut VllmCore, core: &mut VllmCore,
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>, output_tx: &Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
output_signals: &[OutputSignal], output_signals: Vec<OutputSignal>,
) { ) {
let Some(tx) = output_tx.as_ref() else { let Some(tx) = output_tx.as_ref() else {
return; return;
}; };
for signal in output_signals { if output_signals.is_empty() {
if tx.send(signal.clone()).is_ok() { return;
continue; }
if let Err(error) = tx.send(output_signals) {
for signal in error.0 {
core.drop_request(signal.uuid);
} }
core.drop_request(signal.uuid);
} }
} }
...@@ -506,7 +506,7 @@ mod live_scheduler { ...@@ -506,7 +506,7 @@ mod live_scheduler {
#[case] enable_prefix_caching: bool, #[case] enable_prefix_caching: bool,
#[case] enable_chunked_prefill: bool, #[case] enable_chunked_prefill: bool,
) { ) {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>(); let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let args = MockEngineArgs::builder() let args = MockEngineArgs::builder()
.num_gpu_blocks(500) .num_gpu_blocks(500)
...@@ -539,7 +539,7 @@ mod live_scheduler { ...@@ -539,7 +539,7 @@ mod live_scheduler {
let num_requests = 10; let num_requests = 10;
let token_length = 65; let token_length = 65;
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>(); let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let args = MockEngineArgs::builder() let args = MockEngineArgs::builder()
.num_gpu_blocks(100) .num_gpu_blocks(100)
...@@ -576,8 +576,8 @@ mod live_scheduler { ...@@ -576,8 +576,8 @@ mod live_scheduler {
let _metrics = metrics_rx.borrow().clone(); let _metrics = metrics_rx.borrow().clone();
tracing::debug!("Forward Pass Metrics: {_metrics:#?}"); tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
} }
Some(_signal) = output_rx.recv() => { Some(output_batch) = output_rx.recv() => {
received_tokens += 1; received_tokens += output_batch.len();
timeout.set(tokio::time::sleep(Duration::from_millis(500))); timeout.set(tokio::time::sleep(Duration::from_millis(500)));
} }
_ = &mut timeout => break, _ = &mut timeout => break,
...@@ -592,7 +592,7 @@ mod live_scheduler { ...@@ -592,7 +592,7 @@ mod live_scheduler {
#[tokio::test] #[tokio::test]
async fn test_receiver_drop_cleans_up_resources() { async fn test_receiver_drop_cleans_up_resources() {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>(); let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let args = MockEngineArgs::builder() let args = MockEngineArgs::builder()
.num_gpu_blocks(10) .num_gpu_blocks(10)
.block_size(64) .block_size(64)
...@@ -612,8 +612,8 @@ mod live_scheduler { ...@@ -612,8 +612,8 @@ mod live_scheduler {
let mut received_count = 0; let mut received_count = 0;
while received_count < 129 { while received_count < 129 {
if output_rx.recv().await.is_some() { if let Some(output_batch) = output_rx.recv().await {
received_count += 1; received_count += output_batch.len();
continue; continue;
} }
panic!("Channel closed before receiving 129 tokens"); panic!("Channel closed before receiving 129 tokens");
...@@ -639,7 +639,7 @@ mod live_scheduler { ...@@ -639,7 +639,7 @@ mod live_scheduler {
#[tokio::test] #[tokio::test]
async fn test_live_scheduler_forwards_buffered_kv_token_ids() { async fn test_live_scheduler_forwards_buffered_kv_token_ids() {
let sink = Arc::new(CapturingKvSink::default()); let sink = Arc::new(CapturingKvSink::default());
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>(); let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let args = MockEngineArgs::builder() let args = MockEngineArgs::builder()
.block_size(4) .block_size(4)
.num_gpu_blocks(12) .num_gpu_blocks(12)
...@@ -667,10 +667,14 @@ mod live_scheduler { ...@@ -667,10 +667,14 @@ mod live_scheduler {
arrival_timestamp_ms: None, arrival_timestamp_ms: None,
}); });
let signal = tokio::time::timeout(Duration::from_secs(2), output_rx.recv()) let output_batch = tokio::time::timeout(Duration::from_secs(2), output_rx.recv())
.await .await
.expect("scheduler should emit output") .expect("scheduler should emit output")
.expect("output channel should stay open"); .expect("output channel should stay open");
let signal = output_batch
.into_iter()
.next()
.expect("live scheduler should emit one output signal");
assert!(signal.completed); assert!(signal.completed);
tokio::time::sleep(Duration::from_millis(50)).await; tokio::time::sleep(Duration::from_millis(50)).await;
...@@ -691,7 +695,7 @@ mod live_scheduler { ...@@ -691,7 +695,7 @@ mod live_scheduler {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID); let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let (sink, forward_task) = harness.spawn_forwarder(); let (sink, forward_task) = harness.spawn_forwarder();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>(); let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let scheduler = Scheduler::new( let scheduler = Scheduler::new(
MockEngineArgs::builder() MockEngineArgs::builder()
.block_size(4) .block_size(4)
...@@ -726,8 +730,8 @@ mod live_scheduler { ...@@ -726,8 +730,8 @@ mod live_scheduler {
loop { loop {
tokio::select! { tokio::select! {
Some(_) = output_rx.recv() => { Some(output_batch) = output_rx.recv() => {
seen += 1; seen += output_batch.len();
if seen == expected { if seen == expected {
break; break;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment