Unverified Commit 95a750f4 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore(replay): refactor offline components into cleaner lanes (#7866)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 210bbf5d
...@@ -6,12 +6,14 @@ mod collector; ...@@ -6,12 +6,14 @@ mod collector;
mod entrypoints; mod entrypoints;
pub(crate) mod offline; pub(crate) mod offline;
mod online; mod online;
mod router; mod router_shared;
mod validate; mod validate;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::Arc;
use crate::common::protocols::{DirectRequest, MockEngineArgs}; use crate::common::protocols::{DirectRequest, MockEngineArgs};
use dynamo_kv_router::PrefillLoadEstimator;
pub use artifacts::{ pub use artifacts::{
ReplayTimedKvEvent, ReplayTimedOutputSignal, ReplayTimedRequest, ReplayWorkerArtifacts, ReplayTimedKvEvent, ReplayTimedOutputSignal, ReplayTimedRequest, ReplayWorkerArtifacts,
...@@ -35,6 +37,8 @@ pub enum ReplayArgsMode { ...@@ -35,6 +37,8 @@ pub enum ReplayArgsMode {
Disagg, Disagg,
} }
pub type ReplayPrefillLoadEstimator = Arc<dyn PrefillLoadEstimator>;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct OfflineDisaggReplayConfig { pub struct OfflineDisaggReplayConfig {
pub prefill_args: MockEngineArgs, pub prefill_args: MockEngineArgs,
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
#[cfg(test)]
use super::components::OfflineRouterSnapshot;
pub(super) use super::components::ReplayMode;
use super::events::{SimulationEvent, SimulationWorkerStage}; use super::events::{SimulationEvent, SimulationWorkerStage};
use super::progress::ReplayProgress;
use super::runtime_utils::{ use super::runtime_utils::{
WorkerCompletionPayload, next_timestamp as choose_next_timestamp, pop_next_concurrency_ready, next_timestamp as choose_next_timestamp, pop_ready_worker_completion, push_worker_completion,
pop_next_trace_ready, pop_ready_worker_completion, push_worker_completion,
}; };
#[cfg(test)] #[cfg(test)]
use super::state::AggRequestPhase;
#[cfg(test)]
use super::state::OfflineWorkerSnapshot; use super::state::OfflineWorkerSnapshot;
use super::state::{AggRequestState, OfflineWorkerState}; use super::{
components::{
AdmissionQueue, EngineComponent, EngineEffects, EnginePassMode, OfflineReplayRouter,
ScheduledWorkerCompletion, WorkerAdmission,
},
state::AggRequestState,
};
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal}; use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::{ReplayRequestHashes, WorkloadDriver}; use crate::loadgen::{ReplayRequestHashes, WorkloadDriver};
use crate::replay::router::OfflineReplayRouter; use crate::replay::{ReplayPrefillLoadEstimator, ReplayRouterMode, TraceCollector};
#[cfg(test)]
use crate::replay::router::OfflineRouterSnapshot;
use crate::replay::{ReplayRouterMode, TraceCollector};
use crate::scheduler::RouterEventVisibility;
use anyhow::bail; use anyhow::bail;
use dynamo_kv_router::config::KvRouterConfig; use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::RouterEvent; use dynamo_kv_router::protocols::RouterEvent;
...@@ -25,17 +32,6 @@ use std::collections::HashMap; ...@@ -25,17 +32,6 @@ use std::collections::HashMap;
use std::collections::{BinaryHeap, VecDeque}; use std::collections::{BinaryHeap, VecDeque};
use uuid::Uuid; use uuid::Uuid;
#[derive(Debug, Clone, Copy)]
pub(super) enum ReplayMode {
Trace,
Concurrency { max_in_flight: usize },
}
enum AdmissionSource {
Requests(VecDeque<DirectRequest>),
Workload(WorkloadDriver),
}
#[cfg(test)] #[cfg(test)]
#[derive(Debug, Default, Clone, PartialEq, Eq)] #[derive(Debug, Default, Clone, PartialEq, Eq)]
pub(super) struct AggRuntimeStats { pub(super) struct AggRuntimeStats {
...@@ -67,13 +63,13 @@ pub(super) struct AggRuntime { ...@@ -67,13 +63,13 @@ pub(super) struct AggRuntime {
now_ms: f64, now_ms: f64,
next_worker_idx: usize, next_worker_idx: usize,
next_event_seq: u64, next_event_seq: u64,
admission: AdmissionSource, admission: AdmissionQueue,
requests: FxHashMap<Uuid, AggRequestState>, requests: FxHashMap<Uuid, AggRequestState>,
workers: Vec<OfflineWorkerState>, engine: EngineComponent,
collector: TraceCollector, collector: TraceCollector,
events: BinaryHeap<SimulationEvent>, events: BinaryHeap<SimulationEvent>,
mode: ReplayMode,
router: Option<OfflineReplayRouter>, router: Option<OfflineReplayRouter>,
progress: ReplayProgress,
stats: AggRuntimeStats, stats: AggRuntimeStats,
#[cfg(test)] #[cfg(test)]
worker_active_requests: Vec<Vec<Uuid>>, worker_active_requests: Vec<Vec<Uuid>>,
...@@ -86,6 +82,7 @@ impl AggRuntime { ...@@ -86,6 +82,7 @@ impl AggRuntime {
pub(super) fn new( pub(super) fn new(
args: &MockEngineArgs, args: &MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
pending: VecDeque<DirectRequest>, pending: VecDeque<DirectRequest>,
num_workers: usize, num_workers: usize,
mode: ReplayMode, mode: ReplayMode,
...@@ -94,9 +91,9 @@ impl AggRuntime { ...@@ -94,9 +91,9 @@ impl AggRuntime {
Self::new_with_source( Self::new_with_source(
args, args,
router_config, router_config,
AdmissionSource::Requests(pending), prefill_load_estimator,
AdmissionQueue::new_requests(pending, mode),
num_workers, num_workers,
mode,
router_mode, router_mode,
) )
} }
...@@ -105,6 +102,7 @@ impl AggRuntime { ...@@ -105,6 +102,7 @@ impl AggRuntime {
pub(super) fn new_workload( pub(super) fn new_workload(
args: &MockEngineArgs, args: &MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
driver: WorkloadDriver, driver: WorkloadDriver,
num_workers: usize, num_workers: usize,
mode: ReplayMode, mode: ReplayMode,
...@@ -113,9 +111,9 @@ impl AggRuntime { ...@@ -113,9 +111,9 @@ impl AggRuntime {
Self::new_with_source( Self::new_with_source(
args, args,
router_config, router_config,
AdmissionSource::Workload(driver), prefill_load_estimator,
AdmissionQueue::new_workload(driver, mode),
num_workers, num_workers,
mode,
router_mode, router_mode,
) )
} }
...@@ -124,19 +122,36 @@ impl AggRuntime { ...@@ -124,19 +122,36 @@ impl AggRuntime {
fn new_with_source( fn new_with_source(
args: &MockEngineArgs, args: &MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
admission: AdmissionSource, prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
admission: AdmissionQueue,
num_workers: usize, num_workers: usize,
mode: ReplayMode,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
) -> anyhow::Result<Self> { ) -> anyhow::Result<Self> {
let args = args.clone().normalized()?; let args = args.clone().normalized()?;
let progress = ReplayProgress::new(admission.total_requests(), "offline replay");
let router = match router_mode { let router = match router_mode {
ReplayRouterMode::RoundRobin => None, ReplayRouterMode::RoundRobin => None,
ReplayRouterMode::KvRouter => { ReplayRouterMode::KvRouter => Some(OfflineReplayRouter::new(
Some(OfflineReplayRouter::new(&args, router_config, num_workers)?) &args,
} router_config,
prefill_load_estimator,
num_workers,
)?),
}; };
let capture_kv_events = router.is_some(); let capture_kv_events = router.is_some();
let engine = EngineComponent::new(
SimulationWorkerStage::Aggregated,
EnginePassMode::Visible,
(0..num_workers)
.map(|worker_idx| {
super::state::OfflineWorkerState::new(
worker_idx,
args.clone(),
capture_kv_events,
)
})
.collect(),
);
Ok(Self { Ok(Self {
now_ms: 0.0, now_ms: 0.0,
...@@ -144,15 +159,11 @@ impl AggRuntime { ...@@ -144,15 +159,11 @@ impl AggRuntime {
next_event_seq: 0, next_event_seq: 0,
admission, admission,
requests: FxHashMap::default(), requests: FxHashMap::default(),
workers: (0..num_workers) engine,
.map(|worker_idx| {
OfflineWorkerState::new(worker_idx, args.clone(), capture_kv_events)
})
.collect(),
collector: TraceCollector::default(), collector: TraceCollector::default(),
events: BinaryHeap::new(), events: BinaryHeap::new(),
mode,
router, router,
progress,
#[cfg(test)] #[cfg(test)]
stats: AggRuntimeStats::default(), stats: AggRuntimeStats::default(),
#[cfg(not(test))] #[cfg(not(test))]
...@@ -166,10 +177,7 @@ impl AggRuntime { ...@@ -166,10 +177,7 @@ impl AggRuntime {
/// Count all requests currently consuming cluster capacity, including router-queued ones. /// Count all requests currently consuming cluster capacity, including router-queued ones.
fn cluster_in_flight(&self) -> usize { fn cluster_in_flight(&self) -> usize {
self.workers self.engine.in_flight()
.iter()
.map(OfflineWorkerState::in_flight)
.sum::<usize>()
+ self + self
.router .router
.as_ref() .as_ref()
...@@ -203,7 +211,7 @@ impl AggRuntime { ...@@ -203,7 +211,7 @@ impl AggRuntime {
/// Pick the next worker in round-robin order. /// Pick the next worker in round-robin order.
fn next_worker(&mut self) -> usize { fn next_worker(&mut self) -> usize {
let worker_idx = self.next_worker_idx; let worker_idx = self.next_worker_idx;
self.next_worker_idx = (self.next_worker_idx + 1) % self.workers.len(); self.next_worker_idx = (self.next_worker_idx + 1) % self.engine.worker_count();
worker_idx worker_idx
} }
...@@ -220,14 +228,6 @@ impl AggRuntime { ...@@ -220,14 +228,6 @@ impl AggRuntime {
self.record_in_flight_peak(); self.record_in_flight_peak();
} }
/// Fail fast if a router admission points at a worker that does not exist.
fn validate_worker_idx(&self, worker_idx: usize) -> anyhow::Result<()> {
if worker_idx >= self.workers.len() {
bail!("offline replay selected unknown worker index {worker_idx}");
}
Ok(())
}
/// Deliver a request to a worker and update the runtime's bookkeeping for that assignment. /// Deliver a request to a worker and update the runtime's bookkeeping for that assignment.
fn dispatch_to_worker( fn dispatch_to_worker(
&mut self, &mut self,
...@@ -235,32 +235,19 @@ impl AggRuntime { ...@@ -235,32 +235,19 @@ impl AggRuntime {
uuid: Uuid, uuid: Uuid,
worker_idx: usize, worker_idx: usize,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
self.validate_worker_idx(worker_idx)?; self.engine.dispatch(worker_idx, request)?;
self.workers[worker_idx].receive_request(request);
self.record_dispatch(uuid, worker_idx); self.record_dispatch(uuid, worker_idx);
#[cfg(test)] #[cfg(test)]
self.worker_active_requests[worker_idx].push(uuid); self.worker_active_requests[worker_idx].push(uuid);
Ok(()) Ok(())
} }
/// Submit a request to the router and return an immediate admission when one is available.
fn submit_to_router(
&mut self,
request: &DirectRequest,
replay_hashes: Option<ReplayRequestHashes>,
) -> anyhow::Result<Option<usize>> {
let Some(router) = self.router.as_mut() else {
bail!("offline replay router submission requires an active router");
};
let maybe_worker_idx =
router.submit_request_with_hashes(request, replay_hashes, self.now_ms)?;
self.record_router_pending();
Ok(maybe_worker_idx)
}
/// Materialize router admissions into concrete worker dispatches. /// Materialize router admissions into concrete worker dispatches.
fn dispatch_router_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> anyhow::Result<()> { fn dispatch_router_admissions(
for (uuid, worker_idx) in admissions { &mut self,
admissions: Vec<WorkerAdmission>,
) -> anyhow::Result<()> {
for WorkerAdmission { uuid, worker_idx } in admissions {
let request = self let request = self
.requests .requests
.get_mut(&uuid) .get_mut(&uuid)
...@@ -282,7 +269,7 @@ impl AggRuntime { ...@@ -282,7 +269,7 @@ impl AggRuntime {
) -> anyhow::Result<Uuid> { ) -> anyhow::Result<Uuid> {
let uuid = request.uuid.unwrap_or_else(Uuid::new_v4); let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
request.uuid = Some(uuid); request.uuid = Some(uuid);
if matches!(self.mode, ReplayMode::Concurrency { .. }) { if matches!(self.admission.mode(), ReplayMode::Concurrency { .. }) {
request.arrival_timestamp_ms = Some(arrival_time_ms); request.arrival_timestamp_ms = Some(arrival_time_ms);
} }
...@@ -299,15 +286,17 @@ impl AggRuntime { ...@@ -299,15 +286,17 @@ impl AggRuntime {
self.dispatch_to_worker(request, uuid, worker_idx)?; self.dispatch_to_worker(request, uuid, worker_idx)?;
return Ok(uuid); return Ok(uuid);
} }
let maybe_worker_idx = self.submit_to_router(&request, replay_hashes)?; let queued_request = request.clone();
if let Some(worker_idx) = maybe_worker_idx {
self.requests.insert(uuid, AggRequestState::new_running());
self.dispatch_to_worker(request, uuid, worker_idx)?;
return Ok(uuid);
}
self.requests self.requests
.insert(uuid, AggRequestState::new_queued(request)); .insert(uuid, AggRequestState::new_queued(request));
let admissions = {
let router = self.router.as_mut().expect("router presence checked above");
router
.on_request_arrival(&queued_request, replay_hashes, self.now_ms)?
.admissions
};
self.record_router_pending();
self.dispatch_router_admissions(admissions)?;
self.record_in_flight_peak(); self.record_in_flight_peak();
Ok(uuid) Ok(uuid)
} }
...@@ -316,38 +305,17 @@ impl AggRuntime { ...@@ -316,38 +305,17 @@ impl AggRuntime {
fn is_done(&self) -> bool { fn is_done(&self) -> bool {
self.events.is_empty() self.events.is_empty()
&& self.cluster_in_flight() == 0 && self.cluster_in_flight() == 0
&& match &self.admission { && self.admission.is_drained()
AdmissionSource::Requests(pending) => pending.is_empty(), && self.engine.is_drained()
AdmissionSource::Workload(driver) => driver.is_drained(),
}
&& self.workers.iter().all(OfflineWorkerState::is_drained)
} }
/// Pick the next logical timestamp from either arrivals or scheduled worker completions. /// Pick the next logical timestamp from either arrivals or scheduled worker completions.
fn next_timestamp(&mut self) -> Option<f64> { fn next_timestamp(&mut self) -> Option<f64> {
let next_event_ms = self.events.peek().map(|event| event.at_ms); let next_event_ms = self.events.peek().map(|event| event.at_ms);
let cluster_in_flight = self.cluster_in_flight(); choose_next_timestamp(
let next_arrival_ms = match (&self.mode, &mut self.admission) { self.admission.next_ready_time_ms(self.cluster_in_flight()),
(ReplayMode::Trace, AdmissionSource::Requests(pending)) => pending next_event_ms,
.front() )
.and_then(|request| request.arrival_timestamp_ms),
(ReplayMode::Trace, AdmissionSource::Workload(driver)) => driver.next_ready_time_ms(),
(ReplayMode::Concurrency { max_in_flight }, AdmissionSource::Workload(driver)) => {
if cluster_in_flight < *max_in_flight {
driver.next_ready_time_ms()
} else {
None
}
}
(ReplayMode::Concurrency { .. }, AdmissionSource::Requests(_)) => None,
};
choose_next_timestamp(next_arrival_ms, next_event_ms)
}
/// Release completed requests from worker-local accounting after a pass finishes.
fn apply_completed_requests(&mut self, worker_idx: usize, completed_requests: usize) {
self.workers[worker_idx].mark_completed(completed_requests);
} }
/// Apply router-visible KV events at the phase chosen by the scheduler core. /// Apply router-visible KV events at the phase chosen by the scheduler core.
...@@ -355,8 +323,9 @@ impl AggRuntime { ...@@ -355,8 +323,9 @@ impl AggRuntime {
let Some(router) = self.router.as_mut() else { let Some(router) = self.router.as_mut() else {
return Ok(()); return Ok(());
}; };
for event in events { let effects = router.on_kv_events(events)?;
router.apply_event(event)?; if !effects.admissions.is_empty() {
bail!("offline replay router KV event application must not admit requests");
} }
Ok(()) Ok(())
} }
...@@ -368,7 +337,9 @@ impl AggRuntime { ...@@ -368,7 +337,9 @@ impl AggRuntime {
#[cfg(test)] #[cfg(test)]
self.remove_active_request(signal.uuid); self.remove_active_request(signal.uuid);
if let Some(router) = self.router.as_mut() { if let Some(router) = self.router.as_mut() {
admissions = router.free(signal.uuid)?; admissions = router
.on_request_completed(signal.uuid, self.now_ms)?
.admissions;
#[cfg(test)] #[cfg(test)]
{ {
self.stats.router_freed_count += 1; self.stats.router_freed_count += 1;
...@@ -378,9 +349,9 @@ impl AggRuntime { ...@@ -378,9 +349,9 @@ impl AggRuntime {
self.requests.remove(&signal.uuid).ok_or_else(|| { self.requests.remove(&signal.uuid).ok_or_else(|| {
anyhow::anyhow!("offline replay missing request state for {}", signal.uuid) anyhow::anyhow!("offline replay missing request state for {}", signal.uuid)
})?; })?;
if let AdmissionSource::Workload(driver) = &mut self.admission { self.admission
driver.on_complete(signal.uuid, self.now_ms)?; .on_request_completed(signal.uuid, self.now_ms)?;
} self.progress.inc_completed();
self.dispatch_router_admissions(admissions)?; self.dispatch_router_admissions(admissions)?;
return Ok(()); return Ok(());
} }
...@@ -391,7 +362,7 @@ impl AggRuntime { ...@@ -391,7 +362,7 @@ impl AggRuntime {
.ok_or_else(|| { .ok_or_else(|| {
anyhow::anyhow!("offline replay missing request state for {}", signal.uuid) anyhow::anyhow!("offline replay missing request state for {}", signal.uuid)
})? })?
.prefill_completed(); .prefill_completed;
if already_marked { if already_marked {
return Ok(()); return Ok(());
} }
...@@ -401,9 +372,11 @@ impl AggRuntime { ...@@ -401,9 +372,11 @@ impl AggRuntime {
.ok_or_else(|| { .ok_or_else(|| {
anyhow::anyhow!("offline replay missing request state for {}", signal.uuid) anyhow::anyhow!("offline replay missing request state for {}", signal.uuid)
})? })?
.mark_prefill_completed(); .prefill_completed = true;
if let Some(router) = self.router.as_mut() { if let Some(router) = self.router.as_mut() {
admissions = router.mark_prefill_completed(signal.uuid)?; admissions = router
.on_prefill_completed(signal.uuid, self.now_ms)?
.admissions;
#[cfg(test)] #[cfg(test)]
{ {
self.stats.prefill_marked_count += 1; self.stats.prefill_marked_count += 1;
...@@ -433,12 +406,11 @@ impl AggRuntime { ...@@ -433,12 +406,11 @@ impl AggRuntime {
/// Apply one completed pass: free request slots, publish KV events, and handle outputs. /// Apply one completed pass: free request slots, publish KV events, and handle outputs.
fn process_completed_pass( fn process_completed_pass(
&mut self, &mut self,
worker_idx: usize, _worker_idx: usize,
completed_requests: usize, _completed_requests: usize,
output_signals: Vec<OutputSignal>, output_signals: Vec<OutputSignal>,
kv_events: Vec<RouterEvent>, kv_events: Vec<RouterEvent>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
self.apply_completed_requests(worker_idx, completed_requests);
self.apply_router_events(kv_events)?; self.apply_router_events(kv_events)?;
for signal in output_signals { for signal in output_signals {
self.process_output_signal(signal)?; self.process_output_signal(signal)?;
...@@ -449,165 +421,71 @@ impl AggRuntime { ...@@ -449,165 +421,71 @@ impl AggRuntime {
/// Drain all worker-completion events scheduled for the current logical timestamp. /// Drain all worker-completion events scheduled for the current logical timestamp.
fn apply_worker_completions(&mut self) -> anyhow::Result<bool> { fn apply_worker_completions(&mut self) -> anyhow::Result<bool> {
let mut changed = false; let mut changed = false;
while let Some(WorkerCompletionPayload { while let Some(payload) = pop_ready_worker_completion(&mut self.events, self.now_ms) {
stage, debug_assert_eq!(payload.stage, SimulationWorkerStage::Aggregated);
worker_idx, let payload = self.engine.on_scheduled_completion(payload)?;
completed_requests, self.process_completed_pass(
output_signals, payload.worker_idx,
kv_events, payload.completed_requests,
}) = pop_ready_worker_completion(&mut self.events, self.now_ms) payload.output_signals,
{ payload.kv_events,
debug_assert_eq!(stage, SimulationWorkerStage::Aggregated); )?;
self.workers[worker_idx].mark_idle();
self.process_completed_pass(worker_idx, completed_requests, output_signals, kv_events)?;
changed = true; changed = true;
} }
Ok(changed) Ok(changed)
} }
/// Release every trace arrival whose timestamp is now visible to the global clock. /// Release every admission made ready by the shared admission queue.
fn release_trace_arrivals(&mut self) -> anyhow::Result<bool> { fn release_ready_arrivals(&mut self) -> anyhow::Result<bool> {
let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) {
loop {
let next_ready = match &mut self.admission {
AdmissionSource::Requests(pending) => {
pop_next_trace_ready(pending, self.now_ms)
}
AdmissionSource::Workload(_) => unreachable!(),
};
let Some((request, arrival_ms)) = next_ready else {
break;
};
self.assign_request(request, arrival_ms, None)?;
released_any = true;
}
return Ok(released_any);
}
let ready_requests = match &mut self.admission {
AdmissionSource::Requests(_) => unreachable!(),
AdmissionSource::Workload(driver) => driver.pop_ready(self.now_ms, usize::MAX),
};
for ready in ready_requests {
self.assign_request(
ready.request,
ready.scheduled_ready_at_ms,
ready.replay_hashes,
)?;
released_any = true;
}
Ok(released_any)
}
/// Backfill closed-loop concurrency replay until the configured in-flight limit is reached.
fn top_off_concurrency(&mut self, max_in_flight: usize) -> anyhow::Result<bool> {
let mut released_any = false; let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) { for ready in self
loop { .admission
let cluster_in_flight = self.cluster_in_flight(); .drain_ready(self.now_ms, self.cluster_in_flight())?
let next_ready = match &mut self.admission { {
AdmissionSource::Requests(pending) => pop_next_concurrency_ready( self.assign_request(ready.request, ready.arrival_time_ms, ready.replay_hashes)?;
pending,
self.now_ms,
cluster_in_flight,
max_in_flight,
),
AdmissionSource::Workload(_) => unreachable!(),
};
let Some((request, arrival_ms)) = next_ready else {
break;
};
self.assign_request(request, arrival_ms, None)?;
released_any = true;
}
return Ok(released_any);
}
let available = max_in_flight.saturating_sub(self.cluster_in_flight());
if available == 0 {
return Ok(false);
}
let ready_requests = match &mut self.admission {
AdmissionSource::Requests(_) => unreachable!(),
AdmissionSource::Workload(driver) => driver.pop_ready(self.now_ms, available),
};
for ready in ready_requests {
self.assign_request(ready.request, self.now_ms, ready.replay_hashes)?;
released_any = true; released_any = true;
} }
Ok(released_any) Ok(released_any)
} }
/// Start passes on every idle worker that can make progress at the current timestamp. /// Start passes on every idle worker that can make progress at the current timestamp.
fn drive_ready_workers(&mut self) -> anyhow::Result<bool> { fn drive_ready_workers(&mut self) -> anyhow::Result<bool> {
let mut changed = false; let mut changed = false;
for worker_idx in 0..self.workers.len() { loop {
loop { let effects = self
if !self.workers[worker_idx].is_ready() { .engine
break; .drive_ready(self.now_ms, Some(&mut self.collector))?;
} if effects.is_empty() {
return Ok(changed);
let executed = {
let (workers, collector) = (&mut self.workers, &mut self.collector);
workers[worker_idx].execute_pass(collector, self.now_ms)
};
changed = true;
let completion_kv_events =
if executed.router_event_visibility == RouterEventVisibility::PassStart {
self.apply_router_events(executed.kv_events)?;
Vec::new()
} else {
executed.kv_events
};
if executed.end_ms == self.now_ms {
self.process_completed_pass(
worker_idx,
executed.completed_requests,
executed.output_signals,
completion_kv_events,
)?;
continue;
}
self.workers[worker_idx].mark_busy();
push_worker_completion(
&mut self.events,
&mut self.next_event_seq,
executed.end_ms,
WorkerCompletionPayload {
stage: SimulationWorkerStage::Aggregated,
worker_idx,
completed_requests: executed.completed_requests,
output_signals: executed.output_signals,
kv_events: completion_kv_events,
},
);
break;
} }
changed = true;
self.handle_engine_effects(effects)?;
} }
}
Ok(changed) fn handle_engine_effects(&mut self, effects: EngineEffects) -> anyhow::Result<()> {
self.apply_router_events(effects.pass_start_kv_events)?;
for payload in effects.immediate_completions {
let payload = self.engine.on_scheduled_completion(payload)?;
self.process_completed_pass(
payload.worker_idx,
payload.completed_requests,
payload.output_signals,
payload.kv_events,
)?;
}
for ScheduledWorkerCompletion { at_ms, payload } in effects.scheduled_completions {
push_worker_completion(&mut self.events, &mut self.next_event_seq, at_ms, payload);
}
Ok(())
} }
/// Repeatedly process all work that becomes possible without advancing logical time. /// Repeatedly process all work that becomes possible without advancing logical time.
fn drain_current_timestamp(&mut self) -> anyhow::Result<()> { fn drain_current_timestamp(&mut self) -> anyhow::Result<()> {
loop { loop {
let mut changed = self.apply_worker_completions()?; let mut changed = self.apply_worker_completions()?;
changed |= self.release_ready_arrivals()?;
changed |= match self.mode {
ReplayMode::Trace => self.release_trace_arrivals()?,
ReplayMode::Concurrency { max_in_flight } => {
self.top_off_concurrency(max_in_flight)?
}
};
changed |= self.drive_ready_workers()?; changed |= self.drive_ready_workers()?;
if !changed { if !changed {
...@@ -634,6 +512,7 @@ impl AggRuntime { ...@@ -634,6 +512,7 @@ impl AggRuntime {
self.drain_current_timestamp()?; self.drain_current_timestamp()?;
} }
self.progress.finish();
Ok((self.collector, self.stats)) Ok((self.collector, self.stats))
} }
...@@ -668,14 +547,14 @@ impl AggRuntime { ...@@ -668,14 +547,14 @@ impl AggRuntime {
let mut router_pending_request_ids = self let mut router_pending_request_ids = self
.requests .requests
.iter() .iter()
.filter(|(_, state)| state.is_queued_at_router()) .filter(|(_, state)| state.phase == AggRequestPhase::QueuedAtRouter)
.map(|(uuid, _)| *uuid) .map(|(uuid, _)| *uuid)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
router_pending_request_ids.sort_unstable(); router_pending_request_ids.sort_unstable();
let mut prefill_completed = self let mut prefill_completed = self
.requests .requests
.iter() .iter()
.filter(|(_, state)| state.prefill_completed()) .filter(|(_, state)| state.prefill_completed)
.map(|(uuid, _)| *uuid) .map(|(uuid, _)| *uuid)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
prefill_completed.sort_unstable(); prefill_completed.sort_unstable();
...@@ -683,17 +562,13 @@ impl AggRuntime { ...@@ -683,17 +562,13 @@ impl AggRuntime {
AggRuntimeSnapshot { AggRuntimeSnapshot {
now_ms: self.now_ms, now_ms: self.now_ms,
worker_active_requests: self.worker_active_requests.clone(), worker_active_requests: self.worker_active_requests.clone(),
workers: self workers: self.engine.debug_snapshots(),
.workers
.iter()
.map(OfflineWorkerState::debug_snapshot)
.collect(),
router_pending_request_ids, router_pending_request_ids,
prefill_completed, prefill_completed,
router: self router: self
.router .router
.as_ref() .as_ref()
.map(OfflineReplayRouter::debug_snapshot), .map(|router| router.debug_snapshot(self.now_ms)),
} }
} }
} }
...@@ -960,6 +835,7 @@ mod tests { ...@@ -960,6 +835,7 @@ mod tests {
let mut runtime = AggRuntime::new( let mut runtime = AggRuntime::new(
&args, &args,
None, None,
None,
normalize_trace_requests( normalize_trace_requests(
vec![ vec![
DirectRequest { DirectRequest {
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use anyhow::Result;
use uuid::Uuid;
use super::{ReadyArrival, ReplayMode};
use crate::common::protocols::DirectRequest;
use crate::loadgen::WorkloadDriver;
enum AdmissionSource {
Requests(VecDeque<DirectRequest>),
Workload(WorkloadDriver),
}
pub(in crate::replay::offline) struct AdmissionQueue {
source: AdmissionSource,
mode: ReplayMode,
}
impl AdmissionQueue {
pub(in crate::replay::offline) fn new_requests(
source: VecDeque<DirectRequest>,
mode: ReplayMode,
) -> Self {
Self {
source: AdmissionSource::Requests(source),
mode,
}
}
pub(in crate::replay::offline) fn new_workload(
driver: WorkloadDriver,
mode: ReplayMode,
) -> Self {
Self {
source: AdmissionSource::Workload(driver),
mode,
}
}
pub(in crate::replay::offline) fn mode(&self) -> ReplayMode {
self.mode
}
pub(in crate::replay::offline) fn next_ready_time_ms(
&mut self,
cluster_in_flight: usize,
) -> Option<f64> {
match (&self.mode, &mut self.source) {
(ReplayMode::Trace, AdmissionSource::Requests(pending)) => pending
.front()
.and_then(|request| request.arrival_timestamp_ms),
(ReplayMode::Trace, AdmissionSource::Workload(driver)) => driver.next_ready_time_ms(),
(ReplayMode::Concurrency { max_in_flight }, AdmissionSource::Workload(driver)) => {
if cluster_in_flight < *max_in_flight {
driver.next_ready_time_ms()
} else {
None
}
}
(ReplayMode::Concurrency { .. }, AdmissionSource::Requests(_)) => None,
}
}
pub(in crate::replay::offline) fn drain_ready(
&mut self,
now_ms: f64,
cluster_in_flight: usize,
) -> Result<Vec<ReadyArrival>> {
match (&self.mode, &mut self.source) {
(ReplayMode::Trace, AdmissionSource::Requests(pending)) => {
let mut ready = Vec::new();
loop {
let arrival_ms = pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
.filter(|arrival_ms| *arrival_ms <= now_ms);
let Some(arrival_time_ms) = arrival_ms else {
break;
};
let request = pending
.pop_front()
.expect("front request must exist when arrival is ready");
ready.push(ReadyArrival {
request,
arrival_time_ms,
replay_hashes: None,
});
}
Ok(ready)
}
(ReplayMode::Trace, AdmissionSource::Workload(driver)) => Ok(driver
.pop_ready(now_ms, usize::MAX)
.into_iter()
.map(|ready| ReadyArrival {
request: ready.request,
arrival_time_ms: ready.scheduled_ready_at_ms,
replay_hashes: ready.replay_hashes,
})
.collect()),
(ReplayMode::Concurrency { max_in_flight }, AdmissionSource::Requests(pending)) => {
let mut ready = Vec::new();
let mut simulated_in_flight = cluster_in_flight;
while simulated_in_flight < *max_in_flight {
let Some(request) = pending.pop_front() else {
break;
};
ready.push(ReadyArrival {
request,
arrival_time_ms: now_ms,
replay_hashes: None,
});
simulated_in_flight += 1;
}
Ok(ready)
}
(ReplayMode::Concurrency { max_in_flight }, AdmissionSource::Workload(driver)) => {
let available = max_in_flight.saturating_sub(cluster_in_flight);
if available == 0 {
return Ok(Vec::new());
}
Ok(driver
.pop_ready(now_ms, available)
.into_iter()
.map(|ready| ReadyArrival {
request: ready.request,
arrival_time_ms: now_ms,
replay_hashes: ready.replay_hashes,
})
.collect())
}
}
}
pub(in crate::replay::offline) fn on_request_completed(
&mut self,
uuid: Uuid,
now_ms: f64,
) -> Result<()> {
let AdmissionSource::Workload(driver) = &mut self.source else {
return Ok(());
};
driver.on_complete(uuid, now_ms)
}
pub(in crate::replay::offline) fn is_drained(&self) -> bool {
match &self.source {
AdmissionSource::Requests(pending) => pending.is_empty(),
AdmissionSource::Workload(driver) => driver.is_drained(),
}
}
#[cfg(test)]
pub(crate) fn is_workload(&self) -> bool {
matches!(self.source, AdmissionSource::Workload(_))
}
pub(in crate::replay::offline) fn total_requests(&self) -> usize {
match &self.source {
AdmissionSource::Requests(pending) => pending.len(),
AdmissionSource::Workload(driver) => driver.total_turns(),
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::bail;
use super::super::events::SimulationWorkerStage;
use super::super::runtime_utils::WorkerCompletionPayload;
#[cfg(test)]
use super::super::state::OfflineWorkerSnapshot;
use super::super::state::OfflineWorkerState;
use super::{EngineEffects, EnginePassMode, ScheduledWorkerCompletion};
use crate::common::protocols::DirectRequest;
use crate::replay::TraceCollector;
use crate::scheduler::RouterEventVisibility;
pub(in crate::replay::offline) struct EngineComponent {
stage: SimulationWorkerStage,
pass_mode: EnginePassMode,
workers: Vec<OfflineWorkerState>,
}
impl EngineComponent {
pub(in crate::replay::offline) fn new(
stage: SimulationWorkerStage,
pass_mode: EnginePassMode,
workers: Vec<OfflineWorkerState>,
) -> Self {
Self {
stage,
pass_mode,
workers,
}
}
pub(in crate::replay::offline) fn dispatch(
&mut self,
worker_idx: usize,
request: DirectRequest,
) -> anyhow::Result<()> {
self.validate_worker_idx(worker_idx)?;
self.workers[worker_idx].receive_request(request);
Ok(())
}
pub(in crate::replay::offline) fn drive_ready(
&mut self,
now_ms: f64,
mut collector: Option<&mut TraceCollector>,
) -> anyhow::Result<EngineEffects> {
for worker_idx in 0..self.workers.len() {
if !self.workers[worker_idx].is_ready() {
continue;
}
let executed = match self.pass_mode {
EnginePassMode::Visible => {
let Some(collector) = collector.as_deref_mut() else {
bail!("offline replay visible engine pass requires a collector");
};
self.workers[worker_idx].execute_pass(collector, now_ms)
}
EnginePassMode::Hidden => self.workers[worker_idx].execute_hidden_pass(now_ms),
};
let mut effects = EngineEffects::default();
let completion_kv_events =
if executed.router_event_visibility == RouterEventVisibility::PassStart {
effects.pass_start_kv_events = executed.kv_events;
Vec::new()
} else {
executed.kv_events
};
let payload = WorkerCompletionPayload {
stage: self.stage,
worker_idx,
completed_requests: executed.completed_requests,
output_signals: executed.output_signals,
kv_events: completion_kv_events,
};
if executed.end_ms == now_ms {
effects.immediate_completions.push(payload);
return Ok(effects);
}
self.workers[worker_idx].mark_busy();
effects
.scheduled_completions
.push(ScheduledWorkerCompletion {
at_ms: executed.end_ms,
payload,
});
return Ok(effects);
}
Ok(EngineEffects::default())
}
pub(in crate::replay::offline) fn on_scheduled_completion(
&mut self,
payload: WorkerCompletionPayload,
) -> anyhow::Result<WorkerCompletionPayload> {
if payload.stage != self.stage {
bail!(
"offline replay completion stage mismatch: expected {:?}, got {:?}",
self.stage,
payload.stage
);
}
self.validate_worker_idx(payload.worker_idx)?;
self.workers[payload.worker_idx].mark_idle();
self.workers[payload.worker_idx].mark_completed(payload.completed_requests);
Ok(payload)
}
pub(in crate::replay::offline) fn in_flight(&self) -> usize {
self.workers.iter().map(OfflineWorkerState::in_flight).sum()
}
pub(in crate::replay::offline) fn is_drained(&self) -> bool {
self.workers.iter().all(OfflineWorkerState::is_drained)
}
pub(in crate::replay::offline) fn worker_count(&self) -> usize {
self.workers.len()
}
fn validate_worker_idx(&self, worker_idx: usize) -> anyhow::Result<()> {
if worker_idx >= self.workers.len() {
bail!("offline replay selected unknown worker index {worker_idx}");
}
Ok(())
}
#[cfg(test)]
pub(crate) fn debug_snapshots(&self) -> Vec<OfflineWorkerSnapshot> {
self.workers
.iter()
.map(OfflineWorkerState::debug_snapshot)
.collect()
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod admission;
mod engine;
mod router;
mod types;
pub(in crate::replay::offline) use admission::AdmissionQueue;
pub(in crate::replay::offline) use engine::EngineComponent;
pub(crate) use router::OfflineReplayRouter;
#[cfg(test)]
pub(crate) use router::OfflineRouterSnapshot;
pub(in crate::replay::offline) use types::{
EngineEffects, EnginePassMode, ReadyArrival, ReplayMode, ScheduledWorkerCompletion,
};
pub(crate) use types::{RouterEffects, WorkerAdmission};
...@@ -10,8 +10,8 @@ use anyhow::{Context, Result, anyhow}; ...@@ -10,8 +10,8 @@ use anyhow::{Context, Result, anyhow};
use dynamo_kv_router::LocalBlockHash; use dynamo_kv_router::LocalBlockHash;
use dynamo_kv_router::config::KvRouterConfig; use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::{ use dynamo_kv_router::protocols::{
BlockHashOptions, OverlapScores, RouterEvent, WorkerConfigLike, WorkerId, WorkerWithDpRank, BlockHashOptions, OverlapScores, PrefillLoadHint, RouterEvent, WorkerConfigLike, WorkerId,
compute_block_hash_for_seq, WorkerWithDpRank, compute_block_hash_for_seq,
}; };
use dynamo_kv_router::queue::DEFAULT_MAX_BATCHED_TOKENS; use dynamo_kv_router::queue::DEFAULT_MAX_BATCHED_TOKENS;
use dynamo_kv_router::{ use dynamo_kv_router::{
...@@ -19,15 +19,18 @@ use dynamo_kv_router::{ ...@@ -19,15 +19,18 @@ use dynamo_kv_router::{
SchedulingPolicy, SchedulingRequest, SequenceRequest, WorkerSelector, SchedulingPolicy, SchedulingRequest, SequenceRequest, WorkerSelector,
}; };
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
use tokio::time::Instant;
use uuid::Uuid; use uuid::Uuid;
use super::shared::{ use super::{RouterEffects, WorkerAdmission};
ReplayNoopPublisher, ReplayWorkerConfig, replay_policy, replay_router_config, replay_selector,
replay_slots, replay_workers_with_configs,
};
use crate::common::protocols::DirectRequest; use crate::common::protocols::DirectRequest;
use crate::common::protocols::MockEngineArgs; use crate::common::protocols::MockEngineArgs;
use crate::loadgen::ReplayRequestHashes; use crate::loadgen::ReplayRequestHashes;
use crate::replay::ReplayPrefillLoadEstimator;
use crate::replay::router_shared::{
ReplayNoopPublisher, ReplayWorkerConfig, replay_policy, replay_router_config, replay_selector,
replay_slots, replay_workers_with_configs,
};
type ReplayQueueKey = <RouterSchedulingPolicy as SchedulingPolicy>::Key; type ReplayQueueKey = <RouterSchedulingPolicy as SchedulingPolicy>::Key;
...@@ -183,12 +186,15 @@ pub(crate) struct OfflineReplayRouter { ...@@ -183,12 +186,15 @@ pub(crate) struct OfflineReplayRouter {
pending: BinaryHeap<QueueEntry>, pending: BinaryHeap<QueueEntry>,
next_enqueue_seq: u64, next_enqueue_seq: u64,
indexer: SyncReplayIndexer, indexer: SyncReplayIndexer,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
decay_time_epoch: Instant,
} }
impl OfflineReplayRouter { impl OfflineReplayRouter {
pub(crate) fn new( pub(crate) fn new(
args: &MockEngineArgs, args: &MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
num_workers: usize, num_workers: usize,
) -> Result<Self> { ) -> Result<Self> {
let config = replay_router_config(args, router_config); let config = replay_router_config(args, router_config);
...@@ -213,19 +219,25 @@ impl OfflineReplayRouter { ...@@ -213,19 +219,25 @@ impl OfflineReplayRouter {
pending: BinaryHeap::new(), pending: BinaryHeap::new(),
next_enqueue_seq: 0, next_enqueue_seq: 0,
indexer: SyncReplayIndexer::new(args.block_size as u32), indexer: SyncReplayIndexer::new(args.block_size as u32),
prefill_load_estimator,
// This is only a base Instant for converting replay `now_ms` values into
// synthetic `Instant`s. All subsequent decay/accounting uses virtual replay
// time derived from this epoch, not wall-clock progression.
decay_time_epoch: Instant::now(),
}) })
} }
pub(crate) fn submit_request_with_hashes( pub(crate) fn on_request_arrival(
&mut self, &mut self,
request: &DirectRequest, request: &DirectRequest,
replay_hashes: Option<ReplayRequestHashes>, replay_hashes: Option<ReplayRequestHashes>,
now_ms: f64, now_ms: f64,
) -> Result<Option<usize>> { ) -> Result<RouterEffects> {
let pending = self.build_pending_request(request, replay_hashes)?; let pending = self.build_pending_request(request, replay_hashes)?;
let decay_now = self.decay_now(now_ms);
let should_queue = self let should_queue = self
.queue_threshold .queue_threshold
.is_some_and(|threshold| self.all_workers_busy(threshold)); .is_some_and(|threshold| self.all_workers_busy(threshold, decay_now));
if should_queue { if should_queue {
let key = self.enqueue_key(now_ms, &pending); let key = self.enqueue_key(now_ms, &pending);
...@@ -236,28 +248,60 @@ impl OfflineReplayRouter { ...@@ -236,28 +248,60 @@ impl OfflineReplayRouter {
request: pending, request: pending,
}); });
self.next_enqueue_seq += 1; self.next_enqueue_seq += 1;
return Ok(None); return Ok(RouterEffects::default());
} }
self.admit_request(pending).map(Some) Ok(RouterEffects {
admissions: vec![WorkerAdmission {
uuid: request
.uuid
.expect("offline replay requests must have UUIDs before router submission"),
worker_idx: self.admit_request(pending, decay_now)?,
}],
})
} }
pub(crate) fn apply_event(&mut self, event: RouterEvent) -> Result<()> { pub(crate) fn on_kv_events(&mut self, events: Vec<RouterEvent>) -> Result<RouterEffects> {
self.indexer.apply_event(event) for event in events {
self.indexer.apply_event(event)?;
}
Ok(RouterEffects::default())
} }
pub(crate) fn mark_prefill_completed(&mut self, uuid: Uuid) -> Result<Vec<(Uuid, usize)>> { pub(crate) fn on_prefill_completed(
&mut self,
uuid: Uuid,
now_ms: f64,
) -> Result<RouterEffects> {
let decay_now = self.decay_now(now_ms);
self.slots self.slots
.mark_prefill_completed(&uuid.to_string()) .mark_prefill_completed(&uuid.to_string(), decay_now)
.map_err(anyhow::Error::from)?; .map_err(anyhow::Error::from)?;
self.drain_pending() Ok(RouterEffects {
admissions: self
.drain_pending(decay_now)?
.into_iter()
.map(|(uuid, worker_idx)| WorkerAdmission { uuid, worker_idx })
.collect(),
})
} }
pub(crate) fn free(&mut self, uuid: Uuid) -> Result<Vec<(Uuid, usize)>> { pub(crate) fn on_request_completed(
&mut self,
uuid: Uuid,
now_ms: f64,
) -> Result<RouterEffects> {
let decay_now = self.decay_now(now_ms);
self.slots self.slots
.free(&uuid.to_string()) .free(&uuid.to_string(), decay_now)
.map_err(anyhow::Error::from)?; .map_err(anyhow::Error::from)?;
self.drain_pending() Ok(RouterEffects {
admissions: self
.drain_pending(decay_now)?
.into_iter()
.map(|(uuid, worker_idx)| WorkerAdmission { uuid, worker_idx })
.collect(),
})
} }
pub(crate) fn pending_count(&self) -> usize { pub(crate) fn pending_count(&self) -> usize {
...@@ -265,7 +309,8 @@ impl OfflineReplayRouter { ...@@ -265,7 +309,8 @@ impl OfflineReplayRouter {
} }
#[cfg(test)] #[cfg(test)]
pub(crate) fn debug_snapshot(&self) -> OfflineRouterSnapshot { pub(crate) fn debug_snapshot(&self, now_ms: f64) -> OfflineRouterSnapshot {
let decay_now = self.decay_now(now_ms);
let mut pending = self let mut pending = self
.pending .pending
.iter() .iter()
...@@ -302,7 +347,7 @@ impl OfflineReplayRouter { ...@@ -302,7 +347,7 @@ impl OfflineReplayRouter {
let mut active_tokens_by_worker = self let mut active_tokens_by_worker = self
.slots .slots
.active_tokens() .active_tokens(decay_now)
.into_iter() .into_iter()
.map(|(worker, tokens)| (worker.worker_id as usize, tokens)) .map(|(worker, tokens)| (worker.worker_id as usize, tokens))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
...@@ -324,6 +369,10 @@ impl OfflineReplayRouter { ...@@ -324,6 +369,10 @@ impl OfflineReplayRouter {
) )
} }
fn decay_now(&self, now_ms: f64) -> Instant {
self.decay_time_epoch + Duration::from_secs_f64(now_ms.max(0.0) / 1000.0)
}
fn build_pending_request( fn build_pending_request(
&self, &self,
request: &DirectRequest, request: &DirectRequest,
...@@ -378,7 +427,7 @@ impl OfflineReplayRouter { ...@@ -378,7 +427,7 @@ impl OfflineReplayRouter {
}) })
} }
fn admit_request(&mut self, request: PendingRequest) -> Result<usize> { fn admit_request(&mut self, request: PendingRequest, decay_now: Instant) -> Result<usize> {
let (decode_blocks, prefill_tokens) = self let (decode_blocks, prefill_tokens) = self
.slots .slots
.potential_blocks_and_tokens_with_prefill_tracking( .potential_blocks_and_tokens_with_prefill_tracking(
...@@ -386,6 +435,7 @@ impl OfflineReplayRouter { ...@@ -386,6 +435,7 @@ impl OfflineReplayRouter {
request.isl_tokens, request.isl_tokens,
request.overlaps.clone(), request.overlaps.clone(),
request.track_prefill_tokens, request.track_prefill_tokens,
decay_now,
); );
let scheduling_request = request.scheduling_request(decode_blocks, prefill_tokens); let scheduling_request = request.scheduling_request(decode_blocks, prefill_tokens);
let selection = self.selector.select_worker( let selection = self.selector.select_worker(
...@@ -396,56 +446,188 @@ impl OfflineReplayRouter { ...@@ -396,56 +446,188 @@ impl OfflineReplayRouter {
let worker_idx = usize::try_from(selection.worker.worker_id) let worker_idx = usize::try_from(selection.worker.worker_id)
.map_err(|_| anyhow!("selected worker id does not fit into usize"))?; .map_err(|_| anyhow!("selected worker id does not fit into usize"))?;
let request_id = request.request_id(); let request_id = request.request_id();
let prefill_load_hint = self.prefill_load_hint_for(
request.isl_tokens,
selection.overlap_blocks,
request.track_prefill_tokens,
);
self.slots self.slots
.add_request(SequenceRequest { .add_request(
request_id, SequenceRequest {
token_sequence: request.token_seq, request_id,
isl: request.isl_tokens, token_sequence: request.token_seq,
overlap: selection.overlap_blocks, isl: request.isl_tokens,
track_prefill_tokens: request.track_prefill_tokens, overlap: selection.overlap_blocks,
expected_output_tokens: request.expected_output_tokens, track_prefill_tokens: request.track_prefill_tokens,
worker: selection.worker, expected_output_tokens: request.expected_output_tokens,
lora_name: None, prefill_load_hint,
}) worker: selection.worker,
lora_name: None,
},
decay_now,
)
.map_err(anyhow::Error::from)?; .map_err(anyhow::Error::from)?;
Ok(worker_idx) Ok(worker_idx)
} }
fn drain_pending(&mut self) -> Result<Vec<(Uuid, usize)>> { fn drain_pending(&mut self, decay_now: Instant) -> Result<Vec<(Uuid, usize)>> {
let Some(threshold) = self.queue_threshold else { let Some(threshold) = self.queue_threshold else {
return Ok(Vec::new()); return Ok(Vec::new());
}; };
let mut admissions = Vec::new(); let mut admissions = Vec::new();
while !self.all_workers_busy(threshold) { while !self.all_workers_busy(threshold, decay_now) {
let Some(QueueEntry { request, .. }) = self.pending.pop() else { let Some(QueueEntry { request, .. }) = self.pending.pop() else {
break; break;
}; };
let uuid = request.uuid; let uuid = request.uuid;
let worker_idx = self.admit_request(request)?; let worker_idx = self.admit_request(request, decay_now)?;
admissions.push((uuid, worker_idx)); admissions.push((uuid, worker_idx));
} }
Ok(admissions) Ok(admissions)
} }
fn all_workers_busy(&self, threshold: f64) -> bool { fn all_workers_busy(&self, threshold: f64, decay_now: Instant) -> bool {
let mut checked_any = false; let mut checked_any = false;
let any_worker_not_busy = self let any_worker_not_busy =
.slots self.slots
.any_worker_matches_active_tokens(|worker, tokens| { .any_worker_matches_active_tokens(decay_now, |worker, tokens| {
let Some(config) = self.workers_with_configs.get(&worker.worker_id) else { let Some(config) = self.workers_with_configs.get(&worker.worker_id) else {
return false; return false;
}; };
checked_any = true; checked_any = true;
let max_batched = config let max_batched = config
.max_num_batched_tokens() .max_num_batched_tokens()
.unwrap_or(DEFAULT_MAX_BATCHED_TOKENS); .unwrap_or(DEFAULT_MAX_BATCHED_TOKENS);
(tokens as f64) <= threshold * (max_batched as f64) (tokens as f64) <= threshold * (max_batched as f64)
}); });
checked_any && !any_worker_not_busy checked_any && !any_worker_not_busy
} }
fn prefill_load_hint_for(
&self,
isl_tokens: usize,
overlap_blocks: u32,
track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> {
if !track_prefill_tokens {
return None;
}
let prefix = (overlap_blocks as usize) * (self.block_size as usize);
let effective_isl = isl_tokens.saturating_sub(prefix);
if effective_isl == 0 {
return None;
}
let Some(estimator) = &self.prefill_load_estimator else {
return None;
};
match estimator.predict_prefill_duration(1, effective_isl, prefix) {
Ok(expected_prefill_duration) => Some(PrefillLoadHint {
initial_effective_prefill_tokens: effective_isl,
expected_prefill_duration: Some(expected_prefill_duration),
}),
Err(error) => {
tracing::warn!(
effective_isl,
prefix,
"failed to predict replay prefill duration for active load tracking: {error}"
);
None
}
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::time::Duration;
use dynamo_kv_router::PrefillLoadEstimator;
use dynamo_kv_router::config::{KvRouterConfig, RouterPrefillLoadModel};
use uuid::Uuid;
use super::OfflineReplayRouter;
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::replay::ReplayPrefillLoadEstimator;
struct FixedPrefillLoadEstimator {
duration: Duration,
}
impl PrefillLoadEstimator for FixedPrefillLoadEstimator {
fn predict_prefill_duration(
&self,
_batch_size: usize,
_effective_isl: usize,
_prefix: usize,
) -> anyhow::Result<Duration> {
Ok(self.duration)
}
}
fn replay_args() -> MockEngineArgs {
MockEngineArgs::builder()
.block_size(64)
.max_num_batched_tokens(Some(256))
.build()
.unwrap()
}
fn router_config() -> KvRouterConfig {
KvRouterConfig {
router_track_prefill_tokens: true,
router_prefill_load_model: RouterPrefillLoadModel::Aic,
..KvRouterConfig::default()
}
}
fn estimator(duration: Duration) -> ReplayPrefillLoadEstimator {
Arc::new(FixedPrefillLoadEstimator { duration })
}
fn request(uuid: u128, token: u32) -> DirectRequest {
DirectRequest {
tokens: vec![token; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(uuid)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
}
}
#[test]
fn test_prefill_load_estimator_decays_offline_router_active_tokens() {
let mut router = OfflineReplayRouter::new(
&replay_args(),
Some(router_config()),
Some(estimator(Duration::from_secs(10))),
1,
)
.unwrap();
let effects = router
.on_request_arrival(&request(1, 7), None, 0.0)
.unwrap();
assert_eq!(effects.admissions.len(), 1);
assert_eq!(
router.debug_snapshot(0.0).active_tokens_by_worker,
vec![(0, 64)]
);
assert_eq!(
router.debug_snapshot(5_000.0).active_tokens_by_worker,
vec![(0, 32)]
);
assert_eq!(
router.debug_snapshot(10_000.0).active_tokens_by_worker,
vec![(0, 0)]
);
}
} }
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_kv_router::protocols::RouterEvent;
use uuid::Uuid;
use super::super::runtime_utils::WorkerCompletionPayload;
use crate::common::protocols::DirectRequest;
use crate::loadgen::ReplayRequestHashes;
#[derive(Debug, Clone, Copy)]
pub(in crate::replay::offline) enum ReplayMode {
Trace,
Concurrency { max_in_flight: usize },
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(in crate::replay::offline) enum EnginePassMode {
Visible,
Hidden,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct WorkerAdmission {
pub(crate) uuid: Uuid,
pub(crate) worker_idx: usize,
}
#[derive(Debug)]
pub(in crate::replay::offline) struct ScheduledWorkerCompletion {
pub(in crate::replay::offline) at_ms: f64,
pub(in crate::replay::offline) payload: WorkerCompletionPayload,
}
#[derive(Debug, Default)]
pub(in crate::replay::offline) struct EngineEffects {
pub(in crate::replay::offline) pass_start_kv_events: Vec<RouterEvent>,
pub(in crate::replay::offline) immediate_completions: Vec<WorkerCompletionPayload>,
pub(in crate::replay::offline) scheduled_completions: Vec<ScheduledWorkerCompletion>,
}
impl EngineEffects {
pub(in crate::replay::offline) fn is_empty(&self) -> bool {
self.pass_start_kv_events.is_empty()
&& self.immediate_completions.is_empty()
&& self.scheduled_completions.is_empty()
}
}
#[derive(Debug, Default)]
pub(crate) struct RouterEffects {
pub(crate) admissions: Vec<WorkerAdmission>,
}
#[derive(Debug)]
pub(in crate::replay::offline) struct ReadyArrival {
pub(in crate::replay::offline) request: DirectRequest,
pub(in crate::replay::offline) arrival_time_ms: f64,
pub(in crate::replay::offline) replay_hashes: Option<ReplayRequestHashes>,
}
...@@ -8,32 +8,36 @@ use dynamo_kv_router::config::KvRouterConfig; ...@@ -8,32 +8,36 @@ use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::RouterEvent; use dynamo_kv_router::protocols::RouterEvent;
use uuid::Uuid; use uuid::Uuid;
pub(super) use super::components::ReplayMode;
use super::components::{
AdmissionQueue, EngineComponent, EngineEffects, EnginePassMode, OfflineReplayRouter,
ScheduledWorkerCompletion, WorkerAdmission,
};
use super::events::{SimulationEvent, SimulationWorkerStage}; use super::events::{SimulationEvent, SimulationWorkerStage};
use super::progress::ReplayProgress;
use super::runtime_utils::{ use super::runtime_utils::{
WorkerCompletionPayload, next_timestamp as choose_next_timestamp, pop_next_concurrency_ready, next_timestamp as choose_next_timestamp, pop_ready_decode_handoff, pop_ready_worker_completion,
pop_next_trace_ready, pop_ready_decode_handoff, pop_ready_worker_completion,
push_decode_handoff, push_worker_completion, push_decode_handoff, push_worker_completion,
}; };
#[cfg(test)] #[cfg(test)]
use super::state::DisaggPhase;
#[cfg(test)]
use super::state::DisaggRequestSnapshot; use super::state::DisaggRequestSnapshot;
use super::state::{DisaggRequestState, OfflineWorkerState}; use super::state::{DisaggPhase, DisaggRequestState};
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal}; use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::{ReplayRequestHashes, WorkloadDriver}; use crate::loadgen::{ReplayRequestHashes, WorkloadDriver};
use crate::replay::router::OfflineReplayRouter; use crate::replay::{
use crate::replay::{OfflineDisaggReplayConfig, ReplayRouterMode, TraceCollector}; OfflineDisaggReplayConfig, ReplayPrefillLoadEstimator, ReplayRouterMode, TraceCollector,
use crate::scheduler::RouterEventVisibility; };
#[derive(Debug, Clone, Copy)]
pub(super) enum ReplayMode {
Trace,
Concurrency { max_in_flight: usize },
}
enum AdmissionSource { #[cfg(test)]
Requests(VecDeque<DirectRequest>), #[derive(Debug, Clone, Copy, PartialEq, Eq)]
Workload(WorkloadDriver), pub(crate) enum DisaggTransition {
PrefillMarkCompleted { uuid: Uuid },
PrefillFree { uuid: Uuid },
DecodeHandoffQueued { uuid: Uuid },
DecodeEnqueued { uuid: Uuid },
DecodeFree { uuid: Uuid },
RequestMarkedDone { uuid: Uuid },
WorkloadCompleted { uuid: Uuid },
} }
#[cfg(test)] #[cfg(test)]
...@@ -48,6 +52,7 @@ pub(super) struct DisaggRuntimeStats { ...@@ -48,6 +52,7 @@ pub(super) struct DisaggRuntimeStats {
decode_router_freed_count: usize, decode_router_freed_count: usize,
max_prefill_router_pending_count: usize, max_prefill_router_pending_count: usize,
max_decode_router_pending_count: usize, max_decode_router_pending_count: usize,
transition_log: Vec<DisaggTransition>,
} }
#[cfg(not(test))] #[cfg(not(test))]
...@@ -59,15 +64,15 @@ pub(super) struct DisaggRuntime { ...@@ -59,15 +64,15 @@ pub(super) struct DisaggRuntime {
next_prefill_worker_idx: usize, next_prefill_worker_idx: usize,
next_decode_worker_idx: usize, next_decode_worker_idx: usize,
next_event_seq: u64, next_event_seq: u64,
admission: AdmissionSource, admission: AdmissionQueue,
prefill_workers: Vec<OfflineWorkerState>, prefill_engine: EngineComponent,
decode_workers: Vec<OfflineWorkerState>, decode_engine: EngineComponent,
prefill_router: Option<OfflineReplayRouter>, prefill_router: Option<OfflineReplayRouter>,
decode_router: Option<OfflineReplayRouter>, decode_router: Option<OfflineReplayRouter>,
requests: HashMap<Uuid, DisaggRequestState>, requests: HashMap<Uuid, DisaggRequestState>,
collector: TraceCollector, collector: TraceCollector,
events: BinaryHeap<SimulationEvent>, events: BinaryHeap<SimulationEvent>,
mode: ReplayMode, progress: ReplayProgress,
stats: DisaggRuntimeStats, stats: DisaggRuntimeStats,
} }
...@@ -76,6 +81,7 @@ impl DisaggRuntime { ...@@ -76,6 +81,7 @@ impl DisaggRuntime {
pub(super) fn new( pub(super) fn new(
config: &OfflineDisaggReplayConfig, config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
pending: VecDeque<DirectRequest>, pending: VecDeque<DirectRequest>,
mode: ReplayMode, mode: ReplayMode,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
...@@ -83,8 +89,8 @@ impl DisaggRuntime { ...@@ -83,8 +89,8 @@ impl DisaggRuntime {
Self::new_with_source( Self::new_with_source(
config, config,
router_config, router_config,
AdmissionSource::Requests(pending), prefill_load_estimator,
mode, AdmissionQueue::new_requests(pending, mode),
router_mode, router_mode,
) )
} }
...@@ -93,6 +99,7 @@ impl DisaggRuntime { ...@@ -93,6 +99,7 @@ impl DisaggRuntime {
pub(super) fn new_workload( pub(super) fn new_workload(
config: &OfflineDisaggReplayConfig, config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
driver: WorkloadDriver, driver: WorkloadDriver,
mode: ReplayMode, mode: ReplayMode,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
...@@ -100,8 +107,8 @@ impl DisaggRuntime { ...@@ -100,8 +107,8 @@ impl DisaggRuntime {
Self::new_with_source( Self::new_with_source(
config, config,
router_config, router_config,
AdmissionSource::Workload(driver), prefill_load_estimator,
mode, AdmissionQueue::new_workload(driver, mode),
router_mode, router_mode,
) )
} }
...@@ -110,10 +117,11 @@ impl DisaggRuntime { ...@@ -110,10 +117,11 @@ impl DisaggRuntime {
fn new_with_source( fn new_with_source(
config: &OfflineDisaggReplayConfig, config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
admission: AdmissionSource, prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
mode: ReplayMode, admission: AdmissionQueue,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
) -> Result<Self> { ) -> Result<Self> {
let progress = ReplayProgress::new(admission.total_requests(), "offline disagg replay");
let (prefill_router, decode_router) = match router_mode { let (prefill_router, decode_router) = match router_mode {
ReplayRouterMode::RoundRobin => (None, None), ReplayRouterMode::RoundRobin => (None, None),
ReplayRouterMode::KvRouter => { ReplayRouterMode::KvRouter => {
...@@ -125,43 +133,60 @@ impl DisaggRuntime { ...@@ -125,43 +133,60 @@ impl DisaggRuntime {
Some(OfflineReplayRouter::new( Some(OfflineReplayRouter::new(
&config.prefill_args, &config.prefill_args,
Some(prefill_router_config), Some(prefill_router_config),
prefill_load_estimator,
config.num_prefill_workers, config.num_prefill_workers,
)?), )?),
Some(OfflineReplayRouter::new( Some(OfflineReplayRouter::new(
&config.decode_args, &config.decode_args,
Some(decode_router_config), Some(decode_router_config),
None,
config.num_decode_workers, config.num_decode_workers,
)?), )?),
) )
} }
}; };
Ok(Self { let prefill_engine = EngineComponent::new(
now_ms: 0.0, SimulationWorkerStage::Prefill,
next_prefill_worker_idx: 0, EnginePassMode::Hidden,
next_decode_worker_idx: 0, (0..config.num_prefill_workers)
next_event_seq: 0,
admission,
prefill_workers: (0..config.num_prefill_workers)
.map(|worker_idx| { .map(|worker_idx| {
OfflineWorkerState::new( super::state::OfflineWorkerState::new(
worker_idx, worker_idx,
config.prefill_args.clone(), config.prefill_args.clone(),
prefill_router.is_some(), prefill_router.is_some(),
) )
}) })
.collect(), .collect(),
decode_workers: (0..config.num_decode_workers) );
let decode_engine = EngineComponent::new(
SimulationWorkerStage::Decode,
EnginePassMode::Visible,
(0..config.num_decode_workers)
.map(|worker_idx| { .map(|worker_idx| {
OfflineWorkerState::new(worker_idx, config.decode_args.clone(), false) super::state::OfflineWorkerState::new(
worker_idx,
config.decode_args.clone(),
false,
)
}) })
.collect(), .collect(),
);
Ok(Self {
now_ms: 0.0,
next_prefill_worker_idx: 0,
next_decode_worker_idx: 0,
next_event_seq: 0,
admission,
prefill_engine,
decode_engine,
prefill_router, prefill_router,
decode_router, decode_router,
requests: HashMap::new(), requests: HashMap::new(),
collector: TraceCollector::default(), collector: TraceCollector::default(),
events: BinaryHeap::new(), events: BinaryHeap::new(),
mode, progress,
#[cfg(test)] #[cfg(test)]
stats: DisaggRuntimeStats::default(), stats: DisaggRuntimeStats::default(),
#[cfg(not(test))] #[cfg(not(test))]
...@@ -171,15 +196,8 @@ impl DisaggRuntime { ...@@ -171,15 +196,8 @@ impl DisaggRuntime {
/// Count all requests consuming cluster capacity across prefill, decode, and router queues. /// Count all requests consuming cluster capacity across prefill, decode, and router queues.
fn cluster_in_flight(&self) -> usize { fn cluster_in_flight(&self) -> usize {
self.prefill_workers self.prefill_engine.in_flight()
.iter() + self.decode_engine.in_flight()
.map(OfflineWorkerState::in_flight)
.sum::<usize>()
+ self
.decode_workers
.iter()
.map(OfflineWorkerState::in_flight)
.sum::<usize>()
+ self + self
.prefill_router .prefill_router
.as_ref() .as_ref()
...@@ -194,14 +212,15 @@ impl DisaggRuntime { ...@@ -194,14 +212,15 @@ impl DisaggRuntime {
fn next_prefill_worker(&mut self) -> usize { fn next_prefill_worker(&mut self) -> usize {
let worker_idx = self.next_prefill_worker_idx; let worker_idx = self.next_prefill_worker_idx;
self.next_prefill_worker_idx = self.next_prefill_worker_idx =
(self.next_prefill_worker_idx + 1) % self.prefill_workers.len(); (self.next_prefill_worker_idx + 1) % self.prefill_engine.worker_count();
worker_idx worker_idx
} }
/// Pick the next decode worker in round-robin order. /// Pick the next decode worker in round-robin order.
fn next_decode_worker(&mut self) -> usize { fn next_decode_worker(&mut self) -> usize {
let worker_idx = self.next_decode_worker_idx; let worker_idx = self.next_decode_worker_idx;
self.next_decode_worker_idx = (self.next_decode_worker_idx + 1) % self.decode_workers.len(); self.next_decode_worker_idx =
(self.next_decode_worker_idx + 1) % self.decode_engine.worker_count();
worker_idx worker_idx
} }
...@@ -224,19 +243,6 @@ impl DisaggRuntime { ...@@ -224,19 +243,6 @@ impl DisaggRuntime {
} }
} }
/// Fail fast if a router admission targets a non-existent stage worker.
fn validate_worker_idx(&self, stage: SimulationWorkerStage, worker_idx: usize) -> Result<()> {
let worker_count = match stage {
SimulationWorkerStage::Prefill => self.prefill_workers.len(),
SimulationWorkerStage::Decode => self.decode_workers.len(),
SimulationWorkerStage::Aggregated => unreachable!("aggregated stage is not used"),
};
if worker_idx >= worker_count {
bail!("offline disagg replay selected unknown {stage:?} worker index {worker_idx}");
}
Ok(())
}
/// Borrow immutable request state with a structured missing-request error. /// Borrow immutable request state with a structured missing-request error.
fn state(&self, uuid: Uuid) -> Result<&DisaggRequestState> { fn state(&self, uuid: Uuid) -> Result<&DisaggRequestState> {
self.requests self.requests
...@@ -253,9 +259,8 @@ impl DisaggRuntime { ...@@ -253,9 +259,8 @@ impl DisaggRuntime {
/// Dispatch a request's prefill stage onto a specific prefill worker. /// Dispatch a request's prefill stage onto a specific prefill worker.
fn dispatch_prefill(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> { fn dispatch_prefill(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> {
self.validate_worker_idx(SimulationWorkerStage::Prefill, worker_idx)?;
let request = self.state(uuid)?.build_prefill_request()?; let request = self.state(uuid)?.build_prefill_request()?;
self.prefill_workers[worker_idx].receive_request(request); self.prefill_engine.dispatch(worker_idx, request)?;
self.state_mut(uuid)?.start_prefill(worker_idx); self.state_mut(uuid)?.start_prefill(worker_idx);
#[cfg(test)] #[cfg(test)]
{ {
...@@ -266,9 +271,8 @@ impl DisaggRuntime { ...@@ -266,9 +271,8 @@ impl DisaggRuntime {
/// Dispatch a request's decode stage onto a specific decode worker. /// Dispatch a request's decode stage onto a specific decode worker.
fn dispatch_decode(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> { fn dispatch_decode(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> {
self.validate_worker_idx(SimulationWorkerStage::Decode, worker_idx)?; let request = self.state(uuid)?.original_request()?.clone();
let request = self.state(uuid)?.build_decode_request()?; self.decode_engine.dispatch(worker_idx, request)?;
self.decode_workers[worker_idx].receive_request(request);
self.state_mut(uuid)?.start_decode(worker_idx); self.state_mut(uuid)?.start_decode(worker_idx);
#[cfg(test)] #[cfg(test)]
{ {
...@@ -278,9 +282,9 @@ impl DisaggRuntime { ...@@ -278,9 +282,9 @@ impl DisaggRuntime {
} }
/// Turn prefill router admissions into concrete worker dispatches. /// Turn prefill router admissions into concrete worker dispatches.
fn dispatch_prefill_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> Result<()> { fn dispatch_prefill_admissions(&mut self, admissions: Vec<WorkerAdmission>) -> Result<()> {
for (uuid, worker_idx) in admissions { for WorkerAdmission { uuid, worker_idx } in admissions {
if !self.state(uuid)?.is_queued_prefill() { if self.state(uuid)?.phase != DisaggPhase::QueuedPrefill {
bail!("offline disagg replay expected queued prefill request for {uuid}"); bail!("offline disagg replay expected queued prefill request for {uuid}");
} }
self.dispatch_prefill(uuid, worker_idx)?; self.dispatch_prefill(uuid, worker_idx)?;
...@@ -289,9 +293,9 @@ impl DisaggRuntime { ...@@ -289,9 +293,9 @@ impl DisaggRuntime {
} }
/// Turn decode router admissions into concrete worker dispatches. /// Turn decode router admissions into concrete worker dispatches.
fn dispatch_decode_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> Result<()> { fn dispatch_decode_admissions(&mut self, admissions: Vec<WorkerAdmission>) -> Result<()> {
for (uuid, worker_idx) in admissions { for WorkerAdmission { uuid, worker_idx } in admissions {
if !self.state(uuid)?.is_queued_decode() { if self.state(uuid)?.phase != DisaggPhase::QueuedDecode {
bail!("offline disagg replay expected queued decode request for {uuid}"); bail!("offline disagg replay expected queued decode request for {uuid}");
} }
self.dispatch_decode(uuid, worker_idx)?; self.dispatch_decode(uuid, worker_idx)?;
...@@ -299,52 +303,37 @@ impl DisaggRuntime { ...@@ -299,52 +303,37 @@ impl DisaggRuntime {
Ok(()) Ok(())
} }
/// Submit a request to the prefill router and return an immediate admission when available.
fn submit_to_prefill_router(
&mut self,
uuid: Uuid,
replay_hashes: Option<ReplayRequestHashes>,
) -> Result<Option<usize>> {
let request = self.state(uuid)?.original_request()?.clone();
let Some(prefill_router) = self.prefill_router.as_mut() else {
bail!("offline disagg replay prefill submission requires an active router");
};
let maybe_worker_idx =
prefill_router.submit_request_with_hashes(&request, replay_hashes, self.now_ms)?;
self.record_router_pending();
Ok(maybe_worker_idx)
}
/// Submit a request to the decode router and return an immediate admission when available.
fn submit_to_decode_router(&mut self, uuid: Uuid) -> Result<Option<usize>> {
let request = self.state(uuid)?.original_request()?.clone();
let Some(decode_router) = self.decode_router.as_mut() else {
bail!("offline disagg replay decode submission requires an active router");
};
let maybe_worker_idx =
decode_router.submit_request_with_hashes(&request, None, self.now_ms)?;
self.record_router_pending();
Ok(maybe_worker_idx)
}
/// Queue or dispatch a request into decode, depending on whether a decode router is active. /// Queue or dispatch a request into decode, depending on whether a decode router is active.
fn enqueue_decode(&mut self, uuid: Uuid) -> Result<()> { fn enqueue_decode(&mut self, uuid: Uuid) -> Result<()> {
if self.decode_router.is_none() { if self.decode_router.is_none() {
#[cfg(test)]
{
self.stats
.transition_log
.push(DisaggTransition::DecodeEnqueued { uuid });
self.stats.handoff_ms.insert(uuid, self.now_ms);
}
let worker_idx = self.next_decode_worker(); let worker_idx = self.next_decode_worker();
self.dispatch_decode(uuid, worker_idx)?; self.dispatch_decode(uuid, worker_idx)?;
return Ok(()); return Ok(());
} }
let maybe_worker_idx = self.submit_to_decode_router(uuid)?; let request = self.state(uuid)?.original_request()?.clone();
self.state_mut(uuid)?.queue_decode();
#[cfg(test)] #[cfg(test)]
{ {
self.stats
.transition_log
.push(DisaggTransition::DecodeEnqueued { uuid });
self.stats.handoff_ms.insert(uuid, self.now_ms); self.stats.handoff_ms.insert(uuid, self.now_ms);
} }
if let Some(worker_idx) = maybe_worker_idx { let admissions = self
self.dispatch_decode(uuid, worker_idx)?; .decode_router
return Ok(()); .as_mut()
} .expect("decode router presence checked above")
.on_request_arrival(&request, None, self.now_ms)?
self.state_mut(uuid)?.queue_decode(); .admissions;
self.record_router_pending();
self.dispatch_decode_admissions(admissions)?;
Ok(()) Ok(())
} }
...@@ -366,6 +355,7 @@ impl DisaggRuntime { ...@@ -366,6 +355,7 @@ impl DisaggRuntime {
request.max_output_tokens, request.max_output_tokens,
); );
let queued_request = request.clone();
self.requests self.requests
.insert(uuid, DisaggRequestState::new(request, arrival_time_ms)); .insert(uuid, DisaggRequestState::new(request, arrival_time_ms));
if self.prefill_router.is_none() { if self.prefill_router.is_none() {
...@@ -373,10 +363,14 @@ impl DisaggRuntime { ...@@ -373,10 +363,14 @@ impl DisaggRuntime {
self.dispatch_prefill(uuid, worker_idx)?; self.dispatch_prefill(uuid, worker_idx)?;
return Ok(uuid); return Ok(uuid);
} }
let maybe_worker_idx = self.submit_to_prefill_router(uuid, replay_hashes)?; let admissions = self
if let Some(worker_idx) = maybe_worker_idx { .prefill_router
self.dispatch_prefill(uuid, worker_idx)?; .as_mut()
} .expect("prefill router presence checked above")
.on_request_arrival(&queued_request, replay_hashes, self.now_ms)?
.admissions;
self.record_router_pending();
self.dispatch_prefill_admissions(admissions)?;
Ok(uuid) Ok(uuid)
} }
...@@ -384,40 +378,18 @@ impl DisaggRuntime { ...@@ -384,40 +378,18 @@ impl DisaggRuntime {
fn is_done(&self) -> bool { fn is_done(&self) -> bool {
self.events.is_empty() self.events.is_empty()
&& self.cluster_in_flight() == 0 && self.cluster_in_flight() == 0
&& match &self.admission { && self.admission.is_drained()
AdmissionSource::Requests(pending) => pending.is_empty(), && self.prefill_engine.is_drained()
AdmissionSource::Workload(driver) => driver.is_drained(), && self.decode_engine.is_drained()
}
&& self
.prefill_workers
.iter()
.all(OfflineWorkerState::is_drained)
&& self
.decode_workers
.iter()
.all(OfflineWorkerState::is_drained)
} }
/// Pick the next logical timestamp from arrivals, worker completions, or decode handoffs. /// Pick the next logical timestamp from arrivals, worker completions, or decode handoffs.
fn next_timestamp(&mut self) -> Option<f64> { fn next_timestamp(&mut self) -> Option<f64> {
let next_event_ms = self.events.peek().map(|event| event.at_ms); let next_event_ms = self.events.peek().map(|event| event.at_ms);
let cluster_in_flight = self.cluster_in_flight(); choose_next_timestamp(
let next_arrival_ms = match (&self.mode, &mut self.admission) { self.admission.next_ready_time_ms(self.cluster_in_flight()),
(ReplayMode::Trace, AdmissionSource::Requests(pending)) => pending next_event_ms,
.front() )
.and_then(|request| request.arrival_timestamp_ms),
(ReplayMode::Trace, AdmissionSource::Workload(driver)) => driver.next_ready_time_ms(),
(ReplayMode::Concurrency { max_in_flight }, AdmissionSource::Workload(driver)) => {
if cluster_in_flight < *max_in_flight {
driver.next_ready_time_ms()
} else {
None
}
}
(ReplayMode::Concurrency { .. }, AdmissionSource::Requests(_)) => None,
};
choose_next_timestamp(next_arrival_ms, next_event_ms)
} }
/// Apply prefill-side KV router events at the scheduler-selected visibility phase. /// Apply prefill-side KV router events at the scheduler-selected visibility phase.
...@@ -425,8 +397,9 @@ impl DisaggRuntime { ...@@ -425,8 +397,9 @@ impl DisaggRuntime {
let Some(prefill_router) = self.prefill_router.as_mut() else { let Some(prefill_router) = self.prefill_router.as_mut() else {
return Ok(()); return Ok(());
}; };
for event in events { let effects = prefill_router.on_kv_events(events)?;
prefill_router.apply_event(event)?; if !effects.admissions.is_empty() {
bail!("offline disagg replay prefill KV events must not admit requests");
} }
Ok(()) Ok(())
} }
...@@ -440,22 +413,32 @@ impl DisaggRuntime { ...@@ -440,22 +413,32 @@ impl DisaggRuntime {
if self.prefill_router.is_some() { if self.prefill_router.is_some() {
let prefill_complete_admissions = { let prefill_complete_admissions = {
let prefill_router = self.prefill_router.as_mut().expect("router checked above"); let prefill_router = self.prefill_router.as_mut().expect("router checked above");
prefill_router.mark_prefill_completed(signal.uuid)? prefill_router
.on_prefill_completed(signal.uuid, self.now_ms)?
.admissions
}; };
#[cfg(test)] #[cfg(test)]
{ {
self.stats.prefill_marked_count += 1; self.stats.prefill_marked_count += 1;
self.stats
.transition_log
.push(DisaggTransition::PrefillMarkCompleted { uuid: signal.uuid });
} }
self.record_router_pending(); self.record_router_pending();
self.dispatch_prefill_admissions(prefill_complete_admissions)?; self.dispatch_prefill_admissions(prefill_complete_admissions)?;
let admissions = { let admissions = {
let prefill_router = self.prefill_router.as_mut().expect("router checked above"); let prefill_router = self.prefill_router.as_mut().expect("router checked above");
prefill_router.free(signal.uuid)? prefill_router
.on_request_completed(signal.uuid, self.now_ms)?
.admissions
}; };
#[cfg(test)] #[cfg(test)]
{ {
self.stats.prefill_router_freed_count += 1; self.stats.prefill_router_freed_count += 1;
self.stats
.transition_log
.push(DisaggTransition::PrefillFree { uuid: signal.uuid });
} }
self.record_router_pending(); self.record_router_pending();
self.dispatch_prefill_admissions(admissions)?; self.dispatch_prefill_admissions(admissions)?;
...@@ -471,20 +454,37 @@ impl DisaggRuntime { ...@@ -471,20 +454,37 @@ impl DisaggRuntime {
} }
let admissions = if let Some(decode_router) = self.decode_router.as_mut() { let admissions = if let Some(decode_router) = self.decode_router.as_mut() {
let admissions = decode_router.free(signal.uuid)?; let admissions = decode_router
.on_request_completed(signal.uuid, self.now_ms)?
.admissions;
#[cfg(test)] #[cfg(test)]
{ {
self.stats.decode_router_freed_count += 1; self.stats.decode_router_freed_count += 1;
self.stats
.transition_log
.push(DisaggTransition::DecodeFree { uuid: signal.uuid });
} }
admissions admissions
} else { } else {
Vec::new() Vec::new()
}; };
self.record_router_pending(); self.record_router_pending();
if let AdmissionSource::Workload(driver) = &mut self.admission { self.admission
driver.on_complete(signal.uuid, self.now_ms)?; .on_request_completed(signal.uuid, self.now_ms)?;
self.progress.inc_completed();
#[cfg(test)]
if self.admission.is_workload() {
self.stats
.transition_log
.push(DisaggTransition::WorkloadCompleted { uuid: signal.uuid });
} }
self.state_mut(signal.uuid)?.mark_done(); self.state_mut(signal.uuid)?.mark_done();
#[cfg(test)]
{
self.stats
.transition_log
.push(DisaggTransition::RequestMarkedDone { uuid: signal.uuid });
}
self.dispatch_decode_admissions(admissions)?; self.dispatch_decode_admissions(admissions)?;
Ok(()) Ok(())
} }
...@@ -492,12 +492,11 @@ impl DisaggRuntime { ...@@ -492,12 +492,11 @@ impl DisaggRuntime {
/// Apply the side effects of a finished prefill pass. /// Apply the side effects of a finished prefill pass.
fn process_prefill_pass( fn process_prefill_pass(
&mut self, &mut self,
worker_idx: usize, _worker_idx: usize,
completed_requests: usize, _completed_requests: usize,
output_signals: Vec<OutputSignal>, output_signals: Vec<OutputSignal>,
kv_events: Vec<RouterEvent>, kv_events: Vec<RouterEvent>,
) -> Result<()> { ) -> Result<()> {
self.prefill_workers[worker_idx].mark_completed(completed_requests);
self.apply_prefill_router_events(kv_events)?; self.apply_prefill_router_events(kv_events)?;
for signal in output_signals { for signal in output_signals {
self.process_prefill_signal(signal)?; self.process_prefill_signal(signal)?;
...@@ -508,11 +507,10 @@ impl DisaggRuntime { ...@@ -508,11 +507,10 @@ impl DisaggRuntime {
/// Apply the side effects of a finished decode pass. /// Apply the side effects of a finished decode pass.
fn process_decode_pass( fn process_decode_pass(
&mut self, &mut self,
worker_idx: usize, _worker_idx: usize,
completed_requests: usize, _completed_requests: usize,
output_signals: Vec<OutputSignal>, output_signals: Vec<OutputSignal>,
) -> Result<()> { ) -> Result<()> {
self.decode_workers[worker_idx].mark_completed(completed_requests);
for signal in output_signals { for signal in output_signals {
self.process_decode_signal(signal)?; self.process_decode_signal(signal)?;
} }
...@@ -522,27 +520,24 @@ impl DisaggRuntime { ...@@ -522,27 +520,24 @@ impl DisaggRuntime {
/// Drain all worker-completion events scheduled for the current logical timestamp. /// Drain all worker-completion events scheduled for the current logical timestamp.
fn apply_worker_completions(&mut self) -> Result<bool> { fn apply_worker_completions(&mut self) -> Result<bool> {
let mut changed = false; let mut changed = false;
while let Some(WorkerCompletionPayload { while let Some(payload) = pop_ready_worker_completion(&mut self.events, self.now_ms) {
stage, match payload.stage {
worker_idx,
completed_requests,
output_signals,
kv_events,
}) = pop_ready_worker_completion(&mut self.events, self.now_ms)
{
match stage {
SimulationWorkerStage::Prefill => { SimulationWorkerStage::Prefill => {
self.prefill_workers[worker_idx].mark_idle(); let payload = self.prefill_engine.on_scheduled_completion(payload)?;
self.process_prefill_pass( self.process_prefill_pass(
worker_idx, payload.worker_idx,
completed_requests, payload.completed_requests,
output_signals, payload.output_signals,
kv_events, payload.kv_events,
)?; )?;
} }
SimulationWorkerStage::Decode => { SimulationWorkerStage::Decode => {
self.decode_workers[worker_idx].mark_idle(); let payload = self.decode_engine.on_scheduled_completion(payload)?;
self.process_decode_pass(worker_idx, completed_requests, output_signals)?; self.process_decode_pass(
payload.worker_idx,
payload.completed_requests,
payload.output_signals,
)?;
} }
SimulationWorkerStage::Aggregated => { SimulationWorkerStage::Aggregated => {
bail!("offline disagg replay received an aggregated completion event") bail!("offline disagg replay received an aggregated completion event")
...@@ -569,91 +564,33 @@ impl DisaggRuntime { ...@@ -569,91 +564,33 @@ impl DisaggRuntime {
uuid: Uuid, uuid: Uuid,
handoff_delay_ms: Option<f64>, handoff_delay_ms: Option<f64>,
) -> Result<()> { ) -> Result<()> {
if let Some(delay_ms) = handoff_delay_ms let Some(delay_ms) = handoff_delay_ms else {
&& delay_ms > 0.0 return self.enqueue_decode(uuid);
{ };
if delay_ms > 0.0 {
push_decode_handoff( push_decode_handoff(
&mut self.events, &mut self.events,
&mut self.next_event_seq, &mut self.next_event_seq,
self.now_ms + delay_ms, self.now_ms + delay_ms,
uuid, uuid,
); );
#[cfg(test)]
self.stats
.transition_log
.push(DisaggTransition::DecodeHandoffQueued { uuid });
return Ok(()); return Ok(());
} }
self.enqueue_decode(uuid) self.enqueue_decode(uuid)
} }
/// Release every trace arrival whose timestamp is now visible to the global clock. /// Release every admission made ready by the shared admission queue.
fn release_trace_arrivals(&mut self) -> Result<bool> { fn release_ready_arrivals(&mut self) -> Result<bool> {
let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) {
loop {
let next_ready = match &mut self.admission {
AdmissionSource::Requests(pending) => {
pop_next_trace_ready(pending, self.now_ms)
}
AdmissionSource::Workload(_) => unreachable!(),
};
let Some((request, arrival_ms)) = next_ready else {
break;
};
self.on_external_arrival(request, arrival_ms, None)?;
released_any = true;
}
return Ok(released_any);
}
let ready_requests = match &mut self.admission {
AdmissionSource::Requests(_) => unreachable!(),
AdmissionSource::Workload(driver) => driver.pop_ready(self.now_ms, usize::MAX),
};
for ready in ready_requests {
self.on_external_arrival(
ready.request,
ready.scheduled_ready_at_ms,
ready.replay_hashes,
)?;
released_any = true;
}
Ok(released_any)
}
/// Backfill closed-loop concurrency replay until the staged cluster reaches its cap.
fn top_off_concurrency(&mut self, max_in_flight: usize) -> Result<bool> {
let mut released_any = false; let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) { for ready in self
loop { .admission
let cluster_in_flight = self.cluster_in_flight(); .drain_ready(self.now_ms, self.cluster_in_flight())?
let next_ready = match &mut self.admission { {
AdmissionSource::Requests(pending) => pop_next_concurrency_ready( self.on_external_arrival(ready.request, ready.arrival_time_ms, ready.replay_hashes)?;
pending,
self.now_ms,
cluster_in_flight,
max_in_flight,
),
AdmissionSource::Workload(_) => unreachable!(),
};
let Some((request, arrival_ms)) = next_ready else {
break;
};
self.on_external_arrival(request, arrival_ms, None)?;
released_any = true;
}
return Ok(released_any);
}
let available = max_in_flight.saturating_sub(self.cluster_in_flight());
if available == 0 {
return Ok(false);
}
let ready_requests = match &mut self.admission {
AdmissionSource::Requests(_) => unreachable!(),
AdmissionSource::Workload(driver) => driver.pop_ready(self.now_ms, available),
};
for ready in ready_requests {
self.on_external_arrival(ready.request, self.now_ms, ready.replay_hashes)?;
released_any = true; released_any = true;
} }
Ok(released_any) Ok(released_any)
...@@ -662,93 +599,61 @@ impl DisaggRuntime { ...@@ -662,93 +599,61 @@ impl DisaggRuntime {
/// Start passes on every idle prefill worker that can make progress at the current timestamp. /// Start passes on every idle prefill worker that can make progress at the current timestamp.
fn drive_prefill_workers(&mut self) -> Result<bool> { fn drive_prefill_workers(&mut self) -> Result<bool> {
let mut changed = false; let mut changed = false;
for worker_idx in 0..self.prefill_workers.len() { loop {
loop { let effects = self.prefill_engine.drive_ready(self.now_ms, None)?;
if !self.prefill_workers[worker_idx].is_ready() { if effects.is_empty() {
break; return Ok(changed);
}
let executed = self.prefill_workers[worker_idx].execute_hidden_pass(self.now_ms);
changed = true;
let completion_kv_events =
if executed.router_event_visibility == RouterEventVisibility::PassStart {
self.apply_prefill_router_events(executed.kv_events)?;
Vec::new()
} else {
executed.kv_events
};
if executed.end_ms == self.now_ms {
self.process_prefill_pass(
worker_idx,
executed.completed_requests,
executed.output_signals,
completion_kv_events,
)?;
continue;
}
self.prefill_workers[worker_idx].mark_busy();
push_worker_completion(
&mut self.events,
&mut self.next_event_seq,
executed.end_ms,
WorkerCompletionPayload {
stage: SimulationWorkerStage::Prefill,
worker_idx,
completed_requests: executed.completed_requests,
output_signals: executed.output_signals,
kv_events: completion_kv_events,
},
);
break;
} }
changed = true;
self.handle_prefill_engine_effects(effects)?;
} }
Ok(changed)
} }
/// Start passes on every idle decode worker that can make progress at the current timestamp. /// Start passes on every idle decode worker that can make progress at the current timestamp.
fn drive_decode_workers(&mut self) -> Result<bool> { fn drive_decode_workers(&mut self) -> Result<bool> {
let mut changed = false; let mut changed = false;
for worker_idx in 0..self.decode_workers.len() { loop {
loop { let effects = self
if !self.decode_workers[worker_idx].is_ready() { .decode_engine
break; .drive_ready(self.now_ms, Some(&mut self.collector))?;
} if effects.is_empty() {
return Ok(changed);
let executed = { }
let (workers, collector) = (&mut self.decode_workers, &mut self.collector); changed = true;
workers[worker_idx].execute_pass(collector, self.now_ms) self.handle_decode_engine_effects(effects)?;
}; }
changed = true; }
if executed.end_ms == self.now_ms { fn handle_prefill_engine_effects(&mut self, effects: EngineEffects) -> Result<()> {
self.process_decode_pass( self.apply_prefill_router_events(effects.pass_start_kv_events)?;
worker_idx, for payload in effects.immediate_completions {
executed.completed_requests, let payload = self.prefill_engine.on_scheduled_completion(payload)?;
executed.output_signals, self.process_prefill_pass(
)?; payload.worker_idx,
continue; payload.completed_requests,
} payload.output_signals,
payload.kv_events,
)?;
}
for ScheduledWorkerCompletion { at_ms, payload } in effects.scheduled_completions {
push_worker_completion(&mut self.events, &mut self.next_event_seq, at_ms, payload);
}
Ok(())
}
self.decode_workers[worker_idx].mark_busy(); fn handle_decode_engine_effects(&mut self, effects: EngineEffects) -> Result<()> {
push_worker_completion( for payload in effects.immediate_completions {
&mut self.events, let payload = self.decode_engine.on_scheduled_completion(payload)?;
&mut self.next_event_seq, self.process_decode_pass(
executed.end_ms, payload.worker_idx,
WorkerCompletionPayload { payload.completed_requests,
stage: SimulationWorkerStage::Decode, payload.output_signals,
worker_idx, )?;
completed_requests: executed.completed_requests,
output_signals: executed.output_signals,
kv_events: Vec::new(),
},
);
break;
}
} }
Ok(changed) for ScheduledWorkerCompletion { at_ms, payload } in effects.scheduled_completions {
push_worker_completion(&mut self.events, &mut self.next_event_seq, at_ms, payload);
}
Ok(())
} }
/// Repeatedly process all work that becomes possible without advancing logical time. /// Repeatedly process all work that becomes possible without advancing logical time.
...@@ -756,14 +661,7 @@ impl DisaggRuntime { ...@@ -756,14 +661,7 @@ impl DisaggRuntime {
loop { loop {
let mut changed = self.apply_worker_completions()?; let mut changed = self.apply_worker_completions()?;
changed |= self.apply_decode_handoffs()?; changed |= self.apply_decode_handoffs()?;
changed |= self.release_ready_arrivals()?;
changed |= match self.mode {
ReplayMode::Trace => self.release_trace_arrivals()?,
ReplayMode::Concurrency { max_in_flight } => {
self.top_off_concurrency(max_in_flight)?
}
};
changed |= self.drive_prefill_workers()?; changed |= self.drive_prefill_workers()?;
changed |= self.drive_decode_workers()?; changed |= self.drive_decode_workers()?;
...@@ -801,6 +699,7 @@ impl DisaggRuntime { ...@@ -801,6 +699,7 @@ impl DisaggRuntime {
self.drain_current_timestamp()?; self.drain_current_timestamp()?;
} }
self.progress.finish();
self.finish_test_stats(); self.finish_test_stats();
Ok((self.collector, self.stats)) Ok((self.collector, self.stats))
} }
...@@ -834,14 +733,19 @@ fn derive_decode_router_config( ...@@ -834,14 +733,19 @@ fn derive_decode_router_config(
config.overlap_score_weight = 0.0; config.overlap_score_weight = 0.0;
config.router_assume_kv_reuse = false; config.router_assume_kv_reuse = false;
config.router_track_prefill_tokens = false; config.router_track_prefill_tokens = false;
config.router_prefill_load_model = dynamo_kv_router::config::RouterPrefillLoadModel::None;
config config
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::entrypoints::{run_concurrency_collect, run_trace_collect}; use super::super::entrypoints::{
run_concurrency_collect, run_concurrency_workload_collect, run_trace_collect,
run_trace_workload_collect,
};
use super::*; use super::*;
use crate::common::protocols::{MockEngineArgs, WorkerType}; use crate::common::protocols::{EngineType, MockEngineArgs, SglangArgs, WorkerType};
use crate::loadgen::{SessionTrace, Trace, TurnTrace};
fn staged_args(worker_type: WorkerType, speedup_ratio: f64) -> MockEngineArgs { fn staged_args(worker_type: WorkerType, speedup_ratio: f64) -> MockEngineArgs {
MockEngineArgs::builder() MockEngineArgs::builder()
...@@ -858,6 +762,26 @@ mod tests { ...@@ -858,6 +762,26 @@ mod tests {
.unwrap() .unwrap()
} }
fn sglang_staged_args(worker_type: WorkerType, speedup_ratio: f64) -> MockEngineArgs {
MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(64)
.num_gpu_blocks(512)
.max_num_batched_tokens(Some(8192))
.max_num_seqs(Some(8))
.enable_prefix_caching(true)
.enable_chunked_prefill(true)
.speedup_ratio(speedup_ratio)
.decode_speedup_ratio(speedup_ratio)
.worker_type(worker_type)
.sglang(Some(SglangArgs {
page_size: Some(64),
..Default::default()
}))
.build()
.unwrap()
}
fn disagg_config() -> OfflineDisaggReplayConfig { fn disagg_config() -> OfflineDisaggReplayConfig {
OfflineDisaggReplayConfig { OfflineDisaggReplayConfig {
prefill_args: staged_args(WorkerType::Prefill, 1000.0), prefill_args: staged_args(WorkerType::Prefill, 1000.0),
...@@ -867,6 +791,15 @@ mod tests { ...@@ -867,6 +791,15 @@ mod tests {
} }
} }
fn sglang_disagg_config() -> OfflineDisaggReplayConfig {
OfflineDisaggReplayConfig {
prefill_args: sglang_staged_args(WorkerType::Prefill, 1000.0),
decode_args: sglang_staged_args(WorkerType::Decode, 1000.0),
num_prefill_workers: 2,
num_decode_workers: 2,
}
}
fn disagg_config_with_handoff_delay() -> OfflineDisaggReplayConfig { fn disagg_config_with_handoff_delay() -> OfflineDisaggReplayConfig {
let mut config = disagg_config(); let mut config = disagg_config();
config.prefill_args.kv_transfer_bandwidth = Some(1.0); config.prefill_args.kv_transfer_bandwidth = Some(1.0);
...@@ -896,6 +829,49 @@ mod tests { ...@@ -896,6 +829,49 @@ mod tests {
} }
} }
fn multiturn_trace() -> Trace {
Trace {
block_size: 64,
sessions: vec![
SessionTrace {
session_id: "session-a".to_string(),
first_arrival_timestamp_ms: Some(0.0),
turns: vec![
TurnTrace {
input_length: 64,
max_output_tokens: 2,
hash_ids: vec![11],
delay_after_previous_ms: 0.0,
},
TurnTrace {
input_length: 192,
max_output_tokens: 2,
hash_ids: vec![21, 22, 23],
delay_after_previous_ms: 10.0,
},
],
},
SessionTrace {
session_id: "session-b".to_string(),
first_arrival_timestamp_ms: Some(5.0),
turns: vec![TurnTrace {
input_length: 128,
max_output_tokens: 2,
hash_ids: vec![31, 32],
delay_after_previous_ms: 0.0,
}],
},
],
}
}
fn transition_index(transitions: &[DisaggTransition], needle: DisaggTransition) -> usize {
transitions
.iter()
.position(|transition| *transition == needle)
.unwrap()
}
#[test] #[test]
fn test_derive_stage_router_configs_force_required_overrides() { fn test_derive_stage_router_configs_force_required_overrides() {
let config = KvRouterConfig { let config = KvRouterConfig {
...@@ -933,6 +909,7 @@ mod tests { ...@@ -933,6 +909,7 @@ mod tests {
assert!(snapshot.first_token_ms.is_some()); assert!(snapshot.first_token_ms.is_some());
assert_eq!(snapshot.output_length, 3); assert_eq!(snapshot.output_length, 3);
assert_eq!(report.request_counts.completed_requests, 1); assert_eq!(report.request_counts.completed_requests, 1);
assert_eq!(report.request_counts.total_output_tokens, 3);
assert_eq!( assert_eq!(
stats.request_snapshots[&Uuid::from_u128(1)].phase, stats.request_snapshots[&Uuid::from_u128(1)].phase,
DisaggPhase::Done DisaggPhase::Done
...@@ -952,26 +929,45 @@ mod tests { ...@@ -952,26 +929,45 @@ mod tests {
for uuid in [Uuid::from_u128(1), Uuid::from_u128(2)] { for uuid in [Uuid::from_u128(1), Uuid::from_u128(2)] {
assert!(stats.prefill_assignments.contains_key(&uuid)); assert!(stats.prefill_assignments.contains_key(&uuid));
assert!(stats.decode_assignments.contains_key(&uuid)); assert!(stats.decode_assignments.contains_key(&uuid));
assert_eq!(stats.request_snapshots[&uuid].phase, DisaggPhase::Done);
assert_eq!(
stats.request_snapshots[&uuid].prefill_worker_idx,
Some(stats.prefill_assignments[&uuid])
);
assert_eq!(
stats.request_snapshots[&uuid].decode_worker_idx,
Some(stats.decode_assignments[&uuid])
);
} }
} }
#[test] #[test]
fn test_prefill_overlap_prefers_same_worker_after_handoff_delay() { fn test_prefill_overlap_prefers_same_worker_after_handoff_delay() {
let config = disagg_config();
let requests = vec![request(1, 128, 2, 0.0), request(2, 128, 2, 100.0)]; let requests = vec![request(1, 128, 2, 0.0), request(2, 128, 2, 100.0)];
let (_, stats) = run_trace_collect( let cases = [(disagg_config(), true), (sglang_disagg_config(), false)];
&config, for (config, expect_same_worker) in cases {
requests, let (_, stats) = run_trace_collect(
Some(router_config()), &config,
1.0, requests.clone(),
ReplayRouterMode::KvRouter, Some(router_config()),
); 1.0,
ReplayRouterMode::KvRouter,
);
assert_eq!( if expect_same_worker {
stats.prefill_assignments[&Uuid::from_u128(1)], assert_eq!(
stats.prefill_assignments[&Uuid::from_u128(2)], stats.prefill_assignments[&Uuid::from_u128(1)],
); stats.prefill_assignments[&Uuid::from_u128(2)],
);
} else {
for uuid in [Uuid::from_u128(1), Uuid::from_u128(2)] {
assert!(stats.prefill_assignments.contains_key(&uuid));
assert!(stats.decode_assignments.contains_key(&uuid));
assert_eq!(stats.request_snapshots[&uuid].phase, DisaggPhase::Done);
}
}
}
} }
#[rstest::rstest] #[rstest::rstest]
...@@ -999,13 +995,21 @@ mod tests { ...@@ -999,13 +995,21 @@ mod tests {
]; ];
let router_config = (router_mode == ReplayRouterMode::KvRouter).then(router_config); let router_config = (router_mode == ReplayRouterMode::KvRouter).then(router_config);
let (collector, _) = let (collector, stats) =
run_concurrency_collect(&config, requests, router_config, 1, router_mode); run_concurrency_collect(&config, requests, router_config, 1, router_mode);
let first = collector.snapshot(Uuid::from_u128(1)).unwrap(); let first = collector.snapshot(Uuid::from_u128(1)).unwrap();
let second = collector.snapshot(Uuid::from_u128(2)).unwrap(); let second = collector.snapshot(Uuid::from_u128(2)).unwrap();
assert_eq!(first.arrival_time_ms, 0.0); assert_eq!(first.arrival_time_ms, 0.0);
assert_eq!(second.arrival_time_ms, first.last_token_ms.unwrap()); assert_eq!(second.arrival_time_ms, first.last_token_ms.unwrap());
assert_eq!(
stats.request_snapshots[&Uuid::from_u128(1)].phase,
DisaggPhase::Done
);
assert_eq!(
stats.request_snapshots[&Uuid::from_u128(2)].phase,
DisaggPhase::Done
);
} }
#[test] #[test]
...@@ -1024,6 +1028,14 @@ mod tests { ...@@ -1024,6 +1028,14 @@ mod tests {
assert_eq!(stats.prefill_marked_count, 1); assert_eq!(stats.prefill_marked_count, 1);
assert_eq!(stats.prefill_router_freed_count, 1); assert_eq!(stats.prefill_router_freed_count, 1);
assert_eq!(stats.decode_router_freed_count, 1); assert_eq!(stats.decode_router_freed_count, 1);
let transitions = &stats.transition_log;
let uuid = Uuid::from_u128(1);
let mark_idx =
transition_index(transitions, DisaggTransition::PrefillMarkCompleted { uuid });
let free_idx = transition_index(transitions, DisaggTransition::PrefillFree { uuid });
let enqueue_idx = transition_index(transitions, DisaggTransition::DecodeEnqueued { uuid });
assert!(mark_idx < free_idx);
assert!(free_idx < enqueue_idx);
} }
#[test] #[test]
...@@ -1037,7 +1049,7 @@ mod tests { ...@@ -1037,7 +1049,7 @@ mod tests {
1.0, 1.0,
ReplayRouterMode::RoundRobin, ReplayRouterMode::RoundRobin,
); );
let (delayed_collector, _) = run_trace_collect( let (delayed_collector, delayed_stats) = run_trace_collect(
&disagg_config_with_handoff_delay(), &disagg_config_with_handoff_delay(),
requests, requests,
None, None,
...@@ -1054,5 +1066,71 @@ mod tests { ...@@ -1054,5 +1066,71 @@ mod tests {
delayed_ttft >= baseline_ttft + 120.0, delayed_ttft >= baseline_ttft + 120.0,
"expected delayed TTFT to include roughly 128ms of handoff delay, baseline={baseline_ttft}, delayed={delayed_ttft}" "expected delayed TTFT to include roughly 128ms of handoff delay, baseline={baseline_ttft}, delayed={delayed_ttft}"
); );
let uuid = Uuid::from_u128(1);
let queued_idx = transition_index(
&delayed_stats.transition_log,
DisaggTransition::DecodeHandoffQueued { uuid },
);
let enqueued_idx = transition_index(
&delayed_stats.transition_log,
DisaggTransition::DecodeEnqueued { uuid },
);
assert!(queued_idx < enqueued_idx);
assert!(delayed_stats.handoff_ms[&uuid] >= 120.0);
}
#[test]
fn test_trace_workload_follow_up_turn_arrives_after_completion_plus_delay() {
let (collector, _) = run_trace_workload_collect(
&disagg_config(),
multiturn_trace(),
None,
ReplayRouterMode::RoundRobin,
);
let snapshots = collector.snapshots();
let first_turn = snapshots
.iter()
.find(|snapshot| snapshot.input_length == 64)
.unwrap();
let second_turn = snapshots
.iter()
.find(|snapshot| snapshot.input_length == 192)
.unwrap();
let session_b = snapshots
.iter()
.find(|snapshot| snapshot.input_length == 128)
.unwrap();
assert_eq!(first_turn.arrival_time_ms, 0.0);
assert_eq!(session_b.arrival_time_ms, 5.0);
assert!(
second_turn.arrival_time_ms >= first_turn.last_token_ms.unwrap() + 10.0,
"follow-up turn should unlock after completion plus delay"
);
}
#[test]
fn test_concurrency_workload_delayed_follow_up_does_not_bypass_other_ready_sessions() {
let (collector, _) = run_concurrency_workload_collect(
&disagg_config(),
multiturn_trace(),
None,
1,
ReplayRouterMode::RoundRobin,
);
let mut input_lengths = collector
.snapshots()
.into_iter()
.map(|snapshot| (snapshot.arrival_time_ms, snapshot.input_length))
.collect::<Vec<_>>();
input_lengths.sort_by(|left, right| left.0.total_cmp(&right.0));
assert_eq!(
input_lengths
.into_iter()
.map(|(_, input_length)| input_length)
.collect::<Vec<_>>(),
vec![64, 128, 192]
);
} }
} }
...@@ -20,8 +20,8 @@ use crate::common::protocols::{DirectRequest, EngineType, MockEngineArgs}; ...@@ -20,8 +20,8 @@ use crate::common::protocols::{DirectRequest, EngineType, MockEngineArgs};
use crate::loadgen::{Trace, WorkloadDriver}; use crate::loadgen::{Trace, WorkloadDriver};
use crate::replay::OfflineDisaggReplayConfig; use crate::replay::OfflineDisaggReplayConfig;
use crate::replay::{ use crate::replay::{
ReplayRouterMode, ReplayTimedKvEvent, ReplayTimedOutputSignal, ReplayTimedRequest, ReplayPrefillLoadEstimator, ReplayRouterMode, ReplayTimedKvEvent, ReplayTimedOutputSignal,
ReplayWorkerArtifacts, TraceCollector, TraceSimulationReport, ReplayTimedRequest, ReplayWorkerArtifacts, TraceCollector, TraceSimulationReport,
}; };
use crate::scheduler::RouterEventVisibility; use crate::scheduler::RouterEventVisibility;
...@@ -37,8 +37,10 @@ pub(crate) fn generate_trace_worker_artifacts( ...@@ -37,8 +37,10 @@ pub(crate) fn generate_trace_worker_artifacts(
args: MockEngineArgs, args: MockEngineArgs,
trace: Trace, trace: Trace,
) -> Result<ReplayWorkerArtifacts> { ) -> Result<ReplayWorkerArtifacts> {
let args = args.normalized()?;
let engine_block_size = args.block_size;
let mut worker = ReplayWorkerCore::new_with_kv_capture(args, WorkerId::default()); let mut worker = ReplayWorkerCore::new_with_kv_capture(args, WorkerId::default());
let mut driver = trace.into_trace_driver()?; let mut driver = trace.into_trace_driver_with_block_size(engine_block_size)?;
let mut collector = TraceCollector::default(); let mut collector = TraceCollector::default();
let mut artifacts = ReplayWorkerArtifacts::default(); let mut artifacts = ReplayWorkerArtifacts::default();
let mut current_time_ms = 0.0; let mut current_time_ms = 0.0;
...@@ -106,6 +108,7 @@ pub(crate) fn generate_trace_worker_artifacts( ...@@ -106,6 +108,7 @@ pub(crate) fn generate_trace_worker_artifacts(
pub(crate) fn simulate_trace( pub(crate) fn simulate_trace(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>, requests: Vec<DirectRequest>,
num_workers: usize, num_workers: usize,
arrival_speedup_ratio: f64, arrival_speedup_ratio: f64,
...@@ -117,6 +120,7 @@ pub(crate) fn simulate_trace( ...@@ -117,6 +120,7 @@ pub(crate) fn simulate_trace(
simulate_trace_multi( simulate_trace_multi(
args, args,
router_config, router_config,
prefill_load_estimator,
requests, requests,
num_workers, num_workers,
arrival_speedup_ratio, arrival_speedup_ratio,
...@@ -128,6 +132,7 @@ pub(crate) fn simulate_trace( ...@@ -128,6 +132,7 @@ pub(crate) fn simulate_trace(
pub(crate) fn simulate_concurrency( pub(crate) fn simulate_concurrency(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>, requests: Vec<DirectRequest>,
max_in_flight: usize, max_in_flight: usize,
num_workers: usize, num_workers: usize,
...@@ -139,6 +144,7 @@ pub(crate) fn simulate_concurrency( ...@@ -139,6 +144,7 @@ pub(crate) fn simulate_concurrency(
simulate_concurrency_multi( simulate_concurrency_multi(
args, args,
router_config, router_config,
prefill_load_estimator,
requests, requests,
max_in_flight, max_in_flight,
num_workers, num_workers,
...@@ -150,6 +156,7 @@ pub(crate) fn simulate_concurrency( ...@@ -150,6 +156,7 @@ pub(crate) fn simulate_concurrency(
pub(crate) fn simulate_trace_workload( pub(crate) fn simulate_trace_workload(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace, trace: Trace,
num_workers: usize, num_workers: usize,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
...@@ -157,13 +164,21 @@ pub(crate) fn simulate_trace_workload( ...@@ -157,13 +164,21 @@ pub(crate) fn simulate_trace_workload(
if num_workers == 1 && args.engine_type == EngineType::Vllm { if num_workers == 1 && args.engine_type == EngineType::Vllm {
simulate_trace_workload_single(args, trace) simulate_trace_workload_single(args, trace)
} else { } else {
simulate_trace_workload_multi(args, router_config, trace, num_workers, router_mode) simulate_trace_workload_multi(
args,
router_config,
prefill_load_estimator,
trace,
num_workers,
router_mode,
)
} }
} }
pub(crate) fn simulate_concurrency_workload( pub(crate) fn simulate_concurrency_workload(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace, trace: Trace,
max_in_flight: usize, max_in_flight: usize,
num_workers: usize, num_workers: usize,
...@@ -175,6 +190,7 @@ pub(crate) fn simulate_concurrency_workload( ...@@ -175,6 +190,7 @@ pub(crate) fn simulate_concurrency_workload(
simulate_concurrency_workload_multi( simulate_concurrency_workload_multi(
args, args,
router_config, router_config,
prefill_load_estimator,
trace, trace,
max_in_flight, max_in_flight,
num_workers, num_workers,
...@@ -186,6 +202,7 @@ pub(crate) fn simulate_concurrency_workload( ...@@ -186,6 +202,7 @@ pub(crate) fn simulate_concurrency_workload(
pub(crate) fn simulate_trace_disagg( pub(crate) fn simulate_trace_disagg(
config: OfflineDisaggReplayConfig, config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>, requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64, arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
...@@ -194,6 +211,7 @@ pub(crate) fn simulate_trace_disagg( ...@@ -194,6 +211,7 @@ pub(crate) fn simulate_trace_disagg(
let (collector, _) = DisaggRuntime::new( let (collector, _) = DisaggRuntime::new(
&config, &config,
router_config, router_config,
prefill_load_estimator,
pending, pending,
DisaggReplayMode::Trace, DisaggReplayMode::Trace,
router_mode, router_mode,
...@@ -205,6 +223,7 @@ pub(crate) fn simulate_trace_disagg( ...@@ -205,6 +223,7 @@ pub(crate) fn simulate_trace_disagg(
pub(crate) fn simulate_concurrency_disagg( pub(crate) fn simulate_concurrency_disagg(
config: OfflineDisaggReplayConfig, config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>, requests: Vec<DirectRequest>,
max_in_flight: usize, max_in_flight: usize,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
...@@ -213,6 +232,7 @@ pub(crate) fn simulate_concurrency_disagg( ...@@ -213,6 +232,7 @@ pub(crate) fn simulate_concurrency_disagg(
let (collector, _) = DisaggRuntime::new( let (collector, _) = DisaggRuntime::new(
&config, &config,
router_config, router_config,
prefill_load_estimator,
pending, pending,
DisaggReplayMode::Concurrency { max_in_flight }, DisaggReplayMode::Concurrency { max_in_flight },
router_mode, router_mode,
...@@ -224,13 +244,15 @@ pub(crate) fn simulate_concurrency_disagg( ...@@ -224,13 +244,15 @@ pub(crate) fn simulate_concurrency_disagg(
pub(crate) fn simulate_trace_workload_disagg( pub(crate) fn simulate_trace_workload_disagg(
config: OfflineDisaggReplayConfig, config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace, trace: Trace,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
let driver = WorkloadDriver::new_trace(trace)?; let driver = WorkloadDriver::new_trace(trace, config.prefill_args.block_size)?;
let (collector, _) = DisaggRuntime::new_workload( let (collector, _) = DisaggRuntime::new_workload(
&config, &config,
router_config, router_config,
prefill_load_estimator,
driver, driver,
DisaggReplayMode::Trace, DisaggReplayMode::Trace,
router_mode, router_mode,
...@@ -242,14 +264,16 @@ pub(crate) fn simulate_trace_workload_disagg( ...@@ -242,14 +264,16 @@ pub(crate) fn simulate_trace_workload_disagg(
pub(crate) fn simulate_concurrency_workload_disagg( pub(crate) fn simulate_concurrency_workload_disagg(
config: OfflineDisaggReplayConfig, config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace, trace: Trace,
max_in_flight: usize, max_in_flight: usize,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
let driver = WorkloadDriver::new_concurrency(trace)?; let driver = WorkloadDriver::new_concurrency(trace, config.prefill_args.block_size)?;
let (collector, _) = DisaggRuntime::new_workload( let (collector, _) = DisaggRuntime::new_workload(
&config, &config,
router_config, router_config,
prefill_load_estimator,
driver, driver,
DisaggReplayMode::Concurrency { max_in_flight }, DisaggReplayMode::Concurrency { max_in_flight },
router_mode, router_mode,
...@@ -290,9 +314,13 @@ pub(crate) fn simulate_trace_workload_single( ...@@ -290,9 +314,13 @@ pub(crate) fn simulate_trace_workload_single(
trace: Trace, trace: Trace,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
let args = args.normalized()?; let args = args.normalized()?;
let collector = let engine_block_size = args.block_size;
SingleRuntime::new_workload(args, trace.into_trace_driver()?, SingleReplayMode::Trace) let collector = SingleRuntime::new_workload(
.run()?; args,
trace.into_trace_driver_with_block_size(engine_block_size)?,
SingleReplayMode::Trace,
)
.run()?;
Ok(collector.finish()) Ok(collector.finish())
} }
...@@ -302,9 +330,10 @@ pub(crate) fn simulate_concurrency_workload_single( ...@@ -302,9 +330,10 @@ pub(crate) fn simulate_concurrency_workload_single(
max_in_flight: usize, max_in_flight: usize,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
let args = args.normalized()?; let args = args.normalized()?;
let engine_block_size = args.block_size;
let collector = SingleRuntime::new_workload( let collector = SingleRuntime::new_workload(
args, args,
trace.into_concurrency_driver()?, trace.into_concurrency_driver_with_block_size(engine_block_size)?,
SingleReplayMode::Concurrency { max_in_flight }, SingleReplayMode::Concurrency { max_in_flight },
) )
.run()?; .run()?;
...@@ -314,6 +343,7 @@ pub(crate) fn simulate_concurrency_workload_single( ...@@ -314,6 +343,7 @@ pub(crate) fn simulate_concurrency_workload_single(
pub(crate) fn simulate_trace_multi( pub(crate) fn simulate_trace_multi(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>, requests: Vec<DirectRequest>,
num_workers: usize, num_workers: usize,
arrival_speedup_ratio: f64, arrival_speedup_ratio: f64,
...@@ -324,6 +354,7 @@ pub(crate) fn simulate_trace_multi( ...@@ -324,6 +354,7 @@ pub(crate) fn simulate_trace_multi(
let (collector, _) = AggRuntime::new( let (collector, _) = AggRuntime::new(
&args, &args,
router_config, router_config,
prefill_load_estimator,
pending, pending,
num_workers, num_workers,
AggReplayMode::Trace, AggReplayMode::Trace,
...@@ -336,6 +367,7 @@ pub(crate) fn simulate_trace_multi( ...@@ -336,6 +367,7 @@ pub(crate) fn simulate_trace_multi(
pub(crate) fn simulate_concurrency_multi( pub(crate) fn simulate_concurrency_multi(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>, requests: Vec<DirectRequest>,
max_in_flight: usize, max_in_flight: usize,
num_workers: usize, num_workers: usize,
...@@ -346,6 +378,7 @@ pub(crate) fn simulate_concurrency_multi( ...@@ -346,6 +378,7 @@ pub(crate) fn simulate_concurrency_multi(
let (collector, _) = AggRuntime::new( let (collector, _) = AggRuntime::new(
&args, &args,
router_config, router_config,
prefill_load_estimator,
pending, pending,
num_workers, num_workers,
AggReplayMode::Concurrency { max_in_flight }, AggReplayMode::Concurrency { max_in_flight },
...@@ -358,6 +391,7 @@ pub(crate) fn simulate_concurrency_multi( ...@@ -358,6 +391,7 @@ pub(crate) fn simulate_concurrency_multi(
pub(crate) fn simulate_trace_workload_multi( pub(crate) fn simulate_trace_workload_multi(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace, trace: Trace,
num_workers: usize, num_workers: usize,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
...@@ -366,7 +400,8 @@ pub(crate) fn simulate_trace_workload_multi( ...@@ -366,7 +400,8 @@ pub(crate) fn simulate_trace_workload_multi(
let (collector, _) = AggRuntime::new_workload( let (collector, _) = AggRuntime::new_workload(
&args, &args,
router_config, router_config,
trace.into_trace_driver()?, prefill_load_estimator,
trace.into_trace_driver_with_block_size(args.block_size)?,
num_workers, num_workers,
AggReplayMode::Trace, AggReplayMode::Trace,
router_mode, router_mode,
...@@ -378,6 +413,7 @@ pub(crate) fn simulate_trace_workload_multi( ...@@ -378,6 +413,7 @@ pub(crate) fn simulate_trace_workload_multi(
pub(crate) fn simulate_concurrency_workload_multi( pub(crate) fn simulate_concurrency_workload_multi(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace, trace: Trace,
max_in_flight: usize, max_in_flight: usize,
num_workers: usize, num_workers: usize,
...@@ -387,7 +423,8 @@ pub(crate) fn simulate_concurrency_workload_multi( ...@@ -387,7 +423,8 @@ pub(crate) fn simulate_concurrency_workload_multi(
let (collector, _) = AggRuntime::new_workload( let (collector, _) = AggRuntime::new_workload(
&args, &args,
router_config, router_config,
trace.into_concurrency_driver()?, prefill_load_estimator,
trace.into_concurrency_driver_with_block_size(args.block_size)?,
num_workers, num_workers,
AggReplayMode::Concurrency { max_in_flight }, AggReplayMode::Concurrency { max_in_flight },
router_mode, router_mode,
...@@ -428,9 +465,12 @@ pub(super) fn run_trace_workload_single_collect( ...@@ -428,9 +465,12 @@ pub(super) fn run_trace_workload_single_collect(
args: MockEngineArgs, args: MockEngineArgs,
trace: Trace, trace: Trace,
) -> TraceCollector { ) -> TraceCollector {
let engine_block_size = args.block_size;
SingleRuntime::new_workload( SingleRuntime::new_workload(
args, args,
trace.into_trace_driver().unwrap(), trace
.into_trace_driver_with_block_size(engine_block_size)
.unwrap(),
SingleReplayMode::Trace, SingleReplayMode::Trace,
) )
.run() .run()
...@@ -443,9 +483,12 @@ pub(super) fn run_concurrency_workload_single_collect( ...@@ -443,9 +483,12 @@ pub(super) fn run_concurrency_workload_single_collect(
trace: Trace, trace: Trace,
max_in_flight: usize, max_in_flight: usize,
) -> TraceCollector { ) -> TraceCollector {
let engine_block_size = args.block_size;
SingleRuntime::new_workload( SingleRuntime::new_workload(
args, args,
trace.into_concurrency_driver().unwrap(), trace
.into_concurrency_driver_with_block_size(engine_block_size)
.unwrap(),
SingleReplayMode::Concurrency { max_in_flight }, SingleReplayMode::Concurrency { max_in_flight },
) )
.run() .run()
...@@ -463,6 +506,7 @@ pub(super) fn run_trace_multi_collect_with_stats( ...@@ -463,6 +506,7 @@ pub(super) fn run_trace_multi_collect_with_stats(
AggRuntime::new( AggRuntime::new(
args, args,
None, None,
None,
pending, pending,
num_workers, num_workers,
AggReplayMode::Trace, AggReplayMode::Trace,
...@@ -484,6 +528,7 @@ pub(super) fn run_concurrency_multi_collect_with_stats( ...@@ -484,6 +528,7 @@ pub(super) fn run_concurrency_multi_collect_with_stats(
AggRuntime::new( AggRuntime::new(
args, args,
None, None,
None,
VecDeque::from(requests), VecDeque::from(requests),
num_workers, num_workers,
AggReplayMode::Concurrency { max_in_flight }, AggReplayMode::Concurrency { max_in_flight },
...@@ -504,7 +549,10 @@ pub(super) fn run_trace_workload_multi_collect_with_stats( ...@@ -504,7 +549,10 @@ pub(super) fn run_trace_workload_multi_collect_with_stats(
AggRuntime::new_workload( AggRuntime::new_workload(
args, args,
None, None,
trace.into_trace_driver().unwrap(), None,
trace
.into_trace_driver_with_block_size(args.block_size)
.unwrap(),
num_workers, num_workers,
AggReplayMode::Trace, AggReplayMode::Trace,
router_mode, router_mode,
...@@ -525,7 +573,10 @@ pub(super) fn run_concurrency_workload_multi_collect_with_stats( ...@@ -525,7 +573,10 @@ pub(super) fn run_concurrency_workload_multi_collect_with_stats(
AggRuntime::new_workload( AggRuntime::new_workload(
args, args,
None, None,
trace.into_concurrency_driver().unwrap(), None,
trace
.into_concurrency_driver_with_block_size(args.block_size)
.unwrap(),
num_workers, num_workers,
AggReplayMode::Concurrency { max_in_flight }, AggReplayMode::Concurrency { max_in_flight },
router_mode, router_mode,
...@@ -547,6 +598,7 @@ pub(super) fn run_trace_collect( ...@@ -547,6 +598,7 @@ pub(super) fn run_trace_collect(
DisaggRuntime::new( DisaggRuntime::new(
config, config,
router_config, router_config,
None,
pending, pending,
DisaggReplayMode::Trace, DisaggReplayMode::Trace,
router_mode, router_mode,
...@@ -567,6 +619,7 @@ pub(super) fn run_concurrency_collect( ...@@ -567,6 +619,7 @@ pub(super) fn run_concurrency_collect(
DisaggRuntime::new( DisaggRuntime::new(
config, config,
router_config, router_config,
None,
VecDeque::from(requests), VecDeque::from(requests),
DisaggReplayMode::Concurrency { max_in_flight }, DisaggReplayMode::Concurrency { max_in_flight },
router_mode, router_mode,
...@@ -576,6 +629,51 @@ pub(super) fn run_concurrency_collect( ...@@ -576,6 +629,51 @@ pub(super) fn run_concurrency_collect(
.unwrap() .unwrap()
} }
#[cfg(test)]
pub(super) fn run_trace_workload_collect(
config: &OfflineDisaggReplayConfig,
trace: Trace,
router_config: Option<KvRouterConfig>,
router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
DisaggRuntime::new_workload(
config,
router_config,
None,
trace
.into_trace_driver_with_block_size(config.prefill_args.block_size)
.unwrap(),
DisaggReplayMode::Trace,
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_concurrency_workload_collect(
config: &OfflineDisaggReplayConfig,
trace: Trace,
router_config: Option<KvRouterConfig>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
DisaggRuntime::new_workload(
config,
router_config,
None,
trace
.into_concurrency_driver_with_block_size(config.prefill_args.block_size)
.unwrap(),
DisaggReplayMode::Concurrency { max_in_flight },
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::generate_trace_worker_artifacts; use super::generate_trace_worker_artifacts;
......
...@@ -4,10 +4,12 @@ ...@@ -4,10 +4,12 @@
pub(crate) use crate::replay::normalize_trace_requests; pub(crate) use crate::replay::normalize_trace_requests;
pub(crate) mod agg; pub(crate) mod agg;
pub(crate) mod components;
pub(crate) mod core; pub(crate) mod core;
pub(crate) mod disagg; pub(crate) mod disagg;
mod entrypoints; mod entrypoints;
pub(crate) mod events; pub(crate) mod events;
mod progress;
pub(crate) mod runtime_utils; pub(crate) mod runtime_utils;
pub(crate) mod single; pub(crate) mod single;
pub(crate) mod state; pub(crate) mod state;
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use indicatif::{ProgressBar, ProgressStyle};
pub(super) struct ReplayProgress {
bar: ProgressBar,
}
impl ReplayProgress {
pub(super) fn new(total_requests: usize, label: &'static str) -> Self {
let bar = ProgressBar::new(total_requests as u64);
bar.set_style(
ProgressStyle::with_template(
"[{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({eta}) {msg}",
)
.expect("progress bar template must be valid")
.progress_chars("#>-"),
);
bar.set_message(label);
Self { bar }
}
pub(super) fn inc_completed(&self) {
self.bar.inc(1);
}
pub(super) fn finish(&self) {
self.bar.finish_and_clear();
}
}
impl Drop for ReplayProgress {
fn drop(&mut self) {
if !self.bar.is_finished() {
self.bar.finish_and_clear();
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::{BinaryHeap, VecDeque}; use std::collections::BinaryHeap;
#[cfg(test)]
use std::collections::VecDeque;
use dynamo_kv_router::protocols::RouterEvent; use dynamo_kv_router::protocols::RouterEvent;
use super::events::{SimulationEvent, SimulationEventKind, SimulationWorkerStage}; use super::events::{SimulationEvent, SimulationEventKind, SimulationWorkerStage};
use crate::common::protocols::{DirectRequest, OutputSignal}; #[cfg(test)]
use crate::common::protocols::DirectRequest;
use crate::common::protocols::OutputSignal;
#[derive(Debug)] #[derive(Debug)]
pub(super) struct WorkerCompletionPayload { pub(super) struct WorkerCompletionPayload {
...@@ -29,6 +33,7 @@ pub(super) fn next_timestamp( ...@@ -29,6 +33,7 @@ pub(super) fn next_timestamp(
} }
} }
#[cfg(test)]
pub(super) fn pop_next_trace_ready( pub(super) fn pop_next_trace_ready(
pending: &mut VecDeque<DirectRequest>, pending: &mut VecDeque<DirectRequest>,
now_ms: f64, now_ms: f64,
...@@ -43,6 +48,7 @@ pub(super) fn pop_next_trace_ready( ...@@ -43,6 +48,7 @@ pub(super) fn pop_next_trace_ready(
Some((request, arrival_ms)) Some((request, arrival_ms))
} }
#[cfg(test)]
pub(super) fn pop_next_concurrency_ready( pub(super) fn pop_next_concurrency_ready(
pending: &mut VecDeque<DirectRequest>, pending: &mut VecDeque<DirectRequest>,
now_ms: f64, now_ms: f64,
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::core::ReplayWorkerCore; use super::core::ReplayWorkerCore;
use super::progress::ReplayProgress;
use crate::common::protocols::{DirectRequest, MockEngineArgs}; use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::WorkloadDriver; use crate::loadgen::WorkloadDriver;
use crate::replay::TraceCollector; use crate::replay::TraceCollector;
...@@ -26,6 +27,7 @@ pub(super) struct SingleRuntime { ...@@ -26,6 +27,7 @@ pub(super) struct SingleRuntime {
worker: ReplayWorkerCore, worker: ReplayWorkerCore,
collector: TraceCollector, collector: TraceCollector,
mode: SingleReplayMode, mode: SingleReplayMode,
progress: ReplayProgress,
} }
impl SingleRuntime { impl SingleRuntime {
...@@ -50,12 +52,17 @@ impl SingleRuntime { ...@@ -50,12 +52,17 @@ impl SingleRuntime {
admission: AdmissionSource, admission: AdmissionSource,
mode: SingleReplayMode, mode: SingleReplayMode,
) -> Self { ) -> Self {
let total_requests = match &admission {
AdmissionSource::Requests(pending) => pending.len(),
AdmissionSource::Workload(driver) => driver.total_turns(),
};
Self { Self {
current_time_ms: 0.0, current_time_ms: 0.0,
admission, admission,
worker: ReplayWorkerCore::new(args), worker: ReplayWorkerCore::new(args),
collector: TraceCollector::default(), collector: TraceCollector::default(),
mode, mode,
progress: ReplayProgress::new(total_requests, "offline replay"),
} }
} }
...@@ -168,6 +175,14 @@ impl SingleRuntime { ...@@ -168,6 +175,14 @@ impl SingleRuntime {
.expect("completed workload request must belong to a session"); .expect("completed workload request must belong to a session");
} }
} }
let completed_requests = pass
.output_signals
.iter()
.filter(|signal| signal.completed)
.count();
for _ in 0..completed_requests {
self.progress.inc_completed();
}
if admit_arrivals_between_steps { if admit_arrivals_between_steps {
self.enqueue_trace_arrivals(); self.enqueue_trace_arrivals();
} }
...@@ -199,6 +214,7 @@ impl SingleRuntime { ...@@ -199,6 +214,7 @@ impl SingleRuntime {
} }
} }
self.progress.finish();
Ok(self.collector) Ok(self.collector)
} }
} }
......
...@@ -17,8 +17,8 @@ pub(crate) enum AggRequestPhase { ...@@ -17,8 +17,8 @@ pub(crate) enum AggRequestPhase {
pub(crate) struct AggRequestState { pub(crate) struct AggRequestState {
request: Option<DirectRequest>, request: Option<DirectRequest>,
phase: AggRequestPhase, pub(in crate::replay::offline) phase: AggRequestPhase,
prefill_completed: bool, pub(in crate::replay::offline) prefill_completed: bool,
} }
impl AggRequestState { impl AggRequestState {
...@@ -38,12 +38,8 @@ impl AggRequestState { ...@@ -38,12 +38,8 @@ impl AggRequestState {
} }
} }
pub(crate) fn is_queued_at_router(&self) -> bool {
self.phase == AggRequestPhase::QueuedAtRouter
}
pub(crate) fn take_queued_request(&mut self, uuid: Uuid) -> Result<DirectRequest> { pub(crate) fn take_queued_request(&mut self, uuid: Uuid) -> Result<DirectRequest> {
if !self.is_queued_at_router() { if self.phase != AggRequestPhase::QueuedAtRouter {
bail!("offline replay expected queued request state for {uuid}"); bail!("offline replay expected queued request state for {uuid}");
} }
let request = self let request = self
...@@ -53,14 +49,6 @@ impl AggRequestState { ...@@ -53,14 +49,6 @@ impl AggRequestState {
self.phase = AggRequestPhase::Running; self.phase = AggRequestPhase::Running;
Ok(request) Ok(request)
} }
pub(crate) fn prefill_completed(&self) -> bool {
self.prefill_completed
}
pub(crate) fn mark_prefill_completed(&mut self) {
self.prefill_completed = true;
}
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
...@@ -76,7 +64,7 @@ pub(crate) struct DisaggRequestState { ...@@ -76,7 +64,7 @@ pub(crate) struct DisaggRequestState {
original: Option<DirectRequest>, original: Option<DirectRequest>,
#[cfg(test)] #[cfg(test)]
arrival_ms: f64, arrival_ms: f64,
phase: DisaggPhase, pub(in crate::replay::offline) phase: DisaggPhase,
prefill_worker_idx: Option<usize>, prefill_worker_idx: Option<usize>,
decode_worker_idx: Option<usize>, decode_worker_idx: Option<usize>,
} }
...@@ -104,14 +92,6 @@ impl DisaggRequestState { ...@@ -104,14 +92,6 @@ impl DisaggRequestState {
} }
} }
pub(crate) fn is_queued_prefill(&self) -> bool {
self.phase == DisaggPhase::QueuedPrefill
}
pub(crate) fn is_queued_decode(&self) -> bool {
self.phase == DisaggPhase::QueuedDecode
}
pub(crate) fn original_request(&self) -> Result<&DirectRequest> { pub(crate) fn original_request(&self) -> Result<&DirectRequest> {
self.original self.original
.as_ref() .as_ref()
...@@ -124,10 +104,6 @@ impl DisaggRequestState { ...@@ -124,10 +104,6 @@ impl DisaggRequestState {
Ok(request) Ok(request)
} }
pub(crate) fn build_decode_request(&self) -> Result<DirectRequest> {
Ok(self.original_request()?.clone())
}
pub(crate) fn start_prefill(&mut self, worker_idx: usize) { pub(crate) fn start_prefill(&mut self, worker_idx: usize) {
self.phase = DisaggPhase::RunningPrefill; self.phase = DisaggPhase::RunningPrefill;
self.prefill_worker_idx = Some(worker_idx); self.prefill_worker_idx = Some(worker_idx);
......
...@@ -7,10 +7,10 @@ use tokio::sync::mpsc; ...@@ -7,10 +7,10 @@ use tokio::sync::mpsc;
use tokio::time::Instant; use tokio::time::Instant;
use crate::common::protocols::OutputSignal; use crate::common::protocols::OutputSignal;
use crate::replay::router::ReplayRouter;
use crate::replay::{TraceCollector, TraceSimulationReport}; use crate::replay::{TraceCollector, TraceSimulationReport};
use crate::scheduler::AdmissionEvent; use crate::scheduler::AdmissionEvent;
use super::ReplayRouter;
use super::state::{ArrivalEvent, RequestRegistry, SharedLiveRuntimeStats, now_ms}; use super::state::{ArrivalEvent, RequestRegistry, SharedLiveRuntimeStats, now_ms};
async fn process_output_signal( async fn process_output_signal(
......
...@@ -8,7 +8,9 @@ use dynamo_kv_router::config::KvRouterConfig; ...@@ -8,7 +8,9 @@ use dynamo_kv_router::config::KvRouterConfig;
use crate::common::protocols::{DirectRequest, MockEngineArgs}; use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::{Trace, WorkloadDriver}; use crate::loadgen::{Trace, WorkloadDriver};
use crate::replay::{ReplayRouterMode, TraceSimulationReport, normalize_trace_requests}; use crate::replay::{
ReplayPrefillLoadEstimator, ReplayRouterMode, TraceSimulationReport, normalize_trace_requests,
};
use super::live_runtime::LiveRuntime; use super::live_runtime::LiveRuntime;
use super::state::{LiveReplayMode, LiveRuntimeStats}; use super::state::{LiveReplayMode, LiveRuntimeStats};
...@@ -24,6 +26,7 @@ fn total_turns(trace: &Trace) -> usize { ...@@ -24,6 +26,7 @@ fn total_turns(trace: &Trace) -> usize {
fn run_live_runtime( fn run_live_runtime(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
pending: VecDeque<DirectRequest>, pending: VecDeque<DirectRequest>,
num_workers: usize, num_workers: usize,
mode: LiveReplayMode, mode: LiveReplayMode,
...@@ -35,15 +38,25 @@ fn run_live_runtime( ...@@ -35,15 +38,25 @@ fn run_live_runtime(
.map_err(|e| anyhow!("failed to create online replay runtime: {e}"))?; .map_err(|e| anyhow!("failed to create online replay runtime: {e}"))?;
runtime.block_on(async move { runtime.block_on(async move {
LiveRuntime::new(args, router_config, pending, num_workers, mode, router_mode)? LiveRuntime::new(
.run() args,
.await router_config,
prefill_load_estimator,
pending,
num_workers,
mode,
router_mode,
)?
.run()
.await
}) })
} }
#[allow(clippy::too_many_arguments)]
fn run_live_workload_runtime( fn run_live_workload_runtime(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
driver: WorkloadDriver, driver: WorkloadDriver,
total_turns: usize, total_turns: usize,
num_workers: usize, num_workers: usize,
...@@ -59,6 +72,7 @@ fn run_live_workload_runtime( ...@@ -59,6 +72,7 @@ fn run_live_workload_runtime(
LiveRuntime::new( LiveRuntime::new(
args, args,
router_config, router_config,
prefill_load_estimator,
VecDeque::new(), VecDeque::new(),
num_workers, num_workers,
mode, mode,
...@@ -72,6 +86,7 @@ fn run_live_workload_runtime( ...@@ -72,6 +86,7 @@ fn run_live_workload_runtime(
pub(crate) fn simulate_trace_requests( pub(crate) fn simulate_trace_requests(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>, requests: Vec<DirectRequest>,
num_workers: usize, num_workers: usize,
arrival_speedup_ratio: f64, arrival_speedup_ratio: f64,
...@@ -82,6 +97,7 @@ pub(crate) fn simulate_trace_requests( ...@@ -82,6 +97,7 @@ pub(crate) fn simulate_trace_requests(
let (report, _) = run_live_runtime( let (report, _) = run_live_runtime(
args, args,
router_config, router_config,
prefill_load_estimator,
pending, pending,
num_workers, num_workers,
LiveReplayMode::Trace, LiveReplayMode::Trace,
...@@ -93,6 +109,7 @@ pub(crate) fn simulate_trace_requests( ...@@ -93,6 +109,7 @@ pub(crate) fn simulate_trace_requests(
pub(crate) fn simulate_concurrency_requests( pub(crate) fn simulate_concurrency_requests(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>, requests: Vec<DirectRequest>,
max_in_flight: usize, max_in_flight: usize,
num_workers: usize, num_workers: usize,
...@@ -107,6 +124,7 @@ pub(crate) fn simulate_concurrency_requests( ...@@ -107,6 +124,7 @@ pub(crate) fn simulate_concurrency_requests(
let (report, _) = run_live_runtime( let (report, _) = run_live_runtime(
args, args,
router_config, router_config,
prefill_load_estimator,
pending, pending,
num_workers, num_workers,
LiveReplayMode::Concurrency { max_in_flight }, LiveReplayMode::Concurrency { max_in_flight },
...@@ -118,16 +136,19 @@ pub(crate) fn simulate_concurrency_requests( ...@@ -118,16 +136,19 @@ pub(crate) fn simulate_concurrency_requests(
pub(crate) fn simulate_trace_workload( pub(crate) fn simulate_trace_workload(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace, trace: Trace,
num_workers: usize, num_workers: usize,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
let args = args.normalized()?; let args = args.normalized()?;
let engine_block_size = args.block_size;
let total_turns = total_turns(&trace); let total_turns = total_turns(&trace);
let (report, _) = run_live_workload_runtime( let (report, _) = run_live_workload_runtime(
args, args,
router_config, router_config,
trace.into_trace_driver()?, prefill_load_estimator,
trace.into_trace_driver_with_block_size(engine_block_size)?,
total_turns, total_turns,
num_workers, num_workers,
LiveReplayMode::Trace, LiveReplayMode::Trace,
...@@ -139,17 +160,20 @@ pub(crate) fn simulate_trace_workload( ...@@ -139,17 +160,20 @@ pub(crate) fn simulate_trace_workload(
pub(crate) fn simulate_concurrency_workload( pub(crate) fn simulate_concurrency_workload(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace, trace: Trace,
max_in_flight: usize, max_in_flight: usize,
num_workers: usize, num_workers: usize,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
let args = args.normalized()?; let args = args.normalized()?;
let engine_block_size = args.block_size;
let total_turns = total_turns(&trace); let total_turns = total_turns(&trace);
let (report, _) = run_live_workload_runtime( let (report, _) = run_live_workload_runtime(
args, args,
router_config, router_config,
trace.into_concurrency_driver()?, prefill_load_estimator,
trace.into_concurrency_driver_with_block_size(engine_block_size)?,
total_turns, total_turns,
num_workers, num_workers,
LiveReplayMode::Concurrency { max_in_flight }, LiveReplayMode::Concurrency { max_in_flight },
...@@ -171,6 +195,7 @@ pub(super) fn simulate_trace_requests_with_stats( ...@@ -171,6 +195,7 @@ pub(super) fn simulate_trace_requests_with_stats(
run_live_runtime( run_live_runtime(
args, args,
None, None,
None,
pending, pending,
num_workers, num_workers,
LiveReplayMode::Trace, LiveReplayMode::Trace,
...@@ -191,6 +216,7 @@ pub(super) fn simulate_concurrency_requests_with_stats( ...@@ -191,6 +216,7 @@ pub(super) fn simulate_concurrency_requests_with_stats(
run_live_runtime( run_live_runtime(
args, args,
None, None,
None,
pending, pending,
num_workers, num_workers,
LiveReplayMode::Concurrency { max_in_flight }, LiveReplayMode::Concurrency { max_in_flight },
...@@ -206,11 +232,13 @@ pub(super) fn simulate_trace_workload_with_stats( ...@@ -206,11 +232,13 @@ pub(super) fn simulate_trace_workload_with_stats(
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> { ) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let args = args.normalized()?; let args = args.normalized()?;
let engine_block_size = args.block_size;
let total_turns = total_turns(&trace); let total_turns = total_turns(&trace);
run_live_workload_runtime( run_live_workload_runtime(
args, args,
None, None,
trace.into_trace_driver()?, None,
trace.into_trace_driver_with_block_size(engine_block_size)?,
total_turns, total_turns,
num_workers, num_workers,
LiveReplayMode::Trace, LiveReplayMode::Trace,
...@@ -227,11 +255,13 @@ pub(super) fn simulate_concurrency_workload_with_stats( ...@@ -227,11 +255,13 @@ pub(super) fn simulate_concurrency_workload_with_stats(
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> { ) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let args = args.normalized()?; let args = args.normalized()?;
let engine_block_size = args.block_size;
let total_turns = total_turns(&trace); let total_turns = total_turns(&trace);
run_live_workload_runtime( run_live_workload_runtime(
args, args,
None, None,
trace.into_concurrency_driver()?, None,
trace.into_concurrency_driver_with_block_size(engine_block_size)?,
total_turns, total_turns,
num_workers, num_workers,
LiveReplayMode::Concurrency { max_in_flight }, LiveReplayMode::Concurrency { max_in_flight },
......
...@@ -13,10 +13,10 @@ use tokio_util::sync::CancellationToken; ...@@ -13,10 +13,10 @@ use tokio_util::sync::CancellationToken;
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal}; use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::WorkloadDriver; use crate::loadgen::WorkloadDriver;
use crate::replay::router::ReplayRouter; use crate::replay::{ReplayPrefillLoadEstimator, ReplayRouterMode, TraceSimulationReport};
use crate::replay::{ReplayRouterMode, TraceSimulationReport};
use crate::scheduler::{AdmissionEvent, EngineScheduler, SchedulerHandle}; use crate::scheduler::{AdmissionEvent, EngineScheduler, SchedulerHandle};
use super::ReplayRouter;
use super::demux::run_demux; use super::demux::run_demux;
use super::state::{ use super::state::{
LiveReplayMode, LiveRuntimeStats, SharedLiveRuntimeStats, WorkloadDispatchState, now_ms, LiveReplayMode, LiveRuntimeStats, SharedLiveRuntimeStats, WorkloadDispatchState, now_ms,
...@@ -41,6 +41,7 @@ impl LiveRuntime { ...@@ -41,6 +41,7 @@ impl LiveRuntime {
pub(super) fn new( pub(super) fn new(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
pending: std::collections::VecDeque<DirectRequest>, pending: std::collections::VecDeque<DirectRequest>,
num_workers: usize, num_workers: usize,
mode: LiveReplayMode, mode: LiveReplayMode,
...@@ -53,6 +54,7 @@ impl LiveRuntime { ...@@ -53,6 +54,7 @@ impl LiveRuntime {
router_mode, router_mode,
&args, &args,
router_config, router_config,
prefill_load_estimator,
num_workers, num_workers,
)); ));
let mut schedulers = Vec::with_capacity(num_workers); let mut schedulers = Vec::with_capacity(num_workers);
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
mod demux; mod demux;
mod entrypoints; mod entrypoints;
mod live_runtime; mod live_runtime;
mod router;
mod state; mod state;
mod task; mod task;
...@@ -14,3 +15,4 @@ pub(crate) use entrypoints::{ ...@@ -14,3 +15,4 @@ pub(crate) use entrypoints::{
simulate_concurrency_requests, simulate_concurrency_workload, simulate_trace_requests, simulate_concurrency_requests, simulate_concurrency_workload, simulate_trace_requests,
simulate_trace_workload, simulate_trace_workload,
}; };
pub(crate) use router::ReplayRouter;
...@@ -16,14 +16,14 @@ use tokio::sync::mpsc; ...@@ -16,14 +16,14 @@ use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use uuid::Uuid; use uuid::Uuid;
use super::shared::{
ReplayScheduler, replay_policy, replay_router_config, replay_selector, replay_slots,
replay_workers_with_configs,
};
use crate::common::protocols::{ use crate::common::protocols::{
DirectRequest, KvCacheEventSink, KvEventPublishers, MockEngineArgs, DirectRequest, KvCacheEventSink, KvEventPublishers, MockEngineArgs,
}; };
use crate::replay::ReplayRouterMode; use crate::replay::router_shared::{
ReplayScheduler, replay_policy, replay_router_config, replay_selector, replay_slots,
replay_workers_with_configs,
};
use crate::replay::{ReplayPrefillLoadEstimator, ReplayRouterMode};
#[derive(Clone)] #[derive(Clone)]
enum ReplayIndexer { enum ReplayIndexer {
...@@ -120,6 +120,7 @@ impl KvReplayRouter { ...@@ -120,6 +120,7 @@ impl KvReplayRouter {
fn new( fn new(
args: &MockEngineArgs, args: &MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
num_workers: usize, num_workers: usize,
) -> Self { ) -> Self {
let config = replay_router_config(args, router_config); let config = replay_router_config(args, router_config);
...@@ -138,6 +139,8 @@ impl KvReplayRouter { ...@@ -138,6 +139,8 @@ impl KvReplayRouter {
args.block_size as u32, args.block_size as u32,
selector, selector,
policy, policy,
prefill_load_estimator,
config.router_queue_recheck_interval(),
config.router_track_prefill_tokens, config.router_track_prefill_tokens,
CancellationToken::new(), CancellationToken::new(),
"replay", "replay",
...@@ -237,6 +240,20 @@ impl KvReplayRouter { ...@@ -237,6 +240,20 @@ impl KvReplayRouter {
.map_err(|e| anyhow!("replay router event task failed: {e}"))?; .map_err(|e| anyhow!("replay router event task failed: {e}"))?;
Ok(()) Ok(())
} }
#[cfg(test)]
fn debug_potential_loads(
&self,
isl_tokens: usize,
track_prefill_tokens: bool,
) -> Vec<dynamo_kv_router::PotentialLoad> {
self.scheduler.get_potential_loads(
None,
isl_tokens,
OverlapScores::default(),
track_prefill_tokens,
)
}
} }
#[expect( #[expect(
...@@ -253,13 +270,17 @@ impl ReplayRouter { ...@@ -253,13 +270,17 @@ impl ReplayRouter {
mode: ReplayRouterMode, mode: ReplayRouterMode,
args: &MockEngineArgs, args: &MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
num_workers: usize, num_workers: usize,
) -> Self { ) -> Self {
match mode { match mode {
ReplayRouterMode::RoundRobin => Self::RoundRobin(RoundRobinRouter::default()), ReplayRouterMode::RoundRobin => Self::RoundRobin(RoundRobinRouter::default()),
ReplayRouterMode::KvRouter => { ReplayRouterMode::KvRouter => Self::Kv(KvReplayRouter::new(
Self::Kv(KvReplayRouter::new(args, router_config, num_workers)) args,
} router_config,
prefill_load_estimator,
num_workers,
)),
} }
} }
...@@ -307,4 +328,16 @@ impl ReplayRouter { ...@@ -307,4 +328,16 @@ impl ReplayRouter {
Self::Kv(router) => router.shutdown().await, Self::Kv(router) => router.shutdown().await,
} }
} }
#[cfg(test)]
pub(crate) fn debug_potential_loads(
&self,
isl_tokens: usize,
track_prefill_tokens: bool,
) -> Vec<dynamo_kv_router::PotentialLoad> {
match self {
Self::RoundRobin(_) => Vec::new(),
Self::Kv(router) => router.debug_potential_loads(isl_tokens, track_prefill_tokens),
}
}
} }
...@@ -10,8 +10,8 @@ use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc}; ...@@ -10,8 +10,8 @@ use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};
use tokio::time::Instant; use tokio::time::Instant;
use crate::common::protocols::DirectRequest; use crate::common::protocols::DirectRequest;
use crate::replay::router::ReplayRouter;
use super::ReplayRouter;
use super::state::{ use super::state::{
LiveReplayMode, RequestRegistry, RequestState, SharedLiveRuntimeStats, WorkloadDispatchState, LiveReplayMode, RequestRegistry, RequestState, SharedLiveRuntimeStats, WorkloadDispatchState,
now_ms, request_uuid, now_ms, request_uuid,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment