Unverified commit 95a750f4, authored by Yan Ru Pei and committed by GitHub
Browse files

chore(replay): refactor offline components into cleaner lanes (#7866)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 210bbf5d
......@@ -6,12 +6,14 @@ mod collector;
mod entrypoints;
pub(crate) mod offline;
mod online;
mod router;
mod router_shared;
mod validate;
use std::collections::VecDeque;
use std::sync::Arc;
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use dynamo_kv_router::PrefillLoadEstimator;
pub use artifacts::{
ReplayTimedKvEvent, ReplayTimedOutputSignal, ReplayTimedRequest, ReplayWorkerArtifacts,
......@@ -35,6 +37,8 @@ pub enum ReplayArgsMode {
Disagg,
}
/// Shared, reference-counted handle to a `PrefillLoadEstimator` trait object, as handed to the
/// replay runtimes and routers.
pub type ReplayPrefillLoadEstimator = Arc<dyn PrefillLoadEstimator>;
#[derive(Clone, Debug)]
pub struct OfflineDisaggReplayConfig {
pub prefill_args: MockEngineArgs,
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#[cfg(test)]
use super::components::OfflineRouterSnapshot;
pub(super) use super::components::ReplayMode;
use super::events::{SimulationEvent, SimulationWorkerStage};
use super::progress::ReplayProgress;
use super::runtime_utils::{
WorkerCompletionPayload, next_timestamp as choose_next_timestamp, pop_next_concurrency_ready,
pop_next_trace_ready, pop_ready_worker_completion, push_worker_completion,
next_timestamp as choose_next_timestamp, pop_ready_worker_completion, push_worker_completion,
};
#[cfg(test)]
use super::state::AggRequestPhase;
#[cfg(test)]
use super::state::OfflineWorkerSnapshot;
use super::state::{AggRequestState, OfflineWorkerState};
use super::{
components::{
AdmissionQueue, EngineComponent, EngineEffects, EnginePassMode, OfflineReplayRouter,
ScheduledWorkerCompletion, WorkerAdmission,
},
state::AggRequestState,
};
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::{ReplayRequestHashes, WorkloadDriver};
use crate::replay::router::OfflineReplayRouter;
#[cfg(test)]
use crate::replay::router::OfflineRouterSnapshot;
use crate::replay::{ReplayRouterMode, TraceCollector};
use crate::scheduler::RouterEventVisibility;
use crate::replay::{ReplayPrefillLoadEstimator, ReplayRouterMode, TraceCollector};
use anyhow::bail;
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::RouterEvent;
......@@ -25,17 +32,6 @@ use std::collections::HashMap;
use std::collections::{BinaryHeap, VecDeque};
use uuid::Uuid;
#[derive(Debug, Clone, Copy)]
pub(super) enum ReplayMode {
Trace,
Concurrency { max_in_flight: usize },
}
enum AdmissionSource {
Requests(VecDeque<DirectRequest>),
Workload(WorkloadDriver),
}
#[cfg(test)]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub(super) struct AggRuntimeStats {
......@@ -67,13 +63,13 @@ pub(super) struct AggRuntime {
now_ms: f64,
next_worker_idx: usize,
next_event_seq: u64,
admission: AdmissionSource,
admission: AdmissionQueue,
requests: FxHashMap<Uuid, AggRequestState>,
workers: Vec<OfflineWorkerState>,
engine: EngineComponent,
collector: TraceCollector,
events: BinaryHeap<SimulationEvent>,
mode: ReplayMode,
router: Option<OfflineReplayRouter>,
progress: ReplayProgress,
stats: AggRuntimeStats,
#[cfg(test)]
worker_active_requests: Vec<Vec<Uuid>>,
......@@ -86,6 +82,7 @@ impl AggRuntime {
pub(super) fn new(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
pending: VecDeque<DirectRequest>,
num_workers: usize,
mode: ReplayMode,
......@@ -94,9 +91,9 @@ impl AggRuntime {
Self::new_with_source(
args,
router_config,
AdmissionSource::Requests(pending),
prefill_load_estimator,
AdmissionQueue::new_requests(pending, mode),
num_workers,
mode,
router_mode,
)
}
......@@ -105,6 +102,7 @@ impl AggRuntime {
pub(super) fn new_workload(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
driver: WorkloadDriver,
num_workers: usize,
mode: ReplayMode,
......@@ -113,9 +111,9 @@ impl AggRuntime {
Self::new_with_source(
args,
router_config,
AdmissionSource::Workload(driver),
prefill_load_estimator,
AdmissionQueue::new_workload(driver, mode),
num_workers,
mode,
router_mode,
)
}
......@@ -124,19 +122,36 @@ impl AggRuntime {
fn new_with_source(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
admission: AdmissionSource,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
admission: AdmissionQueue,
num_workers: usize,
mode: ReplayMode,
router_mode: ReplayRouterMode,
) -> anyhow::Result<Self> {
let args = args.clone().normalized()?;
let progress = ReplayProgress::new(admission.total_requests(), "offline replay");
let router = match router_mode {
ReplayRouterMode::RoundRobin => None,
ReplayRouterMode::KvRouter => {
Some(OfflineReplayRouter::new(&args, router_config, num_workers)?)
}
ReplayRouterMode::KvRouter => Some(OfflineReplayRouter::new(
&args,
router_config,
prefill_load_estimator,
num_workers,
)?),
};
let capture_kv_events = router.is_some();
let engine = EngineComponent::new(
SimulationWorkerStage::Aggregated,
EnginePassMode::Visible,
(0..num_workers)
.map(|worker_idx| {
super::state::OfflineWorkerState::new(
worker_idx,
args.clone(),
capture_kv_events,
)
})
.collect(),
);
Ok(Self {
now_ms: 0.0,
......@@ -144,15 +159,11 @@ impl AggRuntime {
next_event_seq: 0,
admission,
requests: FxHashMap::default(),
workers: (0..num_workers)
.map(|worker_idx| {
OfflineWorkerState::new(worker_idx, args.clone(), capture_kv_events)
})
.collect(),
engine,
collector: TraceCollector::default(),
events: BinaryHeap::new(),
mode,
router,
progress,
#[cfg(test)]
stats: AggRuntimeStats::default(),
#[cfg(not(test))]
......@@ -166,10 +177,7 @@ impl AggRuntime {
/// Count all requests currently consuming cluster capacity, including router-queued ones.
fn cluster_in_flight(&self) -> usize {
self.workers
.iter()
.map(OfflineWorkerState::in_flight)
.sum::<usize>()
self.engine.in_flight()
+ self
.router
.as_ref()
......@@ -203,7 +211,7 @@ impl AggRuntime {
/// Pick the next worker in round-robin order.
fn next_worker(&mut self) -> usize {
let worker_idx = self.next_worker_idx;
self.next_worker_idx = (self.next_worker_idx + 1) % self.workers.len();
self.next_worker_idx = (self.next_worker_idx + 1) % self.engine.worker_count();
worker_idx
}
......@@ -220,14 +228,6 @@ impl AggRuntime {
self.record_in_flight_peak();
}
/// Fail fast if a router admission points at a worker that does not exist.
fn validate_worker_idx(&self, worker_idx: usize) -> anyhow::Result<()> {
if worker_idx >= self.workers.len() {
bail!("offline replay selected unknown worker index {worker_idx}");
}
Ok(())
}
/// Deliver a request to a worker and update the runtime's bookkeeping for that assignment.
fn dispatch_to_worker(
&mut self,
......@@ -235,32 +235,19 @@ impl AggRuntime {
uuid: Uuid,
worker_idx: usize,
) -> anyhow::Result<()> {
self.validate_worker_idx(worker_idx)?;
self.workers[worker_idx].receive_request(request);
self.engine.dispatch(worker_idx, request)?;
self.record_dispatch(uuid, worker_idx);
#[cfg(test)]
self.worker_active_requests[worker_idx].push(uuid);
Ok(())
}
/// Submit a request to the router and return an immediate admission when one is available.
fn submit_to_router(
&mut self,
request: &DirectRequest,
replay_hashes: Option<ReplayRequestHashes>,
) -> anyhow::Result<Option<usize>> {
let Some(router) = self.router.as_mut() else {
bail!("offline replay router submission requires an active router");
};
let maybe_worker_idx =
router.submit_request_with_hashes(request, replay_hashes, self.now_ms)?;
self.record_router_pending();
Ok(maybe_worker_idx)
}
/// Materialize router admissions into concrete worker dispatches.
fn dispatch_router_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> anyhow::Result<()> {
for (uuid, worker_idx) in admissions {
fn dispatch_router_admissions(
&mut self,
admissions: Vec<WorkerAdmission>,
) -> anyhow::Result<()> {
for WorkerAdmission { uuid, worker_idx } in admissions {
let request = self
.requests
.get_mut(&uuid)
......@@ -282,7 +269,7 @@ impl AggRuntime {
) -> anyhow::Result<Uuid> {
let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
request.uuid = Some(uuid);
if matches!(self.mode, ReplayMode::Concurrency { .. }) {
if matches!(self.admission.mode(), ReplayMode::Concurrency { .. }) {
request.arrival_timestamp_ms = Some(arrival_time_ms);
}
......@@ -299,15 +286,17 @@ impl AggRuntime {
self.dispatch_to_worker(request, uuid, worker_idx)?;
return Ok(uuid);
}
let maybe_worker_idx = self.submit_to_router(&request, replay_hashes)?;
if let Some(worker_idx) = maybe_worker_idx {
self.requests.insert(uuid, AggRequestState::new_running());
self.dispatch_to_worker(request, uuid, worker_idx)?;
return Ok(uuid);
}
let queued_request = request.clone();
self.requests
.insert(uuid, AggRequestState::new_queued(request));
let admissions = {
let router = self.router.as_mut().expect("router presence checked above");
router
.on_request_arrival(&queued_request, replay_hashes, self.now_ms)?
.admissions
};
self.record_router_pending();
self.dispatch_router_admissions(admissions)?;
self.record_in_flight_peak();
Ok(uuid)
}
......@@ -316,38 +305,17 @@ impl AggRuntime {
fn is_done(&self) -> bool {
self.events.is_empty()
&& self.cluster_in_flight() == 0
&& match &self.admission {
AdmissionSource::Requests(pending) => pending.is_empty(),
AdmissionSource::Workload(driver) => driver.is_drained(),
}
&& self.workers.iter().all(OfflineWorkerState::is_drained)
&& self.admission.is_drained()
&& self.engine.is_drained()
}
/// Pick the next logical timestamp from either arrivals or scheduled worker completions.
fn next_timestamp(&mut self) -> Option<f64> {
let next_event_ms = self.events.peek().map(|event| event.at_ms);
let cluster_in_flight = self.cluster_in_flight();
let next_arrival_ms = match (&self.mode, &mut self.admission) {
(ReplayMode::Trace, AdmissionSource::Requests(pending)) => pending
.front()
.and_then(|request| request.arrival_timestamp_ms),
(ReplayMode::Trace, AdmissionSource::Workload(driver)) => driver.next_ready_time_ms(),
(ReplayMode::Concurrency { max_in_flight }, AdmissionSource::Workload(driver)) => {
if cluster_in_flight < *max_in_flight {
driver.next_ready_time_ms()
} else {
None
}
}
(ReplayMode::Concurrency { .. }, AdmissionSource::Requests(_)) => None,
};
choose_next_timestamp(next_arrival_ms, next_event_ms)
}
/// Release completed requests from worker-local accounting after a pass finishes.
fn apply_completed_requests(&mut self, worker_idx: usize, completed_requests: usize) {
self.workers[worker_idx].mark_completed(completed_requests);
choose_next_timestamp(
self.admission.next_ready_time_ms(self.cluster_in_flight()),
next_event_ms,
)
}
/// Apply router-visible KV events at the phase chosen by the scheduler core.
......@@ -355,8 +323,9 @@ impl AggRuntime {
let Some(router) = self.router.as_mut() else {
return Ok(());
};
for event in events {
router.apply_event(event)?;
let effects = router.on_kv_events(events)?;
if !effects.admissions.is_empty() {
bail!("offline replay router KV event application must not admit requests");
}
Ok(())
}
......@@ -368,7 +337,9 @@ impl AggRuntime {
#[cfg(test)]
self.remove_active_request(signal.uuid);
if let Some(router) = self.router.as_mut() {
admissions = router.free(signal.uuid)?;
admissions = router
.on_request_completed(signal.uuid, self.now_ms)?
.admissions;
#[cfg(test)]
{
self.stats.router_freed_count += 1;
......@@ -378,9 +349,9 @@ impl AggRuntime {
self.requests.remove(&signal.uuid).ok_or_else(|| {
anyhow::anyhow!("offline replay missing request state for {}", signal.uuid)
})?;
if let AdmissionSource::Workload(driver) = &mut self.admission {
driver.on_complete(signal.uuid, self.now_ms)?;
}
self.admission
.on_request_completed(signal.uuid, self.now_ms)?;
self.progress.inc_completed();
self.dispatch_router_admissions(admissions)?;
return Ok(());
}
......@@ -391,7 +362,7 @@ impl AggRuntime {
.ok_or_else(|| {
anyhow::anyhow!("offline replay missing request state for {}", signal.uuid)
})?
.prefill_completed();
.prefill_completed;
if already_marked {
return Ok(());
}
......@@ -401,9 +372,11 @@ impl AggRuntime {
.ok_or_else(|| {
anyhow::anyhow!("offline replay missing request state for {}", signal.uuid)
})?
.mark_prefill_completed();
.prefill_completed = true;
if let Some(router) = self.router.as_mut() {
admissions = router.mark_prefill_completed(signal.uuid)?;
admissions = router
.on_prefill_completed(signal.uuid, self.now_ms)?
.admissions;
#[cfg(test)]
{
self.stats.prefill_marked_count += 1;
......@@ -433,12 +406,11 @@ impl AggRuntime {
/// Apply one completed pass: free request slots, publish KV events, and handle outputs.
fn process_completed_pass(
&mut self,
worker_idx: usize,
completed_requests: usize,
_worker_idx: usize,
_completed_requests: usize,
output_signals: Vec<OutputSignal>,
kv_events: Vec<RouterEvent>,
) -> anyhow::Result<()> {
self.apply_completed_requests(worker_idx, completed_requests);
self.apply_router_events(kv_events)?;
for signal in output_signals {
self.process_output_signal(signal)?;
......@@ -449,165 +421,71 @@ impl AggRuntime {
/// Drain all worker-completion events scheduled for the current logical timestamp.
fn apply_worker_completions(&mut self) -> anyhow::Result<bool> {
let mut changed = false;
while let Some(WorkerCompletionPayload {
stage,
worker_idx,
completed_requests,
output_signals,
kv_events,
}) = pop_ready_worker_completion(&mut self.events, self.now_ms)
{
debug_assert_eq!(stage, SimulationWorkerStage::Aggregated);
self.workers[worker_idx].mark_idle();
self.process_completed_pass(worker_idx, completed_requests, output_signals, kv_events)?;
while let Some(payload) = pop_ready_worker_completion(&mut self.events, self.now_ms) {
debug_assert_eq!(payload.stage, SimulationWorkerStage::Aggregated);
let payload = self.engine.on_scheduled_completion(payload)?;
self.process_completed_pass(
payload.worker_idx,
payload.completed_requests,
payload.output_signals,
payload.kv_events,
)?;
changed = true;
}
Ok(changed)
}
/// Release every trace arrival whose timestamp is now visible to the global clock.
fn release_trace_arrivals(&mut self) -> anyhow::Result<bool> {
let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) {
loop {
let next_ready = match &mut self.admission {
AdmissionSource::Requests(pending) => {
pop_next_trace_ready(pending, self.now_ms)
}
AdmissionSource::Workload(_) => unreachable!(),
};
let Some((request, arrival_ms)) = next_ready else {
break;
};
self.assign_request(request, arrival_ms, None)?;
released_any = true;
}
return Ok(released_any);
}
let ready_requests = match &mut self.admission {
AdmissionSource::Requests(_) => unreachable!(),
AdmissionSource::Workload(driver) => driver.pop_ready(self.now_ms, usize::MAX),
};
for ready in ready_requests {
self.assign_request(
ready.request,
ready.scheduled_ready_at_ms,
ready.replay_hashes,
)?;
released_any = true;
}
Ok(released_any)
}
/// Backfill closed-loop concurrency replay until the configured in-flight limit is reached.
fn top_off_concurrency(&mut self, max_in_flight: usize) -> anyhow::Result<bool> {
/// Release every admission made ready by the shared admission queue.
fn release_ready_arrivals(&mut self) -> anyhow::Result<bool> {
let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) {
loop {
let cluster_in_flight = self.cluster_in_flight();
let next_ready = match &mut self.admission {
AdmissionSource::Requests(pending) => pop_next_concurrency_ready(
pending,
self.now_ms,
cluster_in_flight,
max_in_flight,
),
AdmissionSource::Workload(_) => unreachable!(),
};
let Some((request, arrival_ms)) = next_ready else {
break;
};
self.assign_request(request, arrival_ms, None)?;
released_any = true;
}
return Ok(released_any);
}
let available = max_in_flight.saturating_sub(self.cluster_in_flight());
if available == 0 {
return Ok(false);
}
let ready_requests = match &mut self.admission {
AdmissionSource::Requests(_) => unreachable!(),
AdmissionSource::Workload(driver) => driver.pop_ready(self.now_ms, available),
};
for ready in ready_requests {
self.assign_request(ready.request, self.now_ms, ready.replay_hashes)?;
for ready in self
.admission
.drain_ready(self.now_ms, self.cluster_in_flight())?
{
self.assign_request(ready.request, ready.arrival_time_ms, ready.replay_hashes)?;
released_any = true;
}
Ok(released_any)
}
/// Start passes on every idle worker that can make progress at the current timestamp.
fn drive_ready_workers(&mut self) -> anyhow::Result<bool> {
let mut changed = false;
for worker_idx in 0..self.workers.len() {
loop {
if !self.workers[worker_idx].is_ready() {
break;
}
let executed = {
let (workers, collector) = (&mut self.workers, &mut self.collector);
workers[worker_idx].execute_pass(collector, self.now_ms)
};
changed = true;
let completion_kv_events =
if executed.router_event_visibility == RouterEventVisibility::PassStart {
self.apply_router_events(executed.kv_events)?;
Vec::new()
} else {
executed.kv_events
};
if executed.end_ms == self.now_ms {
self.process_completed_pass(
worker_idx,
executed.completed_requests,
executed.output_signals,
completion_kv_events,
)?;
continue;
}
self.workers[worker_idx].mark_busy();
push_worker_completion(
&mut self.events,
&mut self.next_event_seq,
executed.end_ms,
WorkerCompletionPayload {
stage: SimulationWorkerStage::Aggregated,
worker_idx,
completed_requests: executed.completed_requests,
output_signals: executed.output_signals,
kv_events: completion_kv_events,
},
);
break;
loop {
let effects = self
.engine
.drive_ready(self.now_ms, Some(&mut self.collector))?;
if effects.is_empty() {
return Ok(changed);
}
changed = true;
self.handle_engine_effects(effects)?;
}
}
Ok(changed)
fn handle_engine_effects(&mut self, effects: EngineEffects) -> anyhow::Result<()> {
self.apply_router_events(effects.pass_start_kv_events)?;
for payload in effects.immediate_completions {
let payload = self.engine.on_scheduled_completion(payload)?;
self.process_completed_pass(
payload.worker_idx,
payload.completed_requests,
payload.output_signals,
payload.kv_events,
)?;
}
for ScheduledWorkerCompletion { at_ms, payload } in effects.scheduled_completions {
push_worker_completion(&mut self.events, &mut self.next_event_seq, at_ms, payload);
}
Ok(())
}
/// Repeatedly process all work that becomes possible without advancing logical time.
fn drain_current_timestamp(&mut self) -> anyhow::Result<()> {
loop {
let mut changed = self.apply_worker_completions()?;
changed |= match self.mode {
ReplayMode::Trace => self.release_trace_arrivals()?,
ReplayMode::Concurrency { max_in_flight } => {
self.top_off_concurrency(max_in_flight)?
}
};
changed |= self.release_ready_arrivals()?;
changed |= self.drive_ready_workers()?;
if !changed {
......@@ -634,6 +512,7 @@ impl AggRuntime {
self.drain_current_timestamp()?;
}
self.progress.finish();
Ok((self.collector, self.stats))
}
......@@ -668,14 +547,14 @@ impl AggRuntime {
let mut router_pending_request_ids = self
.requests
.iter()
.filter(|(_, state)| state.is_queued_at_router())
.filter(|(_, state)| state.phase == AggRequestPhase::QueuedAtRouter)
.map(|(uuid, _)| *uuid)
.collect::<Vec<_>>();
router_pending_request_ids.sort_unstable();
let mut prefill_completed = self
.requests
.iter()
.filter(|(_, state)| state.prefill_completed())
.filter(|(_, state)| state.prefill_completed)
.map(|(uuid, _)| *uuid)
.collect::<Vec<_>>();
prefill_completed.sort_unstable();
......@@ -683,17 +562,13 @@ impl AggRuntime {
AggRuntimeSnapshot {
now_ms: self.now_ms,
worker_active_requests: self.worker_active_requests.clone(),
workers: self
.workers
.iter()
.map(OfflineWorkerState::debug_snapshot)
.collect(),
workers: self.engine.debug_snapshots(),
router_pending_request_ids,
prefill_completed,
router: self
.router
.as_ref()
.map(OfflineReplayRouter::debug_snapshot),
.map(|router| router.debug_snapshot(self.now_ms)),
}
}
}
......@@ -960,6 +835,7 @@ mod tests {
let mut runtime = AggRuntime::new(
&args,
None,
None,
normalize_trace_requests(
vec![
DirectRequest {
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use anyhow::Result;
use uuid::Uuid;
use super::{ReadyArrival, ReplayMode};
use crate::common::protocols::DirectRequest;
use crate::loadgen::WorkloadDriver;
/// Where not-yet-admitted requests come from.
enum AdmissionSource {
/// A pre-built queue of requests, consumed from the front.
Requests(VecDeque<DirectRequest>),
/// A workload driver that is polled for ready requests and notified of completions.
Workload(WorkloadDriver),
}
/// Pairs an admission source with the replay mode that decides when its requests are released.
pub(in crate::replay::offline) struct AdmissionQueue {
/// Backing supply of requests that have not been admitted yet.
source: AdmissionSource,
/// Replay mode consulted when computing readiness and draining arrivals.
mode: ReplayMode,
}
impl AdmissionQueue {
/// Build an admission queue backed by an already-materialized list of requests.
pub(in crate::replay::offline) fn new_requests(
    source: VecDeque<DirectRequest>,
    mode: ReplayMode,
) -> Self {
    let source = AdmissionSource::Requests(source);
    Self { source, mode }
}
/// Build an admission queue backed by a workload driver.
pub(in crate::replay::offline) fn new_workload(
    driver: WorkloadDriver,
    mode: ReplayMode,
) -> Self {
    let source = AdmissionSource::Workload(driver);
    Self { source, mode }
}
pub(in crate::replay::offline) fn mode(&self) -> ReplayMode {
self.mode
}
/// Report the timestamp at which the next arrival becomes admissible, if one can be named now.
///
/// * Trace mode, request list: the arrival timestamp of the request at the front of the queue
///   (`None` when the queue is empty or that request carries no timestamp).
/// * Trace mode, workload driver: whatever the driver reports.
/// * Concurrency mode, workload driver: the driver's answer, but only while
///   `cluster_in_flight` is below `max_in_flight`; otherwise `None`.
/// * Concurrency mode, request list: always `None`.
pub(in crate::replay::offline) fn next_ready_time_ms(
    &mut self,
    cluster_in_flight: usize,
) -> Option<f64> {
    match self.mode {
        ReplayMode::Trace => match &mut self.source {
            AdmissionSource::Requests(pending) => {
                let head = pending.front()?;
                head.arrival_timestamp_ms
            }
            AdmissionSource::Workload(driver) => driver.next_ready_time_ms(),
        },
        ReplayMode::Concurrency { max_in_flight } => match &mut self.source {
            // NOTE(review): presumably list-backed concurrency replay is topped off purely by
            // freed capacity rather than by the clock — confirm against the runtime loop.
            AdmissionSource::Requests(_) => None,
            // At capacity: the driver is not consulted at all.
            AdmissionSource::Workload(_) if cluster_in_flight >= max_in_flight => None,
            AdmissionSource::Workload(driver) => driver.next_ready_time_ms(),
        },
    }
}
/// Remove and return every arrival that may be admitted at `now_ms`.
///
/// * Trace mode, request list: pops requests from the front while the front request has an
///   arrival timestamp that is `<= now_ms`; each keeps its own timestamp as the arrival time.
/// * Trace mode, workload driver: takes everything the driver reports ready, using the driver's
///   `scheduled_ready_at_ms` as the arrival time.
/// * Concurrency mode: admits at most `max_in_flight - cluster_in_flight` requests (none when
///   already at or above the limit) and stamps each with `now_ms`.
///
/// Requests taken from a request list carry no replay hashes; driver-provided ones keep theirs.
pub(in crate::replay::offline) fn drain_ready(
    &mut self,
    now_ms: f64,
    cluster_in_flight: usize,
) -> Result<Vec<ReadyArrival>> {
    let arrivals = match (&self.mode, &mut self.source) {
        (ReplayMode::Trace, AdmissionSource::Requests(pending)) => {
            let mut due = Vec::new();
            // Stop at the first request that is either not yet due or has no timestamp.
            while let Some(arrival_time_ms) = pending
                .front()
                .and_then(|request| request.arrival_timestamp_ms)
                .filter(|arrival_ms| *arrival_ms <= now_ms)
            {
                let request = pending
                    .pop_front()
                    .expect("front request must exist when arrival is ready");
                due.push(ReadyArrival {
                    request,
                    arrival_time_ms,
                    replay_hashes: None,
                });
            }
            due
        }
        (ReplayMode::Trace, AdmissionSource::Workload(driver)) => driver
            .pop_ready(now_ms, usize::MAX)
            .into_iter()
            .map(|ready| ReadyArrival {
                request: ready.request,
                arrival_time_ms: ready.scheduled_ready_at_ms,
                replay_hashes: ready.replay_hashes,
            })
            .collect(),
        (ReplayMode::Concurrency { max_in_flight }, AdmissionSource::Requests(pending)) => {
            // Fill the free slots from the front of the queue, bounded by what is left.
            let free_slots = max_in_flight.saturating_sub(cluster_in_flight);
            let admit_count = free_slots.min(pending.len());
            pending
                .drain(..admit_count)
                .map(|request| ReadyArrival {
                    request,
                    arrival_time_ms: now_ms,
                    replay_hashes: None,
                })
                .collect()
        }
        (ReplayMode::Concurrency { max_in_flight }, AdmissionSource::Workload(driver)) => {
            let free_slots = max_in_flight.saturating_sub(cluster_in_flight);
            if free_slots == 0 {
                // At capacity: leave the driver untouched.
                Vec::new()
            } else {
                driver
                    .pop_ready(now_ms, free_slots)
                    .into_iter()
                    .map(|ready| ReadyArrival {
                        request: ready.request,
                        arrival_time_ms: now_ms,
                        replay_hashes: ready.replay_hashes,
                    })
                    .collect()
            }
        }
    };
    Ok(arrivals)
}
/// Tell the admission source that request `uuid` finished at `now_ms`.
///
/// Only workload drivers are notified; a plain request list has nothing to update and the call
/// succeeds as a no-op.
pub(in crate::replay::offline) fn on_request_completed(
    &mut self,
    uuid: Uuid,
    now_ms: f64,
) -> Result<()> {
    match &mut self.source {
        AdmissionSource::Workload(driver) => driver.on_complete(uuid, now_ms),
        AdmissionSource::Requests(_) => Ok(()),
    }
}
/// Whether the source has nothing left to admit: an empty request list, or a driver that
/// reports itself drained.
pub(in crate::replay::offline) fn is_drained(&self) -> bool {
    match self.source {
        AdmissionSource::Workload(ref driver) => driver.is_drained(),
        AdmissionSource::Requests(ref pending) => pending.is_empty(),
    }
}
/// Test-only probe: `true` when the queue is backed by a workload driver.
#[cfg(test)]
pub(crate) fn is_workload(&self) -> bool {
    match &self.source {
        AdmissionSource::Workload(_) => true,
        AdmissionSource::Requests(_) => false,
    }
}
/// Size of the workload as seen by this queue: the current length of a request list, or the
/// driver's `total_turns()` for a workload driver.
pub(in crate::replay::offline) fn total_requests(&self) -> usize {
    let total = match &self.source {
        AdmissionSource::Workload(driver) => driver.total_turns(),
        AdmissionSource::Requests(pending) => pending.len(),
    };
    total
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::bail;
use super::super::events::SimulationWorkerStage;
use super::super::runtime_utils::WorkerCompletionPayload;
#[cfg(test)]
use super::super::state::OfflineWorkerSnapshot;
use super::super::state::OfflineWorkerState;
use super::{EngineEffects, EnginePassMode, ScheduledWorkerCompletion};
use crate::common::protocols::DirectRequest;
use crate::replay::TraceCollector;
use crate::scheduler::RouterEventVisibility;
/// Owns the simulated workers of one replay stage and drives their passes.
pub(in crate::replay::offline) struct EngineComponent {
/// Stage stamped onto every completion payload and checked when a completion comes back.
stage: SimulationWorkerStage,
/// Selects between the collector-backed `execute_pass` and `execute_hidden_pass`.
pass_mode: EnginePassMode,
/// Worker states, addressed by worker index.
workers: Vec<OfflineWorkerState>,
}
impl EngineComponent {
/// Assemble an engine for `stage` that runs passes in `pass_mode` over the given workers.
pub(in crate::replay::offline) fn new(
    stage: SimulationWorkerStage,
    pass_mode: EnginePassMode,
    workers: Vec<OfflineWorkerState>,
) -> Self {
    let engine = Self {
        workers,
        pass_mode,
        stage,
    };
    engine
}
pub(in crate::replay::offline) fn dispatch(
&mut self,
worker_idx: usize,
request: DirectRequest,
) -> anyhow::Result<()> {
self.validate_worker_idx(worker_idx)?;
self.workers[worker_idx].receive_request(request);
Ok(())
}
pub(in crate::replay::offline) fn drive_ready(
&mut self,
now_ms: f64,
mut collector: Option<&mut TraceCollector>,
) -> anyhow::Result<EngineEffects> {
for worker_idx in 0..self.workers.len() {
if !self.workers[worker_idx].is_ready() {
continue;
}
let executed = match self.pass_mode {
EnginePassMode::Visible => {
let Some(collector) = collector.as_deref_mut() else {
bail!("offline replay visible engine pass requires a collector");
};
self.workers[worker_idx].execute_pass(collector, now_ms)
}
EnginePassMode::Hidden => self.workers[worker_idx].execute_hidden_pass(now_ms),
};
let mut effects = EngineEffects::default();
let completion_kv_events =
if executed.router_event_visibility == RouterEventVisibility::PassStart {
effects.pass_start_kv_events = executed.kv_events;
Vec::new()
} else {
executed.kv_events
};
let payload = WorkerCompletionPayload {
stage: self.stage,
worker_idx,
completed_requests: executed.completed_requests,
output_signals: executed.output_signals,
kv_events: completion_kv_events,
};
if executed.end_ms == now_ms {
effects.immediate_completions.push(payload);
return Ok(effects);
}
self.workers[worker_idx].mark_busy();
effects
.scheduled_completions
.push(ScheduledWorkerCompletion {
at_ms: executed.end_ms,
payload,
});
return Ok(effects);
}
Ok(EngineEffects::default())
}
/// Record a finished pass on its worker and hand the payload back for further processing.
///
/// The worker is marked idle and credited with `payload.completed_requests`.
///
/// # Errors
/// Fails when the payload belongs to a different stage or names an unknown worker.
pub(in crate::replay::offline) fn on_scheduled_completion(
    &mut self,
    payload: WorkerCompletionPayload,
) -> anyhow::Result<WorkerCompletionPayload> {
    let expected_stage = self.stage;
    if payload.stage != expected_stage {
        bail!(
            "offline replay completion stage mismatch: expected {:?}, got {:?}",
            expected_stage,
            payload.stage
        );
    }
    self.validate_worker_idx(payload.worker_idx)?;

    let worker = &mut self.workers[payload.worker_idx];
    worker.mark_idle();
    worker.mark_completed(payload.completed_requests);
    Ok(payload)
}
/// Sum of the in-flight counts reported by every worker.
pub(in crate::replay::offline) fn in_flight(&self) -> usize {
    self.workers
        .iter()
        .fold(0, |total, worker| total + worker.in_flight())
}
/// `true` when no worker reports outstanding work (vacuously true with zero workers).
pub(in crate::replay::offline) fn is_drained(&self) -> bool {
    !self.workers.iter().any(|worker| !worker.is_drained())
}
pub(in crate::replay::offline) fn worker_count(&self) -> usize {
self.workers.len()
}
/// Reject worker indices that fall outside the worker list.
fn validate_worker_idx(&self, worker_idx: usize) -> anyhow::Result<()> {
    match self.workers.get(worker_idx) {
        Some(_) => Ok(()),
        None => bail!("offline replay selected unknown worker index {worker_idx}"),
    }
}
/// Test-only: capture a debug snapshot of every worker, in worker-index order.
#[cfg(test)]
pub(crate) fn debug_snapshots(&self) -> Vec<OfflineWorkerSnapshot> {
    let mut snapshots = Vec::with_capacity(self.workers.len());
    snapshots.extend(self.workers.iter().map(|worker| worker.debug_snapshot()));
    snapshots
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod admission;
mod engine;
mod router;
mod types;
pub(in crate::replay::offline) use admission::AdmissionQueue;
pub(in crate::replay::offline) use engine::EngineComponent;
pub(crate) use router::OfflineReplayRouter;
#[cfg(test)]
pub(crate) use router::OfflineRouterSnapshot;
pub(in crate::replay::offline) use types::{
EngineEffects, EnginePassMode, ReadyArrival, ReplayMode, ScheduledWorkerCompletion,
};
pub(crate) use types::{RouterEffects, WorkerAdmission};
......@@ -10,8 +10,8 @@ use anyhow::{Context, Result, anyhow};
use dynamo_kv_router::LocalBlockHash;
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::{
BlockHashOptions, OverlapScores, RouterEvent, WorkerConfigLike, WorkerId, WorkerWithDpRank,
compute_block_hash_for_seq,
BlockHashOptions, OverlapScores, PrefillLoadHint, RouterEvent, WorkerConfigLike, WorkerId,
WorkerWithDpRank, compute_block_hash_for_seq,
};
use dynamo_kv_router::queue::DEFAULT_MAX_BATCHED_TOKENS;
use dynamo_kv_router::{
......@@ -19,15 +19,18 @@ use dynamo_kv_router::{
SchedulingPolicy, SchedulingRequest, SequenceRequest, WorkerSelector,
};
use dynamo_tokens::SequenceHash;
use tokio::time::Instant;
use uuid::Uuid;
use super::shared::{
ReplayNoopPublisher, ReplayWorkerConfig, replay_policy, replay_router_config, replay_selector,
replay_slots, replay_workers_with_configs,
};
use super::{RouterEffects, WorkerAdmission};
use crate::common::protocols::DirectRequest;
use crate::common::protocols::MockEngineArgs;
use crate::loadgen::ReplayRequestHashes;
use crate::replay::ReplayPrefillLoadEstimator;
use crate::replay::router_shared::{
ReplayNoopPublisher, ReplayWorkerConfig, replay_policy, replay_router_config, replay_selector,
replay_slots, replay_workers_with_configs,
};
type ReplayQueueKey = <RouterSchedulingPolicy as SchedulingPolicy>::Key;
......@@ -183,12 +186,15 @@ pub(crate) struct OfflineReplayRouter {
pending: BinaryHeap<QueueEntry>,
next_enqueue_seq: u64,
indexer: SyncReplayIndexer,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
decay_time_epoch: Instant,
}
impl OfflineReplayRouter {
pub(crate) fn new(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
num_workers: usize,
) -> Result<Self> {
let config = replay_router_config(args, router_config);
......@@ -213,19 +219,25 @@ impl OfflineReplayRouter {
pending: BinaryHeap::new(),
next_enqueue_seq: 0,
indexer: SyncReplayIndexer::new(args.block_size as u32),
prefill_load_estimator,
// This is only a base Instant for converting replay `now_ms` values into
// synthetic `Instant`s. All subsequent decay/accounting uses virtual replay
// time derived from this epoch, not wall-clock progression.
decay_time_epoch: Instant::now(),
})
}
pub(crate) fn submit_request_with_hashes(
pub(crate) fn on_request_arrival(
&mut self,
request: &DirectRequest,
replay_hashes: Option<ReplayRequestHashes>,
now_ms: f64,
) -> Result<Option<usize>> {
) -> Result<RouterEffects> {
let pending = self.build_pending_request(request, replay_hashes)?;
let decay_now = self.decay_now(now_ms);
let should_queue = self
.queue_threshold
.is_some_and(|threshold| self.all_workers_busy(threshold));
.is_some_and(|threshold| self.all_workers_busy(threshold, decay_now));
if should_queue {
let key = self.enqueue_key(now_ms, &pending);
......@@ -236,28 +248,60 @@ impl OfflineReplayRouter {
request: pending,
});
self.next_enqueue_seq += 1;
return Ok(None);
return Ok(RouterEffects::default());
}
self.admit_request(pending).map(Some)
Ok(RouterEffects {
admissions: vec![WorkerAdmission {
uuid: request
.uuid
.expect("offline replay requests must have UUIDs before router submission"),
worker_idx: self.admit_request(pending, decay_now)?,
}],
})
}
pub(crate) fn apply_event(&mut self, event: RouterEvent) -> Result<()> {
self.indexer.apply_event(event)
/// Apply a batch of KV events to the replay indexer, stopping at the first failure.
///
/// Always returns `RouterEffects::default()` on success.
pub(crate) fn on_kv_events(&mut self, events: Vec<RouterEvent>) -> Result<RouterEffects> {
    events
        .into_iter()
        .try_for_each(|event| self.indexer.apply_event(event))?;
    Ok(RouterEffects::default())
}
pub(crate) fn mark_prefill_completed(&mut self, uuid: Uuid) -> Result<Vec<(Uuid, usize)>> {
pub(crate) fn on_prefill_completed(
&mut self,
uuid: Uuid,
now_ms: f64,
) -> Result<RouterEffects> {
let decay_now = self.decay_now(now_ms);
self.slots
.mark_prefill_completed(&uuid.to_string())
.mark_prefill_completed(&uuid.to_string(), decay_now)
.map_err(anyhow::Error::from)?;
self.drain_pending()
Ok(RouterEffects {
admissions: self
.drain_pending(decay_now)?
.into_iter()
.map(|(uuid, worker_idx)| WorkerAdmission { uuid, worker_idx })
.collect(),
})
}
pub(crate) fn free(&mut self, uuid: Uuid) -> Result<Vec<(Uuid, usize)>> {
pub(crate) fn on_request_completed(
&mut self,
uuid: Uuid,
now_ms: f64,
) -> Result<RouterEffects> {
let decay_now = self.decay_now(now_ms);
self.slots
.free(&uuid.to_string())
.free(&uuid.to_string(), decay_now)
.map_err(anyhow::Error::from)?;
self.drain_pending()
Ok(RouterEffects {
admissions: self
.drain_pending(decay_now)?
.into_iter()
.map(|(uuid, worker_idx)| WorkerAdmission { uuid, worker_idx })
.collect(),
})
}
pub(crate) fn pending_count(&self) -> usize {
......@@ -265,7 +309,8 @@ impl OfflineReplayRouter {
}
#[cfg(test)]
pub(crate) fn debug_snapshot(&self) -> OfflineRouterSnapshot {
pub(crate) fn debug_snapshot(&self, now_ms: f64) -> OfflineRouterSnapshot {
let decay_now = self.decay_now(now_ms);
let mut pending = self
.pending
.iter()
......@@ -302,7 +347,7 @@ impl OfflineReplayRouter {
let mut active_tokens_by_worker = self
.slots
.active_tokens()
.active_tokens(decay_now)
.into_iter()
.map(|(worker, tokens)| (worker.worker_id as usize, tokens))
.collect::<Vec<_>>();
......@@ -324,6 +369,10 @@ impl OfflineReplayRouter {
)
}
/// Converts a replay timestamp in milliseconds into a synthetic `Instant`.
///
/// The result is `decay_time_epoch` advanced by `now_ms`; negative
/// timestamps are clamped so they map to the epoch itself.
fn decay_now(&self, now_ms: f64) -> Instant {
    let elapsed_secs = now_ms.max(0.0) / 1000.0;
    self.decay_time_epoch + Duration::from_secs_f64(elapsed_secs)
}
fn build_pending_request(
&self,
request: &DirectRequest,
......@@ -378,7 +427,7 @@ impl OfflineReplayRouter {
})
}
fn admit_request(&mut self, request: PendingRequest) -> Result<usize> {
fn admit_request(&mut self, request: PendingRequest, decay_now: Instant) -> Result<usize> {
let (decode_blocks, prefill_tokens) = self
.slots
.potential_blocks_and_tokens_with_prefill_tracking(
......@@ -386,6 +435,7 @@ impl OfflineReplayRouter {
request.isl_tokens,
request.overlaps.clone(),
request.track_prefill_tokens,
decay_now,
);
let scheduling_request = request.scheduling_request(decode_blocks, prefill_tokens);
let selection = self.selector.select_worker(
......@@ -396,56 +446,188 @@ impl OfflineReplayRouter {
let worker_idx = usize::try_from(selection.worker.worker_id)
.map_err(|_| anyhow!("selected worker id does not fit into usize"))?;
let request_id = request.request_id();
let prefill_load_hint = self.prefill_load_hint_for(
request.isl_tokens,
selection.overlap_blocks,
request.track_prefill_tokens,
);
self.slots
.add_request(SequenceRequest {
request_id,
token_sequence: request.token_seq,
isl: request.isl_tokens,
overlap: selection.overlap_blocks,
track_prefill_tokens: request.track_prefill_tokens,
expected_output_tokens: request.expected_output_tokens,
worker: selection.worker,
lora_name: None,
})
.add_request(
SequenceRequest {
request_id,
token_sequence: request.token_seq,
isl: request.isl_tokens,
overlap: selection.overlap_blocks,
track_prefill_tokens: request.track_prefill_tokens,
expected_output_tokens: request.expected_output_tokens,
prefill_load_hint,
worker: selection.worker,
lora_name: None,
},
decay_now,
)
.map_err(anyhow::Error::from)?;
Ok(worker_idx)
}
fn drain_pending(&mut self) -> Result<Vec<(Uuid, usize)>> {
fn drain_pending(&mut self, decay_now: Instant) -> Result<Vec<(Uuid, usize)>> {
let Some(threshold) = self.queue_threshold else {
return Ok(Vec::new());
};
let mut admissions = Vec::new();
while !self.all_workers_busy(threshold) {
while !self.all_workers_busy(threshold, decay_now) {
let Some(QueueEntry { request, .. }) = self.pending.pop() else {
break;
};
let uuid = request.uuid;
let worker_idx = self.admit_request(request)?;
let worker_idx = self.admit_request(request, decay_now)?;
admissions.push((uuid, worker_idx));
}
Ok(admissions)
}
fn all_workers_busy(&self, threshold: f64) -> bool {
fn all_workers_busy(&self, threshold: f64, decay_now: Instant) -> bool {
let mut checked_any = false;
let any_worker_not_busy = self
.slots
.any_worker_matches_active_tokens(|worker, tokens| {
let Some(config) = self.workers_with_configs.get(&worker.worker_id) else {
return false;
};
checked_any = true;
let max_batched = config
.max_num_batched_tokens()
.unwrap_or(DEFAULT_MAX_BATCHED_TOKENS);
(tokens as f64) <= threshold * (max_batched as f64)
});
let any_worker_not_busy =
self.slots
.any_worker_matches_active_tokens(decay_now, |worker, tokens| {
let Some(config) = self.workers_with_configs.get(&worker.worker_id) else {
return false;
};
checked_any = true;
let max_batched = config
.max_num_batched_tokens()
.unwrap_or(DEFAULT_MAX_BATCHED_TOKENS);
(tokens as f64) <= threshold * (max_batched as f64)
});
checked_any && !any_worker_not_busy
}
/// Builds the prefill load hint for a request, or `None` when no hint applies.
///
/// No hint is produced when prefill-token tracking is off, when the
/// overlapping prefix (`overlap_blocks * block_size` tokens) already covers
/// the whole input, when no estimator is configured, or when the estimator
/// fails. An estimator failure is logged as a warning instead of being
/// returned to the caller.
fn prefill_load_hint_for(
    &self,
    isl_tokens: usize,
    overlap_blocks: u32,
    track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> {
    if !track_prefill_tokens {
        return None;
    }

    // Tokens covered by the overlapping prefix versus the remainder of the input.
    let prefix = (overlap_blocks as usize) * (self.block_size as usize);
    let effective_isl = isl_tokens.saturating_sub(prefix);
    if effective_isl == 0 {
        return None;
    }

    let estimator = self.prefill_load_estimator.as_ref()?;
    // The first argument is the batch size: the estimate is for this request alone.
    estimator
        .predict_prefill_duration(1, effective_isl, prefix)
        .inspect_err(|error| {
            tracing::warn!(
                effective_isl,
                prefix,
                "failed to predict replay prefill duration for active load tracking: {error}"
            );
        })
        .ok()
        .map(|expected_prefill_duration| PrefillLoadHint {
            initial_effective_prefill_tokens: effective_isl,
            expected_prefill_duration: Some(expected_prefill_duration),
        })
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use dynamo_kv_router::PrefillLoadEstimator;
    use dynamo_kv_router::config::{KvRouterConfig, RouterPrefillLoadModel};
    use uuid::Uuid;

    use super::OfflineReplayRouter;
    use crate::common::protocols::{DirectRequest, MockEngineArgs};
    use crate::replay::ReplayPrefillLoadEstimator;

    /// Estimator stub that reports the same prefill duration for every query.
    struct FixedPrefillLoadEstimator {
        duration: Duration,
    }

    impl PrefillLoadEstimator for FixedPrefillLoadEstimator {
        fn predict_prefill_duration(
            &self,
            _batch_size: usize,
            _effective_isl: usize,
            _prefix: usize,
        ) -> anyhow::Result<Duration> {
            Ok(self.duration)
        }
    }

    /// Engine arguments with 64-token blocks and a 256-token batching budget.
    fn replay_args() -> MockEngineArgs {
        let built = MockEngineArgs::builder()
            .block_size(64)
            .max_num_batched_tokens(Some(256))
            .build();
        built.unwrap()
    }

    /// Router configuration that tracks prefill tokens using the `Aic` load model.
    fn router_config() -> KvRouterConfig {
        KvRouterConfig {
            router_track_prefill_tokens: true,
            router_prefill_load_model: RouterPrefillLoadModel::Aic,
            ..KvRouterConfig::default()
        }
    }

    /// Wraps a fixed-duration estimator in the shared estimator handle type.
    fn estimator(duration: Duration) -> ReplayPrefillLoadEstimator {
        let fixed = FixedPrefillLoadEstimator { duration };
        Arc::new(fixed)
    }

    /// A 64-token request (one full block of `token`) arriving at time zero.
    fn request(uuid: u128, token: u32) -> DirectRequest {
        let tokens = vec![token; 64];
        DirectRequest {
            tokens,
            max_output_tokens: 2,
            uuid: Some(Uuid::from_u128(uuid)),
            dp_rank: 0,
            arrival_timestamp_ms: Some(0.0),
        }
    }

    #[test]
    fn test_prefill_load_estimator_decays_offline_router_active_tokens() {
        let mut router = OfflineReplayRouter::new(
            &replay_args(),
            Some(router_config()),
            Some(estimator(Duration::from_secs(10))),
            1,
        )
        .unwrap();

        let first_request = request(1, 7);
        let admitted = router
            .on_request_arrival(&first_request, None, 0.0)
            .unwrap()
            .admissions;
        assert_eq!(admitted.len(), 1);

        // With a 10 s predicted prefill, the 64 tracked tokens are expected to
        // read 64, 32 and 0 at the start, midpoint and end of that window.
        for (now_ms, expected_tokens) in [(0.0, 64), (5_000.0, 32), (10_000.0, 0)] {
            assert_eq!(
                router.debug_snapshot(now_ms).active_tokens_by_worker,
                vec![(0, expected_tokens)]
            );
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_kv_router::protocols::RouterEvent;
use uuid::Uuid;
use super::super::runtime_utils::WorkerCompletionPayload;
use crate::common::protocols::DirectRequest;
use crate::loadgen::ReplayRequestHashes;
/// Selects how the offline replay admits requests.
#[derive(Debug, Clone, Copy)]
pub(in crate::replay::offline) enum ReplayMode {
// NOTE(review): presumably releases requests at their recorded arrival timestamps — confirm against the admission queue.
Trace,
// NOTE(review): `max_in_flight` presumably caps how many requests may be in flight at once — confirm against the admission queue.
Concurrency { max_in_flight: usize },
}
/// Pass mode an engine component is constructed with.
// NOTE(review): the exact semantics live in the engine component; presumably this
// selects whether a pass records into the trace collector — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(in crate::replay::offline) enum EnginePassMode {
Visible,
Hidden,
}
/// A router decision to admit one request onto one worker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct WorkerAdmission {
// Identifier of the admitted request.
pub(crate) uuid: Uuid,
// Index of the worker the request was assigned to.
pub(crate) worker_idx: usize,
}
/// A worker completion payload paired with the replay time it is scheduled for.
#[derive(Debug)]
pub(in crate::replay::offline) struct ScheduledWorkerCompletion {
// Replay timestamp in milliseconds. NOTE(review): presumably the time at which
// the completion becomes due — confirm against the producer.
pub(in crate::replay::offline) at_ms: f64,
// Completion data to deliver.
pub(in crate::replay::offline) payload: WorkerCompletionPayload,
}
/// Side effects produced by an engine component for its caller to apply.
///
/// The derived `Default` yields three empty lists.
#[derive(Debug, Default)]
pub(in crate::replay::offline) struct EngineEffects {
// KV router events. NOTE(review): presumably those that become visible when a
// pass starts — confirm against the engine component.
pub(in crate::replay::offline) pass_start_kv_events: Vec<RouterEvent>,
// Completion payloads. NOTE(review): presumably for passes that finish at the
// current timestamp — confirm against the engine component.
pub(in crate::replay::offline) immediate_completions: Vec<WorkerCompletionPayload>,
// Completion payloads carrying their own scheduled timestamp.
pub(in crate::replay::offline) scheduled_completions: Vec<ScheduledWorkerCompletion>,
}
impl EngineEffects {
    /// Returns `true` when none of the three effect lists holds an entry.
    pub(in crate::replay::offline) fn is_empty(&self) -> bool {
        let Self {
            pass_start_kv_events,
            immediate_completions,
            scheduled_completions,
        } = self;
        let has_any_effect = !pass_start_kv_events.is_empty()
            || !immediate_completions.is_empty()
            || !scheduled_completions.is_empty();
        !has_any_effect
    }
}
/// Side effects returned by a router event handler.
///
/// The derived `Default` yields an empty admission list.
#[derive(Debug, Default)]
pub(crate) struct RouterEffects {
// Requests the router admitted while handling the event, each with its worker index.
pub(crate) admissions: Vec<WorkerAdmission>,
}
/// A request that is ready to enter the replay, together with its arrival metadata.
#[derive(Debug)]
pub(in crate::replay::offline) struct ReadyArrival {
// The request to admit.
pub(in crate::replay::offline) request: DirectRequest,
// Arrival time, in replay milliseconds, to record for the request.
pub(in crate::replay::offline) arrival_time_ms: f64,
// Optional replay hashes accompanying the request.
// NOTE(review): presumably precomputed block hashes reused by the router — confirm.
pub(in crate::replay::offline) replay_hashes: Option<ReplayRequestHashes>,
}
......@@ -8,32 +8,36 @@ use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::RouterEvent;
use uuid::Uuid;
pub(super) use super::components::ReplayMode;
use super::components::{
AdmissionQueue, EngineComponent, EngineEffects, EnginePassMode, OfflineReplayRouter,
ScheduledWorkerCompletion, WorkerAdmission,
};
use super::events::{SimulationEvent, SimulationWorkerStage};
use super::progress::ReplayProgress;
use super::runtime_utils::{
WorkerCompletionPayload, next_timestamp as choose_next_timestamp, pop_next_concurrency_ready,
pop_next_trace_ready, pop_ready_decode_handoff, pop_ready_worker_completion,
next_timestamp as choose_next_timestamp, pop_ready_decode_handoff, pop_ready_worker_completion,
push_decode_handoff, push_worker_completion,
};
#[cfg(test)]
use super::state::DisaggPhase;
#[cfg(test)]
use super::state::DisaggRequestSnapshot;
use super::state::{DisaggRequestState, OfflineWorkerState};
use super::state::{DisaggPhase, DisaggRequestState};
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::{ReplayRequestHashes, WorkloadDriver};
use crate::replay::router::OfflineReplayRouter;
use crate::replay::{OfflineDisaggReplayConfig, ReplayRouterMode, TraceCollector};
use crate::scheduler::RouterEventVisibility;
#[derive(Debug, Clone, Copy)]
pub(super) enum ReplayMode {
Trace,
Concurrency { max_in_flight: usize },
}
use crate::replay::{
OfflineDisaggReplayConfig, ReplayPrefillLoadEstimator, ReplayRouterMode, TraceCollector,
};
enum AdmissionSource {
Requests(VecDeque<DirectRequest>),
Workload(WorkloadDriver),
/// Test-only record of one per-request lifecycle step in the disaggregated runtime.
///
/// Entries are appended to the runtime stats `transition_log`.
#[cfg(test)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum DisaggTransition {
// Logged after the prefill router's `on_prefill_completed` call for the request.
PrefillMarkCompleted { uuid: Uuid },
// Logged after the prefill router's `on_request_completed` call for the request.
PrefillFree { uuid: Uuid },
// Logged when a decode handoff with a positive delay is pushed onto the event queue.
DecodeHandoffQueued { uuid: Uuid },
// Logged when the request is enqueued for decode, with or without a decode router.
DecodeEnqueued { uuid: Uuid },
// Logged after the decode router's `on_request_completed` call for the request.
DecodeFree { uuid: Uuid },
// Logged after the request state is marked done.
RequestMarkedDone { uuid: Uuid },
// Logged after the admission queue is told of the completion, only when it is workload-backed.
WorkloadCompleted { uuid: Uuid },
}
#[cfg(test)]
......@@ -48,6 +52,7 @@ pub(super) struct DisaggRuntimeStats {
decode_router_freed_count: usize,
max_prefill_router_pending_count: usize,
max_decode_router_pending_count: usize,
transition_log: Vec<DisaggTransition>,
}
#[cfg(not(test))]
......@@ -59,15 +64,15 @@ pub(super) struct DisaggRuntime {
next_prefill_worker_idx: usize,
next_decode_worker_idx: usize,
next_event_seq: u64,
admission: AdmissionSource,
prefill_workers: Vec<OfflineWorkerState>,
decode_workers: Vec<OfflineWorkerState>,
admission: AdmissionQueue,
prefill_engine: EngineComponent,
decode_engine: EngineComponent,
prefill_router: Option<OfflineReplayRouter>,
decode_router: Option<OfflineReplayRouter>,
requests: HashMap<Uuid, DisaggRequestState>,
collector: TraceCollector,
events: BinaryHeap<SimulationEvent>,
mode: ReplayMode,
progress: ReplayProgress,
stats: DisaggRuntimeStats,
}
......@@ -76,6 +81,7 @@ impl DisaggRuntime {
pub(super) fn new(
config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
pending: VecDeque<DirectRequest>,
mode: ReplayMode,
router_mode: ReplayRouterMode,
......@@ -83,8 +89,8 @@ impl DisaggRuntime {
Self::new_with_source(
config,
router_config,
AdmissionSource::Requests(pending),
mode,
prefill_load_estimator,
AdmissionQueue::new_requests(pending, mode),
router_mode,
)
}
......@@ -93,6 +99,7 @@ impl DisaggRuntime {
pub(super) fn new_workload(
config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
driver: WorkloadDriver,
mode: ReplayMode,
router_mode: ReplayRouterMode,
......@@ -100,8 +107,8 @@ impl DisaggRuntime {
Self::new_with_source(
config,
router_config,
AdmissionSource::Workload(driver),
mode,
prefill_load_estimator,
AdmissionQueue::new_workload(driver, mode),
router_mode,
)
}
......@@ -110,10 +117,11 @@ impl DisaggRuntime {
fn new_with_source(
config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
admission: AdmissionSource,
mode: ReplayMode,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
admission: AdmissionQueue,
router_mode: ReplayRouterMode,
) -> Result<Self> {
let progress = ReplayProgress::new(admission.total_requests(), "offline disagg replay");
let (prefill_router, decode_router) = match router_mode {
ReplayRouterMode::RoundRobin => (None, None),
ReplayRouterMode::KvRouter => {
......@@ -125,43 +133,60 @@ impl DisaggRuntime {
Some(OfflineReplayRouter::new(
&config.prefill_args,
Some(prefill_router_config),
prefill_load_estimator,
config.num_prefill_workers,
)?),
Some(OfflineReplayRouter::new(
&config.decode_args,
Some(decode_router_config),
None,
config.num_decode_workers,
)?),
)
}
};
Ok(Self {
now_ms: 0.0,
next_prefill_worker_idx: 0,
next_decode_worker_idx: 0,
next_event_seq: 0,
admission,
prefill_workers: (0..config.num_prefill_workers)
let prefill_engine = EngineComponent::new(
SimulationWorkerStage::Prefill,
EnginePassMode::Hidden,
(0..config.num_prefill_workers)
.map(|worker_idx| {
OfflineWorkerState::new(
super::state::OfflineWorkerState::new(
worker_idx,
config.prefill_args.clone(),
prefill_router.is_some(),
)
})
.collect(),
decode_workers: (0..config.num_decode_workers)
);
let decode_engine = EngineComponent::new(
SimulationWorkerStage::Decode,
EnginePassMode::Visible,
(0..config.num_decode_workers)
.map(|worker_idx| {
OfflineWorkerState::new(worker_idx, config.decode_args.clone(), false)
super::state::OfflineWorkerState::new(
worker_idx,
config.decode_args.clone(),
false,
)
})
.collect(),
);
Ok(Self {
now_ms: 0.0,
next_prefill_worker_idx: 0,
next_decode_worker_idx: 0,
next_event_seq: 0,
admission,
prefill_engine,
decode_engine,
prefill_router,
decode_router,
requests: HashMap::new(),
collector: TraceCollector::default(),
events: BinaryHeap::new(),
mode,
progress,
#[cfg(test)]
stats: DisaggRuntimeStats::default(),
#[cfg(not(test))]
......@@ -171,15 +196,8 @@ impl DisaggRuntime {
/// Count all requests consuming cluster capacity across prefill, decode, and router queues.
fn cluster_in_flight(&self) -> usize {
self.prefill_workers
.iter()
.map(OfflineWorkerState::in_flight)
.sum::<usize>()
+ self
.decode_workers
.iter()
.map(OfflineWorkerState::in_flight)
.sum::<usize>()
self.prefill_engine.in_flight()
+ self.decode_engine.in_flight()
+ self
.prefill_router
.as_ref()
......@@ -194,14 +212,15 @@ impl DisaggRuntime {
fn next_prefill_worker(&mut self) -> usize {
let worker_idx = self.next_prefill_worker_idx;
self.next_prefill_worker_idx =
(self.next_prefill_worker_idx + 1) % self.prefill_workers.len();
(self.next_prefill_worker_idx + 1) % self.prefill_engine.worker_count();
worker_idx
}
/// Pick the next decode worker in round-robin order.
fn next_decode_worker(&mut self) -> usize {
let worker_idx = self.next_decode_worker_idx;
self.next_decode_worker_idx = (self.next_decode_worker_idx + 1) % self.decode_workers.len();
self.next_decode_worker_idx =
(self.next_decode_worker_idx + 1) % self.decode_engine.worker_count();
worker_idx
}
......@@ -224,19 +243,6 @@ impl DisaggRuntime {
}
}
/// Fail fast if a router admission targets a non-existent stage worker.
fn validate_worker_idx(&self, stage: SimulationWorkerStage, worker_idx: usize) -> Result<()> {
let worker_count = match stage {
SimulationWorkerStage::Prefill => self.prefill_workers.len(),
SimulationWorkerStage::Decode => self.decode_workers.len(),
SimulationWorkerStage::Aggregated => unreachable!("aggregated stage is not used"),
};
if worker_idx >= worker_count {
bail!("offline disagg replay selected unknown {stage:?} worker index {worker_idx}");
}
Ok(())
}
/// Borrow immutable request state with a structured missing-request error.
fn state(&self, uuid: Uuid) -> Result<&DisaggRequestState> {
self.requests
......@@ -253,9 +259,8 @@ impl DisaggRuntime {
/// Dispatch a request's prefill stage onto a specific prefill worker.
fn dispatch_prefill(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> {
self.validate_worker_idx(SimulationWorkerStage::Prefill, worker_idx)?;
let request = self.state(uuid)?.build_prefill_request()?;
self.prefill_workers[worker_idx].receive_request(request);
self.prefill_engine.dispatch(worker_idx, request)?;
self.state_mut(uuid)?.start_prefill(worker_idx);
#[cfg(test)]
{
......@@ -266,9 +271,8 @@ impl DisaggRuntime {
/// Dispatch a request's decode stage onto a specific decode worker.
fn dispatch_decode(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> {
self.validate_worker_idx(SimulationWorkerStage::Decode, worker_idx)?;
let request = self.state(uuid)?.build_decode_request()?;
self.decode_workers[worker_idx].receive_request(request);
let request = self.state(uuid)?.original_request()?.clone();
self.decode_engine.dispatch(worker_idx, request)?;
self.state_mut(uuid)?.start_decode(worker_idx);
#[cfg(test)]
{
......@@ -278,9 +282,9 @@ impl DisaggRuntime {
}
/// Turn prefill router admissions into concrete worker dispatches.
fn dispatch_prefill_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> Result<()> {
for (uuid, worker_idx) in admissions {
if !self.state(uuid)?.is_queued_prefill() {
fn dispatch_prefill_admissions(&mut self, admissions: Vec<WorkerAdmission>) -> Result<()> {
for WorkerAdmission { uuid, worker_idx } in admissions {
if self.state(uuid)?.phase != DisaggPhase::QueuedPrefill {
bail!("offline disagg replay expected queued prefill request for {uuid}");
}
self.dispatch_prefill(uuid, worker_idx)?;
......@@ -289,9 +293,9 @@ impl DisaggRuntime {
}
/// Turn decode router admissions into concrete worker dispatches.
fn dispatch_decode_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> Result<()> {
for (uuid, worker_idx) in admissions {
if !self.state(uuid)?.is_queued_decode() {
fn dispatch_decode_admissions(&mut self, admissions: Vec<WorkerAdmission>) -> Result<()> {
for WorkerAdmission { uuid, worker_idx } in admissions {
if self.state(uuid)?.phase != DisaggPhase::QueuedDecode {
bail!("offline disagg replay expected queued decode request for {uuid}");
}
self.dispatch_decode(uuid, worker_idx)?;
......@@ -299,52 +303,37 @@ impl DisaggRuntime {
Ok(())
}
/// Submit a request to the prefill router and return an immediate admission when available.
fn submit_to_prefill_router(
&mut self,
uuid: Uuid,
replay_hashes: Option<ReplayRequestHashes>,
) -> Result<Option<usize>> {
let request = self.state(uuid)?.original_request()?.clone();
let Some(prefill_router) = self.prefill_router.as_mut() else {
bail!("offline disagg replay prefill submission requires an active router");
};
let maybe_worker_idx =
prefill_router.submit_request_with_hashes(&request, replay_hashes, self.now_ms)?;
self.record_router_pending();
Ok(maybe_worker_idx)
}
/// Submit a request to the decode router and return an immediate admission when available.
fn submit_to_decode_router(&mut self, uuid: Uuid) -> Result<Option<usize>> {
let request = self.state(uuid)?.original_request()?.clone();
let Some(decode_router) = self.decode_router.as_mut() else {
bail!("offline disagg replay decode submission requires an active router");
};
let maybe_worker_idx =
decode_router.submit_request_with_hashes(&request, None, self.now_ms)?;
self.record_router_pending();
Ok(maybe_worker_idx)
}
/// Queue or dispatch a request into decode, depending on whether a decode router is active.
fn enqueue_decode(&mut self, uuid: Uuid) -> Result<()> {
if self.decode_router.is_none() {
#[cfg(test)]
{
self.stats
.transition_log
.push(DisaggTransition::DecodeEnqueued { uuid });
self.stats.handoff_ms.insert(uuid, self.now_ms);
}
let worker_idx = self.next_decode_worker();
self.dispatch_decode(uuid, worker_idx)?;
return Ok(());
}
let maybe_worker_idx = self.submit_to_decode_router(uuid)?;
let request = self.state(uuid)?.original_request()?.clone();
self.state_mut(uuid)?.queue_decode();
#[cfg(test)]
{
self.stats
.transition_log
.push(DisaggTransition::DecodeEnqueued { uuid });
self.stats.handoff_ms.insert(uuid, self.now_ms);
}
if let Some(worker_idx) = maybe_worker_idx {
self.dispatch_decode(uuid, worker_idx)?;
return Ok(());
}
self.state_mut(uuid)?.queue_decode();
let admissions = self
.decode_router
.as_mut()
.expect("decode router presence checked above")
.on_request_arrival(&request, None, self.now_ms)?
.admissions;
self.record_router_pending();
self.dispatch_decode_admissions(admissions)?;
Ok(())
}
......@@ -366,6 +355,7 @@ impl DisaggRuntime {
request.max_output_tokens,
);
let queued_request = request.clone();
self.requests
.insert(uuid, DisaggRequestState::new(request, arrival_time_ms));
if self.prefill_router.is_none() {
......@@ -373,10 +363,14 @@ impl DisaggRuntime {
self.dispatch_prefill(uuid, worker_idx)?;
return Ok(uuid);
}
let maybe_worker_idx = self.submit_to_prefill_router(uuid, replay_hashes)?;
if let Some(worker_idx) = maybe_worker_idx {
self.dispatch_prefill(uuid, worker_idx)?;
}
let admissions = self
.prefill_router
.as_mut()
.expect("prefill router presence checked above")
.on_request_arrival(&queued_request, replay_hashes, self.now_ms)?
.admissions;
self.record_router_pending();
self.dispatch_prefill_admissions(admissions)?;
Ok(uuid)
}
......@@ -384,40 +378,18 @@ impl DisaggRuntime {
fn is_done(&self) -> bool {
self.events.is_empty()
&& self.cluster_in_flight() == 0
&& match &self.admission {
AdmissionSource::Requests(pending) => pending.is_empty(),
AdmissionSource::Workload(driver) => driver.is_drained(),
}
&& self
.prefill_workers
.iter()
.all(OfflineWorkerState::is_drained)
&& self
.decode_workers
.iter()
.all(OfflineWorkerState::is_drained)
&& self.admission.is_drained()
&& self.prefill_engine.is_drained()
&& self.decode_engine.is_drained()
}
/// Pick the next logical timestamp from arrivals, worker completions, or decode handoffs.
fn next_timestamp(&mut self) -> Option<f64> {
let next_event_ms = self.events.peek().map(|event| event.at_ms);
let cluster_in_flight = self.cluster_in_flight();
let next_arrival_ms = match (&self.mode, &mut self.admission) {
(ReplayMode::Trace, AdmissionSource::Requests(pending)) => pending
.front()
.and_then(|request| request.arrival_timestamp_ms),
(ReplayMode::Trace, AdmissionSource::Workload(driver)) => driver.next_ready_time_ms(),
(ReplayMode::Concurrency { max_in_flight }, AdmissionSource::Workload(driver)) => {
if cluster_in_flight < *max_in_flight {
driver.next_ready_time_ms()
} else {
None
}
}
(ReplayMode::Concurrency { .. }, AdmissionSource::Requests(_)) => None,
};
choose_next_timestamp(next_arrival_ms, next_event_ms)
choose_next_timestamp(
self.admission.next_ready_time_ms(self.cluster_in_flight()),
next_event_ms,
)
}
/// Apply prefill-side KV router events at the scheduler-selected visibility phase.
......@@ -425,8 +397,9 @@ impl DisaggRuntime {
let Some(prefill_router) = self.prefill_router.as_mut() else {
return Ok(());
};
for event in events {
prefill_router.apply_event(event)?;
let effects = prefill_router.on_kv_events(events)?;
if !effects.admissions.is_empty() {
bail!("offline disagg replay prefill KV events must not admit requests");
}
Ok(())
}
......@@ -440,22 +413,32 @@ impl DisaggRuntime {
if self.prefill_router.is_some() {
let prefill_complete_admissions = {
let prefill_router = self.prefill_router.as_mut().expect("router checked above");
prefill_router.mark_prefill_completed(signal.uuid)?
prefill_router
.on_prefill_completed(signal.uuid, self.now_ms)?
.admissions
};
#[cfg(test)]
{
self.stats.prefill_marked_count += 1;
self.stats
.transition_log
.push(DisaggTransition::PrefillMarkCompleted { uuid: signal.uuid });
}
self.record_router_pending();
self.dispatch_prefill_admissions(prefill_complete_admissions)?;
let admissions = {
let prefill_router = self.prefill_router.as_mut().expect("router checked above");
prefill_router.free(signal.uuid)?
prefill_router
.on_request_completed(signal.uuid, self.now_ms)?
.admissions
};
#[cfg(test)]
{
self.stats.prefill_router_freed_count += 1;
self.stats
.transition_log
.push(DisaggTransition::PrefillFree { uuid: signal.uuid });
}
self.record_router_pending();
self.dispatch_prefill_admissions(admissions)?;
......@@ -471,20 +454,37 @@ impl DisaggRuntime {
}
let admissions = if let Some(decode_router) = self.decode_router.as_mut() {
let admissions = decode_router.free(signal.uuid)?;
let admissions = decode_router
.on_request_completed(signal.uuid, self.now_ms)?
.admissions;
#[cfg(test)]
{
self.stats.decode_router_freed_count += 1;
self.stats
.transition_log
.push(DisaggTransition::DecodeFree { uuid: signal.uuid });
}
admissions
} else {
Vec::new()
};
self.record_router_pending();
if let AdmissionSource::Workload(driver) = &mut self.admission {
driver.on_complete(signal.uuid, self.now_ms)?;
self.admission
.on_request_completed(signal.uuid, self.now_ms)?;
self.progress.inc_completed();
#[cfg(test)]
if self.admission.is_workload() {
self.stats
.transition_log
.push(DisaggTransition::WorkloadCompleted { uuid: signal.uuid });
}
self.state_mut(signal.uuid)?.mark_done();
#[cfg(test)]
{
self.stats
.transition_log
.push(DisaggTransition::RequestMarkedDone { uuid: signal.uuid });
}
self.dispatch_decode_admissions(admissions)?;
Ok(())
}
......@@ -492,12 +492,11 @@ impl DisaggRuntime {
/// Apply the side effects of a finished prefill pass.
fn process_prefill_pass(
&mut self,
worker_idx: usize,
completed_requests: usize,
_worker_idx: usize,
_completed_requests: usize,
output_signals: Vec<OutputSignal>,
kv_events: Vec<RouterEvent>,
) -> Result<()> {
self.prefill_workers[worker_idx].mark_completed(completed_requests);
self.apply_prefill_router_events(kv_events)?;
for signal in output_signals {
self.process_prefill_signal(signal)?;
......@@ -508,11 +507,10 @@ impl DisaggRuntime {
/// Apply the side effects of a finished decode pass.
fn process_decode_pass(
&mut self,
worker_idx: usize,
completed_requests: usize,
_worker_idx: usize,
_completed_requests: usize,
output_signals: Vec<OutputSignal>,
) -> Result<()> {
self.decode_workers[worker_idx].mark_completed(completed_requests);
for signal in output_signals {
self.process_decode_signal(signal)?;
}
......@@ -522,27 +520,24 @@ impl DisaggRuntime {
/// Drain all worker-completion events scheduled for the current logical timestamp.
fn apply_worker_completions(&mut self) -> Result<bool> {
let mut changed = false;
while let Some(WorkerCompletionPayload {
stage,
worker_idx,
completed_requests,
output_signals,
kv_events,
}) = pop_ready_worker_completion(&mut self.events, self.now_ms)
{
match stage {
while let Some(payload) = pop_ready_worker_completion(&mut self.events, self.now_ms) {
match payload.stage {
SimulationWorkerStage::Prefill => {
self.prefill_workers[worker_idx].mark_idle();
let payload = self.prefill_engine.on_scheduled_completion(payload)?;
self.process_prefill_pass(
worker_idx,
completed_requests,
output_signals,
kv_events,
payload.worker_idx,
payload.completed_requests,
payload.output_signals,
payload.kv_events,
)?;
}
SimulationWorkerStage::Decode => {
self.decode_workers[worker_idx].mark_idle();
self.process_decode_pass(worker_idx, completed_requests, output_signals)?;
let payload = self.decode_engine.on_scheduled_completion(payload)?;
self.process_decode_pass(
payload.worker_idx,
payload.completed_requests,
payload.output_signals,
)?;
}
SimulationWorkerStage::Aggregated => {
bail!("offline disagg replay received an aggregated completion event")
......@@ -569,91 +564,33 @@ impl DisaggRuntime {
uuid: Uuid,
handoff_delay_ms: Option<f64>,
) -> Result<()> {
if let Some(delay_ms) = handoff_delay_ms
&& delay_ms > 0.0
{
let Some(delay_ms) = handoff_delay_ms else {
return self.enqueue_decode(uuid);
};
if delay_ms > 0.0 {
push_decode_handoff(
&mut self.events,
&mut self.next_event_seq,
self.now_ms + delay_ms,
uuid,
);
#[cfg(test)]
self.stats
.transition_log
.push(DisaggTransition::DecodeHandoffQueued { uuid });
return Ok(());
}
self.enqueue_decode(uuid)
}
/// Release every trace arrival whose timestamp is now visible to the global clock.
fn release_trace_arrivals(&mut self) -> Result<bool> {
let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) {
loop {
let next_ready = match &mut self.admission {
AdmissionSource::Requests(pending) => {
pop_next_trace_ready(pending, self.now_ms)
}
AdmissionSource::Workload(_) => unreachable!(),
};
let Some((request, arrival_ms)) = next_ready else {
break;
};
self.on_external_arrival(request, arrival_ms, None)?;
released_any = true;
}
return Ok(released_any);
}
let ready_requests = match &mut self.admission {
AdmissionSource::Requests(_) => unreachable!(),
AdmissionSource::Workload(driver) => driver.pop_ready(self.now_ms, usize::MAX),
};
for ready in ready_requests {
self.on_external_arrival(
ready.request,
ready.scheduled_ready_at_ms,
ready.replay_hashes,
)?;
released_any = true;
}
Ok(released_any)
}
/// Backfill closed-loop concurrency replay until the staged cluster reaches its cap.
fn top_off_concurrency(&mut self, max_in_flight: usize) -> Result<bool> {
/// Release every admission made ready by the shared admission queue.
fn release_ready_arrivals(&mut self) -> Result<bool> {
let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) {
loop {
let cluster_in_flight = self.cluster_in_flight();
let next_ready = match &mut self.admission {
AdmissionSource::Requests(pending) => pop_next_concurrency_ready(
pending,
self.now_ms,
cluster_in_flight,
max_in_flight,
),
AdmissionSource::Workload(_) => unreachable!(),
};
let Some((request, arrival_ms)) = next_ready else {
break;
};
self.on_external_arrival(request, arrival_ms, None)?;
released_any = true;
}
return Ok(released_any);
}
let available = max_in_flight.saturating_sub(self.cluster_in_flight());
if available == 0 {
return Ok(false);
}
let ready_requests = match &mut self.admission {
AdmissionSource::Requests(_) => unreachable!(),
AdmissionSource::Workload(driver) => driver.pop_ready(self.now_ms, available),
};
for ready in ready_requests {
self.on_external_arrival(ready.request, self.now_ms, ready.replay_hashes)?;
for ready in self
.admission
.drain_ready(self.now_ms, self.cluster_in_flight())?
{
self.on_external_arrival(ready.request, ready.arrival_time_ms, ready.replay_hashes)?;
released_any = true;
}
Ok(released_any)
......@@ -662,93 +599,61 @@ impl DisaggRuntime {
/// Start passes on every idle prefill worker that can make progress at the current timestamp.
fn drive_prefill_workers(&mut self) -> Result<bool> {
let mut changed = false;
for worker_idx in 0..self.prefill_workers.len() {
loop {
if !self.prefill_workers[worker_idx].is_ready() {
break;
}
let executed = self.prefill_workers[worker_idx].execute_hidden_pass(self.now_ms);
changed = true;
let completion_kv_events =
if executed.router_event_visibility == RouterEventVisibility::PassStart {
self.apply_prefill_router_events(executed.kv_events)?;
Vec::new()
} else {
executed.kv_events
};
if executed.end_ms == self.now_ms {
self.process_prefill_pass(
worker_idx,
executed.completed_requests,
executed.output_signals,
completion_kv_events,
)?;
continue;
}
self.prefill_workers[worker_idx].mark_busy();
push_worker_completion(
&mut self.events,
&mut self.next_event_seq,
executed.end_ms,
WorkerCompletionPayload {
stage: SimulationWorkerStage::Prefill,
worker_idx,
completed_requests: executed.completed_requests,
output_signals: executed.output_signals,
kv_events: completion_kv_events,
},
);
break;
loop {
let effects = self.prefill_engine.drive_ready(self.now_ms, None)?;
if effects.is_empty() {
return Ok(changed);
}
changed = true;
self.handle_prefill_engine_effects(effects)?;
}
Ok(changed)
}
/// Start passes on every idle decode worker that can make progress at the current timestamp.
fn drive_decode_workers(&mut self) -> Result<bool> {
let mut changed = false;
for worker_idx in 0..self.decode_workers.len() {
loop {
if !self.decode_workers[worker_idx].is_ready() {
break;
}
let executed = {
let (workers, collector) = (&mut self.decode_workers, &mut self.collector);
workers[worker_idx].execute_pass(collector, self.now_ms)
};
changed = true;
loop {
let effects = self
.decode_engine
.drive_ready(self.now_ms, Some(&mut self.collector))?;
if effects.is_empty() {
return Ok(changed);
}
changed = true;
self.handle_decode_engine_effects(effects)?;
}
}
if executed.end_ms == self.now_ms {
self.process_decode_pass(
worker_idx,
executed.completed_requests,
executed.output_signals,
)?;
continue;
}
fn handle_prefill_engine_effects(&mut self, effects: EngineEffects) -> Result<()> {
self.apply_prefill_router_events(effects.pass_start_kv_events)?;
for payload in effects.immediate_completions {
let payload = self.prefill_engine.on_scheduled_completion(payload)?;
self.process_prefill_pass(
payload.worker_idx,
payload.completed_requests,
payload.output_signals,
payload.kv_events,
)?;
}
for ScheduledWorkerCompletion { at_ms, payload } in effects.scheduled_completions {
push_worker_completion(&mut self.events, &mut self.next_event_seq, at_ms, payload);
}
Ok(())
}
self.decode_workers[worker_idx].mark_busy();
push_worker_completion(
&mut self.events,
&mut self.next_event_seq,
executed.end_ms,
WorkerCompletionPayload {
stage: SimulationWorkerStage::Decode,
worker_idx,
completed_requests: executed.completed_requests,
output_signals: executed.output_signals,
kv_events: Vec::new(),
},
);
break;
}
/// Applies the effects produced by one decode-engine drive step.
///
/// Each immediate completion is acknowledged by the decode engine and handed
/// to `process_decode_pass`; every scheduled completion is pushed onto
/// `self.events` for its `at_ms` timestamp.
///
/// Returns the first error raised by the engine or by pass processing.
fn handle_decode_engine_effects(&mut self, effects: EngineEffects) -> Result<()> {
    // Completions that finish at the current timestamp are handled inline.
    for payload in effects.immediate_completions {
        let payload = self.decode_engine.on_scheduled_completion(payload)?;
        self.process_decode_pass(
            payload.worker_idx,
            payload.completed_requests,
            payload.output_signals,
        )?;
    }
    // Later completions are queued; the function reports success only once
    // both loops have run (no early `Ok(..)` between them).
    for ScheduledWorkerCompletion { at_ms, payload } in effects.scheduled_completions {
        push_worker_completion(&mut self.events, &mut self.next_event_seq, at_ms, payload);
    }
    Ok(())
}
/// Repeatedly process all work that becomes possible without advancing logical time.
......@@ -756,14 +661,7 @@ impl DisaggRuntime {
loop {
let mut changed = self.apply_worker_completions()?;
changed |= self.apply_decode_handoffs()?;
changed |= match self.mode {
ReplayMode::Trace => self.release_trace_arrivals()?,
ReplayMode::Concurrency { max_in_flight } => {
self.top_off_concurrency(max_in_flight)?
}
};
changed |= self.release_ready_arrivals()?;
changed |= self.drive_prefill_workers()?;
changed |= self.drive_decode_workers()?;
......@@ -801,6 +699,7 @@ impl DisaggRuntime {
self.drain_current_timestamp()?;
}
self.progress.finish();
self.finish_test_stats();
Ok((self.collector, self.stats))
}
......@@ -834,14 +733,19 @@ fn derive_decode_router_config(
config.overlap_score_weight = 0.0;
config.router_assume_kv_reuse = false;
config.router_track_prefill_tokens = false;
config.router_prefill_load_model = dynamo_kv_router::config::RouterPrefillLoadModel::None;
config
}
#[cfg(test)]
mod tests {
use super::super::entrypoints::{run_concurrency_collect, run_trace_collect};
use super::super::entrypoints::{
run_concurrency_collect, run_concurrency_workload_collect, run_trace_collect,
run_trace_workload_collect,
};
use super::*;
use crate::common::protocols::{MockEngineArgs, WorkerType};
use crate::common::protocols::{EngineType, MockEngineArgs, SglangArgs, WorkerType};
use crate::loadgen::{SessionTrace, Trace, TurnTrace};
fn staged_args(worker_type: WorkerType, speedup_ratio: f64) -> MockEngineArgs {
MockEngineArgs::builder()
......@@ -858,6 +762,26 @@ mod tests {
.unwrap()
}
/// Builds mock SGLang engine arguments for one stage of a disaggregated replay test.
///
/// `worker_type` selects the stage and `speedup_ratio` is applied to both the
/// general and the decode speedup settings. Panics if the builder rejects the
/// configuration.
fn sglang_staged_args(worker_type: WorkerType, speedup_ratio: f64) -> MockEngineArgs {
    // SGLang-specific options: only the page size is overridden.
    let sglang_options = SglangArgs {
        page_size: Some(64),
        ..Default::default()
    };

    MockEngineArgs::builder()
        .engine_type(EngineType::Sglang)
        .block_size(64)
        .num_gpu_blocks(512)
        .max_num_batched_tokens(Some(8192))
        .max_num_seqs(Some(8))
        .enable_prefix_caching(true)
        .enable_chunked_prefill(true)
        .speedup_ratio(speedup_ratio)
        .decode_speedup_ratio(speedup_ratio)
        .worker_type(worker_type)
        .sglang(Some(sglang_options))
        .build()
        .unwrap()
}
fn disagg_config() -> OfflineDisaggReplayConfig {
OfflineDisaggReplayConfig {
prefill_args: staged_args(WorkerType::Prefill, 1000.0),
......@@ -867,6 +791,15 @@ mod tests {
}
}
/// Disaggregated replay config with two SGLang prefill workers and two SGLang
/// decode workers, both stages built with a speedup ratio of 1000.
fn sglang_disagg_config() -> OfflineDisaggReplayConfig {
    let prefill_args = sglang_staged_args(WorkerType::Prefill, 1000.0);
    let decode_args = sglang_staged_args(WorkerType::Decode, 1000.0);

    OfflineDisaggReplayConfig {
        prefill_args,
        decode_args,
        num_prefill_workers: 2,
        num_decode_workers: 2,
    }
}
fn disagg_config_with_handoff_delay() -> OfflineDisaggReplayConfig {
let mut config = disagg_config();
config.prefill_args.kv_transfer_bandwidth = Some(1.0);
......@@ -896,6 +829,49 @@ mod tests {
}
}
/// Builds a two-session workload trace (block size 64) for the workload replay tests.
///
/// `session-a` first arrives at 0 ms and has two turns, the second carrying a
/// 10 ms `delay_after_previous_ms`; `session-b` first arrives at 5 ms and has
/// a single turn.
fn multiturn_trace() -> Trace {
    // session-a: an opening turn plus a delayed follow-up.
    let opening_turn = TurnTrace {
        input_length: 64,
        max_output_tokens: 2,
        hash_ids: vec![11],
        delay_after_previous_ms: 0.0,
    };
    let follow_up_turn = TurnTrace {
        input_length: 192,
        max_output_tokens: 2,
        hash_ids: vec![21, 22, 23],
        delay_after_previous_ms: 10.0,
    };
    let session_a = SessionTrace {
        session_id: String::from("session-a"),
        first_arrival_timestamp_ms: Some(0.0),
        turns: vec![opening_turn, follow_up_turn],
    };

    // session-b: a single turn.
    let only_turn = TurnTrace {
        input_length: 128,
        max_output_tokens: 2,
        hash_ids: vec![31, 32],
        delay_after_previous_ms: 0.0,
    };
    let session_b = SessionTrace {
        session_id: String::from("session-b"),
        first_arrival_timestamp_ms: Some(5.0),
        turns: vec![only_turn],
    };

    Trace {
        block_size: 64,
        sessions: vec![session_a, session_b],
    }
}
/// Returns the index of the first entry in `transitions` equal to `needle`.
///
/// Panics if `needle` does not occur in the slice.
fn transition_index(transitions: &[DisaggTransition], needle: DisaggTransition) -> usize {
    let first_match = transitions
        .iter()
        .enumerate()
        .find_map(|(idx, candidate)| (*candidate == needle).then_some(idx));
    first_match.unwrap()
}
#[test]
fn test_derive_stage_router_configs_force_required_overrides() {
let config = KvRouterConfig {
......@@ -933,6 +909,7 @@ mod tests {
assert!(snapshot.first_token_ms.is_some());
assert_eq!(snapshot.output_length, 3);
assert_eq!(report.request_counts.completed_requests, 1);
assert_eq!(report.request_counts.total_output_tokens, 3);
assert_eq!(
stats.request_snapshots[&Uuid::from_u128(1)].phase,
DisaggPhase::Done
......@@ -952,26 +929,45 @@ mod tests {
for uuid in [Uuid::from_u128(1), Uuid::from_u128(2)] {
assert!(stats.prefill_assignments.contains_key(&uuid));
assert!(stats.decode_assignments.contains_key(&uuid));
assert_eq!(stats.request_snapshots[&uuid].phase, DisaggPhase::Done);
assert_eq!(
stats.request_snapshots[&uuid].prefill_worker_idx,
Some(stats.prefill_assignments[&uuid])
);
assert_eq!(
stats.request_snapshots[&uuid].decode_worker_idx,
Some(stats.decode_assignments[&uuid])
);
}
}
#[test]
fn test_prefill_overlap_prefers_same_worker_after_handoff_delay() {
let config = disagg_config();
let requests = vec![request(1, 128, 2, 0.0), request(2, 128, 2, 100.0)];
let (_, stats) = run_trace_collect(
&config,
requests,
Some(router_config()),
1.0,
ReplayRouterMode::KvRouter,
);
let cases = [(disagg_config(), true), (sglang_disagg_config(), false)];
for (config, expect_same_worker) in cases {
let (_, stats) = run_trace_collect(
&config,
requests.clone(),
Some(router_config()),
1.0,
ReplayRouterMode::KvRouter,
);
assert_eq!(
stats.prefill_assignments[&Uuid::from_u128(1)],
stats.prefill_assignments[&Uuid::from_u128(2)],
);
if expect_same_worker {
assert_eq!(
stats.prefill_assignments[&Uuid::from_u128(1)],
stats.prefill_assignments[&Uuid::from_u128(2)],
);
} else {
for uuid in [Uuid::from_u128(1), Uuid::from_u128(2)] {
assert!(stats.prefill_assignments.contains_key(&uuid));
assert!(stats.decode_assignments.contains_key(&uuid));
assert_eq!(stats.request_snapshots[&uuid].phase, DisaggPhase::Done);
}
}
}
}
#[rstest::rstest]
......@@ -999,13 +995,21 @@ mod tests {
];
let router_config = (router_mode == ReplayRouterMode::KvRouter).then(router_config);
let (collector, _) =
let (collector, stats) =
run_concurrency_collect(&config, requests, router_config, 1, router_mode);
let first = collector.snapshot(Uuid::from_u128(1)).unwrap();
let second = collector.snapshot(Uuid::from_u128(2)).unwrap();
assert_eq!(first.arrival_time_ms, 0.0);
assert_eq!(second.arrival_time_ms, first.last_token_ms.unwrap());
assert_eq!(
stats.request_snapshots[&Uuid::from_u128(1)].phase,
DisaggPhase::Done
);
assert_eq!(
stats.request_snapshots[&Uuid::from_u128(2)].phase,
DisaggPhase::Done
);
}
#[test]
......@@ -1024,6 +1028,14 @@ mod tests {
assert_eq!(stats.prefill_marked_count, 1);
assert_eq!(stats.prefill_router_freed_count, 1);
assert_eq!(stats.decode_router_freed_count, 1);
let transitions = &stats.transition_log;
let uuid = Uuid::from_u128(1);
let mark_idx =
transition_index(transitions, DisaggTransition::PrefillMarkCompleted { uuid });
let free_idx = transition_index(transitions, DisaggTransition::PrefillFree { uuid });
let enqueue_idx = transition_index(transitions, DisaggTransition::DecodeEnqueued { uuid });
assert!(mark_idx < free_idx);
assert!(free_idx < enqueue_idx);
}
#[test]
......@@ -1037,7 +1049,7 @@ mod tests {
1.0,
ReplayRouterMode::RoundRobin,
);
let (delayed_collector, _) = run_trace_collect(
let (delayed_collector, delayed_stats) = run_trace_collect(
&disagg_config_with_handoff_delay(),
requests,
None,
......@@ -1054,5 +1066,71 @@ mod tests {
delayed_ttft >= baseline_ttft + 120.0,
"expected delayed TTFT to include roughly 128ms of handoff delay, baseline={baseline_ttft}, delayed={delayed_ttft}"
);
let uuid = Uuid::from_u128(1);
let queued_idx = transition_index(
&delayed_stats.transition_log,
DisaggTransition::DecodeHandoffQueued { uuid },
);
let enqueued_idx = transition_index(
&delayed_stats.transition_log,
DisaggTransition::DecodeEnqueued { uuid },
);
assert!(queued_idx < enqueued_idx);
assert!(delayed_stats.handoff_ms[&uuid] >= 120.0);
}
/// In trace mode the second turn of `session-a` must not arrive before the
/// first turn's last token plus the configured 10 ms delay, while the first
/// arrivals of both sessions keep their trace timestamps.
#[test]
fn test_trace_workload_follow_up_turn_arrives_after_completion_plus_delay() {
    let config = disagg_config();
    let (collector, _) =
        run_trace_workload_collect(&config, multiturn_trace(), None, ReplayRouterMode::RoundRobin);
    let snapshots = collector.snapshots();

    // Turns are identified by their distinct input lengths in `multiturn_trace`.
    let opening = snapshots.iter().find(|s| s.input_length == 64).unwrap();
    let follow_up = snapshots.iter().find(|s| s.input_length == 192).unwrap();
    let other_session = snapshots.iter().find(|s| s.input_length == 128).unwrap();

    assert_eq!(opening.arrival_time_ms, 0.0);
    assert_eq!(other_session.arrival_time_ms, 5.0);

    let earliest_follow_up_ms = opening.last_token_ms.unwrap() + 10.0;
    assert!(
        follow_up.arrival_time_ms >= earliest_follow_up_ms,
        "follow-up turn should unlock after completion plus delay"
    );
}
/// With `max_in_flight` set to 1, requests ordered by arrival time must be the
/// 64-token turn, then the 128-token turn of the other session, then the
/// delayed 192-token follow-up.
#[test]
fn test_concurrency_workload_delayed_follow_up_does_not_bypass_other_ready_sessions() {
    let config = disagg_config();
    let (collector, _) = run_concurrency_workload_collect(
        &config,
        multiturn_trace(),
        None,
        1,
        ReplayRouterMode::RoundRobin,
    );

    // Pair each request's arrival time with its input length, then order by arrival.
    let mut arrivals: Vec<_> = collector
        .snapshots()
        .into_iter()
        .map(|snapshot| (snapshot.arrival_time_ms, snapshot.input_length))
        .collect();
    arrivals.sort_by(|earlier, later| earlier.0.total_cmp(&later.0));

    let admitted_order: Vec<_> = arrivals
        .into_iter()
        .map(|(_, input_length)| input_length)
        .collect();
    assert_eq!(admitted_order, vec![64, 128, 192]);
}
}
......@@ -20,8 +20,8 @@ use crate::common::protocols::{DirectRequest, EngineType, MockEngineArgs};
use crate::loadgen::{Trace, WorkloadDriver};
use crate::replay::OfflineDisaggReplayConfig;
use crate::replay::{
ReplayRouterMode, ReplayTimedKvEvent, ReplayTimedOutputSignal, ReplayTimedRequest,
ReplayWorkerArtifacts, TraceCollector, TraceSimulationReport,
ReplayPrefillLoadEstimator, ReplayRouterMode, ReplayTimedKvEvent, ReplayTimedOutputSignal,
ReplayTimedRequest, ReplayWorkerArtifacts, TraceCollector, TraceSimulationReport,
};
use crate::scheduler::RouterEventVisibility;
......@@ -37,8 +37,10 @@ pub(crate) fn generate_trace_worker_artifacts(
args: MockEngineArgs,
trace: Trace,
) -> Result<ReplayWorkerArtifacts> {
let args = args.normalized()?;
let engine_block_size = args.block_size;
let mut worker = ReplayWorkerCore::new_with_kv_capture(args, WorkerId::default());
let mut driver = trace.into_trace_driver()?;
let mut driver = trace.into_trace_driver_with_block_size(engine_block_size)?;
let mut collector = TraceCollector::default();
let mut artifacts = ReplayWorkerArtifacts::default();
let mut current_time_ms = 0.0;
......@@ -106,6 +108,7 @@ pub(crate) fn generate_trace_worker_artifacts(
pub(crate) fn simulate_trace(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
......@@ -117,6 +120,7 @@ pub(crate) fn simulate_trace(
simulate_trace_multi(
args,
router_config,
prefill_load_estimator,
requests,
num_workers,
arrival_speedup_ratio,
......@@ -128,6 +132,7 @@ pub(crate) fn simulate_trace(
pub(crate) fn simulate_concurrency(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
......@@ -139,6 +144,7 @@ pub(crate) fn simulate_concurrency(
simulate_concurrency_multi(
args,
router_config,
prefill_load_estimator,
requests,
max_in_flight,
num_workers,
......@@ -150,6 +156,7 @@ pub(crate) fn simulate_concurrency(
pub(crate) fn simulate_trace_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
......@@ -157,13 +164,21 @@ pub(crate) fn simulate_trace_workload(
if num_workers == 1 && args.engine_type == EngineType::Vllm {
simulate_trace_workload_single(args, trace)
} else {
simulate_trace_workload_multi(args, router_config, trace, num_workers, router_mode)
simulate_trace_workload_multi(
args,
router_config,
prefill_load_estimator,
trace,
num_workers,
router_mode,
)
}
}
pub(crate) fn simulate_concurrency_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
......@@ -175,6 +190,7 @@ pub(crate) fn simulate_concurrency_workload(
simulate_concurrency_workload_multi(
args,
router_config,
prefill_load_estimator,
trace,
max_in_flight,
num_workers,
......@@ -186,6 +202,7 @@ pub(crate) fn simulate_concurrency_workload(
pub(crate) fn simulate_trace_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
......@@ -194,6 +211,7 @@ pub(crate) fn simulate_trace_disagg(
let (collector, _) = DisaggRuntime::new(
&config,
router_config,
prefill_load_estimator,
pending,
DisaggReplayMode::Trace,
router_mode,
......@@ -205,6 +223,7 @@ pub(crate) fn simulate_trace_disagg(
pub(crate) fn simulate_concurrency_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
......@@ -213,6 +232,7 @@ pub(crate) fn simulate_concurrency_disagg(
let (collector, _) = DisaggRuntime::new(
&config,
router_config,
prefill_load_estimator,
pending,
DisaggReplayMode::Concurrency { max_in_flight },
router_mode,
......@@ -224,13 +244,15 @@ pub(crate) fn simulate_concurrency_disagg(
pub(crate) fn simulate_trace_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let driver = WorkloadDriver::new_trace(trace)?;
let driver = WorkloadDriver::new_trace(trace, config.prefill_args.block_size)?;
let (collector, _) = DisaggRuntime::new_workload(
&config,
router_config,
prefill_load_estimator,
driver,
DisaggReplayMode::Trace,
router_mode,
......@@ -242,14 +264,16 @@ pub(crate) fn simulate_trace_workload_disagg(
pub(crate) fn simulate_concurrency_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let driver = WorkloadDriver::new_concurrency(trace)?;
let driver = WorkloadDriver::new_concurrency(trace, config.prefill_args.block_size)?;
let (collector, _) = DisaggRuntime::new_workload(
&config,
router_config,
prefill_load_estimator,
driver,
DisaggReplayMode::Concurrency { max_in_flight },
router_mode,
......@@ -290,9 +314,13 @@ pub(crate) fn simulate_trace_workload_single(
trace: Trace,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let collector =
SingleRuntime::new_workload(args, trace.into_trace_driver()?, SingleReplayMode::Trace)
.run()?;
let engine_block_size = args.block_size;
let collector = SingleRuntime::new_workload(
args,
trace.into_trace_driver_with_block_size(engine_block_size)?,
SingleReplayMode::Trace,
)
.run()?;
Ok(collector.finish())
}
......@@ -302,9 +330,10 @@ pub(crate) fn simulate_concurrency_workload_single(
max_in_flight: usize,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let engine_block_size = args.block_size;
let collector = SingleRuntime::new_workload(
args,
trace.into_concurrency_driver()?,
trace.into_concurrency_driver_with_block_size(engine_block_size)?,
SingleReplayMode::Concurrency { max_in_flight },
)
.run()?;
......@@ -314,6 +343,7 @@ pub(crate) fn simulate_concurrency_workload_single(
pub(crate) fn simulate_trace_multi(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
......@@ -324,6 +354,7 @@ pub(crate) fn simulate_trace_multi(
let (collector, _) = AggRuntime::new(
&args,
router_config,
prefill_load_estimator,
pending,
num_workers,
AggReplayMode::Trace,
......@@ -336,6 +367,7 @@ pub(crate) fn simulate_trace_multi(
pub(crate) fn simulate_concurrency_multi(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
......@@ -346,6 +378,7 @@ pub(crate) fn simulate_concurrency_multi(
let (collector, _) = AggRuntime::new(
&args,
router_config,
prefill_load_estimator,
pending,
num_workers,
AggReplayMode::Concurrency { max_in_flight },
......@@ -358,6 +391,7 @@ pub(crate) fn simulate_concurrency_multi(
pub(crate) fn simulate_trace_workload_multi(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
......@@ -366,7 +400,8 @@ pub(crate) fn simulate_trace_workload_multi(
let (collector, _) = AggRuntime::new_workload(
&args,
router_config,
trace.into_trace_driver()?,
prefill_load_estimator,
trace.into_trace_driver_with_block_size(args.block_size)?,
num_workers,
AggReplayMode::Trace,
router_mode,
......@@ -378,6 +413,7 @@ pub(crate) fn simulate_trace_workload_multi(
pub(crate) fn simulate_concurrency_workload_multi(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
......@@ -387,7 +423,8 @@ pub(crate) fn simulate_concurrency_workload_multi(
let (collector, _) = AggRuntime::new_workload(
&args,
router_config,
trace.into_concurrency_driver()?,
prefill_load_estimator,
trace.into_concurrency_driver_with_block_size(args.block_size)?,
num_workers,
AggReplayMode::Concurrency { max_in_flight },
router_mode,
......@@ -428,9 +465,12 @@ pub(super) fn run_trace_workload_single_collect(
args: MockEngineArgs,
trace: Trace,
) -> TraceCollector {
let engine_block_size = args.block_size;
SingleRuntime::new_workload(
args,
trace.into_trace_driver().unwrap(),
trace
.into_trace_driver_with_block_size(engine_block_size)
.unwrap(),
SingleReplayMode::Trace,
)
.run()
......@@ -443,9 +483,12 @@ pub(super) fn run_concurrency_workload_single_collect(
trace: Trace,
max_in_flight: usize,
) -> TraceCollector {
let engine_block_size = args.block_size;
SingleRuntime::new_workload(
args,
trace.into_concurrency_driver().unwrap(),
trace
.into_concurrency_driver_with_block_size(engine_block_size)
.unwrap(),
SingleReplayMode::Concurrency { max_in_flight },
)
.run()
......@@ -463,6 +506,7 @@ pub(super) fn run_trace_multi_collect_with_stats(
AggRuntime::new(
args,
None,
None,
pending,
num_workers,
AggReplayMode::Trace,
......@@ -484,6 +528,7 @@ pub(super) fn run_concurrency_multi_collect_with_stats(
AggRuntime::new(
args,
None,
None,
VecDeque::from(requests),
num_workers,
AggReplayMode::Concurrency { max_in_flight },
......@@ -504,7 +549,10 @@ pub(super) fn run_trace_workload_multi_collect_with_stats(
AggRuntime::new_workload(
args,
None,
trace.into_trace_driver().unwrap(),
None,
trace
.into_trace_driver_with_block_size(args.block_size)
.unwrap(),
num_workers,
AggReplayMode::Trace,
router_mode,
......@@ -525,7 +573,10 @@ pub(super) fn run_concurrency_workload_multi_collect_with_stats(
AggRuntime::new_workload(
args,
None,
trace.into_concurrency_driver().unwrap(),
None,
trace
.into_concurrency_driver_with_block_size(args.block_size)
.unwrap(),
num_workers,
AggReplayMode::Concurrency { max_in_flight },
router_mode,
......@@ -547,6 +598,7 @@ pub(super) fn run_trace_collect(
DisaggRuntime::new(
config,
router_config,
None,
pending,
DisaggReplayMode::Trace,
router_mode,
......@@ -567,6 +619,7 @@ pub(super) fn run_concurrency_collect(
DisaggRuntime::new(
config,
router_config,
None,
VecDeque::from(requests),
DisaggReplayMode::Concurrency { max_in_flight },
router_mode,
......@@ -576,6 +629,51 @@ pub(super) fn run_concurrency_collect(
.unwrap()
}
/// Test helper: runs a trace-mode disaggregated workload replay and returns the
/// collector together with the runtime stats.
///
/// The workload driver is built from `trace` using the prefill stage's block
/// size, and no prefill load estimator is supplied. Panics if the driver or
/// runtime cannot be built, or if the run fails.
#[cfg(test)]
pub(super) fn run_trace_workload_collect(
    config: &OfflineDisaggReplayConfig,
    trace: Trace,
    router_config: Option<KvRouterConfig>,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
    let block_size = config.prefill_args.block_size;
    let driver = trace.into_trace_driver_with_block_size(block_size).unwrap();

    let runtime = DisaggRuntime::new_workload(
        config,
        router_config,
        None,
        driver,
        DisaggReplayMode::Trace,
        router_mode,
    )
    .unwrap();

    runtime.run().unwrap()
}
/// Test helper: runs a concurrency-mode disaggregated workload replay capped at
/// `max_in_flight` and returns the collector together with the runtime stats.
///
/// The workload driver is built from `trace` using the prefill stage's block
/// size, and no prefill load estimator is supplied. Panics if the driver or
/// runtime cannot be built, or if the run fails.
#[cfg(test)]
pub(super) fn run_concurrency_workload_collect(
    config: &OfflineDisaggReplayConfig,
    trace: Trace,
    router_config: Option<KvRouterConfig>,
    max_in_flight: usize,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
    let block_size = config.prefill_args.block_size;
    let driver = trace
        .into_concurrency_driver_with_block_size(block_size)
        .unwrap();

    let runtime = DisaggRuntime::new_workload(
        config,
        router_config,
        None,
        driver,
        DisaggReplayMode::Concurrency { max_in_flight },
        router_mode,
    )
    .unwrap();

    runtime.run().unwrap()
}
#[cfg(test)]
mod tests {
use super::generate_trace_worker_artifacts;
......
......@@ -4,10 +4,12 @@
pub(crate) use crate::replay::normalize_trace_requests;
pub(crate) mod agg;
pub(crate) mod components;
pub(crate) mod core;
pub(crate) mod disagg;
mod entrypoints;
pub(crate) mod events;
mod progress;
pub(crate) mod runtime_utils;
pub(crate) mod single;
pub(crate) mod state;
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use indicatif::{ProgressBar, ProgressStyle};
/// Progress reporter for an offline replay run, wrapping an `indicatif`
/// [`ProgressBar`] sized to the total number of requests.
pub(super) struct ReplayProgress {
// Underlying bar; advanced by `inc_completed` and cleared by `finish` or on drop.
bar: ProgressBar,
}
impl ReplayProgress {
    /// Creates a progress bar of length `total_requests` whose message is `label`.
    ///
    /// Panics only if the hard-coded template string is rejected by `indicatif`.
    pub(super) fn new(total_requests: usize, label: &'static str) -> Self {
        let bar = ProgressBar::new(total_requests as u64);

        let template = "[{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({eta}) {msg}";
        let style = ProgressStyle::with_template(template)
            .expect("progress bar template must be valid")
            .progress_chars("#>-");

        bar.set_style(style);
        bar.set_message(label);
        Self { bar }
    }

    /// Advances the bar by one completed request.
    pub(super) fn inc_completed(&self) {
        self.bar.inc(1);
    }

    /// Finishes the bar and clears it from the terminal.
    pub(super) fn finish(&self) {
        self.bar.finish_and_clear();
    }
}
impl Drop for ReplayProgress {
    /// Clears the bar if it was dropped without an explicit `finish` (for
    /// example on an early error return), so it is not left on screen.
    fn drop(&mut self) {
        if self.bar.is_finished() {
            return;
        }
        self.bar.finish_and_clear();
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{BinaryHeap, VecDeque};
use std::collections::BinaryHeap;
#[cfg(test)]
use std::collections::VecDeque;
use dynamo_kv_router::protocols::RouterEvent;
use super::events::{SimulationEvent, SimulationEventKind, SimulationWorkerStage};
use crate::common::protocols::{DirectRequest, OutputSignal};
#[cfg(test)]
use crate::common::protocols::DirectRequest;
use crate::common::protocols::OutputSignal;
#[derive(Debug)]
pub(super) struct WorkerCompletionPayload {
......@@ -29,6 +33,7 @@ pub(super) fn next_timestamp(
}
}
#[cfg(test)]
pub(super) fn pop_next_trace_ready(
pending: &mut VecDeque<DirectRequest>,
now_ms: f64,
......@@ -43,6 +48,7 @@ pub(super) fn pop_next_trace_ready(
Some((request, arrival_ms))
}
#[cfg(test)]
pub(super) fn pop_next_concurrency_ready(
pending: &mut VecDeque<DirectRequest>,
now_ms: f64,
......
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use super::core::ReplayWorkerCore;
use super::progress::ReplayProgress;
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::WorkloadDriver;
use crate::replay::TraceCollector;
......@@ -26,6 +27,7 @@ pub(super) struct SingleRuntime {
worker: ReplayWorkerCore,
collector: TraceCollector,
mode: SingleReplayMode,
progress: ReplayProgress,
}
impl SingleRuntime {
......@@ -50,12 +52,17 @@ impl SingleRuntime {
admission: AdmissionSource,
mode: SingleReplayMode,
) -> Self {
let total_requests = match &admission {
AdmissionSource::Requests(pending) => pending.len(),
AdmissionSource::Workload(driver) => driver.total_turns(),
};
Self {
current_time_ms: 0.0,
admission,
worker: ReplayWorkerCore::new(args),
collector: TraceCollector::default(),
mode,
progress: ReplayProgress::new(total_requests, "offline replay"),
}
}
......@@ -168,6 +175,14 @@ impl SingleRuntime {
.expect("completed workload request must belong to a session");
}
}
let completed_requests = pass
.output_signals
.iter()
.filter(|signal| signal.completed)
.count();
for _ in 0..completed_requests {
self.progress.inc_completed();
}
if admit_arrivals_between_steps {
self.enqueue_trace_arrivals();
}
......@@ -199,6 +214,7 @@ impl SingleRuntime {
}
}
self.progress.finish();
Ok(self.collector)
}
}
......
......@@ -17,8 +17,8 @@ pub(crate) enum AggRequestPhase {
pub(crate) struct AggRequestState {
request: Option<DirectRequest>,
phase: AggRequestPhase,
prefill_completed: bool,
pub(in crate::replay::offline) phase: AggRequestPhase,
pub(in crate::replay::offline) prefill_completed: bool,
}
impl AggRequestState {
......@@ -38,12 +38,8 @@ impl AggRequestState {
}
}
pub(crate) fn is_queued_at_router(&self) -> bool {
self.phase == AggRequestPhase::QueuedAtRouter
}
pub(crate) fn take_queued_request(&mut self, uuid: Uuid) -> Result<DirectRequest> {
if !self.is_queued_at_router() {
if self.phase != AggRequestPhase::QueuedAtRouter {
bail!("offline replay expected queued request state for {uuid}");
}
let request = self
......@@ -53,14 +49,6 @@ impl AggRequestState {
self.phase = AggRequestPhase::Running;
Ok(request)
}
pub(crate) fn prefill_completed(&self) -> bool {
self.prefill_completed
}
pub(crate) fn mark_prefill_completed(&mut self) {
self.prefill_completed = true;
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
......@@ -76,7 +64,7 @@ pub(crate) struct DisaggRequestState {
original: Option<DirectRequest>,
#[cfg(test)]
arrival_ms: f64,
phase: DisaggPhase,
pub(in crate::replay::offline) phase: DisaggPhase,
prefill_worker_idx: Option<usize>,
decode_worker_idx: Option<usize>,
}
......@@ -104,14 +92,6 @@ impl DisaggRequestState {
}
}
pub(crate) fn is_queued_prefill(&self) -> bool {
self.phase == DisaggPhase::QueuedPrefill
}
pub(crate) fn is_queued_decode(&self) -> bool {
self.phase == DisaggPhase::QueuedDecode
}
pub(crate) fn original_request(&self) -> Result<&DirectRequest> {
self.original
.as_ref()
......@@ -124,10 +104,6 @@ impl DisaggRequestState {
Ok(request)
}
pub(crate) fn build_decode_request(&self) -> Result<DirectRequest> {
Ok(self.original_request()?.clone())
}
pub(crate) fn start_prefill(&mut self, worker_idx: usize) {
self.phase = DisaggPhase::RunningPrefill;
self.prefill_worker_idx = Some(worker_idx);
......
......@@ -7,10 +7,10 @@ use tokio::sync::mpsc;
use tokio::time::Instant;
use crate::common::protocols::OutputSignal;
use crate::replay::router::ReplayRouter;
use crate::replay::{TraceCollector, TraceSimulationReport};
use crate::scheduler::AdmissionEvent;
use super::ReplayRouter;
use super::state::{ArrivalEvent, RequestRegistry, SharedLiveRuntimeStats, now_ms};
async fn process_output_signal(
......
......@@ -8,7 +8,9 @@ use dynamo_kv_router::config::KvRouterConfig;
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::{Trace, WorkloadDriver};
use crate::replay::{ReplayRouterMode, TraceSimulationReport, normalize_trace_requests};
use crate::replay::{
ReplayPrefillLoadEstimator, ReplayRouterMode, TraceSimulationReport, normalize_trace_requests,
};
use super::live_runtime::LiveRuntime;
use super::state::{LiveReplayMode, LiveRuntimeStats};
......@@ -24,6 +26,7 @@ fn total_turns(trace: &Trace) -> usize {
fn run_live_runtime(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
pending: VecDeque<DirectRequest>,
num_workers: usize,
mode: LiveReplayMode,
......@@ -35,15 +38,25 @@ fn run_live_runtime(
.map_err(|e| anyhow!("failed to create online replay runtime: {e}"))?;
runtime.block_on(async move {
LiveRuntime::new(args, router_config, pending, num_workers, mode, router_mode)?
.run()
.await
LiveRuntime::new(
args,
router_config,
prefill_load_estimator,
pending,
num_workers,
mode,
router_mode,
)?
.run()
.await
})
}
#[allow(clippy::too_many_arguments)]
fn run_live_workload_runtime(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
driver: WorkloadDriver,
total_turns: usize,
num_workers: usize,
......@@ -59,6 +72,7 @@ fn run_live_workload_runtime(
LiveRuntime::new(
args,
router_config,
prefill_load_estimator,
VecDeque::new(),
num_workers,
mode,
......@@ -72,6 +86,7 @@ fn run_live_workload_runtime(
pub(crate) fn simulate_trace_requests(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
......@@ -82,6 +97,7 @@ pub(crate) fn simulate_trace_requests(
let (report, _) = run_live_runtime(
args,
router_config,
prefill_load_estimator,
pending,
num_workers,
LiveReplayMode::Trace,
......@@ -93,6 +109,7 @@ pub(crate) fn simulate_trace_requests(
pub(crate) fn simulate_concurrency_requests(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
......@@ -107,6 +124,7 @@ pub(crate) fn simulate_concurrency_requests(
let (report, _) = run_live_runtime(
args,
router_config,
prefill_load_estimator,
pending,
num_workers,
LiveReplayMode::Concurrency { max_in_flight },
......@@ -118,16 +136,19 @@ pub(crate) fn simulate_concurrency_requests(
pub(crate) fn simulate_trace_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let engine_block_size = args.block_size;
let total_turns = total_turns(&trace);
let (report, _) = run_live_workload_runtime(
args,
router_config,
trace.into_trace_driver()?,
prefill_load_estimator,
trace.into_trace_driver_with_block_size(engine_block_size)?,
total_turns,
num_workers,
LiveReplayMode::Trace,
......@@ -139,17 +160,20 @@ pub(crate) fn simulate_trace_workload(
pub(crate) fn simulate_concurrency_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let engine_block_size = args.block_size;
let total_turns = total_turns(&trace);
let (report, _) = run_live_workload_runtime(
args,
router_config,
trace.into_concurrency_driver()?,
prefill_load_estimator,
trace.into_concurrency_driver_with_block_size(engine_block_size)?,
total_turns,
num_workers,
LiveReplayMode::Concurrency { max_in_flight },
......@@ -171,6 +195,7 @@ pub(super) fn simulate_trace_requests_with_stats(
run_live_runtime(
args,
None,
None,
pending,
num_workers,
LiveReplayMode::Trace,
......@@ -191,6 +216,7 @@ pub(super) fn simulate_concurrency_requests_with_stats(
run_live_runtime(
args,
None,
None,
pending,
num_workers,
LiveReplayMode::Concurrency { max_in_flight },
......@@ -206,11 +232,13 @@ pub(super) fn simulate_trace_workload_with_stats(
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let args = args.normalized()?;
let engine_block_size = args.block_size;
let total_turns = total_turns(&trace);
run_live_workload_runtime(
args,
None,
trace.into_trace_driver()?,
None,
trace.into_trace_driver_with_block_size(engine_block_size)?,
total_turns,
num_workers,
LiveReplayMode::Trace,
......@@ -227,11 +255,13 @@ pub(super) fn simulate_concurrency_workload_with_stats(
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let args = args.normalized()?;
let engine_block_size = args.block_size;
let total_turns = total_turns(&trace);
run_live_workload_runtime(
args,
None,
trace.into_concurrency_driver()?,
None,
trace.into_concurrency_driver_with_block_size(engine_block_size)?,
total_turns,
num_workers,
LiveReplayMode::Concurrency { max_in_flight },
......
......@@ -13,10 +13,10 @@ use tokio_util::sync::CancellationToken;
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::WorkloadDriver;
use crate::replay::router::ReplayRouter;
use crate::replay::{ReplayRouterMode, TraceSimulationReport};
use crate::replay::{ReplayPrefillLoadEstimator, ReplayRouterMode, TraceSimulationReport};
use crate::scheduler::{AdmissionEvent, EngineScheduler, SchedulerHandle};
use super::ReplayRouter;
use super::demux::run_demux;
use super::state::{
LiveReplayMode, LiveRuntimeStats, SharedLiveRuntimeStats, WorkloadDispatchState, now_ms,
......@@ -41,6 +41,7 @@ impl LiveRuntime {
pub(super) fn new(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
pending: std::collections::VecDeque<DirectRequest>,
num_workers: usize,
mode: LiveReplayMode,
......@@ -53,6 +54,7 @@ impl LiveRuntime {
router_mode,
&args,
router_config,
prefill_load_estimator,
num_workers,
));
let mut schedulers = Vec::with_capacity(num_workers);
......
......@@ -4,6 +4,7 @@
mod demux;
mod entrypoints;
mod live_runtime;
mod router;
mod state;
mod task;
......@@ -14,3 +15,4 @@ pub(crate) use entrypoints::{
simulate_concurrency_requests, simulate_concurrency_workload, simulate_trace_requests,
simulate_trace_workload,
};
pub(crate) use router::ReplayRouter;
......@@ -16,14 +16,14 @@ use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use super::shared::{
ReplayScheduler, replay_policy, replay_router_config, replay_selector, replay_slots,
replay_workers_with_configs,
};
use crate::common::protocols::{
DirectRequest, KvCacheEventSink, KvEventPublishers, MockEngineArgs,
};
use crate::replay::ReplayRouterMode;
use crate::replay::router_shared::{
ReplayScheduler, replay_policy, replay_router_config, replay_selector, replay_slots,
replay_workers_with_configs,
};
use crate::replay::{ReplayPrefillLoadEstimator, ReplayRouterMode};
#[derive(Clone)]
enum ReplayIndexer {
......@@ -120,6 +120,7 @@ impl KvReplayRouter {
fn new(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
num_workers: usize,
) -> Self {
let config = replay_router_config(args, router_config);
......@@ -138,6 +139,8 @@ impl KvReplayRouter {
args.block_size as u32,
selector,
policy,
prefill_load_estimator,
config.router_queue_recheck_interval(),
config.router_track_prefill_tokens,
CancellationToken::new(),
"replay",
......@@ -237,6 +240,20 @@ impl KvReplayRouter {
.map_err(|e| anyhow!("replay router event task failed: {e}"))?;
Ok(())
}
/// Test-only probe that asks the wrapped scheduler for its potential loads
/// for a request with `isl_tokens` input tokens.
///
/// The query is made with `None` as the scheduler's first argument and with
/// default (empty) overlap scores; `track_prefill_tokens` is forwarded
/// unchanged.
// NOTE(review): the meaning of the leading `None` argument is defined by the
// scheduler's `get_potential_loads` signature, which is not visible here.
#[cfg(test)]
fn debug_potential_loads(
    &self,
    isl_tokens: usize,
    track_prefill_tokens: bool,
) -> Vec<dynamo_kv_router::PotentialLoad> {
    let no_overlap = OverlapScores::default();
    self.scheduler
        .get_potential_loads(None, isl_tokens, no_overlap, track_prefill_tokens)
}
}
#[expect(
......@@ -253,13 +270,17 @@ impl ReplayRouter {
mode: ReplayRouterMode,
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
num_workers: usize,
) -> Self {
match mode {
ReplayRouterMode::RoundRobin => Self::RoundRobin(RoundRobinRouter::default()),
ReplayRouterMode::KvRouter => {
Self::Kv(KvReplayRouter::new(args, router_config, num_workers))
}
ReplayRouterMode::KvRouter => Self::Kv(KvReplayRouter::new(
args,
router_config,
prefill_load_estimator,
num_workers,
)),
}
}
......@@ -307,4 +328,16 @@ impl ReplayRouter {
Self::Kv(router) => router.shutdown().await,
}
}
/// Test-only accessor for the potential loads reported by the active router.
///
/// The KV variant delegates to its own `debug_potential_loads`; the
/// round-robin variant returns an empty vector.
#[cfg(test)]
pub(crate) fn debug_potential_loads(
    &self,
    isl_tokens: usize,
    track_prefill_tokens: bool,
) -> Vec<dynamo_kv_router::PotentialLoad> {
    match self {
        Self::Kv(kv) => kv.debug_potential_loads(isl_tokens, track_prefill_tokens),
        Self::RoundRobin(_) => Vec::new(),
    }
}
}
......@@ -10,8 +10,8 @@ use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};
use tokio::time::Instant;
use crate::common::protocols::DirectRequest;
use crate::replay::router::ReplayRouter;
use super::ReplayRouter;
use super::state::{
LiveReplayMode, RequestRegistry, RequestState, SharedLiveRuntimeStats, WorkloadDispatchState,
now_ms, request_uuid,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment