Unverified Commit 02b1c58a authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(mocker): add offline disagg replay (#7617)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 4b8826b3
......@@ -23,7 +23,7 @@ use dynamo_mocker::common::protocols::{
DirectRequest, KvCacheEventSink, KvEventPublishers, MockEngineArgs, OutputSignal, RawKvEvent,
RawKvEventSink,
};
use dynamo_mocker::common::utils::{compute_kv_transfer_delay, sleep_precise};
use dynamo_mocker::common::utils::sleep_precise;
use dynamo_mocker::engine::create_engine;
use dynamo_mocker::scheduler::SchedulerHandle;
use dynamo_runtime::DistributedRuntime;
......@@ -645,14 +645,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
let bootstrap_server = self.bootstrap_server.clone();
let reasoning = self.engine_args.reasoning.clone();
// Compute KV transfer delay for prefill workers.
// Simulates the time to transfer KV cache from prefill to decode worker.
let kv_transfer_delay = if is_prefill {
compute_kv_transfer_delay(&self.engine_args, request.token_ids.len())
} else {
None
};
// Spawn a task to handle the complex async logic
tokio::spawn(async move {
let mut token_count = 0;
......@@ -693,17 +685,15 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
if signal.completed {
let _ = stream_tx.send(output);
// Simulate KV transfer delay before prefill's first (and only) token.
// This models the time to transfer KV cache to the decode worker.
if token_count == 1
&& let Some(delay) = kv_transfer_delay
// Prefill-to-decode handoff delay is emitted by the shared mocker core.
if is_prefill
&& let Some(delay_ms) = signal.handoff_delay_ms
{
sleep_precise(delay).await;
sleep_precise(Duration::from_secs_f64(delay_ms / 1000.0)).await;
}
// Prefill: after first token, mark room complete (unblocks decode)
if is_prefill
&& token_count == 1
&& let (Some(server), Some(room_id)) = (bootstrap_server.get(), bootstrap_room)
{
server.complete_room(room_id);
......
......@@ -139,6 +139,8 @@ impl PrefillCost {
pub struct OutputSignal {
pub uuid: Uuid,
pub completed: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub handoff_delay_ms: Option<f64>,
}
/// Preemption policy for evicting decode requests under memory pressure
......@@ -286,6 +288,10 @@ pub struct MockEngineArgs {
#[builder(default = "WorkerType::Aggregated")]
pub worker_type: WorkerType,
/// Original planner profile NPZ path used to materialize `perf_model`.
#[builder(default = "None")]
pub planner_profile_data: Option<PathBuf>,
/// Performance model for timing predictions (not serialized, loaded from planner_profile_data)
#[serde(skip)]
#[builder(default = "Arc::new(PerfModel::default())")]
......@@ -691,6 +697,7 @@ impl MockEngineArgs {
&& let Some(path_str) = path_str.as_str()
{
let npz_path = PathBuf::from(path_str);
builder = builder.planner_profile_data(Some(npz_path.clone()));
match PerfModel::from_npz(&npz_path) {
Ok(model) => {
tracing::info!("Successfully loaded performance model from: {:?}", npz_path);
......
......@@ -3,32 +3,59 @@
use std::time::{Duration, Instant};
use crate::common::protocols::MockEngineArgs;
use crate::common::protocols::{MockEngineArgs, WorkerType};
/// Compute the KV transfer delay duration for a given number of input tokens.
/// Compute the modeled handoff delay after a prefill worker emits its terminal token.
///
/// Returns `None` if KV transfer simulation is disabled (bandwidth is 0 or not configured).
pub fn compute_kv_transfer_delay(
args: &MockEngineArgs,
/// NOTE: this intentionally does not model the internal prefill TTFT itself accurately, and the
/// exact prefill/decode boundary is backend dependent. For now we only care about decode-visible
/// TTFT, which is what the client observes, so modeling the delay as prefill-to-decode handoff is
/// good enough.
pub fn compute_prefill_handoff_delay_ms(
worker_type: WorkerType,
completed: bool,
num_input_tokens: usize,
) -> Option<Duration> {
match (args.kv_transfer_bandwidth, args.kv_bytes_per_token) {
kv_transfer_bandwidth: Option<f64>,
kv_bytes_per_token: Option<usize>,
) -> Option<f64> {
if worker_type != WorkerType::Prefill || !completed {
return None;
}
match (kv_transfer_bandwidth, kv_bytes_per_token) {
(Some(bw), Some(bpt)) if bw > 0.0 => {
let kv_bytes = num_input_tokens as f64 * bpt as f64;
let delay = Duration::from_secs_f64(kv_bytes / (bw * 1e9));
let delay_ms = kv_bytes / (bw * 1e9) * 1000.0;
tracing::debug!(
num_input_tokens,
kv_bytes,
bandwidth_gb_s = bw,
delay_ms = format!("{:.2}", delay.as_secs_f64() * 1000.0),
"KV transfer delay for prefill"
delay_ms = format!("{delay_ms:.2}"),
"KV handoff delay for prefill completion"
);
Some(delay)
Some(delay_ms)
}
_ => None,
}
}
/// Compute the KV transfer delay duration for a given number of input tokens.
///
/// Returns `None` if KV transfer simulation is disabled (bandwidth is 0 or not configured).
pub fn compute_kv_transfer_delay(
args: &MockEngineArgs,
num_input_tokens: usize,
) -> Option<Duration> {
compute_prefill_handoff_delay_ms(
args.worker_type,
true,
num_input_tokens,
args.kv_transfer_bandwidth,
args.kv_bytes_per_token,
)
.map(|delay_ms| Duration::from_secs_f64(delay_ms / 1000.0))
}
/// Sleep for the specified duration using timerfd on Linux for precision.
pub async fn sleep_precise(duration: Duration) {
sleep_until_precise(Instant::now() + duration).await;
......@@ -53,3 +80,42 @@ pub async fn sleep_until_precise(deadline: Instant) {
tokio::time::sleep_until(tokio::time::Instant::from_std(deadline)).await;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_prefill_handoff_delay_only_applies_to_completed_prefill() {
let delay_ms = compute_prefill_handoff_delay_ms(
WorkerType::Prefill,
true,
128,
Some(1.0),
Some(1_000_000),
)
.expect("prefill completion should produce a handoff delay");
assert!((delay_ms - 128.0).abs() < 1e-9);
assert!(
compute_prefill_handoff_delay_ms(
WorkerType::Prefill,
false,
128,
Some(1.0),
Some(1_000_000),
)
.is_none()
);
assert!(
compute_prefill_handoff_delay_ms(
WorkerType::Decode,
true,
128,
Some(1.0),
Some(1_000_000),
)
.is_none()
);
}
}
......@@ -7,7 +7,8 @@ use std::collections::{BinaryHeap, HashMap};
use anyhow::{Result, anyhow, bail};
use uuid::Uuid;
use super::types::{ReadyTurn, Trace, TurnTrace};
use super::types::{ReadyTurn, ReplayRequestHashes, Trace};
use crate::common::protocols::DirectRequest;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DriverMode {
......@@ -18,12 +19,20 @@ enum DriverMode {
#[derive(Debug)]
struct SessionRuntime {
session_id: String,
turns: Vec<TurnTrace>,
turns: Vec<TurnRuntime>,
next_turn_index: usize,
next_ready_at_ms: Option<f64>,
in_flight: Option<Uuid>,
}
#[derive(Debug)]
struct TurnRuntime {
tokens: Vec<u32>,
max_output_tokens: usize,
delay_after_previous_ms: f64,
replay_hashes: ReplayRequestHashes,
}
#[derive(Debug)]
struct InFlightTurn {
session_index: usize,
......@@ -66,7 +75,6 @@ impl PartialOrd for ReadySession {
#[derive(Debug)]
pub struct WorkloadDriver {
mode: DriverMode,
block_size: usize,
sessions: Vec<SessionRuntime>,
in_flight: HashMap<Uuid, InFlightTurn>,
ready_sessions: BinaryHeap<ReadySession>,
......@@ -82,20 +90,36 @@ impl WorkloadDriver {
}
fn new(trace: Trace, mode: DriverMode) -> Result<Self> {
let block_size = trace.block_size;
let sessions: Vec<SessionRuntime> = trace
.sessions
.into_iter()
.map(|session| SessionRuntime {
session_id: session.session_id,
turns: session.turns,
next_turn_index: 0,
next_ready_at_ms: Some(match mode {
.map(|session| -> Result<SessionRuntime> {
let next_ready_at_ms = Some(match mode {
DriverMode::Trace => session.first_arrival_timestamp_ms.unwrap_or(0.0),
DriverMode::Concurrency => 0.0,
}),
in_flight: None,
});
let turns = session
.turns
.into_iter()
.map(|turn| -> Result<TurnRuntime> {
Ok(TurnRuntime {
tokens: turn.synthesize_tokens(block_size)?,
max_output_tokens: turn.max_output_tokens,
delay_after_previous_ms: turn.delay_after_previous_ms,
replay_hashes: turn.to_replay_hashes(block_size)?,
})
})
.collect::<Result<Vec<_>>>()?;
Ok(SessionRuntime {
session_id: session.session_id,
turns,
next_turn_index: 0,
next_ready_at_ms,
in_flight: None,
})
})
.collect();
.collect::<Result<Vec<_>>>()?;
let ready_sessions = sessions
.iter()
......@@ -111,7 +135,6 @@ impl WorkloadDriver {
Ok(Self {
mode,
block_size: trace.block_size,
sessions,
in_flight: HashMap::new(),
ready_sessions,
......@@ -146,16 +169,18 @@ impl WorkloadDriver {
.next_ready_at_ms
.expect("ready session must have a timestamp");
let request_uuid = Uuid::new_v4();
let replay_hashes = session.turns[turn_index]
.to_replay_hashes(self.block_size)
.expect("validated trace should always synthesize replay hashes");
let turn = &session.turns[turn_index];
let arrival_timestamp_ms = match self.mode {
DriverMode::Trace => Some(scheduled_ready_at_ms),
DriverMode::Concurrency => None,
};
let request = session.turns[turn_index]
.to_direct_request(self.block_size, request_uuid, arrival_timestamp_ms)
.expect("validated trace should always synthesize into a direct request");
let request = DirectRequest {
tokens: turn.tokens.clone(),
max_output_tokens: turn.max_output_tokens,
uuid: Some(request_uuid),
dp_rank: 0,
arrival_timestamp_ms,
};
session.in_flight = Some(request_uuid);
session.next_ready_at_ms = None;
self.in_flight.insert(
......@@ -170,7 +195,7 @@ impl WorkloadDriver {
session_id: session.session_id.clone(),
turn_index,
scheduled_ready_at_ms,
replay_hashes: Some(replay_hashes),
replay_hashes: Some(turn.replay_hashes.clone()),
request,
});
}
......
......@@ -59,12 +59,7 @@ impl TurnTrace {
Ok(())
}
pub fn to_direct_request(
&self,
block_size: usize,
request_uuid: Uuid,
arrival_timestamp_ms: Option<f64>,
) -> Result<DirectRequest> {
pub(crate) fn synthesize_tokens(&self, block_size: usize) -> Result<Vec<u32>> {
self.validate_block_size_and_capacity(block_size)?;
let mut tokens = Vec::with_capacity(self.input_length);
......@@ -85,6 +80,16 @@ impl TurnTrace {
);
}
Ok(tokens)
}
pub fn to_direct_request(
&self,
block_size: usize,
request_uuid: Uuid,
arrival_timestamp_ms: Option<f64>,
) -> Result<DirectRequest> {
let tokens = self.synthesize_tokens(block_size)?;
Ok(DirectRequest {
tokens,
max_output_tokens: self.max_output_tokens,
......
......@@ -309,6 +309,7 @@ impl TraceCollector {
}
let duration_s = (duration_ms / 1000.0).max(1e-9);
let itl_distribution = build_distribution_stats(&itls);
TraceSimulationReport {
request_counts: TraceRequestCounts {
num_requests: requests.len(),
......@@ -335,8 +336,8 @@ impl TraceCollector {
ttst: build_distribution_stats(&ttsts),
tpot: build_distribution_stats(&tpots),
itl: TraceInterTokenLatencyStats {
distribution: build_distribution_stats(&itls),
max_ms: max_value(&itls),
max_ms: itl_distribution.max_ms,
distribution: itl_distribution,
},
e2e: build_distribution_stats(&e2e_latencies),
output_token_throughput_per_user: build_distribution_stats(
......@@ -386,39 +387,41 @@ fn mean(values: &[f64]) -> f64 {
}
}
fn max_value(values: &[f64]) -> f64 {
values.iter().copied().reduce(f64::max).unwrap_or(0.0)
}
fn build_distribution_stats(values: &[f64]) -> TraceDistributionStats {
if values.is_empty() {
return TraceDistributionStats {
mean_ms: 0.0,
min_ms: 0.0,
max_ms: 0.0,
median_ms: 0.0,
p75_ms: 0.0,
p90_ms: 0.0,
p95_ms: 0.0,
p99_ms: 0.0,
std_ms: 0.0,
};
}
let mut sorted = values.to_vec();
sorted.sort_by(|left, right| left.total_cmp(right));
TraceDistributionStats {
mean_ms: mean(values),
min_ms: min_value(values),
max_ms: max_value(values),
median_ms: percentile(values, 50.0),
p75_ms: percentile(values, 75.0),
p90_ms: percentile(values, 90.0),
p95_ms: percentile(values, 95.0),
p99_ms: percentile(values, 99.0),
min_ms: sorted[0],
max_ms: *sorted.last().expect("sorted values must be non-empty"),
median_ms: percentile_sorted(&sorted, 50.0),
p75_ms: percentile_sorted(&sorted, 75.0),
p90_ms: percentile_sorted(&sorted, 90.0),
p95_ms: percentile_sorted(&sorted, 95.0),
p99_ms: percentile_sorted(&sorted, 99.0),
std_ms: std_dev(values),
}
}
fn percentile(values: &[f64], percentile: f64) -> f64 {
if values.is_empty() {
return 0.0;
}
let mut sorted = values.to_vec();
sorted.sort_by(|left, right| left.total_cmp(right));
fn percentile_sorted(sorted: &[f64], percentile: f64) -> f64 {
let rank = ((sorted.len() - 1) as f64 * percentile / 100.0).round() as usize;
sorted[rank.min(sorted.len() - 1)]
}
fn min_value(values: &[f64]) -> f64 {
values.iter().copied().reduce(f64::min).unwrap_or(0.0)
}
fn std_dev(values: &[f64]) -> f64 {
if values.is_empty() {
return 0.0;
......
......@@ -9,10 +9,11 @@ use dynamo_kv_router::config::KvRouterConfig;
use super::online;
use super::validate::{
validate_offline_concurrency_args, validate_offline_replay_args,
validate_offline_concurrency_args, validate_offline_disagg_concurrency_args,
validate_offline_disagg_replay_args, validate_offline_replay_args,
validate_online_concurrency_args, validate_online_replay_args,
};
use super::{ReplayRouterMode, TraceSimulationReport};
use super::{OfflineDisaggReplayConfig, ReplayRouterMode, TraceSimulationReport};
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::Trace;
......@@ -56,6 +57,28 @@ pub fn simulate_trace_file_with_router_mode(
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
}
pub fn simulate_trace_file_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace_path: &Path,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let config = config.normalized()?;
validate_offline_disagg_replay_args(&config, router_mode)?;
let trace = Trace::from_mooncake(trace_path, config.prefill_args.block_size)?
.normalize_session_starts()?
.speed_up_timing(arrival_speedup_ratio)?;
let started_at = Instant::now();
let report = crate::replay::offline::simulate_trace_workload_disagg(
config,
router_config,
trace,
router_mode,
)?;
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
}
pub fn simulate_trace_live_file(
args: MockEngineArgs,
trace_path: &Path,
......@@ -130,6 +153,30 @@ pub fn simulate_trace_requests_with_router_mode(
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
}
pub fn simulate_trace_requests_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let config = config.normalized()?;
validate_offline_disagg_replay_args(&config, router_mode)?;
if requests.is_empty() {
bail!("trace replay requires at least one request");
}
let started_at = Instant::now();
let report = crate::replay::offline::simulate_trace_disagg(
config,
router_config,
requests,
arrival_speedup_ratio,
router_mode,
)?;
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
}
pub fn simulate_trace_live_requests(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
......@@ -209,6 +256,27 @@ pub fn simulate_concurrency_file_with_router_mode(
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
}
pub fn simulate_concurrency_file_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace_path: &Path,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let config = config.normalized()?;
validate_offline_disagg_concurrency_args(&config, max_in_flight, router_mode)?;
let trace = Trace::from_mooncake(trace_path, config.prefill_args.block_size)?;
let started_at = Instant::now();
let report = simulate_concurrency_workload_disagg_with_router_mode(
config,
router_config,
trace,
max_in_flight,
router_mode,
)?;
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
}
pub fn simulate_concurrency_live_file(
args: MockEngineArgs,
trace_path: &Path,
......@@ -326,6 +394,28 @@ pub fn simulate_concurrency_requests_with_router_mode(
)
}
pub fn simulate_concurrency_requests_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let config = config.normalized()?;
validate_offline_disagg_concurrency_args(&config, max_in_flight, router_mode)?;
if requests.is_empty() {
bail!("concurrency replay requires at least one request");
}
crate::replay::offline::simulate_concurrency_disagg(
config,
router_config,
requests,
max_in_flight,
router_mode,
)
}
pub fn simulate_trace_workload(
args: MockEngineArgs,
trace: Trace,
......@@ -360,6 +450,24 @@ pub fn simulate_trace_workload_with_router_mode(
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
}
pub fn simulate_trace_workload_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let config = config.normalized()?;
validate_offline_disagg_replay_args(&config, router_mode)?;
let started_at = Instant::now();
let report = crate::replay::offline::simulate_trace_workload_disagg(
config,
router_config,
trace,
router_mode,
)?;
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
}
pub fn simulate_trace_live_workload(
args: MockEngineArgs,
trace: Trace,
......@@ -422,6 +530,24 @@ pub fn simulate_concurrency_workload_with_router_mode(
)
}
pub fn simulate_concurrency_workload_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let config = config.normalized()?;
validate_offline_disagg_concurrency_args(&config, max_in_flight, router_mode)?;
crate::replay::offline::simulate_concurrency_workload_disagg(
config,
router_config,
trace,
max_in_flight,
router_mode,
)
}
pub fn simulate_concurrency_live_workload(
args: MockEngineArgs,
trace: Trace,
......
......@@ -10,7 +10,7 @@ mod validate;
use std::collections::VecDeque;
use crate::common::protocols::DirectRequest;
use crate::common::protocols::{DirectRequest, MockEngineArgs};
pub(crate) use collector::TraceCollector;
#[cfg(test)]
......@@ -25,20 +25,50 @@ pub enum ReplayRouterMode {
KvRouter,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ReplayArgsMode {
Aggregated,
Disagg,
}
#[derive(Clone, Debug)]
pub struct OfflineDisaggReplayConfig {
pub prefill_args: MockEngineArgs,
pub decode_args: MockEngineArgs,
pub num_prefill_workers: usize,
pub num_decode_workers: usize,
}
impl OfflineDisaggReplayConfig {
pub fn normalized(self) -> anyhow::Result<Self> {
Ok(Self {
prefill_args: self.prefill_args.normalized()?,
decode_args: self.decode_args.normalized()?,
num_prefill_workers: self.num_prefill_workers,
num_decode_workers: self.num_decode_workers,
})
}
}
pub use entrypoints::{
simulate_concurrency_file, simulate_concurrency_file_with_router_mode,
simulate_concurrency_live_file, simulate_concurrency_live_file_with_router_mode,
simulate_concurrency_live_requests, simulate_concurrency_live_requests_with_router_mode,
simulate_concurrency_live_workload, simulate_concurrency_live_workload_with_router_mode,
simulate_concurrency_requests, simulate_concurrency_requests_with_router_mode,
simulate_concurrency_workload, simulate_concurrency_workload_with_router_mode,
simulate_trace_file, simulate_trace_file_with_router_mode, simulate_trace_live_file,
simulate_trace_live_file_with_router_mode, simulate_trace_live_requests,
simulate_trace_live_requests_with_router_mode, simulate_trace_live_workload,
simulate_trace_live_workload_with_router_mode, simulate_trace_requests,
simulate_concurrency_file, simulate_concurrency_file_disagg_with_router_mode,
simulate_concurrency_file_with_router_mode, simulate_concurrency_live_file,
simulate_concurrency_live_file_with_router_mode, simulate_concurrency_live_requests,
simulate_concurrency_live_requests_with_router_mode, simulate_concurrency_live_workload,
simulate_concurrency_live_workload_with_router_mode, simulate_concurrency_requests,
simulate_concurrency_requests_disagg_with_router_mode,
simulate_concurrency_requests_with_router_mode, simulate_concurrency_workload,
simulate_concurrency_workload_disagg_with_router_mode,
simulate_concurrency_workload_with_router_mode, simulate_trace_file,
simulate_trace_file_disagg_with_router_mode, simulate_trace_file_with_router_mode,
simulate_trace_live_file, simulate_trace_live_file_with_router_mode,
simulate_trace_live_requests, simulate_trace_live_requests_with_router_mode,
simulate_trace_live_workload, simulate_trace_live_workload_with_router_mode,
simulate_trace_requests, simulate_trace_requests_disagg_with_router_mode,
simulate_trace_requests_with_router_mode, simulate_trace_workload,
simulate_trace_workload_with_router_mode,
simulate_trace_workload_disagg_with_router_mode, simulate_trace_workload_with_router_mode,
};
pub use validate::validate_replay_args_mode;
pub(crate) fn normalize_trace_requests(
mut requests: Vec<DirectRequest>,
......
......@@ -15,10 +15,11 @@ The public replay entrypoints live one level up in `lib/mocker/src/replay/entryp
Offline replay starts in `lib/mocker/src/replay/offline/mod.rs`.
`offline/mod.rs` chooses between two implementations:
`offline/mod.rs` chooses between three implementations:
- `lib/mocker/src/replay/offline/single.rs` for the special case `num_workers == 1` with the vLLM engine
- `lib/mocker/src/replay/offline/multi.rs` for everything else, including multi-worker replay and `kv_router` replay
- `lib/mocker/src/replay/offline/disagg.rs` for offline disaggregated prefill/decode replay
## File Map
......@@ -28,6 +29,8 @@ Offline replay starts in `lib/mocker/src/replay/offline/mod.rs`.
Minimal replay loop for one vLLM worker.
- `lib/mocker/src/replay/offline/multi.rs`
General offline cluster simulator for multi-worker replay and KV-router replay.
- `lib/mocker/src/replay/offline/disagg.rs`
Offline two-stage replay harness with separate prefill and decode pools.
- `lib/mocker/src/replay/offline/state.rs`
Per-worker wrapper around `EngineCore`, including optional KV event capture.
- `lib/mocker/src/replay/offline/events.rs`
......@@ -114,7 +117,7 @@ So offline replay is not a toy simulator. It reuses the real per-pass mocker sch
## Completion Event Queue
The multi-worker harness uses `SimulationEvent` from `lib/mocker/src/replay/offline/events.rs` as a min-time priority queue implemented with `BinaryHeap`.
The multi-worker and disagg harnesses use `SimulationEvent` from `lib/mocker/src/replay/offline/events.rs` as a min-time priority queue implemented with `BinaryHeap`.
Right now the only scheduled event type is:
......@@ -122,6 +125,7 @@ Right now the only scheduled event type is:
That event carries:
- worker `stage` (`aggregated`, `prefill`, or `decode`)
- `worker_idx`
- `completed_requests`
- `output_signals`
......@@ -167,7 +171,7 @@ flowchart LR
M --> G
```
### Why KV events are captured only here
### Why KV events are captured only where needed
When offline replay uses `kv_router`, workers are created with KV event capture enabled via:
......@@ -177,6 +181,37 @@ When offline replay uses `kv_router`, workers are created with KV event capture
That causes each pass to return router-visible `kv_events`, which the harness applies synchronously to the offline router indexer after the pass completes.
In round-robin mode, this capture is skipped because nothing consumes those events.
In offline disagg replay, only the prefill workers capture and publish KV events; the decode workers
run with capture disabled because the decode router is overlap-blind and does not consume router
events.
## Disaggregated Harness
The disaggregated runtime in `lib/mocker/src/replay/offline/disagg.rs` models two distinct stages:
- a prefill router and prefill worker pool
- a decode router and decode worker pool
It keeps one logical clock and one completion-event heap, but request ownership moves through a
two-stage state machine instead of the aggregated single-pool lifecycle.
The prefill router is derived from the main router config with `router_track_active_blocks = false`.
The decode router is derived with:
- overlap disabled
- `assume_kv_reuse = false`
- `track_prefill_tokens = false`
The prefill stage runs a hidden synthetic one-token bootstrap request. When prefill completes, the
harness:
1. applies any prefill KV events
2. marks prefill complete in the prefill router
3. frees prefill router state
4. enqueues the original request into decode at the same logical timestamp
Decode then runs with normal collector visibility. The public replay report remains decode-only, so
TTFT includes prefill queueing and prefill compute.
## Trace vs Concurrency Modes
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{BinaryHeap, HashMap, VecDeque};
use anyhow::{Result, anyhow, bail};
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::RouterEvent;
use uuid::Uuid;
use super::events::{SimulationEvent, SimulationWorkerStage};
use super::normalize_trace_requests;
use super::runtime_utils::{
WorkerCompletionPayload, next_timestamp as choose_next_timestamp, pop_next_concurrency_ready,
pop_next_trace_ready, pop_ready_decode_handoff, pop_ready_worker_completion,
push_decode_handoff, push_worker_completion,
};
#[cfg(test)]
use super::state::DisaggPhase;
#[cfg(test)]
use super::state::DisaggRequestSnapshot;
use super::state::{DisaggRequestState, OfflineWorkerState};
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::{ReplayRequestHashes, Trace, WorkloadDriver};
use crate::replay::router::OfflineReplayRouter;
use crate::replay::{
OfflineDisaggReplayConfig, ReplayRouterMode, TraceCollector, TraceSimulationReport,
};
use crate::scheduler::RouterEventVisibility;
#[derive(Debug, Clone, Copy)]
enum ReplayMode {
Trace,
Concurrency { max_in_flight: usize },
}
enum AdmissionSource {
Requests(VecDeque<DirectRequest>),
Workload(WorkloadDriver),
}
#[cfg(test)]
#[derive(Debug, Default, Clone, PartialEq)]
struct DisaggRuntimeStats {
request_snapshots: HashMap<Uuid, DisaggRequestSnapshot>,
prefill_assignments: HashMap<Uuid, usize>,
decode_assignments: HashMap<Uuid, usize>,
handoff_ms: HashMap<Uuid, f64>,
prefill_marked_count: usize,
prefill_freed_count: usize,
decode_freed_count: usize,
max_prefill_router_pending: usize,
max_decode_router_pending: usize,
}
#[cfg(not(test))]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct DisaggRuntimeStats;
struct DisaggRuntime {
now_ms: f64,
next_prefill_worker_idx: usize,
next_decode_worker_idx: usize,
next_event_seq: u64,
admission: AdmissionSource,
prefill_workers: Vec<OfflineWorkerState>,
decode_workers: Vec<OfflineWorkerState>,
prefill_router: Option<OfflineReplayRouter>,
decode_router: Option<OfflineReplayRouter>,
requests: HashMap<Uuid, DisaggRequestState>,
collector: TraceCollector,
events: BinaryHeap<SimulationEvent>,
mode: ReplayMode,
stats: DisaggRuntimeStats,
}
impl DisaggRuntime {
fn new(
config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
pending: VecDeque<DirectRequest>,
mode: ReplayMode,
router_mode: ReplayRouterMode,
) -> Result<Self> {
Self::new_with_source(
config,
router_config,
AdmissionSource::Requests(pending),
mode,
router_mode,
)
}
fn new_workload(
config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
driver: WorkloadDriver,
mode: ReplayMode,
router_mode: ReplayRouterMode,
) -> Result<Self> {
Self::new_with_source(
config,
router_config,
AdmissionSource::Workload(driver),
mode,
router_mode,
)
}
fn new_with_source(
config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
admission: AdmissionSource,
mode: ReplayMode,
router_mode: ReplayRouterMode,
) -> Result<Self> {
let (prefill_router, decode_router) = match router_mode {
ReplayRouterMode::RoundRobin => (None, None),
ReplayRouterMode::KvRouter => {
let prefill_router_config =
derive_prefill_router_config(&config.prefill_args, router_config.clone());
let decode_router_config =
derive_decode_router_config(&config.decode_args, router_config);
(
Some(OfflineReplayRouter::new(
&config.prefill_args,
Some(prefill_router_config),
config.num_prefill_workers,
)?),
Some(OfflineReplayRouter::new(
&config.decode_args,
Some(decode_router_config),
config.num_decode_workers,
)?),
)
}
};
Ok(Self {
now_ms: 0.0,
next_prefill_worker_idx: 0,
next_decode_worker_idx: 0,
next_event_seq: 0,
admission,
prefill_workers: (0..config.num_prefill_workers)
.map(|worker_idx| {
OfflineWorkerState::new(
worker_idx,
config.prefill_args.clone(),
prefill_router.is_some(),
)
})
.collect(),
decode_workers: (0..config.num_decode_workers)
.map(|worker_idx| {
OfflineWorkerState::new(worker_idx, config.decode_args.clone(), false)
})
.collect(),
prefill_router,
decode_router,
requests: HashMap::new(),
collector: TraceCollector::default(),
events: BinaryHeap::new(),
mode,
#[cfg(test)]
stats: DisaggRuntimeStats::default(),
#[cfg(not(test))]
stats: DisaggRuntimeStats,
})
}
fn cluster_in_flight(&self) -> usize {
self.prefill_workers
.iter()
.map(OfflineWorkerState::in_flight)
.sum::<usize>()
+ self
.decode_workers
.iter()
.map(OfflineWorkerState::in_flight)
.sum::<usize>()
+ self
.prefill_router
.as_ref()
.map_or(0, OfflineReplayRouter::pending_count)
+ self
.decode_router
.as_ref()
.map_or(0, OfflineReplayRouter::pending_count)
}
fn next_prefill_worker(&mut self) -> usize {
let worker_idx = self.next_prefill_worker_idx;
self.next_prefill_worker_idx =
(self.next_prefill_worker_idx + 1) % self.prefill_workers.len();
worker_idx
}
fn next_decode_worker(&mut self) -> usize {
let worker_idx = self.next_decode_worker_idx;
self.next_decode_worker_idx = (self.next_decode_worker_idx + 1) % self.decode_workers.len();
worker_idx
}
fn record_router_pending(&mut self) {
#[cfg(test)]
{
self.stats.max_prefill_router_pending = self.stats.max_prefill_router_pending.max(
self.prefill_router
.as_ref()
.map_or(0, OfflineReplayRouter::pending_count),
);
self.stats.max_decode_router_pending = self.stats.max_decode_router_pending.max(
self.decode_router
.as_ref()
.map_or(0, OfflineReplayRouter::pending_count),
);
}
}
fn validate_worker_idx(&self, stage: SimulationWorkerStage, worker_idx: usize) -> Result<()> {
let worker_count = match stage {
SimulationWorkerStage::Prefill => self.prefill_workers.len(),
SimulationWorkerStage::Decode => self.decode_workers.len(),
SimulationWorkerStage::Aggregated => unreachable!("aggregated stage is not used"),
};
if worker_idx >= worker_count {
bail!("offline disagg replay selected unknown {stage:?} worker index {worker_idx}");
}
Ok(())
}
fn state(&self, uuid: Uuid) -> Result<&DisaggRequestState> {
self.requests
.get(&uuid)
.ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))
}
fn state_mut(&mut self, uuid: Uuid) -> Result<&mut DisaggRequestState> {
self.requests
.get_mut(&uuid)
.ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))
}
fn dispatch_prefill(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> {
self.validate_worker_idx(SimulationWorkerStage::Prefill, worker_idx)?;
let request = self.state(uuid)?.build_prefill_request()?;
self.prefill_workers[worker_idx].receive_request(request);
self.state_mut(uuid)?.start_prefill(worker_idx);
#[cfg(test)]
{
self.stats.prefill_assignments.insert(uuid, worker_idx);
}
Ok(())
}
fn dispatch_decode(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> {
self.validate_worker_idx(SimulationWorkerStage::Decode, worker_idx)?;
let request = self.state(uuid)?.build_decode_request()?;
self.decode_workers[worker_idx].receive_request(request);
self.state_mut(uuid)?.start_decode(worker_idx);
#[cfg(test)]
{
self.stats.decode_assignments.insert(uuid, worker_idx);
}
Ok(())
}
fn dispatch_prefill_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> Result<()> {
for (uuid, worker_idx) in admissions {
if !self.state(uuid)?.is_queued_prefill() {
bail!("offline disagg replay expected queued prefill request for {uuid}");
}
self.dispatch_prefill(uuid, worker_idx)?;
}
Ok(())
}
fn dispatch_decode_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> Result<()> {
for (uuid, worker_idx) in admissions {
if !self.state(uuid)?.is_queued_decode() {
bail!("offline disagg replay expected queued decode request for {uuid}");
}
self.dispatch_decode(uuid, worker_idx)?;
}
Ok(())
}
fn enqueue_decode(&mut self, uuid: Uuid) -> Result<()> {
let Some(decode_router) = self.decode_router.as_mut() else {
let worker_idx = self.next_decode_worker();
self.dispatch_decode(uuid, worker_idx)?;
return Ok(());
};
let maybe_worker_idx = {
let requests = &self.requests;
let request = requests
.get(&uuid)
.ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))?
.original_request()?;
decode_router.submit_request_with_hashes(request, None, self.now_ms)?
};
self.record_router_pending();
#[cfg(test)]
{
self.stats.handoff_ms.insert(uuid, self.now_ms);
}
if let Some(worker_idx) = maybe_worker_idx {
self.dispatch_decode(uuid, worker_idx)?;
return Ok(());
}
self.state_mut(uuid)?.queue_decode();
Ok(())
}
fn on_external_arrival(
&mut self,
mut request: DirectRequest,
arrival_time_ms: f64,
replay_hashes: Option<ReplayRequestHashes>,
) -> Result<Uuid> {
let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
request.uuid = Some(uuid);
request.arrival_timestamp_ms = Some(arrival_time_ms);
self.collector.on_arrival(
uuid,
arrival_time_ms,
request.tokens.len(),
request.max_output_tokens,
);
self.requests
.insert(uuid, DisaggRequestState::new(request, arrival_time_ms));
let Some(prefill_router) = self.prefill_router.as_mut() else {
let worker_idx = self.next_prefill_worker();
self.dispatch_prefill(uuid, worker_idx)?;
return Ok(uuid);
};
let maybe_worker_idx = {
let requests = &self.requests;
let request = requests
.get(&uuid)
.ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))?
.original_request()?;
prefill_router.submit_request_with_hashes(request, replay_hashes, self.now_ms)?
};
self.record_router_pending();
if let Some(worker_idx) = maybe_worker_idx {
self.dispatch_prefill(uuid, worker_idx)?;
}
Ok(uuid)
}
fn is_done(&self) -> bool {
self.events.is_empty()
&& self.cluster_in_flight() == 0
&& match &self.admission {
AdmissionSource::Requests(pending) => pending.is_empty(),
AdmissionSource::Workload(driver) => driver.is_drained(),
}
&& self
.prefill_workers
.iter()
.all(OfflineWorkerState::is_drained)
&& self
.decode_workers
.iter()
.all(OfflineWorkerState::is_drained)
}
fn next_timestamp(&mut self) -> Option<f64> {
let next_event_ms = self.events.peek().map(|event| event.at_ms);
let cluster_in_flight = self.cluster_in_flight();
let next_arrival_ms = match (&self.mode, &mut self.admission) {
(ReplayMode::Trace, AdmissionSource::Requests(pending)) => pending
.front()
.and_then(|request| request.arrival_timestamp_ms),
(ReplayMode::Trace, AdmissionSource::Workload(driver)) => driver.next_ready_time_ms(),
(ReplayMode::Concurrency { max_in_flight }, AdmissionSource::Workload(driver)) => {
if cluster_in_flight < *max_in_flight {
driver.next_ready_time_ms()
} else {
None
}
}
(ReplayMode::Concurrency { .. }, AdmissionSource::Requests(_)) => None,
};
choose_next_timestamp(next_arrival_ms, next_event_ms)
}
fn apply_prefill_router_events(&mut self, events: Vec<RouterEvent>) -> Result<()> {
let Some(prefill_router) = self.prefill_router.as_mut() else {
return Ok(());
};
for event in events {
prefill_router.apply_event(event)?;
}
Ok(())
}
fn process_prefill_signal(&mut self, signal: OutputSignal) -> Result<()> {
if !signal.completed {
return Ok(());
}
if self.prefill_router.is_some() {
let prefill_complete_admissions = {
let prefill_router = self.prefill_router.as_mut().expect("router checked above");
prefill_router.mark_prefill_completed(signal.uuid)?
};
#[cfg(test)]
{
self.stats.prefill_marked_count += 1;
}
self.record_router_pending();
self.dispatch_prefill_admissions(prefill_complete_admissions)?;
let admissions = {
let prefill_router = self.prefill_router.as_mut().expect("router checked above");
prefill_router.free(signal.uuid)?
};
#[cfg(test)]
{
self.stats.prefill_freed_count += 1;
}
self.record_router_pending();
self.dispatch_prefill_admissions(admissions)?;
}
self.enqueue_decode_after_handoff(signal.uuid, signal.handoff_delay_ms)
}
fn process_decode_signal(&mut self, signal: OutputSignal) -> Result<()> {
if !signal.completed {
return Ok(());
}
let admissions = if let Some(decode_router) = self.decode_router.as_mut() {
let admissions = decode_router.free(signal.uuid)?;
#[cfg(test)]
{
self.stats.decode_freed_count += 1;
}
admissions
} else {
Vec::new()
};
self.record_router_pending();
if let AdmissionSource::Workload(driver) = &mut self.admission {
driver.on_complete(signal.uuid, self.now_ms)?;
}
self.state_mut(signal.uuid)?.mark_done();
self.dispatch_decode_admissions(admissions)?;
Ok(())
}
fn process_prefill_pass(
&mut self,
worker_idx: usize,
completed_requests: usize,
output_signals: Vec<OutputSignal>,
kv_events: Vec<RouterEvent>,
) -> Result<()> {
self.prefill_workers[worker_idx].mark_completed(completed_requests);
self.apply_prefill_router_events(kv_events)?;
for signal in output_signals {
self.process_prefill_signal(signal)?;
}
Ok(())
}
fn process_decode_pass(
&mut self,
worker_idx: usize,
completed_requests: usize,
output_signals: Vec<OutputSignal>,
) -> Result<()> {
self.decode_workers[worker_idx].mark_completed(completed_requests);
for signal in output_signals {
self.process_decode_signal(signal)?;
}
Ok(())
}
fn apply_worker_completions(&mut self) -> Result<bool> {
let mut changed = false;
while let Some(WorkerCompletionPayload {
stage,
worker_idx,
completed_requests,
output_signals,
kv_events,
}) = pop_ready_worker_completion(&mut self.events, self.now_ms)
{
match stage {
SimulationWorkerStage::Prefill => {
self.prefill_workers[worker_idx].mark_idle();
self.process_prefill_pass(
worker_idx,
completed_requests,
output_signals,
kv_events,
)?;
}
SimulationWorkerStage::Decode => {
self.decode_workers[worker_idx].mark_idle();
self.process_decode_pass(worker_idx, completed_requests, output_signals)?;
}
SimulationWorkerStage::Aggregated => {
bail!("offline disagg replay received an aggregated completion event")
}
}
changed = true;
}
Ok(changed)
}
fn apply_decode_handoffs(&mut self) -> Result<bool> {
let mut changed = false;
while let Some(uuid) = pop_ready_decode_handoff(&mut self.events, self.now_ms) {
self.enqueue_decode(uuid)?;
changed = true;
}
Ok(changed)
}
fn enqueue_decode_after_handoff(
&mut self,
uuid: Uuid,
handoff_delay_ms: Option<f64>,
) -> Result<()> {
if let Some(delay_ms) = handoff_delay_ms
&& delay_ms > 0.0
{
push_decode_handoff(
&mut self.events,
&mut self.next_event_seq,
self.now_ms + delay_ms,
uuid,
);
return Ok(());
}
self.enqueue_decode(uuid)
}
fn release_trace_arrivals(&mut self) -> Result<bool> {
let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) {
loop {
let next_ready = match &mut self.admission {
AdmissionSource::Requests(pending) => {
pop_next_trace_ready(pending, self.now_ms)
}
AdmissionSource::Workload(_) => unreachable!(),
};
let Some((request, arrival_ms)) = next_ready else {
break;
};
self.on_external_arrival(request, arrival_ms, None)?;
released_any = true;
}
return Ok(released_any);
}
let ready_requests = match &mut self.admission {
AdmissionSource::Requests(_) => unreachable!(),
AdmissionSource::Workload(driver) => driver.pop_ready(self.now_ms, usize::MAX),
};
for ready in ready_requests {
self.on_external_arrival(
ready.request,
ready.scheduled_ready_at_ms,
ready.replay_hashes,
)?;
released_any = true;
}
Ok(released_any)
}
fn top_off_concurrency(&mut self, max_in_flight: usize) -> Result<bool> {
let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) {
loop {
let cluster_in_flight = self.cluster_in_flight();
let next_ready = match &mut self.admission {
AdmissionSource::Requests(pending) => pop_next_concurrency_ready(
pending,
self.now_ms,
cluster_in_flight,
max_in_flight,
),
AdmissionSource::Workload(_) => unreachable!(),
};
let Some((request, arrival_ms)) = next_ready else {
break;
};
self.on_external_arrival(request, arrival_ms, None)?;
released_any = true;
}
return Ok(released_any);
}
let available = max_in_flight.saturating_sub(self.cluster_in_flight());
if available == 0 {
return Ok(false);
}
let ready_requests = match &mut self.admission {
AdmissionSource::Requests(_) => unreachable!(),
AdmissionSource::Workload(driver) => driver.pop_ready(self.now_ms, available),
};
for ready in ready_requests {
self.on_external_arrival(ready.request, self.now_ms, ready.replay_hashes)?;
released_any = true;
}
Ok(released_any)
}
fn drive_prefill_workers(&mut self) -> Result<bool> {
let mut changed = false;
for worker_idx in 0..self.prefill_workers.len() {
loop {
if !self.prefill_workers[worker_idx].is_ready() {
break;
}
let executed = self.prefill_workers[worker_idx].execute_hidden_pass(self.now_ms);
changed = true;
let completion_kv_events =
if executed.router_event_visibility == RouterEventVisibility::PassStart {
self.apply_prefill_router_events(executed.kv_events)?;
Vec::new()
} else {
executed.kv_events
};
if executed.end_ms == self.now_ms {
self.process_prefill_pass(
worker_idx,
executed.completed_requests,
executed.output_signals,
completion_kv_events,
)?;
continue;
}
self.prefill_workers[worker_idx].mark_busy();
push_worker_completion(
&mut self.events,
&mut self.next_event_seq,
executed.end_ms,
WorkerCompletionPayload {
stage: SimulationWorkerStage::Prefill,
worker_idx,
completed_requests: executed.completed_requests,
output_signals: executed.output_signals,
kv_events: completion_kv_events,
},
);
break;
}
}
Ok(changed)
}
fn drive_decode_workers(&mut self) -> Result<bool> {
let mut changed = false;
for worker_idx in 0..self.decode_workers.len() {
loop {
if !self.decode_workers[worker_idx].is_ready() {
break;
}
let executed = {
let (workers, collector) = (&mut self.decode_workers, &mut self.collector);
workers[worker_idx].execute_pass(collector, self.now_ms)
};
changed = true;
if executed.end_ms == self.now_ms {
self.process_decode_pass(
worker_idx,
executed.completed_requests,
executed.output_signals,
)?;
continue;
}
self.decode_workers[worker_idx].mark_busy();
push_worker_completion(
&mut self.events,
&mut self.next_event_seq,
executed.end_ms,
WorkerCompletionPayload {
stage: SimulationWorkerStage::Decode,
worker_idx,
completed_requests: executed.completed_requests,
output_signals: executed.output_signals,
kv_events: Vec::new(),
},
);
break;
}
}
Ok(changed)
}
fn drain_current_timestamp(&mut self) -> Result<()> {
loop {
let mut changed = self.apply_worker_completions()?;
changed |= self.apply_decode_handoffs()?;
changed |= match self.mode {
ReplayMode::Trace => self.release_trace_arrivals()?,
ReplayMode::Concurrency { max_in_flight } => {
self.top_off_concurrency(max_in_flight)?
}
};
changed |= self.drive_prefill_workers()?;
changed |= self.drive_decode_workers()?;
if !changed {
break;
}
}
Ok(())
}
fn finish_test_stats(&mut self) {
#[cfg(test)]
{
self.stats.request_snapshots = self
.requests
.iter()
.map(|(uuid, state)| (*uuid, state.debug_snapshot()))
.collect();
}
}
fn run(mut self) -> Result<(TraceCollector, DisaggRuntimeStats)> {
self.drain_current_timestamp()?;
while !self.is_done() {
let Some(next_timestamp_ms) = self.next_timestamp() else {
bail!(
"offline disagg replay reached a dead end with {} in-flight requests remaining",
self.cluster_in_flight()
);
};
self.now_ms = next_timestamp_ms;
self.drain_current_timestamp()?;
}
self.finish_test_stats();
Ok((self.collector, self.stats))
}
}
fn base_router_config(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
) -> KvRouterConfig {
let mut config = router_config.unwrap_or_default();
if let Some(policy) = args.router_queue_policy {
config.router_queue_policy = policy;
}
config
}
fn derive_prefill_router_config(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
) -> KvRouterConfig {
let mut config = base_router_config(args, router_config);
config.router_track_active_blocks = false;
config
}
fn derive_decode_router_config(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
) -> KvRouterConfig {
let mut config = base_router_config(args, router_config);
config.overlap_score_weight = 0.0;
config.router_assume_kv_reuse = false;
config.router_track_prefill_tokens = false;
config
}
pub(crate) fn simulate_trace_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
let (collector, _) = DisaggRuntime::new(
&config,
router_config,
pending,
ReplayMode::Trace,
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let pending = VecDeque::from(requests);
let (collector, _) = DisaggRuntime::new(
&config,
router_config,
pending,
ReplayMode::Concurrency { max_in_flight },
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_trace_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let driver = WorkloadDriver::new_trace(trace)?;
let (collector, _) = DisaggRuntime::new_workload(
&config,
router_config,
driver,
ReplayMode::Trace,
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let driver = WorkloadDriver::new_concurrency(trace)?;
let (collector, _) = DisaggRuntime::new_workload(
&config,
router_config,
driver,
ReplayMode::Concurrency { max_in_flight },
router_mode,
)?
.run()?;
Ok(collector.finish())
}
#[cfg(test)]
fn run_trace_collect(
config: &OfflineDisaggReplayConfig,
requests: Vec<DirectRequest>,
router_config: Option<KvRouterConfig>,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
let pending = normalize_trace_requests(requests, arrival_speedup_ratio).unwrap();
DisaggRuntime::new(
config,
router_config,
pending,
ReplayMode::Trace,
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
fn run_concurrency_collect(
config: &OfflineDisaggReplayConfig,
requests: Vec<DirectRequest>,
router_config: Option<KvRouterConfig>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
DisaggRuntime::new(
config,
router_config,
VecDeque::from(requests),
ReplayMode::Concurrency { max_in_flight },
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::common::protocols::{MockEngineArgs, WorkerType};
fn staged_args(worker_type: WorkerType, speedup_ratio: f64) -> MockEngineArgs {
MockEngineArgs::builder()
.block_size(64)
.num_gpu_blocks(256)
.max_num_batched_tokens(Some(8192))
.max_num_seqs(Some(8))
.enable_prefix_caching(true)
.enable_chunked_prefill(true)
.speedup_ratio(speedup_ratio)
.decode_speedup_ratio(speedup_ratio)
.worker_type(worker_type)
.build()
.unwrap()
}
fn disagg_config() -> OfflineDisaggReplayConfig {
OfflineDisaggReplayConfig {
prefill_args: staged_args(WorkerType::Prefill, 1000.0),
decode_args: staged_args(WorkerType::Decode, 1000.0),
num_prefill_workers: 2,
num_decode_workers: 2,
}
}
fn disagg_config_with_handoff_delay() -> OfflineDisaggReplayConfig {
let mut config = disagg_config();
config.prefill_args.kv_transfer_bandwidth = Some(1.0);
config.prefill_args.kv_bytes_per_token = Some(1_000_000);
config
}
fn router_config() -> KvRouterConfig {
KvRouterConfig {
router_queue_threshold: Some(1.25),
..KvRouterConfig::default()
}
}
fn request(
uuid: u128,
prompt_tokens: usize,
output_tokens: usize,
arrival_ms: f64,
) -> DirectRequest {
DirectRequest {
tokens: vec![1; prompt_tokens],
max_output_tokens: output_tokens,
uuid: Some(Uuid::from_u128(uuid)),
dp_rank: 0,
arrival_timestamp_ms: Some(arrival_ms),
}
}
#[test]
fn test_derive_stage_router_configs_force_required_overrides() {
let config = KvRouterConfig {
overlap_score_weight: 2.0,
router_track_active_blocks: true,
router_assume_kv_reuse: true,
router_track_prefill_tokens: true,
..KvRouterConfig::default()
};
let args = staged_args(WorkerType::Prefill, 1.0);
let prefill = derive_prefill_router_config(&args, Some(config.clone()));
let decode = derive_decode_router_config(&args, Some(config));
assert!(!prefill.router_track_active_blocks);
assert_eq!(decode.overlap_score_weight, 0.0);
assert!(!decode.router_assume_kv_reuse);
assert!(!decode.router_track_prefill_tokens);
}
#[rstest::rstest]
#[case(ReplayRouterMode::RoundRobin)]
#[case(ReplayRouterMode::KvRouter)]
fn test_trace_smoke_reports_decode_only_tokens(#[case] router_mode: ReplayRouterMode) {
let config = disagg_config();
let requests = vec![request(1, 128, 3, 5.0)];
let router_config = (router_mode == ReplayRouterMode::KvRouter).then(router_config);
let (collector, stats) =
run_trace_collect(&config, requests, router_config, 1.0, router_mode);
let snapshot = collector.snapshot(Uuid::from_u128(1)).unwrap();
let report = collector.finish();
assert_eq!(snapshot.arrival_time_ms, 0.0);
assert!(snapshot.first_admit_ms.is_some());
assert!(snapshot.first_token_ms.is_some());
assert_eq!(snapshot.output_length, 3);
assert_eq!(report.request_counts.completed_requests, 1);
assert_eq!(
stats.request_snapshots[&Uuid::from_u128(1)].phase,
DisaggPhase::Done
);
}
#[rstest::rstest]
#[case(ReplayRouterMode::RoundRobin)]
#[case(ReplayRouterMode::KvRouter)]
fn test_prefill_and_decode_use_separate_worker_pools(#[case] router_mode: ReplayRouterMode) {
let config = disagg_config();
let requests = vec![request(1, 128, 2, 0.0), request(2, 128, 2, 10.0)];
let router_config = (router_mode == ReplayRouterMode::KvRouter).then(router_config);
let (_, stats) = run_trace_collect(&config, requests, router_config, 1.0, router_mode);
for uuid in [Uuid::from_u128(1), Uuid::from_u128(2)] {
assert!(stats.prefill_assignments.contains_key(&uuid));
assert!(stats.decode_assignments.contains_key(&uuid));
}
}
#[test]
fn test_prefill_overlap_prefers_same_worker_after_handoff_delay() {
let config = disagg_config();
let requests = vec![request(1, 128, 2, 0.0), request(2, 128, 2, 100.0)];
let (_, stats) = run_trace_collect(
&config,
requests,
Some(router_config()),
1.0,
ReplayRouterMode::KvRouter,
);
assert_eq!(
stats.prefill_assignments[&Uuid::from_u128(1)],
stats.prefill_assignments[&Uuid::from_u128(2)],
);
}
#[rstest::rstest]
#[case(ReplayRouterMode::RoundRobin)]
#[case(ReplayRouterMode::KvRouter)]
fn test_concurrency_backfill_waits_for_decode_completion(
#[case] router_mode: ReplayRouterMode,
) {
let config = disagg_config();
let requests = vec![
DirectRequest {
tokens: vec![1; 128],
max_output_tokens: 3,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: None,
},
DirectRequest {
tokens: vec![2; 128],
max_output_tokens: 3,
uuid: Some(Uuid::from_u128(2)),
dp_rank: 0,
arrival_timestamp_ms: None,
},
];
let router_config = (router_mode == ReplayRouterMode::KvRouter).then(router_config);
let (collector, _) =
run_concurrency_collect(&config, requests, router_config, 1, router_mode);
let first = collector.snapshot(Uuid::from_u128(1)).unwrap();
let second = collector.snapshot(Uuid::from_u128(2)).unwrap();
assert_eq!(first.arrival_time_ms, 0.0);
assert_eq!(second.arrival_time_ms, first.last_token_ms.unwrap());
}
#[test]
fn test_prefill_completion_marks_and_frees_before_decode_handoff() {
let config = disagg_config();
let requests = vec![request(1, 128, 2, 0.0)];
let (_, stats) = run_trace_collect(
&config,
requests,
Some(router_config()),
1.0,
ReplayRouterMode::KvRouter,
);
assert_eq!(stats.prefill_marked_count, 1);
assert_eq!(stats.prefill_freed_count, 1);
assert_eq!(stats.decode_freed_count, 1);
}
#[test]
fn test_handoff_delay_increases_decode_visible_ttft() {
let requests = vec![request(1, 128, 2, 0.0)];
let (baseline_collector, _) = run_trace_collect(
&disagg_config(),
requests.clone(),
None,
1.0,
ReplayRouterMode::RoundRobin,
);
let (delayed_collector, _) = run_trace_collect(
&disagg_config_with_handoff_delay(),
requests,
None,
1.0,
ReplayRouterMode::RoundRobin,
);
let baseline = baseline_collector.snapshot(Uuid::from_u128(1)).unwrap();
let delayed = delayed_collector.snapshot(Uuid::from_u128(1)).unwrap();
let baseline_ttft = baseline.first_token_ms.unwrap() - baseline.arrival_time_ms;
let delayed_ttft = delayed.first_token_ms.unwrap() - delayed.arrival_time_ms;
assert!(
delayed_ttft >= baseline_ttft + 120.0,
"expected delayed TTFT to include roughly 128ms of handoff delay, baseline={baseline_ttft}, delayed={delayed_ttft}"
);
}
}
......@@ -4,15 +4,27 @@
use std::cmp::Ordering;
use crate::common::protocols::OutputSignal;
use uuid::Uuid;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum SimulationWorkerStage {
Aggregated,
Prefill,
Decode,
}
#[derive(Debug)]
pub(crate) enum SimulationEventKind {
WorkerCompletion {
stage: SimulationWorkerStage,
worker_idx: usize,
completed_requests: usize,
output_signals: Vec<OutputSignal>,
kv_events: Vec<dynamo_kv_router::protocols::RouterEvent>,
},
DecodeHandoff {
uuid: Uuid,
},
}
#[derive(Debug)]
......
......@@ -3,13 +3,16 @@
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::Trace;
use crate::replay::OfflineDisaggReplayConfig;
pub(crate) use crate::replay::normalize_trace_requests;
use crate::replay::{ReplayRouterMode, TraceSimulationReport};
use dynamo_kv_router::config::KvRouterConfig;
pub(crate) mod core;
pub(crate) mod disagg;
pub(crate) mod events;
pub(crate) mod multi;
pub(crate) mod runtime_utils;
pub(crate) mod single;
pub(crate) mod state;
......@@ -92,3 +95,54 @@ pub(crate) fn simulate_concurrency_workload(
)
}
}
pub(crate) fn simulate_trace_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_trace_disagg(
config,
router_config,
requests,
arrival_speedup_ratio,
router_mode,
)
}
pub(crate) fn simulate_concurrency_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_concurrency_disagg(config, router_config, requests, max_in_flight, router_mode)
}
pub(crate) fn simulate_trace_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_trace_workload_disagg(config, router_config, trace, router_mode)
}
pub(crate) fn simulate_concurrency_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_concurrency_workload_disagg(
config,
router_config,
trace,
max_in_flight,
router_mode,
)
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::events::{SimulationEvent, SimulationEventKind};
use super::events::{SimulationEvent, SimulationWorkerStage};
use super::normalize_trace_requests;
use super::runtime_utils::{
WorkerCompletionPayload, next_timestamp as choose_next_timestamp, pop_next_concurrency_ready,
pop_next_trace_ready, pop_ready_worker_completion, push_worker_completion,
};
#[cfg(test)]
use super::state::OfflineWorkerSnapshot;
use super::state::OfflineWorkerState;
use super::state::{AggRequestState, OfflineWorkerState};
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::{ReplayRequestHashes, Trace, WorkloadDriver};
use crate::replay::router::OfflineReplayRouter;
......@@ -16,7 +20,7 @@ use crate::scheduler::RouterEventVisibility;
use anyhow::bail;
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::RouterEvent;
use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
use std::collections::{BinaryHeap, HashMap, VecDeque};
use uuid::Uuid;
#[derive(Debug, Clone, Copy)]
......@@ -62,13 +66,13 @@ struct OfflineRuntime {
next_worker_idx: usize,
next_event_seq: u64,
admission: AdmissionSource,
router_pending: HashMap<Uuid, DirectRequest>,
requests: HashMap<Uuid, AggRequestState>,
queued_requests: usize,
workers: Vec<OfflineWorkerState>,
collector: TraceCollector,
events: BinaryHeap<SimulationEvent>,
mode: ReplayMode,
router: Option<OfflineReplayRouter>,
prefill_completed: HashSet<Uuid>,
stats: OfflineRuntimeStats,
#[cfg(test)]
worker_active_requests: Vec<Vec<Uuid>>,
......@@ -135,7 +139,8 @@ impl OfflineRuntime {
next_worker_idx: 0,
next_event_seq: 0,
admission,
router_pending: HashMap::new(),
requests: HashMap::new(),
queued_requests: 0,
workers: (0..num_workers)
.map(|worker_idx| {
OfflineWorkerState::new(worker_idx, args.clone(), capture_kv_events)
......@@ -145,7 +150,6 @@ impl OfflineRuntime {
events: BinaryHeap::new(),
mode,
router,
prefill_completed: HashSet::new(),
#[cfg(test)]
stats: OfflineRuntimeStats::default(),
#[cfg(not(test))]
......@@ -162,7 +166,7 @@ impl OfflineRuntime {
.iter()
.map(OfflineWorkerState::in_flight)
.sum::<usize>()
+ self.router_pending.len()
+ self.queued_requests
}
fn record_in_flight_peak(&mut self) {
......@@ -220,9 +224,14 @@ impl OfflineRuntime {
fn dispatch_router_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> anyhow::Result<()> {
for (uuid, worker_idx) in admissions {
let request = self.router_pending.remove(&uuid).ok_or_else(|| {
anyhow::anyhow!("offline replay missing queued request state for {uuid}")
})?;
let request = self
.requests
.get_mut(&uuid)
.ok_or_else(|| {
anyhow::anyhow!("offline replay missing queued request state for {uuid}")
})?
.take_queued_request(uuid)?;
self.queued_requests = self.queued_requests.saturating_sub(1);
self.dispatch_to_worker(request, uuid, worker_idx)?;
}
Ok(())
......@@ -248,6 +257,7 @@ impl OfflineRuntime {
);
let Some(router) = self.router.as_mut() else {
self.requests.insert(uuid, AggRequestState::new_running());
let worker_idx = self.next_worker_idx;
self.next_worker_idx = (self.next_worker_idx + 1) % self.workers.len();
self.dispatch_to_worker(request, uuid, worker_idx)?;
......@@ -258,11 +268,14 @@ impl OfflineRuntime {
router.submit_request_with_hashes(&request, replay_hashes, self.now_ms)?;
self.record_router_pending();
if let Some(worker_idx) = maybe_worker_idx {
self.requests.insert(uuid, AggRequestState::new_running());
self.dispatch_to_worker(request, uuid, worker_idx)?;
return Ok(uuid);
}
self.router_pending.insert(uuid, request);
self.requests
.insert(uuid, AggRequestState::new_queued(request));
self.queued_requests += 1;
self.record_in_flight_peak();
Ok(uuid)
}
......@@ -295,21 +308,7 @@ impl OfflineRuntime {
(ReplayMode::Concurrency { .. }, AdmissionSource::Requests(_)) => None,
};
match (next_arrival_ms, next_event_ms) {
(Some(arrival_ms), Some(event_ms)) => Some(arrival_ms.min(event_ms)),
(Some(arrival_ms), None) => Some(arrival_ms),
(None, Some(event_ms)) => Some(event_ms),
(None, None) => None,
}
}
fn push_event(&mut self, at_ms: f64, kind: SimulationEventKind) {
self.events.push(SimulationEvent {
at_ms,
seq_no: self.next_event_seq,
kind,
});
self.next_event_seq += 1;
choose_next_timestamp(next_arrival_ms, next_event_ms)
}
fn apply_completed_requests(&mut self, worker_idx: usize, completed_requests: usize) {
......@@ -339,7 +338,9 @@ impl OfflineRuntime {
}
self.record_router_pending();
}
self.prefill_completed.remove(&signal.uuid);
self.requests.remove(&signal.uuid).ok_or_else(|| {
anyhow::anyhow!("offline replay missing request state for {}", signal.uuid)
})?;
if let AdmissionSource::Workload(driver) = &mut self.admission {
driver.on_complete(signal.uuid, self.now_ms)?;
}
......@@ -347,10 +348,23 @@ impl OfflineRuntime {
return Ok(());
}
if !self.prefill_completed.insert(signal.uuid) {
let already_marked = self
.requests
.get(&signal.uuid)
.ok_or_else(|| {
anyhow::anyhow!("offline replay missing request state for {}", signal.uuid)
})?
.prefill_completed();
if already_marked {
return Ok(());
}
self.requests
.get_mut(&signal.uuid)
.ok_or_else(|| {
anyhow::anyhow!("offline replay missing request state for {}", signal.uuid)
})?
.mark_prefill_completed();
if let Some(router) = self.router.as_mut() {
admissions = router.mark_prefill_completed(signal.uuid)?;
#[cfg(test)]
......@@ -395,25 +409,15 @@ impl OfflineRuntime {
fn apply_worker_completions(&mut self) -> anyhow::Result<bool> {
let mut changed = false;
loop {
let Some(event) = self.events.peek() else {
break;
};
if event.at_ms != self.now_ms {
break;
}
if !matches!(event.kind, SimulationEventKind::WorkerCompletion { .. }) {
break;
}
let event = self.events.pop().expect("event must exist after peek");
let SimulationEventKind::WorkerCompletion {
worker_idx,
completed_requests,
output_signals,
kv_events,
} = event.kind;
while let Some(WorkerCompletionPayload {
stage,
worker_idx,
completed_requests,
output_signals,
kv_events,
}) = pop_ready_worker_completion(&mut self.events, self.now_ms)
{
debug_assert_eq!(stage, SimulationWorkerStage::Aggregated);
self.workers[worker_idx].mark_idle();
self.process_completed_pass(worker_idx, completed_requests, output_signals, kv_events)?;
changed = true;
......@@ -424,39 +428,33 @@ impl OfflineRuntime {
fn release_trace_arrivals(&mut self) -> anyhow::Result<bool> {
let mut released_any = false;
let mut ready_requests = Vec::new();
match &mut self.admission {
AdmissionSource::Requests(pending) => {
while pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
.is_some_and(|arrival_ms| arrival_ms <= self.now_ms)
{
let request = pending
.pop_front()
.expect("front request must exist when arrival is ready");
let arrival_ms = request
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
ready_requests.push((request, arrival_ms, None));
}
}
AdmissionSource::Workload(driver) => {
ready_requests.extend(driver.pop_ready(self.now_ms, usize::MAX).into_iter().map(
|ready| {
(
ready.request,
ready.scheduled_ready_at_ms,
ready.replay_hashes,
)
},
));
if matches!(self.admission, AdmissionSource::Requests(_)) {
loop {
let next_ready = match &mut self.admission {
AdmissionSource::Requests(pending) => {
pop_next_trace_ready(pending, self.now_ms)
}
AdmissionSource::Workload(_) => unreachable!(),
};
let Some((request, arrival_ms)) = next_ready else {
break;
};
self.assign_request(request, arrival_ms, None)?;
released_any = true;
}
return Ok(released_any);
}
for (request, arrival_ms, replay_hashes) in ready_requests {
self.assign_request(request, arrival_ms, replay_hashes)?;
let ready_requests = match &mut self.admission {
AdmissionSource::Requests(_) => unreachable!(),
AdmissionSource::Workload(driver) => driver.pop_ready(self.now_ms, usize::MAX),
};
for ready in ready_requests {
self.assign_request(
ready.request,
ready.scheduled_ready_at_ms,
ready.replay_hashes,
)?;
released_any = true;
}
......@@ -465,33 +463,38 @@ impl OfflineRuntime {
fn top_off_concurrency(&mut self, max_in_flight: usize) -> anyhow::Result<bool> {
let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) {
loop {
let cluster_in_flight = self.cluster_in_flight();
let next_ready = match &mut self.admission {
AdmissionSource::Requests(pending) => pop_next_concurrency_ready(
pending,
self.now_ms,
cluster_in_flight,
max_in_flight,
),
AdmissionSource::Workload(_) => unreachable!(),
};
let Some((request, arrival_ms)) = next_ready else {
break;
};
self.assign_request(request, arrival_ms, None)?;
released_any = true;
}
return Ok(released_any);
}
let available = max_in_flight.saturating_sub(self.cluster_in_flight());
if available == 0 {
return Ok(false);
}
let mut ready_requests = Vec::new();
match &mut self.admission {
AdmissionSource::Requests(pending) => {
for _ in 0..available {
let Some(request) = pending.pop_front() else {
break;
};
ready_requests.push((request, None));
}
}
AdmissionSource::Workload(driver) => {
ready_requests.extend(
driver
.pop_ready(self.now_ms, available)
.into_iter()
.map(|ready| (ready.request, ready.replay_hashes)),
);
}
}
for (request, replay_hashes) in ready_requests {
self.assign_request(request, self.now_ms, replay_hashes)?;
let ready_requests = match &mut self.admission {
AdmissionSource::Requests(_) => unreachable!(),
AdmissionSource::Workload(driver) => driver.pop_ready(self.now_ms, available),
};
for ready in ready_requests {
self.assign_request(ready.request, self.now_ms, ready.replay_hashes)?;
released_any = true;
}
......@@ -531,9 +534,12 @@ impl OfflineRuntime {
}
self.workers[worker_idx].mark_busy();
self.push_event(
push_worker_completion(
&mut self.events,
&mut self.next_event_seq,
executed.end_ms,
SimulationEventKind::WorkerCompletion {
WorkerCompletionPayload {
stage: SimulationWorkerStage::Aggregated,
worker_idx,
completed_requests: executed.completed_requests,
output_signals: executed.output_signals,
......@@ -616,10 +622,19 @@ impl OfflineRuntime {
#[cfg(test)]
fn debug_snapshot(&self) -> OfflineRuntimeSnapshot {
let mut router_pending_request_ids =
self.router_pending.keys().copied().collect::<Vec<_>>();
let mut router_pending_request_ids = self
.requests
.iter()
.filter(|(_, state)| state.is_queued_at_router())
.map(|(uuid, _)| *uuid)
.collect::<Vec<_>>();
router_pending_request_ids.sort_unstable();
let mut prefill_completed = self.prefill_completed.iter().copied().collect::<Vec<_>>();
let mut prefill_completed = self
.requests
.iter()
.filter(|(_, state)| state.prefill_completed())
.map(|(uuid, _)| *uuid)
.collect::<Vec<_>>();
prefill_completed.sort_unstable();
OfflineRuntimeSnapshot {
......@@ -1483,16 +1498,16 @@ mod tests {
let request_1 = collector.snapshot(Uuid::from_u128(1)).unwrap();
let request_2 = collector.snapshot(Uuid::from_u128(2)).unwrap();
let request_3 = collector.snapshot(Uuid::from_u128(3)).unwrap();
let first_unblock_ms = request_1
.first_token_ms
.unwrap()
.min(request_2.first_token_ms.unwrap());
assert!(stats.max_router_pending > 0);
assert!(request_3.first_admit_ms.unwrap() > request_3.arrival_time_ms);
assert!(
request_3.first_admit_ms.unwrap()
< request_1
.last_token_ms
.unwrap()
.min(request_2.last_token_ms.unwrap())
);
assert_eq!(request_3.first_admit_ms.unwrap(), first_unblock_ms);
assert!(request_3.first_admit_ms.unwrap() < request_1.last_token_ms.unwrap());
assert!(request_3.first_admit_ms.unwrap() < request_2.last_token_ms.unwrap());
}
#[test]
......@@ -1561,6 +1576,72 @@ mod tests {
);
}
#[test]
fn test_multi_worker_trace_kv_router_fcfs_and_lcfs_admit_queued_requests_in_opposite_timestamp_order()
{
let requests = vec![
DirectRequest {
tokens: vec![10; 64],
max_output_tokens: 8,
uuid: Some(Uuid::from_u128(10)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![20; 128],
max_output_tokens: 8,
uuid: Some(Uuid::from_u128(20)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![30; 64],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(30)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.1),
},
DirectRequest {
tokens: vec![40; 64],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(40)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.2),
},
];
let (fcfs_collector, fcfs_stats) = run_trace_multi_collect_with_stats(
&queueing_router_args(RouterQueuePolicy::Fcfs),
requests.clone(),
2,
ReplayRouterMode::KvRouter,
);
let (lcfs_collector, lcfs_stats) = run_trace_multi_collect_with_stats(
&queueing_router_args(RouterQueuePolicy::Lcfs),
requests,
2,
ReplayRouterMode::KvRouter,
);
let fcfs_request_30 = fcfs_collector.snapshot(Uuid::from_u128(30)).unwrap();
let fcfs_request_40 = fcfs_collector.snapshot(Uuid::from_u128(40)).unwrap();
let lcfs_request_30 = lcfs_collector.snapshot(Uuid::from_u128(30)).unwrap();
let lcfs_request_40 = lcfs_collector.snapshot(Uuid::from_u128(40)).unwrap();
assert!(fcfs_stats.max_router_pending > 0);
assert!(lcfs_stats.max_router_pending > 0);
assert_eq!(
&fcfs_stats.dispatch_order[2..4],
&[Uuid::from_u128(30), Uuid::from_u128(40)]
);
assert_eq!(
&lcfs_stats.dispatch_order[2..4],
&[Uuid::from_u128(40), Uuid::from_u128(30)]
);
assert!(fcfs_request_30.first_admit_ms.unwrap() < fcfs_request_40.first_admit_ms.unwrap());
assert!(lcfs_request_40.first_admit_ms.unwrap() < lcfs_request_30.first_admit_ms.unwrap());
}
#[test]
fn test_multi_worker_concurrency_kv_router_respects_max_in_flight() {
let args = queueing_router_args(RouterQueuePolicy::Fcfs);
......@@ -1605,6 +1686,51 @@ mod tests {
assert!(stats.max_router_pending > 0);
}
#[test]
fn test_multi_worker_concurrency_kv_router_records_backfill_timing() {
let args = queueing_router_args(RouterQueuePolicy::Fcfs);
let (collector, stats) = run_concurrency_multi_collect_with_stats(
&args,
vec![
DirectRequest {
tokens: vec![1; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(11)),
dp_rank: 0,
arrival_timestamp_ms: None,
},
DirectRequest {
tokens: vec![2; 64],
max_output_tokens: 4,
uuid: Some(Uuid::from_u128(22)),
dp_rank: 0,
arrival_timestamp_ms: None,
},
DirectRequest {
tokens: vec![3; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(33)),
dp_rank: 0,
arrival_timestamp_ms: None,
},
],
2,
2,
ReplayRouterMode::KvRouter,
);
let request_1 = collector.snapshot(Uuid::from_u128(11)).unwrap();
let request_2 = collector.snapshot(Uuid::from_u128(22)).unwrap();
let request_3 = collector.snapshot(Uuid::from_u128(33)).unwrap();
assert_eq!(request_1.arrival_time_ms, 0.0);
assert_eq!(request_2.arrival_time_ms, 0.0);
assert_eq!(request_3.arrival_time_ms, request_1.last_token_ms.unwrap());
assert!(request_3.arrival_time_ms < request_2.last_token_ms.unwrap());
assert_eq!(request_3.first_admit_ms.unwrap(), request_3.arrival_time_ms);
assert_eq!(stats.max_in_flight_seen, 2);
}
#[test]
fn test_multi_worker_trace_single_worker_round_robin_matches_single_runtime() {
let args = replay_args(true, true);
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{BinaryHeap, VecDeque};
use dynamo_kv_router::protocols::RouterEvent;
use super::events::{SimulationEvent, SimulationEventKind, SimulationWorkerStage};
use crate::common::protocols::{DirectRequest, OutputSignal};
#[derive(Debug)]
pub(super) struct WorkerCompletionPayload {
pub stage: SimulationWorkerStage,
pub worker_idx: usize,
pub completed_requests: usize,
pub output_signals: Vec<OutputSignal>,
pub kv_events: Vec<RouterEvent>,
}
pub(super) fn next_timestamp(
next_arrival_ms: Option<f64>,
next_event_ms: Option<f64>,
) -> Option<f64> {
match (next_arrival_ms, next_event_ms) {
(Some(arrival_ms), Some(event_ms)) => Some(arrival_ms.min(event_ms)),
(Some(arrival_ms), None) => Some(arrival_ms),
(None, Some(event_ms)) => Some(event_ms),
(None, None) => None,
}
}
pub(super) fn pop_next_trace_ready(
pending: &mut VecDeque<DirectRequest>,
now_ms: f64,
) -> Option<(DirectRequest, f64)> {
let arrival_ms = pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
.filter(|arrival_ms| *arrival_ms <= now_ms)?;
let request = pending
.pop_front()
.expect("front request must exist when arrival is ready");
Some((request, arrival_ms))
}
pub(super) fn pop_next_concurrency_ready(
pending: &mut VecDeque<DirectRequest>,
now_ms: f64,
cluster_in_flight: usize,
max_in_flight: usize,
) -> Option<(DirectRequest, f64)> {
if cluster_in_flight >= max_in_flight {
return None;
}
let request = pending.pop_front()?;
Some((request, now_ms))
}
pub(super) fn push_worker_completion(
events: &mut BinaryHeap<SimulationEvent>,
next_event_seq: &mut u64,
at_ms: f64,
payload: WorkerCompletionPayload,
) {
events.push(SimulationEvent {
at_ms,
seq_no: *next_event_seq,
kind: SimulationEventKind::WorkerCompletion {
stage: payload.stage,
worker_idx: payload.worker_idx,
completed_requests: payload.completed_requests,
output_signals: payload.output_signals,
kv_events: payload.kv_events,
},
});
*next_event_seq += 1;
}
pub(super) fn pop_ready_worker_completion(
events: &mut BinaryHeap<SimulationEvent>,
now_ms: f64,
) -> Option<WorkerCompletionPayload> {
let event = events.peek()?;
if event.at_ms != now_ms {
return None;
}
let SimulationEventKind::WorkerCompletion { .. } = &event.kind else {
return None;
};
let event = events.pop().expect("event must exist after peek");
let (stage, worker_idx, completed_requests, output_signals, kv_events) = match event.kind {
SimulationEventKind::WorkerCompletion {
stage,
worker_idx,
completed_requests,
output_signals,
kv_events,
} => (
stage,
worker_idx,
completed_requests,
output_signals,
kv_events,
),
SimulationEventKind::DecodeHandoff { .. } => {
unreachable!("peeked worker completion event must match popped event")
}
};
Some(WorkerCompletionPayload {
stage,
worker_idx,
completed_requests,
output_signals,
kv_events,
})
}
pub(super) fn push_decode_handoff(
events: &mut BinaryHeap<SimulationEvent>,
next_event_seq: &mut u64,
at_ms: f64,
uuid: uuid::Uuid,
) {
events.push(SimulationEvent {
at_ms,
seq_no: *next_event_seq,
kind: SimulationEventKind::DecodeHandoff { uuid },
});
*next_event_seq += 1;
}
pub(super) fn pop_ready_decode_handoff(
events: &mut BinaryHeap<SimulationEvent>,
now_ms: f64,
) -> Option<uuid::Uuid> {
let event = events.peek()?;
if event.at_ms != now_ms {
return None;
}
let SimulationEventKind::DecodeHandoff { .. } = &event.kind else {
return None;
};
let event = events.pop().expect("event must exist after peek");
let SimulationEventKind::DecodeHandoff { uuid } = event.kind else {
unreachable!("peeked decode handoff event must match popped event");
};
Some(uuid)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::replay::offline::events::SimulationWorkerStage;
use uuid::Uuid;
fn direct_request(uuid: u128, arrival_timestamp_ms: Option<f64>) -> DirectRequest {
DirectRequest {
tokens: vec![1; 8],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(uuid)),
dp_rank: 0,
arrival_timestamp_ms,
}
}
#[test]
fn test_next_timestamp_matches_current_choice_logic() {
assert_eq!(next_timestamp(Some(1.0), Some(2.0)), Some(1.0));
assert_eq!(next_timestamp(Some(2.0), Some(1.0)), Some(1.0));
assert_eq!(next_timestamp(Some(3.0), None), Some(3.0));
assert_eq!(next_timestamp(None, Some(4.0)), Some(4.0));
assert_eq!(next_timestamp(None, None), None);
}
#[test]
fn test_pop_next_trace_ready_releases_only_arrivals_at_or_before_now() {
let mut pending = VecDeque::from(vec![
direct_request(1, Some(1.0)),
direct_request(2, Some(1.1)),
direct_request(3, Some(2.0)),
]);
let (request_1, arrival_1) = pop_next_trace_ready(&mut pending, 1.0).unwrap();
assert_eq!(request_1.uuid, Some(Uuid::from_u128(1)));
assert_eq!(arrival_1, 1.0);
assert!(pop_next_trace_ready(&mut pending, 1.0).is_none());
let (request_2, arrival_2) = pop_next_trace_ready(&mut pending, 1.1).unwrap();
assert_eq!(request_2.uuid, Some(Uuid::from_u128(2)));
assert_eq!(arrival_2, 1.1);
assert_eq!(pending.len(), 1);
}
#[test]
fn test_pop_next_concurrency_ready_stops_at_max_in_flight() {
let mut pending = VecDeque::from(vec![direct_request(1, None), direct_request(2, None)]);
assert!(pop_next_concurrency_ready(&mut pending, 5.0, 2, 2).is_none());
let (request, arrival_ms) = pop_next_concurrency_ready(&mut pending, 5.0, 1, 2).unwrap();
assert_eq!(request.uuid, Some(Uuid::from_u128(1)));
assert_eq!(arrival_ms, 5.0);
assert_eq!(pending.len(), 1);
}
#[test]
fn test_worker_completion_helpers_preserve_same_time_sequence_ordering() {
let mut events = BinaryHeap::new();
let mut next_event_seq = 0;
push_worker_completion(
&mut events,
&mut next_event_seq,
10.0,
WorkerCompletionPayload {
stage: SimulationWorkerStage::Aggregated,
worker_idx: 7,
completed_requests: 1,
output_signals: vec![OutputSignal {
uuid: Uuid::from_u128(7),
completed: true,
handoff_delay_ms: None,
}],
kv_events: Vec::new(),
},
);
push_worker_completion(
&mut events,
&mut next_event_seq,
10.0,
WorkerCompletionPayload {
stage: SimulationWorkerStage::Aggregated,
worker_idx: 8,
completed_requests: 2,
output_signals: vec![OutputSignal {
uuid: Uuid::from_u128(8),
completed: false,
handoff_delay_ms: None,
}],
kv_events: Vec::new(),
},
);
assert!(pop_ready_worker_completion(&mut events, 9.0).is_none());
let first = pop_ready_worker_completion(&mut events, 10.0).unwrap();
let second = pop_ready_worker_completion(&mut events, 10.0).unwrap();
assert_eq!(first.stage, SimulationWorkerStage::Aggregated);
assert_eq!(first.worker_idx, 7);
assert_eq!(first.completed_requests, 1);
assert_eq!(second.stage, SimulationWorkerStage::Aggregated);
assert_eq!(second.worker_idx, 8);
assert_eq!(second.completed_requests, 2);
assert!(events.is_empty());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::{Result, anyhow, bail};
use crate::common::protocols::DirectRequest;
use crate::common::protocols::MockEngineArgs;
use crate::replay::TraceCollector;
use crate::scheduler::{EngineCore, EnginePassResult};
use uuid::Uuid;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum AggRequestPhase {
QueuedAtRouter,
Running,
}
pub(crate) struct AggRequestState {
request: Option<DirectRequest>,
phase: AggRequestPhase,
prefill_completed: bool,
}
impl AggRequestState {
pub(crate) fn new_queued(request: DirectRequest) -> Self {
Self {
request: Some(request),
phase: AggRequestPhase::QueuedAtRouter,
prefill_completed: false,
}
}
pub(crate) fn new_running() -> Self {
Self {
request: None,
phase: AggRequestPhase::Running,
prefill_completed: false,
}
}
pub(crate) fn is_queued_at_router(&self) -> bool {
self.phase == AggRequestPhase::QueuedAtRouter
}
pub(crate) fn take_queued_request(&mut self, uuid: Uuid) -> Result<DirectRequest> {
if !self.is_queued_at_router() {
bail!("offline replay expected queued request state for {uuid}");
}
let request = self
.request
.take()
.ok_or_else(|| anyhow!("offline replay missing queued request payload for {uuid}"))?;
self.phase = AggRequestPhase::Running;
Ok(request)
}
pub(crate) fn prefill_completed(&self) -> bool {
self.prefill_completed
}
pub(crate) fn mark_prefill_completed(&mut self) {
self.prefill_completed = true;
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum DisaggPhase {
QueuedPrefill,
RunningPrefill,
QueuedDecode,
RunningDecode,
Done,
}
pub(crate) struct DisaggRequestState {
original: Option<DirectRequest>,
#[cfg(test)]
arrival_ms: f64,
phase: DisaggPhase,
prefill_worker_idx: Option<usize>,
decode_worker_idx: Option<usize>,
}
#[cfg(test)]
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct DisaggRequestSnapshot {
pub(crate) arrival_ms: f64,
pub(crate) phase: DisaggPhase,
pub(crate) prefill_worker_idx: Option<usize>,
pub(crate) decode_worker_idx: Option<usize>,
}
impl DisaggRequestState {
pub(crate) fn new(request: DirectRequest, arrival_ms: f64) -> Self {
#[cfg(not(test))]
let _ = arrival_ms;
Self {
original: Some(request),
#[cfg(test)]
arrival_ms,
phase: DisaggPhase::QueuedPrefill,
prefill_worker_idx: None,
decode_worker_idx: None,
}
}
pub(crate) fn is_queued_prefill(&self) -> bool {
self.phase == DisaggPhase::QueuedPrefill
}
pub(crate) fn is_queued_decode(&self) -> bool {
self.phase == DisaggPhase::QueuedDecode
}
pub(crate) fn original_request(&self) -> Result<&DirectRequest> {
self.original
.as_ref()
.ok_or_else(|| anyhow!("offline disagg replay request payload was already released"))
}
pub(crate) fn build_prefill_request(&self) -> Result<DirectRequest> {
let mut request = self.original_request()?.clone();
request.max_output_tokens = 1;
Ok(request)
}
pub(crate) fn build_decode_request(&self) -> Result<DirectRequest> {
Ok(self.original_request()?.clone())
}
pub(crate) fn start_prefill(&mut self, worker_idx: usize) {
self.phase = DisaggPhase::RunningPrefill;
self.prefill_worker_idx = Some(worker_idx);
}
pub(crate) fn queue_decode(&mut self) {
self.phase = DisaggPhase::QueuedDecode;
}
pub(crate) fn start_decode(&mut self, worker_idx: usize) {
self.phase = DisaggPhase::RunningDecode;
self.decode_worker_idx = Some(worker_idx);
}
pub(crate) fn mark_done(&mut self) {
self.phase = DisaggPhase::Done;
self.original = None;
}
#[cfg(test)]
pub(crate) fn debug_snapshot(&self) -> DisaggRequestSnapshot {
DisaggRequestSnapshot {
arrival_ms: self.arrival_ms,
phase: self.phase,
prefill_worker_idx: self.prefill_worker_idx,
decode_worker_idx: self.decode_worker_idx,
}
}
}
pub(crate) struct OfflineWorkerState {
core: EngineCore,
......@@ -91,6 +243,10 @@ impl OfflineWorkerState {
self.core.execute_pass(collector, now_ms)
}
pub(crate) fn execute_hidden_pass(&mut self, now_ms: f64) -> EnginePassResult {
self.core.execute_hidden_pass(now_ms)
}
#[cfg(test)]
pub(crate) fn debug_snapshot(&self) -> OfflineWorkerSnapshot {
OfflineWorkerSnapshot {
......
......@@ -110,6 +110,7 @@ struct PendingRequest {
token_seq: Option<Vec<SequenceHash>>,
isl_tokens: usize,
overlaps: OverlapScores,
track_prefill_tokens: bool,
expected_output_tokens: Option<u32>,
}
......@@ -130,6 +131,7 @@ impl PendingRequest {
overlaps: self.overlaps.clone(),
decode_blocks,
prefill_tokens,
track_prefill_tokens: self.track_prefill_tokens,
router_config_override: None,
update_states: true,
lora_name: None,
......@@ -258,7 +260,6 @@ impl OfflineReplayRouter {
self.drain_pending()
}
#[cfg(test)]
pub(crate) fn pending_count(&self) -> usize {
self.pending.len()
}
......@@ -371,6 +372,7 @@ impl OfflineReplayRouter {
token_seq,
isl_tokens: request.tokens.len(),
overlaps,
track_prefill_tokens: self.config.router_track_prefill_tokens,
expected_output_tokens: Some(
u32::try_from(request.max_output_tokens)
.context("max_output_tokens does not fit into u32")?,
......@@ -379,11 +381,14 @@ impl OfflineReplayRouter {
}
fn admit_request(&mut self, request: PendingRequest) -> Result<usize> {
let (decode_blocks, prefill_tokens) = self.slots.potential_blocks_and_tokens(
request.token_seq.as_deref(),
request.isl_tokens,
request.overlaps.clone(),
);
let (decode_blocks, prefill_tokens) = self
.slots
.potential_blocks_and_tokens_with_prefill_tracking(
request.token_seq.as_deref(),
request.isl_tokens,
request.overlaps.clone(),
request.track_prefill_tokens,
);
let scheduling_request = request.scheduling_request(decode_blocks, prefill_tokens);
let selection = self.selector.select_worker(
&self.workers_with_configs,
......@@ -400,6 +405,7 @@ impl OfflineReplayRouter {
token_sequence: request.token_seq,
isl: request.isl_tokens,
overlap: selection.overlap_blocks,
track_prefill_tokens: request.track_prefill_tokens,
expected_output_tokens: request.expected_output_tokens,
worker: selection.worker,
lora_name: None,
......
......@@ -138,6 +138,7 @@ impl KvReplayRouter {
args.block_size as u32,
selector,
policy,
config.router_track_prefill_tokens,
CancellationToken::new(),
"replay",
false,
......
......@@ -107,7 +107,7 @@ pub(super) fn replay_selector(config: &KvRouterConfig) -> DefaultWorkerSelector
DefaultWorkerSelector::new(Some(config.clone()), "replay")
}
pub(super) fn replay_router_config(
pub(crate) fn replay_router_config(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
) -> KvRouterConfig {
......
......@@ -3,9 +3,45 @@
use anyhow::{Result, bail};
use super::ReplayRouterMode;
use super::{OfflineDisaggReplayConfig, ReplayArgsMode, ReplayRouterMode};
use crate::common::protocols::{MockEngineArgs, WorkerType};
pub fn validate_replay_args_mode(
aggregated_args: Option<&MockEngineArgs>,
prefill_args: Option<&MockEngineArgs>,
decode_args: Option<&MockEngineArgs>,
num_workers: usize,
num_prefill_workers: usize,
num_decode_workers: usize,
) -> Result<ReplayArgsMode> {
if aggregated_args.is_some() && (prefill_args.is_some() || decode_args.is_some()) {
bail!("extra_engine_args cannot be combined with prefill_engine_args/decode_engine_args");
}
match (aggregated_args, prefill_args, decode_args) {
(Some(_), None, None) | (None, None, None) => {
if num_prefill_workers != 1 || num_decode_workers != 1 {
bail!(
"num_prefill_workers and num_decode_workers are only used for disagg replay; use num_workers for aggregated replay"
);
}
Ok(ReplayArgsMode::Aggregated)
}
(None, Some(_), Some(_)) => {
if num_workers != 1 {
bail!(
"num_workers is only used for aggregated replay; use num_prefill_workers and num_decode_workers for disagg replay"
);
}
Ok(ReplayArgsMode::Disagg)
}
(None, Some(_), None) | (None, None, Some(_)) => {
bail!("prefill_engine_args and decode_engine_args must be provided together")
}
(Some(_), Some(_), _) | (Some(_), _, Some(_)) => unreachable!(),
}
}
fn validate_replay_args(args: &MockEngineArgs, num_workers: usize, mode: &str) -> Result<()> {
if num_workers == 0 {
bail!("{mode} requires num_workers >= 1");
......@@ -75,3 +111,63 @@ pub(super) fn validate_online_concurrency_args(
validate_replay_args(args, num_workers, "online replay")
}
fn validate_disagg_args(config: &OfflineDisaggReplayConfig, mode: &str) -> Result<()> {
if config.num_prefill_workers == 0 {
bail!("{mode} requires num_prefill_workers >= 1");
}
if config.num_decode_workers == 0 {
bail!("{mode} requires num_decode_workers >= 1");
}
if config.prefill_args.worker_type != WorkerType::Prefill {
bail!(
"{mode} requires prefill_engine_args.worker_type=prefill, got {:?}",
config.prefill_args.worker_type,
);
}
if config.decode_args.worker_type != WorkerType::Decode {
bail!(
"{mode} requires decode_engine_args.worker_type=decode, got {:?}",
config.decode_args.worker_type,
);
}
if config.prefill_args.dp_size != 1 {
bail!(
"{mode} only supports prefill data_parallel_size=1, got {}",
config.prefill_args.dp_size,
);
}
if config.decode_args.dp_size != 1 {
bail!(
"{mode} only supports decode data_parallel_size=1, got {}",
config.decode_args.dp_size,
);
}
if config.prefill_args.block_size != config.decode_args.block_size {
bail!(
"{mode} requires matching prefill/decode block_size, got {} and {}",
config.prefill_args.block_size,
config.decode_args.block_size,
);
}
Ok(())
}
pub(super) fn validate_offline_disagg_replay_args(
config: &OfflineDisaggReplayConfig,
_router_mode: ReplayRouterMode,
) -> Result<()> {
validate_disagg_args(config, "trace replay")
}
pub(super) fn validate_offline_disagg_concurrency_args(
config: &OfflineDisaggReplayConfig,
max_in_flight: usize,
_router_mode: ReplayRouterMode,
) -> Result<()> {
if max_in_flight == 0 {
bail!("concurrency replay requires max_in_flight >= 1");
}
validate_disagg_args(config, "concurrency replay")
}
......@@ -89,6 +89,13 @@ impl EngineCore {
Self::Sglang(core) => core.execute_pass(collector, now_ms),
}
}
pub(crate) fn execute_hidden_pass(&mut self, now_ms: f64) -> EnginePassResult {
match self {
Self::Vllm(core) => core.execute_hidden_pass(now_ms),
Self::Sglang(core) => core.execute_hidden_pass(now_ms),
}
}
}
#[derive(Clone)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment