Unverified Commit 02b1c58a authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(mocker): add offline disagg replay (#7617)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 4b8826b3
...@@ -23,7 +23,7 @@ use dynamo_mocker::common::protocols::{ ...@@ -23,7 +23,7 @@ use dynamo_mocker::common::protocols::{
DirectRequest, KvCacheEventSink, KvEventPublishers, MockEngineArgs, OutputSignal, RawKvEvent, DirectRequest, KvCacheEventSink, KvEventPublishers, MockEngineArgs, OutputSignal, RawKvEvent,
RawKvEventSink, RawKvEventSink,
}; };
use dynamo_mocker::common::utils::{compute_kv_transfer_delay, sleep_precise}; use dynamo_mocker::common::utils::sleep_precise;
use dynamo_mocker::engine::create_engine; use dynamo_mocker::engine::create_engine;
use dynamo_mocker::scheduler::SchedulerHandle; use dynamo_mocker::scheduler::SchedulerHandle;
use dynamo_runtime::DistributedRuntime; use dynamo_runtime::DistributedRuntime;
...@@ -645,14 +645,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -645,14 +645,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
let bootstrap_server = self.bootstrap_server.clone(); let bootstrap_server = self.bootstrap_server.clone();
let reasoning = self.engine_args.reasoning.clone(); let reasoning = self.engine_args.reasoning.clone();
// Compute KV transfer delay for prefill workers.
// Simulates the time to transfer KV cache from prefill to decode worker.
let kv_transfer_delay = if is_prefill {
compute_kv_transfer_delay(&self.engine_args, request.token_ids.len())
} else {
None
};
// Spawn a task to handle the complex async logic // Spawn a task to handle the complex async logic
tokio::spawn(async move { tokio::spawn(async move {
let mut token_count = 0; let mut token_count = 0;
...@@ -693,17 +685,15 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -693,17 +685,15 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
if signal.completed { if signal.completed {
let _ = stream_tx.send(output); let _ = stream_tx.send(output);
// Simulate KV transfer delay before prefill's first (and only) token. // Prefill-to-decode handoff delay is emitted by the shared mocker core.
// This models the time to transfer KV cache to the decode worker. if is_prefill
if token_count == 1 && let Some(delay_ms) = signal.handoff_delay_ms
&& let Some(delay) = kv_transfer_delay
{ {
sleep_precise(delay).await; sleep_precise(Duration::from_secs_f64(delay_ms / 1000.0)).await;
} }
// Prefill: after first token, mark room complete (unblocks decode) // Prefill: after first token, mark room complete (unblocks decode)
if is_prefill if is_prefill
&& token_count == 1
&& let (Some(server), Some(room_id)) = (bootstrap_server.get(), bootstrap_room) && let (Some(server), Some(room_id)) = (bootstrap_server.get(), bootstrap_room)
{ {
server.complete_room(room_id); server.complete_room(room_id);
......
...@@ -139,6 +139,8 @@ impl PrefillCost { ...@@ -139,6 +139,8 @@ impl PrefillCost {
pub struct OutputSignal { pub struct OutputSignal {
pub uuid: Uuid, pub uuid: Uuid,
pub completed: bool, pub completed: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub handoff_delay_ms: Option<f64>,
} }
/// Preemption policy for evicting decode requests under memory pressure /// Preemption policy for evicting decode requests under memory pressure
...@@ -286,6 +288,10 @@ pub struct MockEngineArgs { ...@@ -286,6 +288,10 @@ pub struct MockEngineArgs {
#[builder(default = "WorkerType::Aggregated")] #[builder(default = "WorkerType::Aggregated")]
pub worker_type: WorkerType, pub worker_type: WorkerType,
/// Original planner profile NPZ path used to materialize `perf_model`.
#[builder(default = "None")]
pub planner_profile_data: Option<PathBuf>,
/// Performance model for timing predictions (not serialized, loaded from planner_profile_data) /// Performance model for timing predictions (not serialized, loaded from planner_profile_data)
#[serde(skip)] #[serde(skip)]
#[builder(default = "Arc::new(PerfModel::default())")] #[builder(default = "Arc::new(PerfModel::default())")]
...@@ -691,6 +697,7 @@ impl MockEngineArgs { ...@@ -691,6 +697,7 @@ impl MockEngineArgs {
&& let Some(path_str) = path_str.as_str() && let Some(path_str) = path_str.as_str()
{ {
let npz_path = PathBuf::from(path_str); let npz_path = PathBuf::from(path_str);
builder = builder.planner_profile_data(Some(npz_path.clone()));
match PerfModel::from_npz(&npz_path) { match PerfModel::from_npz(&npz_path) {
Ok(model) => { Ok(model) => {
tracing::info!("Successfully loaded performance model from: {:?}", npz_path); tracing::info!("Successfully loaded performance model from: {:?}", npz_path);
......
...@@ -3,32 +3,59 @@ ...@@ -3,32 +3,59 @@
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use crate::common::protocols::MockEngineArgs; use crate::common::protocols::{MockEngineArgs, WorkerType};
/// Compute the KV transfer delay duration for a given number of input tokens. /// Compute the modeled handoff delay after a prefill worker emits its terminal token.
/// ///
/// Returns `None` if KV transfer simulation is disabled (bandwidth is 0 or not configured). /// NOTE: this intentionally does not model the internal prefill TTFT itself accurately, and the
pub fn compute_kv_transfer_delay( /// exact prefill/decode boundary is backend dependent. For now we only care about decode-visible
args: &MockEngineArgs, /// TTFT, which is what the client observes, so modeling the delay as prefill-to-decode handoff is
/// good enough.
pub fn compute_prefill_handoff_delay_ms(
worker_type: WorkerType,
completed: bool,
num_input_tokens: usize, num_input_tokens: usize,
) -> Option<Duration> { kv_transfer_bandwidth: Option<f64>,
match (args.kv_transfer_bandwidth, args.kv_bytes_per_token) { kv_bytes_per_token: Option<usize>,
) -> Option<f64> {
if worker_type != WorkerType::Prefill || !completed {
return None;
}
match (kv_transfer_bandwidth, kv_bytes_per_token) {
(Some(bw), Some(bpt)) if bw > 0.0 => { (Some(bw), Some(bpt)) if bw > 0.0 => {
let kv_bytes = num_input_tokens as f64 * bpt as f64; let kv_bytes = num_input_tokens as f64 * bpt as f64;
let delay = Duration::from_secs_f64(kv_bytes / (bw * 1e9)); let delay_ms = kv_bytes / (bw * 1e9) * 1000.0;
tracing::debug!( tracing::debug!(
num_input_tokens, num_input_tokens,
kv_bytes, kv_bytes,
bandwidth_gb_s = bw, bandwidth_gb_s = bw,
delay_ms = format!("{:.2}", delay.as_secs_f64() * 1000.0), delay_ms = format!("{delay_ms:.2}"),
"KV transfer delay for prefill" "KV handoff delay for prefill completion"
); );
Some(delay) Some(delay_ms)
} }
_ => None, _ => None,
} }
} }
/// Compute the KV transfer delay duration for a given number of input tokens.
///
/// Returns `None` if KV transfer simulation is disabled (bandwidth is 0 or not configured).
pub fn compute_kv_transfer_delay(
args: &MockEngineArgs,
num_input_tokens: usize,
) -> Option<Duration> {
compute_prefill_handoff_delay_ms(
args.worker_type,
true,
num_input_tokens,
args.kv_transfer_bandwidth,
args.kv_bytes_per_token,
)
.map(|delay_ms| Duration::from_secs_f64(delay_ms / 1000.0))
}
/// Sleep for the specified duration using timerfd on Linux for precision. /// Sleep for the specified duration using timerfd on Linux for precision.
pub async fn sleep_precise(duration: Duration) { pub async fn sleep_precise(duration: Duration) {
sleep_until_precise(Instant::now() + duration).await; sleep_until_precise(Instant::now() + duration).await;
...@@ -53,3 +80,42 @@ pub async fn sleep_until_precise(deadline: Instant) { ...@@ -53,3 +80,42 @@ pub async fn sleep_until_precise(deadline: Instant) {
tokio::time::sleep_until(tokio::time::Instant::from_std(deadline)).await; tokio::time::sleep_until(tokio::time::Instant::from_std(deadline)).await;
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_prefill_handoff_delay_only_applies_to_completed_prefill() {
let delay_ms = compute_prefill_handoff_delay_ms(
WorkerType::Prefill,
true,
128,
Some(1.0),
Some(1_000_000),
)
.expect("prefill completion should produce a handoff delay");
assert!((delay_ms - 128.0).abs() < 1e-9);
assert!(
compute_prefill_handoff_delay_ms(
WorkerType::Prefill,
false,
128,
Some(1.0),
Some(1_000_000),
)
.is_none()
);
assert!(
compute_prefill_handoff_delay_ms(
WorkerType::Decode,
true,
128,
Some(1.0),
Some(1_000_000),
)
.is_none()
);
}
}
...@@ -7,7 +7,8 @@ use std::collections::{BinaryHeap, HashMap}; ...@@ -7,7 +7,8 @@ use std::collections::{BinaryHeap, HashMap};
use anyhow::{Result, anyhow, bail}; use anyhow::{Result, anyhow, bail};
use uuid::Uuid; use uuid::Uuid;
use super::types::{ReadyTurn, Trace, TurnTrace}; use super::types::{ReadyTurn, ReplayRequestHashes, Trace};
use crate::common::protocols::DirectRequest;
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DriverMode { enum DriverMode {
...@@ -18,12 +19,20 @@ enum DriverMode { ...@@ -18,12 +19,20 @@ enum DriverMode {
#[derive(Debug)] #[derive(Debug)]
struct SessionRuntime { struct SessionRuntime {
session_id: String, session_id: String,
turns: Vec<TurnTrace>, turns: Vec<TurnRuntime>,
next_turn_index: usize, next_turn_index: usize,
next_ready_at_ms: Option<f64>, next_ready_at_ms: Option<f64>,
in_flight: Option<Uuid>, in_flight: Option<Uuid>,
} }
#[derive(Debug)]
struct TurnRuntime {
tokens: Vec<u32>,
max_output_tokens: usize,
delay_after_previous_ms: f64,
replay_hashes: ReplayRequestHashes,
}
#[derive(Debug)] #[derive(Debug)]
struct InFlightTurn { struct InFlightTurn {
session_index: usize, session_index: usize,
...@@ -66,7 +75,6 @@ impl PartialOrd for ReadySession { ...@@ -66,7 +75,6 @@ impl PartialOrd for ReadySession {
#[derive(Debug)] #[derive(Debug)]
pub struct WorkloadDriver { pub struct WorkloadDriver {
mode: DriverMode, mode: DriverMode,
block_size: usize,
sessions: Vec<SessionRuntime>, sessions: Vec<SessionRuntime>,
in_flight: HashMap<Uuid, InFlightTurn>, in_flight: HashMap<Uuid, InFlightTurn>,
ready_sessions: BinaryHeap<ReadySession>, ready_sessions: BinaryHeap<ReadySession>,
...@@ -82,20 +90,36 @@ impl WorkloadDriver { ...@@ -82,20 +90,36 @@ impl WorkloadDriver {
} }
fn new(trace: Trace, mode: DriverMode) -> Result<Self> { fn new(trace: Trace, mode: DriverMode) -> Result<Self> {
let block_size = trace.block_size;
let sessions: Vec<SessionRuntime> = trace let sessions: Vec<SessionRuntime> = trace
.sessions .sessions
.into_iter() .into_iter()
.map(|session| SessionRuntime { .map(|session| -> Result<SessionRuntime> {
session_id: session.session_id, let next_ready_at_ms = Some(match mode {
turns: session.turns,
next_turn_index: 0,
next_ready_at_ms: Some(match mode {
DriverMode::Trace => session.first_arrival_timestamp_ms.unwrap_or(0.0), DriverMode::Trace => session.first_arrival_timestamp_ms.unwrap_or(0.0),
DriverMode::Concurrency => 0.0, DriverMode::Concurrency => 0.0,
}), });
in_flight: None, let turns = session
.turns
.into_iter()
.map(|turn| -> Result<TurnRuntime> {
Ok(TurnRuntime {
tokens: turn.synthesize_tokens(block_size)?,
max_output_tokens: turn.max_output_tokens,
delay_after_previous_ms: turn.delay_after_previous_ms,
replay_hashes: turn.to_replay_hashes(block_size)?,
})
})
.collect::<Result<Vec<_>>>()?;
Ok(SessionRuntime {
session_id: session.session_id,
turns,
next_turn_index: 0,
next_ready_at_ms,
in_flight: None,
})
}) })
.collect(); .collect::<Result<Vec<_>>>()?;
let ready_sessions = sessions let ready_sessions = sessions
.iter() .iter()
...@@ -111,7 +135,6 @@ impl WorkloadDriver { ...@@ -111,7 +135,6 @@ impl WorkloadDriver {
Ok(Self { Ok(Self {
mode, mode,
block_size: trace.block_size,
sessions, sessions,
in_flight: HashMap::new(), in_flight: HashMap::new(),
ready_sessions, ready_sessions,
...@@ -146,16 +169,18 @@ impl WorkloadDriver { ...@@ -146,16 +169,18 @@ impl WorkloadDriver {
.next_ready_at_ms .next_ready_at_ms
.expect("ready session must have a timestamp"); .expect("ready session must have a timestamp");
let request_uuid = Uuid::new_v4(); let request_uuid = Uuid::new_v4();
let replay_hashes = session.turns[turn_index] let turn = &session.turns[turn_index];
.to_replay_hashes(self.block_size)
.expect("validated trace should always synthesize replay hashes");
let arrival_timestamp_ms = match self.mode { let arrival_timestamp_ms = match self.mode {
DriverMode::Trace => Some(scheduled_ready_at_ms), DriverMode::Trace => Some(scheduled_ready_at_ms),
DriverMode::Concurrency => None, DriverMode::Concurrency => None,
}; };
let request = session.turns[turn_index] let request = DirectRequest {
.to_direct_request(self.block_size, request_uuid, arrival_timestamp_ms) tokens: turn.tokens.clone(),
.expect("validated trace should always synthesize into a direct request"); max_output_tokens: turn.max_output_tokens,
uuid: Some(request_uuid),
dp_rank: 0,
arrival_timestamp_ms,
};
session.in_flight = Some(request_uuid); session.in_flight = Some(request_uuid);
session.next_ready_at_ms = None; session.next_ready_at_ms = None;
self.in_flight.insert( self.in_flight.insert(
...@@ -170,7 +195,7 @@ impl WorkloadDriver { ...@@ -170,7 +195,7 @@ impl WorkloadDriver {
session_id: session.session_id.clone(), session_id: session.session_id.clone(),
turn_index, turn_index,
scheduled_ready_at_ms, scheduled_ready_at_ms,
replay_hashes: Some(replay_hashes), replay_hashes: Some(turn.replay_hashes.clone()),
request, request,
}); });
} }
......
...@@ -59,12 +59,7 @@ impl TurnTrace { ...@@ -59,12 +59,7 @@ impl TurnTrace {
Ok(()) Ok(())
} }
pub fn to_direct_request( pub(crate) fn synthesize_tokens(&self, block_size: usize) -> Result<Vec<u32>> {
&self,
block_size: usize,
request_uuid: Uuid,
arrival_timestamp_ms: Option<f64>,
) -> Result<DirectRequest> {
self.validate_block_size_and_capacity(block_size)?; self.validate_block_size_and_capacity(block_size)?;
let mut tokens = Vec::with_capacity(self.input_length); let mut tokens = Vec::with_capacity(self.input_length);
...@@ -85,6 +80,16 @@ impl TurnTrace { ...@@ -85,6 +80,16 @@ impl TurnTrace {
); );
} }
Ok(tokens)
}
pub fn to_direct_request(
&self,
block_size: usize,
request_uuid: Uuid,
arrival_timestamp_ms: Option<f64>,
) -> Result<DirectRequest> {
let tokens = self.synthesize_tokens(block_size)?;
Ok(DirectRequest { Ok(DirectRequest {
tokens, tokens,
max_output_tokens: self.max_output_tokens, max_output_tokens: self.max_output_tokens,
......
...@@ -309,6 +309,7 @@ impl TraceCollector { ...@@ -309,6 +309,7 @@ impl TraceCollector {
} }
let duration_s = (duration_ms / 1000.0).max(1e-9); let duration_s = (duration_ms / 1000.0).max(1e-9);
let itl_distribution = build_distribution_stats(&itls);
TraceSimulationReport { TraceSimulationReport {
request_counts: TraceRequestCounts { request_counts: TraceRequestCounts {
num_requests: requests.len(), num_requests: requests.len(),
...@@ -335,8 +336,8 @@ impl TraceCollector { ...@@ -335,8 +336,8 @@ impl TraceCollector {
ttst: build_distribution_stats(&ttsts), ttst: build_distribution_stats(&ttsts),
tpot: build_distribution_stats(&tpots), tpot: build_distribution_stats(&tpots),
itl: TraceInterTokenLatencyStats { itl: TraceInterTokenLatencyStats {
distribution: build_distribution_stats(&itls), max_ms: itl_distribution.max_ms,
max_ms: max_value(&itls), distribution: itl_distribution,
}, },
e2e: build_distribution_stats(&e2e_latencies), e2e: build_distribution_stats(&e2e_latencies),
output_token_throughput_per_user: build_distribution_stats( output_token_throughput_per_user: build_distribution_stats(
...@@ -386,39 +387,41 @@ fn mean(values: &[f64]) -> f64 { ...@@ -386,39 +387,41 @@ fn mean(values: &[f64]) -> f64 {
} }
} }
fn max_value(values: &[f64]) -> f64 {
values.iter().copied().reduce(f64::max).unwrap_or(0.0)
}
fn build_distribution_stats(values: &[f64]) -> TraceDistributionStats { fn build_distribution_stats(values: &[f64]) -> TraceDistributionStats {
if values.is_empty() {
return TraceDistributionStats {
mean_ms: 0.0,
min_ms: 0.0,
max_ms: 0.0,
median_ms: 0.0,
p75_ms: 0.0,
p90_ms: 0.0,
p95_ms: 0.0,
p99_ms: 0.0,
std_ms: 0.0,
};
}
let mut sorted = values.to_vec();
sorted.sort_by(|left, right| left.total_cmp(right));
TraceDistributionStats { TraceDistributionStats {
mean_ms: mean(values), mean_ms: mean(values),
min_ms: min_value(values), min_ms: sorted[0],
max_ms: max_value(values), max_ms: *sorted.last().expect("sorted values must be non-empty"),
median_ms: percentile(values, 50.0), median_ms: percentile_sorted(&sorted, 50.0),
p75_ms: percentile(values, 75.0), p75_ms: percentile_sorted(&sorted, 75.0),
p90_ms: percentile(values, 90.0), p90_ms: percentile_sorted(&sorted, 90.0),
p95_ms: percentile(values, 95.0), p95_ms: percentile_sorted(&sorted, 95.0),
p99_ms: percentile(values, 99.0), p99_ms: percentile_sorted(&sorted, 99.0),
std_ms: std_dev(values), std_ms: std_dev(values),
} }
} }
fn percentile(values: &[f64], percentile: f64) -> f64 { fn percentile_sorted(sorted: &[f64], percentile: f64) -> f64 {
if values.is_empty() {
return 0.0;
}
let mut sorted = values.to_vec();
sorted.sort_by(|left, right| left.total_cmp(right));
let rank = ((sorted.len() - 1) as f64 * percentile / 100.0).round() as usize; let rank = ((sorted.len() - 1) as f64 * percentile / 100.0).round() as usize;
sorted[rank.min(sorted.len() - 1)] sorted[rank.min(sorted.len() - 1)]
} }
fn min_value(values: &[f64]) -> f64 {
values.iter().copied().reduce(f64::min).unwrap_or(0.0)
}
fn std_dev(values: &[f64]) -> f64 { fn std_dev(values: &[f64]) -> f64 {
if values.is_empty() { if values.is_empty() {
return 0.0; return 0.0;
......
...@@ -9,10 +9,11 @@ use dynamo_kv_router::config::KvRouterConfig; ...@@ -9,10 +9,11 @@ use dynamo_kv_router::config::KvRouterConfig;
use super::online; use super::online;
use super::validate::{ use super::validate::{
validate_offline_concurrency_args, validate_offline_replay_args, validate_offline_concurrency_args, validate_offline_disagg_concurrency_args,
validate_offline_disagg_replay_args, validate_offline_replay_args,
validate_online_concurrency_args, validate_online_replay_args, validate_online_concurrency_args, validate_online_replay_args,
}; };
use super::{ReplayRouterMode, TraceSimulationReport}; use super::{OfflineDisaggReplayConfig, ReplayRouterMode, TraceSimulationReport};
use crate::common::protocols::{DirectRequest, MockEngineArgs}; use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::Trace; use crate::loadgen::Trace;
...@@ -56,6 +57,28 @@ pub fn simulate_trace_file_with_router_mode( ...@@ -56,6 +57,28 @@ pub fn simulate_trace_file_with_router_mode(
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0)) Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
} }
pub fn simulate_trace_file_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace_path: &Path,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let config = config.normalized()?;
validate_offline_disagg_replay_args(&config, router_mode)?;
let trace = Trace::from_mooncake(trace_path, config.prefill_args.block_size)?
.normalize_session_starts()?
.speed_up_timing(arrival_speedup_ratio)?;
let started_at = Instant::now();
let report = crate::replay::offline::simulate_trace_workload_disagg(
config,
router_config,
trace,
router_mode,
)?;
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
}
pub fn simulate_trace_live_file( pub fn simulate_trace_live_file(
args: MockEngineArgs, args: MockEngineArgs,
trace_path: &Path, trace_path: &Path,
...@@ -130,6 +153,30 @@ pub fn simulate_trace_requests_with_router_mode( ...@@ -130,6 +153,30 @@ pub fn simulate_trace_requests_with_router_mode(
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0)) Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
} }
pub fn simulate_trace_requests_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let config = config.normalized()?;
validate_offline_disagg_replay_args(&config, router_mode)?;
if requests.is_empty() {
bail!("trace replay requires at least one request");
}
let started_at = Instant::now();
let report = crate::replay::offline::simulate_trace_disagg(
config,
router_config,
requests,
arrival_speedup_ratio,
router_mode,
)?;
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
}
pub fn simulate_trace_live_requests( pub fn simulate_trace_live_requests(
args: MockEngineArgs, args: MockEngineArgs,
requests: Vec<DirectRequest>, requests: Vec<DirectRequest>,
...@@ -209,6 +256,27 @@ pub fn simulate_concurrency_file_with_router_mode( ...@@ -209,6 +256,27 @@ pub fn simulate_concurrency_file_with_router_mode(
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0)) Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
} }
pub fn simulate_concurrency_file_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace_path: &Path,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let config = config.normalized()?;
validate_offline_disagg_concurrency_args(&config, max_in_flight, router_mode)?;
let trace = Trace::from_mooncake(trace_path, config.prefill_args.block_size)?;
let started_at = Instant::now();
let report = simulate_concurrency_workload_disagg_with_router_mode(
config,
router_config,
trace,
max_in_flight,
router_mode,
)?;
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
}
pub fn simulate_concurrency_live_file( pub fn simulate_concurrency_live_file(
args: MockEngineArgs, args: MockEngineArgs,
trace_path: &Path, trace_path: &Path,
...@@ -326,6 +394,28 @@ pub fn simulate_concurrency_requests_with_router_mode( ...@@ -326,6 +394,28 @@ pub fn simulate_concurrency_requests_with_router_mode(
) )
} }
pub fn simulate_concurrency_requests_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let config = config.normalized()?;
validate_offline_disagg_concurrency_args(&config, max_in_flight, router_mode)?;
if requests.is_empty() {
bail!("concurrency replay requires at least one request");
}
crate::replay::offline::simulate_concurrency_disagg(
config,
router_config,
requests,
max_in_flight,
router_mode,
)
}
pub fn simulate_trace_workload( pub fn simulate_trace_workload(
args: MockEngineArgs, args: MockEngineArgs,
trace: Trace, trace: Trace,
...@@ -360,6 +450,24 @@ pub fn simulate_trace_workload_with_router_mode( ...@@ -360,6 +450,24 @@ pub fn simulate_trace_workload_with_router_mode(
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0)) Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
} }
pub fn simulate_trace_workload_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let config = config.normalized()?;
validate_offline_disagg_replay_args(&config, router_mode)?;
let started_at = Instant::now();
let report = crate::replay::offline::simulate_trace_workload_disagg(
config,
router_config,
trace,
router_mode,
)?;
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
}
pub fn simulate_trace_live_workload( pub fn simulate_trace_live_workload(
args: MockEngineArgs, args: MockEngineArgs,
trace: Trace, trace: Trace,
...@@ -422,6 +530,24 @@ pub fn simulate_concurrency_workload_with_router_mode( ...@@ -422,6 +530,24 @@ pub fn simulate_concurrency_workload_with_router_mode(
) )
} }
pub fn simulate_concurrency_workload_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let config = config.normalized()?;
validate_offline_disagg_concurrency_args(&config, max_in_flight, router_mode)?;
crate::replay::offline::simulate_concurrency_workload_disagg(
config,
router_config,
trace,
max_in_flight,
router_mode,
)
}
pub fn simulate_concurrency_live_workload( pub fn simulate_concurrency_live_workload(
args: MockEngineArgs, args: MockEngineArgs,
trace: Trace, trace: Trace,
......
...@@ -10,7 +10,7 @@ mod validate; ...@@ -10,7 +10,7 @@ mod validate;
use std::collections::VecDeque; use std::collections::VecDeque;
use crate::common::protocols::DirectRequest; use crate::common::protocols::{DirectRequest, MockEngineArgs};
pub(crate) use collector::TraceCollector; pub(crate) use collector::TraceCollector;
#[cfg(test)] #[cfg(test)]
...@@ -25,20 +25,50 @@ pub enum ReplayRouterMode { ...@@ -25,20 +25,50 @@ pub enum ReplayRouterMode {
KvRouter, KvRouter,
} }
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ReplayArgsMode {
Aggregated,
Disagg,
}
#[derive(Clone, Debug)]
pub struct OfflineDisaggReplayConfig {
pub prefill_args: MockEngineArgs,
pub decode_args: MockEngineArgs,
pub num_prefill_workers: usize,
pub num_decode_workers: usize,
}
impl OfflineDisaggReplayConfig {
pub fn normalized(self) -> anyhow::Result<Self> {
Ok(Self {
prefill_args: self.prefill_args.normalized()?,
decode_args: self.decode_args.normalized()?,
num_prefill_workers: self.num_prefill_workers,
num_decode_workers: self.num_decode_workers,
})
}
}
pub use entrypoints::{ pub use entrypoints::{
simulate_concurrency_file, simulate_concurrency_file_with_router_mode, simulate_concurrency_file, simulate_concurrency_file_disagg_with_router_mode,
simulate_concurrency_live_file, simulate_concurrency_live_file_with_router_mode, simulate_concurrency_file_with_router_mode, simulate_concurrency_live_file,
simulate_concurrency_live_requests, simulate_concurrency_live_requests_with_router_mode, simulate_concurrency_live_file_with_router_mode, simulate_concurrency_live_requests,
simulate_concurrency_live_workload, simulate_concurrency_live_workload_with_router_mode, simulate_concurrency_live_requests_with_router_mode, simulate_concurrency_live_workload,
simulate_concurrency_requests, simulate_concurrency_requests_with_router_mode, simulate_concurrency_live_workload_with_router_mode, simulate_concurrency_requests,
simulate_concurrency_workload, simulate_concurrency_workload_with_router_mode, simulate_concurrency_requests_disagg_with_router_mode,
simulate_trace_file, simulate_trace_file_with_router_mode, simulate_trace_live_file, simulate_concurrency_requests_with_router_mode, simulate_concurrency_workload,
simulate_trace_live_file_with_router_mode, simulate_trace_live_requests, simulate_concurrency_workload_disagg_with_router_mode,
simulate_trace_live_requests_with_router_mode, simulate_trace_live_workload, simulate_concurrency_workload_with_router_mode, simulate_trace_file,
simulate_trace_live_workload_with_router_mode, simulate_trace_requests, simulate_trace_file_disagg_with_router_mode, simulate_trace_file_with_router_mode,
simulate_trace_live_file, simulate_trace_live_file_with_router_mode,
simulate_trace_live_requests, simulate_trace_live_requests_with_router_mode,
simulate_trace_live_workload, simulate_trace_live_workload_with_router_mode,
simulate_trace_requests, simulate_trace_requests_disagg_with_router_mode,
simulate_trace_requests_with_router_mode, simulate_trace_workload, simulate_trace_requests_with_router_mode, simulate_trace_workload,
simulate_trace_workload_with_router_mode, simulate_trace_workload_disagg_with_router_mode, simulate_trace_workload_with_router_mode,
}; };
pub use validate::validate_replay_args_mode;
pub(crate) fn normalize_trace_requests( pub(crate) fn normalize_trace_requests(
mut requests: Vec<DirectRequest>, mut requests: Vec<DirectRequest>,
......
...@@ -15,10 +15,11 @@ The public replay entrypoints live one level up in `lib/mocker/src/replay/entryp ...@@ -15,10 +15,11 @@ The public replay entrypoints live one level up in `lib/mocker/src/replay/entryp
Offline replay starts in `lib/mocker/src/replay/offline/mod.rs`. Offline replay starts in `lib/mocker/src/replay/offline/mod.rs`.
`offline/mod.rs` chooses between two implementations: `offline/mod.rs` chooses between three implementations:
- `lib/mocker/src/replay/offline/single.rs` for the special case `num_workers == 1` with the vLLM engine - `lib/mocker/src/replay/offline/single.rs` for the special case `num_workers == 1` with the vLLM engine
- `lib/mocker/src/replay/offline/multi.rs` for everything else, including multi-worker replay and `kv_router` replay - `lib/mocker/src/replay/offline/multi.rs` for everything else, including multi-worker replay and `kv_router` replay
- `lib/mocker/src/replay/offline/disagg.rs` for offline disaggregated prefill/decode replay
## File Map ## File Map
...@@ -28,6 +29,8 @@ Offline replay starts in `lib/mocker/src/replay/offline/mod.rs`. ...@@ -28,6 +29,8 @@ Offline replay starts in `lib/mocker/src/replay/offline/mod.rs`.
Minimal replay loop for one vLLM worker. Minimal replay loop for one vLLM worker.
- `lib/mocker/src/replay/offline/multi.rs` - `lib/mocker/src/replay/offline/multi.rs`
General offline cluster simulator for multi-worker replay and KV-router replay. General offline cluster simulator for multi-worker replay and KV-router replay.
- `lib/mocker/src/replay/offline/disagg.rs`
Offline two-stage replay harness with separate prefill and decode pools.
- `lib/mocker/src/replay/offline/state.rs` - `lib/mocker/src/replay/offline/state.rs`
Per-worker wrapper around `EngineCore`, including optional KV event capture. Per-worker wrapper around `EngineCore`, including optional KV event capture.
- `lib/mocker/src/replay/offline/events.rs` - `lib/mocker/src/replay/offline/events.rs`
...@@ -114,7 +117,7 @@ So offline replay is not a toy simulator. It reuses the real per-pass mocker sch ...@@ -114,7 +117,7 @@ So offline replay is not a toy simulator. It reuses the real per-pass mocker sch
## Completion Event Queue ## Completion Event Queue
The multi-worker harness uses `SimulationEvent` from `lib/mocker/src/replay/offline/events.rs` as a min-time priority queue implemented with `BinaryHeap`. The multi-worker and disagg harnesses use `SimulationEvent` from `lib/mocker/src/replay/offline/events.rs` as a min-time priority queue implemented with `BinaryHeap`.
Right now the only scheduled event type is: Right now the only scheduled event type is:
...@@ -122,6 +125,7 @@ Right now the only scheduled event type is: ...@@ -122,6 +125,7 @@ Right now the only scheduled event type is:
That event carries: That event carries:
- worker `stage` (`aggregated`, `prefill`, or `decode`)
- `worker_idx` - `worker_idx`
- `completed_requests` - `completed_requests`
- `output_signals` - `output_signals`
...@@ -167,7 +171,7 @@ flowchart LR ...@@ -167,7 +171,7 @@ flowchart LR
M --> G M --> G
``` ```
### Why KV events are captured only here ### Why KV events are captured only where needed
When offline replay uses `kv_router`, workers are created with KV event capture enabled via: When offline replay uses `kv_router`, workers are created with KV event capture enabled via:
...@@ -177,6 +181,37 @@ When offline replay uses `kv_router`, workers are created with KV event capture ...@@ -177,6 +181,37 @@ When offline replay uses `kv_router`, workers are created with KV event capture
That causes each pass to return router-visible `kv_events`, which the harness applies synchronously to the offline router indexer after the pass completes. That causes each pass to return router-visible `kv_events`, which the harness applies synchronously to the offline router indexer after the pass completes.
In round-robin mode, this capture is skipped because nothing consumes those events. In round-robin mode, this capture is skipped because nothing consumes those events.
In offline disagg replay, only the prefill workers capture and publish KV events; the decode workers
run with capture disabled because the decode router is overlap-blind and does not consume router
events.
## Disaggregated Harness
The disaggregated runtime in `lib/mocker/src/replay/offline/disagg.rs` models two distinct stages:
- a prefill router and prefill worker pool
- a decode router and decode worker pool
It keeps one logical clock and one completion-event heap, but request ownership moves through a
two-stage state machine instead of the aggregated single-pool lifecycle.
The prefill router is derived from the main router config with `router_track_active_blocks = false`.
The decode router is derived with:
- overlap disabled
- `assume_kv_reuse = false`
- `track_prefill_tokens = false`
The prefill stage runs a hidden synthetic one-token bootstrap request. When prefill completes, the
harness:
1. applies any prefill KV events
2. marks prefill complete in the prefill router
3. frees prefill router state
4. enqueues the original request into decode at the same logical timestamp
Decode then runs with normal collector visibility. The public replay report remains decode-only, so
TTFT includes prefill queueing and prefill compute.
## Trace vs Concurrency Modes ## Trace vs Concurrency Modes
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{BinaryHeap, HashMap, VecDeque};
use anyhow::{Result, anyhow, bail};
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::RouterEvent;
use uuid::Uuid;
use super::events::{SimulationEvent, SimulationWorkerStage};
use super::normalize_trace_requests;
use super::runtime_utils::{
WorkerCompletionPayload, next_timestamp as choose_next_timestamp, pop_next_concurrency_ready,
pop_next_trace_ready, pop_ready_decode_handoff, pop_ready_worker_completion,
push_decode_handoff, push_worker_completion,
};
#[cfg(test)]
use super::state::DisaggPhase;
#[cfg(test)]
use super::state::DisaggRequestSnapshot;
use super::state::{DisaggRequestState, OfflineWorkerState};
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::{ReplayRequestHashes, Trace, WorkloadDriver};
use crate::replay::router::OfflineReplayRouter;
use crate::replay::{
OfflineDisaggReplayConfig, ReplayRouterMode, TraceCollector, TraceSimulationReport,
};
use crate::scheduler::RouterEventVisibility;
#[derive(Debug, Clone, Copy)]
enum ReplayMode {
Trace,
Concurrency { max_in_flight: usize },
}
enum AdmissionSource {
Requests(VecDeque<DirectRequest>),
Workload(WorkloadDriver),
}
#[cfg(test)]
#[derive(Debug, Default, Clone, PartialEq)]
struct DisaggRuntimeStats {
request_snapshots: HashMap<Uuid, DisaggRequestSnapshot>,
prefill_assignments: HashMap<Uuid, usize>,
decode_assignments: HashMap<Uuid, usize>,
handoff_ms: HashMap<Uuid, f64>,
prefill_marked_count: usize,
prefill_freed_count: usize,
decode_freed_count: usize,
max_prefill_router_pending: usize,
max_decode_router_pending: usize,
}
#[cfg(not(test))]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct DisaggRuntimeStats;
struct DisaggRuntime {
now_ms: f64,
next_prefill_worker_idx: usize,
next_decode_worker_idx: usize,
next_event_seq: u64,
admission: AdmissionSource,
prefill_workers: Vec<OfflineWorkerState>,
decode_workers: Vec<OfflineWorkerState>,
prefill_router: Option<OfflineReplayRouter>,
decode_router: Option<OfflineReplayRouter>,
requests: HashMap<Uuid, DisaggRequestState>,
collector: TraceCollector,
events: BinaryHeap<SimulationEvent>,
mode: ReplayMode,
stats: DisaggRuntimeStats,
}
impl DisaggRuntime {
fn new(
config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
pending: VecDeque<DirectRequest>,
mode: ReplayMode,
router_mode: ReplayRouterMode,
) -> Result<Self> {
Self::new_with_source(
config,
router_config,
AdmissionSource::Requests(pending),
mode,
router_mode,
)
}
fn new_workload(
config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
driver: WorkloadDriver,
mode: ReplayMode,
router_mode: ReplayRouterMode,
) -> Result<Self> {
Self::new_with_source(
config,
router_config,
AdmissionSource::Workload(driver),
mode,
router_mode,
)
}
fn new_with_source(
config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
admission: AdmissionSource,
mode: ReplayMode,
router_mode: ReplayRouterMode,
) -> Result<Self> {
let (prefill_router, decode_router) = match router_mode {
ReplayRouterMode::RoundRobin => (None, None),
ReplayRouterMode::KvRouter => {
let prefill_router_config =
derive_prefill_router_config(&config.prefill_args, router_config.clone());
let decode_router_config =
derive_decode_router_config(&config.decode_args, router_config);
(
Some(OfflineReplayRouter::new(
&config.prefill_args,
Some(prefill_router_config),
config.num_prefill_workers,
)?),
Some(OfflineReplayRouter::new(
&config.decode_args,
Some(decode_router_config),
config.num_decode_workers,
)?),
)
}
};
Ok(Self {
now_ms: 0.0,
next_prefill_worker_idx: 0,
next_decode_worker_idx: 0,
next_event_seq: 0,
admission,
prefill_workers: (0..config.num_prefill_workers)
.map(|worker_idx| {
OfflineWorkerState::new(
worker_idx,
config.prefill_args.clone(),
prefill_router.is_some(),
)
})
.collect(),
decode_workers: (0..config.num_decode_workers)
.map(|worker_idx| {
OfflineWorkerState::new(worker_idx, config.decode_args.clone(), false)
})
.collect(),
prefill_router,
decode_router,
requests: HashMap::new(),
collector: TraceCollector::default(),
events: BinaryHeap::new(),
mode,
#[cfg(test)]
stats: DisaggRuntimeStats::default(),
#[cfg(not(test))]
stats: DisaggRuntimeStats,
})
}
fn cluster_in_flight(&self) -> usize {
self.prefill_workers
.iter()
.map(OfflineWorkerState::in_flight)
.sum::<usize>()
+ self
.decode_workers
.iter()
.map(OfflineWorkerState::in_flight)
.sum::<usize>()
+ self
.prefill_router
.as_ref()
.map_or(0, OfflineReplayRouter::pending_count)
+ self
.decode_router
.as_ref()
.map_or(0, OfflineReplayRouter::pending_count)
}
fn next_prefill_worker(&mut self) -> usize {
let worker_idx = self.next_prefill_worker_idx;
self.next_prefill_worker_idx =
(self.next_prefill_worker_idx + 1) % self.prefill_workers.len();
worker_idx
}
fn next_decode_worker(&mut self) -> usize {
let worker_idx = self.next_decode_worker_idx;
self.next_decode_worker_idx = (self.next_decode_worker_idx + 1) % self.decode_workers.len();
worker_idx
}
fn record_router_pending(&mut self) {
#[cfg(test)]
{
self.stats.max_prefill_router_pending = self.stats.max_prefill_router_pending.max(
self.prefill_router
.as_ref()
.map_or(0, OfflineReplayRouter::pending_count),
);
self.stats.max_decode_router_pending = self.stats.max_decode_router_pending.max(
self.decode_router
.as_ref()
.map_or(0, OfflineReplayRouter::pending_count),
);
}
}
fn validate_worker_idx(&self, stage: SimulationWorkerStage, worker_idx: usize) -> Result<()> {
let worker_count = match stage {
SimulationWorkerStage::Prefill => self.prefill_workers.len(),
SimulationWorkerStage::Decode => self.decode_workers.len(),
SimulationWorkerStage::Aggregated => unreachable!("aggregated stage is not used"),
};
if worker_idx >= worker_count {
bail!("offline disagg replay selected unknown {stage:?} worker index {worker_idx}");
}
Ok(())
}
fn state(&self, uuid: Uuid) -> Result<&DisaggRequestState> {
self.requests
.get(&uuid)
.ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))
}
fn state_mut(&mut self, uuid: Uuid) -> Result<&mut DisaggRequestState> {
self.requests
.get_mut(&uuid)
.ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))
}
fn dispatch_prefill(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> {
self.validate_worker_idx(SimulationWorkerStage::Prefill, worker_idx)?;
let request = self.state(uuid)?.build_prefill_request()?;
self.prefill_workers[worker_idx].receive_request(request);
self.state_mut(uuid)?.start_prefill(worker_idx);
#[cfg(test)]
{
self.stats.prefill_assignments.insert(uuid, worker_idx);
}
Ok(())
}
fn dispatch_decode(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> {
self.validate_worker_idx(SimulationWorkerStage::Decode, worker_idx)?;
let request = self.state(uuid)?.build_decode_request()?;
self.decode_workers[worker_idx].receive_request(request);
self.state_mut(uuid)?.start_decode(worker_idx);
#[cfg(test)]
{
self.stats.decode_assignments.insert(uuid, worker_idx);
}
Ok(())
}
fn dispatch_prefill_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> Result<()> {
for (uuid, worker_idx) in admissions {
if !self.state(uuid)?.is_queued_prefill() {
bail!("offline disagg replay expected queued prefill request for {uuid}");
}
self.dispatch_prefill(uuid, worker_idx)?;
}
Ok(())
}
fn dispatch_decode_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> Result<()> {
for (uuid, worker_idx) in admissions {
if !self.state(uuid)?.is_queued_decode() {
bail!("offline disagg replay expected queued decode request for {uuid}");
}
self.dispatch_decode(uuid, worker_idx)?;
}
Ok(())
}
fn enqueue_decode(&mut self, uuid: Uuid) -> Result<()> {
let Some(decode_router) = self.decode_router.as_mut() else {
let worker_idx = self.next_decode_worker();
self.dispatch_decode(uuid, worker_idx)?;
return Ok(());
};
let maybe_worker_idx = {
let requests = &self.requests;
let request = requests
.get(&uuid)
.ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))?
.original_request()?;
decode_router.submit_request_with_hashes(request, None, self.now_ms)?
};
self.record_router_pending();
#[cfg(test)]
{
self.stats.handoff_ms.insert(uuid, self.now_ms);
}
if let Some(worker_idx) = maybe_worker_idx {
self.dispatch_decode(uuid, worker_idx)?;
return Ok(());
}
self.state_mut(uuid)?.queue_decode();
Ok(())
}
fn on_external_arrival(
&mut self,
mut request: DirectRequest,
arrival_time_ms: f64,
replay_hashes: Option<ReplayRequestHashes>,
) -> Result<Uuid> {
let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
request.uuid = Some(uuid);
request.arrival_timestamp_ms = Some(arrival_time_ms);
self.collector.on_arrival(
uuid,
arrival_time_ms,
request.tokens.len(),
request.max_output_tokens,
);
self.requests
.insert(uuid, DisaggRequestState::new(request, arrival_time_ms));
let Some(prefill_router) = self.prefill_router.as_mut() else {
let worker_idx = self.next_prefill_worker();
self.dispatch_prefill(uuid, worker_idx)?;
return Ok(uuid);
};
let maybe_worker_idx = {
let requests = &self.requests;
let request = requests
.get(&uuid)
.ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))?
.original_request()?;
prefill_router.submit_request_with_hashes(request, replay_hashes, self.now_ms)?
};
self.record_router_pending();
if let Some(worker_idx) = maybe_worker_idx {
self.dispatch_prefill(uuid, worker_idx)?;
}
Ok(uuid)
}
fn is_done(&self) -> bool {
self.events.is_empty()
&& self.cluster_in_flight() == 0
&& match &self.admission {
AdmissionSource::Requests(pending) => pending.is_empty(),
AdmissionSource::Workload(driver) => driver.is_drained(),
}
&& self
.prefill_workers
.iter()
.all(OfflineWorkerState::is_drained)
&& self
.decode_workers
.iter()
.all(OfflineWorkerState::is_drained)
}
fn next_timestamp(&mut self) -> Option<f64> {
let next_event_ms = self.events.peek().map(|event| event.at_ms);
let cluster_in_flight = self.cluster_in_flight();
let next_arrival_ms = match (&self.mode, &mut self.admission) {
(ReplayMode::Trace, AdmissionSource::Requests(pending)) => pending
.front()
.and_then(|request| request.arrival_timestamp_ms),
(ReplayMode::Trace, AdmissionSource::Workload(driver)) => driver.next_ready_time_ms(),
(ReplayMode::Concurrency { max_in_flight }, AdmissionSource::Workload(driver)) => {
if cluster_in_flight < *max_in_flight {
driver.next_ready_time_ms()
} else {
None
}
}
(ReplayMode::Concurrency { .. }, AdmissionSource::Requests(_)) => None,
};
choose_next_timestamp(next_arrival_ms, next_event_ms)
}
fn apply_prefill_router_events(&mut self, events: Vec<RouterEvent>) -> Result<()> {
let Some(prefill_router) = self.prefill_router.as_mut() else {
return Ok(());
};
for event in events {
prefill_router.apply_event(event)?;
}
Ok(())
}
fn process_prefill_signal(&mut self, signal: OutputSignal) -> Result<()> {
if !signal.completed {
return Ok(());
}
if self.prefill_router.is_some() {
let prefill_complete_admissions = {
let prefill_router = self.prefill_router.as_mut().expect("router checked above");
prefill_router.mark_prefill_completed(signal.uuid)?
};
#[cfg(test)]
{
self.stats.prefill_marked_count += 1;
}
self.record_router_pending();
self.dispatch_prefill_admissions(prefill_complete_admissions)?;
let admissions = {
let prefill_router = self.prefill_router.as_mut().expect("router checked above");
prefill_router.free(signal.uuid)?
};
#[cfg(test)]
{
self.stats.prefill_freed_count += 1;
}
self.record_router_pending();
self.dispatch_prefill_admissions(admissions)?;
}
self.enqueue_decode_after_handoff(signal.uuid, signal.handoff_delay_ms)
}
fn process_decode_signal(&mut self, signal: OutputSignal) -> Result<()> {
if !signal.completed {
return Ok(());
}
let admissions = if let Some(decode_router) = self.decode_router.as_mut() {
let admissions = decode_router.free(signal.uuid)?;
#[cfg(test)]
{
self.stats.decode_freed_count += 1;
}
admissions
} else {
Vec::new()
};
self.record_router_pending();
if let AdmissionSource::Workload(driver) = &mut self.admission {
driver.on_complete(signal.uuid, self.now_ms)?;
}
self.state_mut(signal.uuid)?.mark_done();
self.dispatch_decode_admissions(admissions)?;
Ok(())
}
fn process_prefill_pass(
&mut self,
worker_idx: usize,
completed_requests: usize,
output_signals: Vec<OutputSignal>,
kv_events: Vec<RouterEvent>,
) -> Result<()> {
self.prefill_workers[worker_idx].mark_completed(completed_requests);
self.apply_prefill_router_events(kv_events)?;
for signal in output_signals {
self.process_prefill_signal(signal)?;
}
Ok(())
}
fn process_decode_pass(
&mut self,
worker_idx: usize,
completed_requests: usize,
output_signals: Vec<OutputSignal>,
) -> Result<()> {
self.decode_workers[worker_idx].mark_completed(completed_requests);
for signal in output_signals {
self.process_decode_signal(signal)?;
}
Ok(())
}
fn apply_worker_completions(&mut self) -> Result<bool> {
let mut changed = false;
while let Some(WorkerCompletionPayload {
stage,
worker_idx,
completed_requests,
output_signals,
kv_events,
}) = pop_ready_worker_completion(&mut self.events, self.now_ms)
{
match stage {
SimulationWorkerStage::Prefill => {
self.prefill_workers[worker_idx].mark_idle();
self.process_prefill_pass(
worker_idx,
completed_requests,
output_signals,
kv_events,
)?;
}
SimulationWorkerStage::Decode => {
self.decode_workers[worker_idx].mark_idle();
self.process_decode_pass(worker_idx, completed_requests, output_signals)?;
}
SimulationWorkerStage::Aggregated => {
bail!("offline disagg replay received an aggregated completion event")
}
}
changed = true;
}
Ok(changed)
}
fn apply_decode_handoffs(&mut self) -> Result<bool> {
let mut changed = false;
while let Some(uuid) = pop_ready_decode_handoff(&mut self.events, self.now_ms) {
self.enqueue_decode(uuid)?;
changed = true;
}
Ok(changed)
}
fn enqueue_decode_after_handoff(
&mut self,
uuid: Uuid,
handoff_delay_ms: Option<f64>,
) -> Result<()> {
if let Some(delay_ms) = handoff_delay_ms
&& delay_ms > 0.0
{
push_decode_handoff(
&mut self.events,
&mut self.next_event_seq,
self.now_ms + delay_ms,
uuid,
);
return Ok(());
}
self.enqueue_decode(uuid)
}
fn release_trace_arrivals(&mut self) -> Result<bool> {
let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) {
loop {
let next_ready = match &mut self.admission {
AdmissionSource::Requests(pending) => {
pop_next_trace_ready(pending, self.now_ms)
}
AdmissionSource::Workload(_) => unreachable!(),
};
let Some((request, arrival_ms)) = next_ready else {
break;
};
self.on_external_arrival(request, arrival_ms, None)?;
released_any = true;
}
return Ok(released_any);
}
let ready_requests = match &mut self.admission {
AdmissionSource::Requests(_) => unreachable!(),
AdmissionSource::Workload(driver) => driver.pop_ready(self.now_ms, usize::MAX),
};
for ready in ready_requests {
self.on_external_arrival(
ready.request,
ready.scheduled_ready_at_ms,
ready.replay_hashes,
)?;
released_any = true;
}
Ok(released_any)
}
fn top_off_concurrency(&mut self, max_in_flight: usize) -> Result<bool> {
let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) {
loop {
let cluster_in_flight = self.cluster_in_flight();
let next_ready = match &mut self.admission {
AdmissionSource::Requests(pending) => pop_next_concurrency_ready(
pending,
self.now_ms,
cluster_in_flight,
max_in_flight,
),
AdmissionSource::Workload(_) => unreachable!(),
};
let Some((request, arrival_ms)) = next_ready else {
break;
};
self.on_external_arrival(request, arrival_ms, None)?;
released_any = true;
}
return Ok(released_any);
}
let available = max_in_flight.saturating_sub(self.cluster_in_flight());
if available == 0 {
return Ok(false);
}
let ready_requests = match &mut self.admission {
AdmissionSource::Requests(_) => unreachable!(),
AdmissionSource::Workload(driver) => driver.pop_ready(self.now_ms, available),
};
for ready in ready_requests {
self.on_external_arrival(ready.request, self.now_ms, ready.replay_hashes)?;
released_any = true;
}
Ok(released_any)
}
fn drive_prefill_workers(&mut self) -> Result<bool> {
let mut changed = false;
for worker_idx in 0..self.prefill_workers.len() {
loop {
if !self.prefill_workers[worker_idx].is_ready() {
break;
}
let executed = self.prefill_workers[worker_idx].execute_hidden_pass(self.now_ms);
changed = true;
let completion_kv_events =
if executed.router_event_visibility == RouterEventVisibility::PassStart {
self.apply_prefill_router_events(executed.kv_events)?;
Vec::new()
} else {
executed.kv_events
};
if executed.end_ms == self.now_ms {
self.process_prefill_pass(
worker_idx,
executed.completed_requests,
executed.output_signals,
completion_kv_events,
)?;
continue;
}
self.prefill_workers[worker_idx].mark_busy();
push_worker_completion(
&mut self.events,
&mut self.next_event_seq,
executed.end_ms,
WorkerCompletionPayload {
stage: SimulationWorkerStage::Prefill,
worker_idx,
completed_requests: executed.completed_requests,
output_signals: executed.output_signals,
kv_events: completion_kv_events,
},
);
break;
}
}
Ok(changed)
}
fn drive_decode_workers(&mut self) -> Result<bool> {
let mut changed = false;
for worker_idx in 0..self.decode_workers.len() {
loop {
if !self.decode_workers[worker_idx].is_ready() {
break;
}
let executed = {
let (workers, collector) = (&mut self.decode_workers, &mut self.collector);
workers[worker_idx].execute_pass(collector, self.now_ms)
};
changed = true;
if executed.end_ms == self.now_ms {
self.process_decode_pass(
worker_idx,
executed.completed_requests,
executed.output_signals,
)?;
continue;
}
self.decode_workers[worker_idx].mark_busy();
push_worker_completion(
&mut self.events,
&mut self.next_event_seq,
executed.end_ms,
WorkerCompletionPayload {
stage: SimulationWorkerStage::Decode,
worker_idx,
completed_requests: executed.completed_requests,
output_signals: executed.output_signals,
kv_events: Vec::new(),
},
);
break;
}
}
Ok(changed)
}
fn drain_current_timestamp(&mut self) -> Result<()> {
loop {
let mut changed = self.apply_worker_completions()?;
changed |= self.apply_decode_handoffs()?;
changed |= match self.mode {
ReplayMode::Trace => self.release_trace_arrivals()?,
ReplayMode::Concurrency { max_in_flight } => {
self.top_off_concurrency(max_in_flight)?
}
};
changed |= self.drive_prefill_workers()?;
changed |= self.drive_decode_workers()?;
if !changed {
break;
}
}
Ok(())
}
fn finish_test_stats(&mut self) {
#[cfg(test)]
{
self.stats.request_snapshots = self
.requests
.iter()
.map(|(uuid, state)| (*uuid, state.debug_snapshot()))
.collect();
}
}
fn run(mut self) -> Result<(TraceCollector, DisaggRuntimeStats)> {
self.drain_current_timestamp()?;
while !self.is_done() {
let Some(next_timestamp_ms) = self.next_timestamp() else {
bail!(
"offline disagg replay reached a dead end with {} in-flight requests remaining",
self.cluster_in_flight()
);
};
self.now_ms = next_timestamp_ms;
self.drain_current_timestamp()?;
}
self.finish_test_stats();
Ok((self.collector, self.stats))
}
}
fn base_router_config(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
) -> KvRouterConfig {
let mut config = router_config.unwrap_or_default();
if let Some(policy) = args.router_queue_policy {
config.router_queue_policy = policy;
}
config
}
fn derive_prefill_router_config(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
) -> KvRouterConfig {
let mut config = base_router_config(args, router_config);
config.router_track_active_blocks = false;
config
}
fn derive_decode_router_config(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
) -> KvRouterConfig {
let mut config = base_router_config(args, router_config);
config.overlap_score_weight = 0.0;
config.router_assume_kv_reuse = false;
config.router_track_prefill_tokens = false;
config
}
pub(crate) fn simulate_trace_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
let (collector, _) = DisaggRuntime::new(
&config,
router_config,
pending,
ReplayMode::Trace,
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let pending = VecDeque::from(requests);
let (collector, _) = DisaggRuntime::new(
&config,
router_config,
pending,
ReplayMode::Concurrency { max_in_flight },
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_trace_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let driver = WorkloadDriver::new_trace(trace)?;
let (collector, _) = DisaggRuntime::new_workload(
&config,
router_config,
driver,
ReplayMode::Trace,
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let driver = WorkloadDriver::new_concurrency(trace)?;
let (collector, _) = DisaggRuntime::new_workload(
&config,
router_config,
driver,
ReplayMode::Concurrency { max_in_flight },
router_mode,
)?
.run()?;
Ok(collector.finish())
}
#[cfg(test)]
fn run_trace_collect(
config: &OfflineDisaggReplayConfig,
requests: Vec<DirectRequest>,
router_config: Option<KvRouterConfig>,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
let pending = normalize_trace_requests(requests, arrival_speedup_ratio).unwrap();
DisaggRuntime::new(
config,
router_config,
pending,
ReplayMode::Trace,
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
fn run_concurrency_collect(
config: &OfflineDisaggReplayConfig,
requests: Vec<DirectRequest>,
router_config: Option<KvRouterConfig>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
DisaggRuntime::new(
config,
router_config,
VecDeque::from(requests),
ReplayMode::Concurrency { max_in_flight },
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::common::protocols::{MockEngineArgs, WorkerType};
fn staged_args(worker_type: WorkerType, speedup_ratio: f64) -> MockEngineArgs {
MockEngineArgs::builder()
.block_size(64)
.num_gpu_blocks(256)
.max_num_batched_tokens(Some(8192))
.max_num_seqs(Some(8))
.enable_prefix_caching(true)
.enable_chunked_prefill(true)
.speedup_ratio(speedup_ratio)
.decode_speedup_ratio(speedup_ratio)
.worker_type(worker_type)
.build()
.unwrap()
}
fn disagg_config() -> OfflineDisaggReplayConfig {
OfflineDisaggReplayConfig {
prefill_args: staged_args(WorkerType::Prefill, 1000.0),
decode_args: staged_args(WorkerType::Decode, 1000.0),
num_prefill_workers: 2,
num_decode_workers: 2,
}
}
fn disagg_config_with_handoff_delay() -> OfflineDisaggReplayConfig {
let mut config = disagg_config();
config.prefill_args.kv_transfer_bandwidth = Some(1.0);
config.prefill_args.kv_bytes_per_token = Some(1_000_000);
config
}
fn router_config() -> KvRouterConfig {
KvRouterConfig {
router_queue_threshold: Some(1.25),
..KvRouterConfig::default()
}
}
fn request(
uuid: u128,
prompt_tokens: usize,
output_tokens: usize,
arrival_ms: f64,
) -> DirectRequest {
DirectRequest {
tokens: vec![1; prompt_tokens],
max_output_tokens: output_tokens,
uuid: Some(Uuid::from_u128(uuid)),
dp_rank: 0,
arrival_timestamp_ms: Some(arrival_ms),
}
}
#[test]
fn test_derive_stage_router_configs_force_required_overrides() {
let config = KvRouterConfig {
overlap_score_weight: 2.0,
router_track_active_blocks: true,
router_assume_kv_reuse: true,
router_track_prefill_tokens: true,
..KvRouterConfig::default()
};
let args = staged_args(WorkerType::Prefill, 1.0);
let prefill = derive_prefill_router_config(&args, Some(config.clone()));
let decode = derive_decode_router_config(&args, Some(config));
assert!(!prefill.router_track_active_blocks);
assert_eq!(decode.overlap_score_weight, 0.0);
assert!(!decode.router_assume_kv_reuse);
assert!(!decode.router_track_prefill_tokens);
}
#[rstest::rstest]
#[case(ReplayRouterMode::RoundRobin)]
#[case(ReplayRouterMode::KvRouter)]
fn test_trace_smoke_reports_decode_only_tokens(#[case] router_mode: ReplayRouterMode) {
let config = disagg_config();
let requests = vec![request(1, 128, 3, 5.0)];
let router_config = (router_mode == ReplayRouterMode::KvRouter).then(router_config);
let (collector, stats) =
run_trace_collect(&config, requests, router_config, 1.0, router_mode);
let snapshot = collector.snapshot(Uuid::from_u128(1)).unwrap();
let report = collector.finish();
assert_eq!(snapshot.arrival_time_ms, 0.0);
assert!(snapshot.first_admit_ms.is_some());
assert!(snapshot.first_token_ms.is_some());
assert_eq!(snapshot.output_length, 3);
assert_eq!(report.request_counts.completed_requests, 1);
assert_eq!(
stats.request_snapshots[&Uuid::from_u128(1)].phase,
DisaggPhase::Done
);
}
#[rstest::rstest]
#[case(ReplayRouterMode::RoundRobin)]
#[case(ReplayRouterMode::KvRouter)]
fn test_prefill_and_decode_use_separate_worker_pools(#[case] router_mode: ReplayRouterMode) {
let config = disagg_config();
let requests = vec![request(1, 128, 2, 0.0), request(2, 128, 2, 10.0)];
let router_config = (router_mode == ReplayRouterMode::KvRouter).then(router_config);
let (_, stats) = run_trace_collect(&config, requests, router_config, 1.0, router_mode);
for uuid in [Uuid::from_u128(1), Uuid::from_u128(2)] {
assert!(stats.prefill_assignments.contains_key(&uuid));
assert!(stats.decode_assignments.contains_key(&uuid));
}
}
#[test]
fn test_prefill_overlap_prefers_same_worker_after_handoff_delay() {
let config = disagg_config();
let requests = vec![request(1, 128, 2, 0.0), request(2, 128, 2, 100.0)];
let (_, stats) = run_trace_collect(
&config,
requests,
Some(router_config()),
1.0,
ReplayRouterMode::KvRouter,
);
assert_eq!(
stats.prefill_assignments[&Uuid::from_u128(1)],
stats.prefill_assignments[&Uuid::from_u128(2)],
);
}
#[rstest::rstest]
#[case(ReplayRouterMode::RoundRobin)]
#[case(ReplayRouterMode::KvRouter)]
fn test_concurrency_backfill_waits_for_decode_completion(
#[case] router_mode: ReplayRouterMode,
) {
let config = disagg_config();
let requests = vec![
DirectRequest {
tokens: vec![1; 128],
max_output_tokens: 3,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: None,
},
DirectRequest {
tokens: vec![2; 128],
max_output_tokens: 3,
uuid: Some(Uuid::from_u128(2)),
dp_rank: 0,
arrival_timestamp_ms: None,
},
];
let router_config = (router_mode == ReplayRouterMode::KvRouter).then(router_config);
let (collector, _) =
run_concurrency_collect(&config, requests, router_config, 1, router_mode);
let first = collector.snapshot(Uuid::from_u128(1)).unwrap();
let second = collector.snapshot(Uuid::from_u128(2)).unwrap();
assert_eq!(first.arrival_time_ms, 0.0);
assert_eq!(second.arrival_time_ms, first.last_token_ms.unwrap());
}
#[test]
fn test_prefill_completion_marks_and_frees_before_decode_handoff() {
let config = disagg_config();
let requests = vec![request(1, 128, 2, 0.0)];
let (_, stats) = run_trace_collect(
&config,
requests,
Some(router_config()),
1.0,
ReplayRouterMode::KvRouter,
);
assert_eq!(stats.prefill_marked_count, 1);
assert_eq!(stats.prefill_freed_count, 1);
assert_eq!(stats.decode_freed_count, 1);
}
#[test]
fn test_handoff_delay_increases_decode_visible_ttft() {
let requests = vec![request(1, 128, 2, 0.0)];
let (baseline_collector, _) = run_trace_collect(
&disagg_config(),
requests.clone(),
None,
1.0,
ReplayRouterMode::RoundRobin,
);
let (delayed_collector, _) = run_trace_collect(
&disagg_config_with_handoff_delay(),
requests,
None,
1.0,
ReplayRouterMode::RoundRobin,
);
let baseline = baseline_collector.snapshot(Uuid::from_u128(1)).unwrap();
let delayed = delayed_collector.snapshot(Uuid::from_u128(1)).unwrap();
let baseline_ttft = baseline.first_token_ms.unwrap() - baseline.arrival_time_ms;
let delayed_ttft = delayed.first_token_ms.unwrap() - delayed.arrival_time_ms;
assert!(
delayed_ttft >= baseline_ttft + 120.0,
"expected delayed TTFT to include roughly 128ms of handoff delay, baseline={baseline_ttft}, delayed={delayed_ttft}"
);
}
}
...@@ -4,15 +4,27 @@ ...@@ -4,15 +4,27 @@
use std::cmp::Ordering; use std::cmp::Ordering;
use crate::common::protocols::OutputSignal; use crate::common::protocols::OutputSignal;
use uuid::Uuid;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum SimulationWorkerStage {
Aggregated,
Prefill,
Decode,
}
#[derive(Debug)] #[derive(Debug)]
pub(crate) enum SimulationEventKind { pub(crate) enum SimulationEventKind {
WorkerCompletion { WorkerCompletion {
stage: SimulationWorkerStage,
worker_idx: usize, worker_idx: usize,
completed_requests: usize, completed_requests: usize,
output_signals: Vec<OutputSignal>, output_signals: Vec<OutputSignal>,
kv_events: Vec<dynamo_kv_router::protocols::RouterEvent>, kv_events: Vec<dynamo_kv_router::protocols::RouterEvent>,
}, },
DecodeHandoff {
uuid: Uuid,
},
} }
#[derive(Debug)] #[derive(Debug)]
......
...@@ -3,13 +3,16 @@ ...@@ -3,13 +3,16 @@
use crate::common::protocols::{DirectRequest, MockEngineArgs}; use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::Trace; use crate::loadgen::Trace;
use crate::replay::OfflineDisaggReplayConfig;
pub(crate) use crate::replay::normalize_trace_requests; pub(crate) use crate::replay::normalize_trace_requests;
use crate::replay::{ReplayRouterMode, TraceSimulationReport}; use crate::replay::{ReplayRouterMode, TraceSimulationReport};
use dynamo_kv_router::config::KvRouterConfig; use dynamo_kv_router::config::KvRouterConfig;
pub(crate) mod core; pub(crate) mod core;
pub(crate) mod disagg;
pub(crate) mod events; pub(crate) mod events;
pub(crate) mod multi; pub(crate) mod multi;
pub(crate) mod runtime_utils;
pub(crate) mod single; pub(crate) mod single;
pub(crate) mod state; pub(crate) mod state;
...@@ -92,3 +95,54 @@ pub(crate) fn simulate_concurrency_workload( ...@@ -92,3 +95,54 @@ pub(crate) fn simulate_concurrency_workload(
) )
} }
} }
pub(crate) fn simulate_trace_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_trace_disagg(
config,
router_config,
requests,
arrival_speedup_ratio,
router_mode,
)
}
pub(crate) fn simulate_concurrency_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_concurrency_disagg(config, router_config, requests, max_in_flight, router_mode)
}
pub(crate) fn simulate_trace_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_trace_workload_disagg(config, router_config, trace, router_mode)
}
pub(crate) fn simulate_concurrency_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_concurrency_workload_disagg(
config,
router_config,
trace,
max_in_flight,
router_mode,
)
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::events::{SimulationEvent, SimulationEventKind}; use super::events::{SimulationEvent, SimulationWorkerStage};
use super::normalize_trace_requests; use super::normalize_trace_requests;
use super::runtime_utils::{
WorkerCompletionPayload, next_timestamp as choose_next_timestamp, pop_next_concurrency_ready,
pop_next_trace_ready, pop_ready_worker_completion, push_worker_completion,
};
#[cfg(test)] #[cfg(test)]
use super::state::OfflineWorkerSnapshot; use super::state::OfflineWorkerSnapshot;
use super::state::OfflineWorkerState; use super::state::{AggRequestState, OfflineWorkerState};
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal}; use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::{ReplayRequestHashes, Trace, WorkloadDriver}; use crate::loadgen::{ReplayRequestHashes, Trace, WorkloadDriver};
use crate::replay::router::OfflineReplayRouter; use crate::replay::router::OfflineReplayRouter;
...@@ -16,7 +20,7 @@ use crate::scheduler::RouterEventVisibility; ...@@ -16,7 +20,7 @@ use crate::scheduler::RouterEventVisibility;
use anyhow::bail; use anyhow::bail;
use dynamo_kv_router::config::KvRouterConfig; use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::RouterEvent; use dynamo_kv_router::protocols::RouterEvent;
use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque}; use std::collections::{BinaryHeap, HashMap, VecDeque};
use uuid::Uuid; use uuid::Uuid;
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
...@@ -62,13 +66,13 @@ struct OfflineRuntime { ...@@ -62,13 +66,13 @@ struct OfflineRuntime {
next_worker_idx: usize, next_worker_idx: usize,
next_event_seq: u64, next_event_seq: u64,
admission: AdmissionSource, admission: AdmissionSource,
router_pending: HashMap<Uuid, DirectRequest>, requests: HashMap<Uuid, AggRequestState>,
queued_requests: usize,
workers: Vec<OfflineWorkerState>, workers: Vec<OfflineWorkerState>,
collector: TraceCollector, collector: TraceCollector,
events: BinaryHeap<SimulationEvent>, events: BinaryHeap<SimulationEvent>,
mode: ReplayMode, mode: ReplayMode,
router: Option<OfflineReplayRouter>, router: Option<OfflineReplayRouter>,
prefill_completed: HashSet<Uuid>,
stats: OfflineRuntimeStats, stats: OfflineRuntimeStats,
#[cfg(test)] #[cfg(test)]
worker_active_requests: Vec<Vec<Uuid>>, worker_active_requests: Vec<Vec<Uuid>>,
...@@ -135,7 +139,8 @@ impl OfflineRuntime { ...@@ -135,7 +139,8 @@ impl OfflineRuntime {
next_worker_idx: 0, next_worker_idx: 0,
next_event_seq: 0, next_event_seq: 0,
admission, admission,
router_pending: HashMap::new(), requests: HashMap::new(),
queued_requests: 0,
workers: (0..num_workers) workers: (0..num_workers)
.map(|worker_idx| { .map(|worker_idx| {
OfflineWorkerState::new(worker_idx, args.clone(), capture_kv_events) OfflineWorkerState::new(worker_idx, args.clone(), capture_kv_events)
...@@ -145,7 +150,6 @@ impl OfflineRuntime { ...@@ -145,7 +150,6 @@ impl OfflineRuntime {
events: BinaryHeap::new(), events: BinaryHeap::new(),
mode, mode,
router, router,
prefill_completed: HashSet::new(),
#[cfg(test)] #[cfg(test)]
stats: OfflineRuntimeStats::default(), stats: OfflineRuntimeStats::default(),
#[cfg(not(test))] #[cfg(not(test))]
...@@ -162,7 +166,7 @@ impl OfflineRuntime { ...@@ -162,7 +166,7 @@ impl OfflineRuntime {
.iter() .iter()
.map(OfflineWorkerState::in_flight) .map(OfflineWorkerState::in_flight)
.sum::<usize>() .sum::<usize>()
+ self.router_pending.len() + self.queued_requests
} }
fn record_in_flight_peak(&mut self) { fn record_in_flight_peak(&mut self) {
...@@ -220,9 +224,14 @@ impl OfflineRuntime { ...@@ -220,9 +224,14 @@ impl OfflineRuntime {
fn dispatch_router_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> anyhow::Result<()> { fn dispatch_router_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> anyhow::Result<()> {
for (uuid, worker_idx) in admissions { for (uuid, worker_idx) in admissions {
let request = self.router_pending.remove(&uuid).ok_or_else(|| { let request = self
anyhow::anyhow!("offline replay missing queued request state for {uuid}") .requests
})?; .get_mut(&uuid)
.ok_or_else(|| {
anyhow::anyhow!("offline replay missing queued request state for {uuid}")
})?
.take_queued_request(uuid)?;
self.queued_requests = self.queued_requests.saturating_sub(1);
self.dispatch_to_worker(request, uuid, worker_idx)?; self.dispatch_to_worker(request, uuid, worker_idx)?;
} }
Ok(()) Ok(())
...@@ -248,6 +257,7 @@ impl OfflineRuntime { ...@@ -248,6 +257,7 @@ impl OfflineRuntime {
); );
let Some(router) = self.router.as_mut() else { let Some(router) = self.router.as_mut() else {
self.requests.insert(uuid, AggRequestState::new_running());
let worker_idx = self.next_worker_idx; let worker_idx = self.next_worker_idx;
self.next_worker_idx = (self.next_worker_idx + 1) % self.workers.len(); self.next_worker_idx = (self.next_worker_idx + 1) % self.workers.len();
self.dispatch_to_worker(request, uuid, worker_idx)?; self.dispatch_to_worker(request, uuid, worker_idx)?;
...@@ -258,11 +268,14 @@ impl OfflineRuntime { ...@@ -258,11 +268,14 @@ impl OfflineRuntime {
router.submit_request_with_hashes(&request, replay_hashes, self.now_ms)?; router.submit_request_with_hashes(&request, replay_hashes, self.now_ms)?;
self.record_router_pending(); self.record_router_pending();
if let Some(worker_idx) = maybe_worker_idx { if let Some(worker_idx) = maybe_worker_idx {
self.requests.insert(uuid, AggRequestState::new_running());
self.dispatch_to_worker(request, uuid, worker_idx)?; self.dispatch_to_worker(request, uuid, worker_idx)?;
return Ok(uuid); return Ok(uuid);
} }
self.router_pending.insert(uuid, request); self.requests
.insert(uuid, AggRequestState::new_queued(request));
self.queued_requests += 1;
self.record_in_flight_peak(); self.record_in_flight_peak();
Ok(uuid) Ok(uuid)
} }
...@@ -295,21 +308,7 @@ impl OfflineRuntime { ...@@ -295,21 +308,7 @@ impl OfflineRuntime {
(ReplayMode::Concurrency { .. }, AdmissionSource::Requests(_)) => None, (ReplayMode::Concurrency { .. }, AdmissionSource::Requests(_)) => None,
}; };
match (next_arrival_ms, next_event_ms) { choose_next_timestamp(next_arrival_ms, next_event_ms)
(Some(arrival_ms), Some(event_ms)) => Some(arrival_ms.min(event_ms)),
(Some(arrival_ms), None) => Some(arrival_ms),
(None, Some(event_ms)) => Some(event_ms),
(None, None) => None,
}
}
fn push_event(&mut self, at_ms: f64, kind: SimulationEventKind) {
self.events.push(SimulationEvent {
at_ms,
seq_no: self.next_event_seq,
kind,
});
self.next_event_seq += 1;
} }
fn apply_completed_requests(&mut self, worker_idx: usize, completed_requests: usize) { fn apply_completed_requests(&mut self, worker_idx: usize, completed_requests: usize) {
...@@ -339,7 +338,9 @@ impl OfflineRuntime { ...@@ -339,7 +338,9 @@ impl OfflineRuntime {
} }
self.record_router_pending(); self.record_router_pending();
} }
self.prefill_completed.remove(&signal.uuid); self.requests.remove(&signal.uuid).ok_or_else(|| {
anyhow::anyhow!("offline replay missing request state for {}", signal.uuid)
})?;
if let AdmissionSource::Workload(driver) = &mut self.admission { if let AdmissionSource::Workload(driver) = &mut self.admission {
driver.on_complete(signal.uuid, self.now_ms)?; driver.on_complete(signal.uuid, self.now_ms)?;
} }
...@@ -347,10 +348,23 @@ impl OfflineRuntime { ...@@ -347,10 +348,23 @@ impl OfflineRuntime {
return Ok(()); return Ok(());
} }
if !self.prefill_completed.insert(signal.uuid) { let already_marked = self
.requests
.get(&signal.uuid)
.ok_or_else(|| {
anyhow::anyhow!("offline replay missing request state for {}", signal.uuid)
})?
.prefill_completed();
if already_marked {
return Ok(()); return Ok(());
} }
self.requests
.get_mut(&signal.uuid)
.ok_or_else(|| {
anyhow::anyhow!("offline replay missing request state for {}", signal.uuid)
})?
.mark_prefill_completed();
if let Some(router) = self.router.as_mut() { if let Some(router) = self.router.as_mut() {
admissions = router.mark_prefill_completed(signal.uuid)?; admissions = router.mark_prefill_completed(signal.uuid)?;
#[cfg(test)] #[cfg(test)]
...@@ -395,25 +409,15 @@ impl OfflineRuntime { ...@@ -395,25 +409,15 @@ impl OfflineRuntime {
fn apply_worker_completions(&mut self) -> anyhow::Result<bool> { fn apply_worker_completions(&mut self) -> anyhow::Result<bool> {
let mut changed = false; let mut changed = false;
loop { while let Some(WorkerCompletionPayload {
let Some(event) = self.events.peek() else { stage,
break; worker_idx,
}; completed_requests,
if event.at_ms != self.now_ms { output_signals,
break; kv_events,
} }) = pop_ready_worker_completion(&mut self.events, self.now_ms)
if !matches!(event.kind, SimulationEventKind::WorkerCompletion { .. }) { {
break; debug_assert_eq!(stage, SimulationWorkerStage::Aggregated);
}
let event = self.events.pop().expect("event must exist after peek");
let SimulationEventKind::WorkerCompletion {
worker_idx,
completed_requests,
output_signals,
kv_events,
} = event.kind;
self.workers[worker_idx].mark_idle(); self.workers[worker_idx].mark_idle();
self.process_completed_pass(worker_idx, completed_requests, output_signals, kv_events)?; self.process_completed_pass(worker_idx, completed_requests, output_signals, kv_events)?;
changed = true; changed = true;
...@@ -424,39 +428,33 @@ impl OfflineRuntime { ...@@ -424,39 +428,33 @@ impl OfflineRuntime {
fn release_trace_arrivals(&mut self) -> anyhow::Result<bool> { fn release_trace_arrivals(&mut self) -> anyhow::Result<bool> {
let mut released_any = false; let mut released_any = false;
let mut ready_requests = Vec::new(); if matches!(self.admission, AdmissionSource::Requests(_)) {
loop {
match &mut self.admission { let next_ready = match &mut self.admission {
AdmissionSource::Requests(pending) => { AdmissionSource::Requests(pending) => {
while pending pop_next_trace_ready(pending, self.now_ms)
.front() }
.and_then(|request| request.arrival_timestamp_ms) AdmissionSource::Workload(_) => unreachable!(),
.is_some_and(|arrival_ms| arrival_ms <= self.now_ms) };
{ let Some((request, arrival_ms)) = next_ready else {
let request = pending break;
.pop_front() };
.expect("front request must exist when arrival is ready"); self.assign_request(request, arrival_ms, None)?;
let arrival_ms = request released_any = true;
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
ready_requests.push((request, arrival_ms, None));
}
}
AdmissionSource::Workload(driver) => {
ready_requests.extend(driver.pop_ready(self.now_ms, usize::MAX).into_iter().map(
|ready| {
(
ready.request,
ready.scheduled_ready_at_ms,
ready.replay_hashes,
)
},
));
} }
return Ok(released_any);
} }
for (request, arrival_ms, replay_hashes) in ready_requests { let ready_requests = match &mut self.admission {
self.assign_request(request, arrival_ms, replay_hashes)?; AdmissionSource::Requests(_) => unreachable!(),
AdmissionSource::Workload(driver) => driver.pop_ready(self.now_ms, usize::MAX),
};
for ready in ready_requests {
self.assign_request(
ready.request,
ready.scheduled_ready_at_ms,
ready.replay_hashes,
)?;
released_any = true; released_any = true;
} }
...@@ -465,33 +463,38 @@ impl OfflineRuntime { ...@@ -465,33 +463,38 @@ impl OfflineRuntime {
fn top_off_concurrency(&mut self, max_in_flight: usize) -> anyhow::Result<bool> { fn top_off_concurrency(&mut self, max_in_flight: usize) -> anyhow::Result<bool> {
let mut released_any = false; let mut released_any = false;
if matches!(self.admission, AdmissionSource::Requests(_)) {
loop {
let cluster_in_flight = self.cluster_in_flight();
let next_ready = match &mut self.admission {
AdmissionSource::Requests(pending) => pop_next_concurrency_ready(
pending,
self.now_ms,
cluster_in_flight,
max_in_flight,
),
AdmissionSource::Workload(_) => unreachable!(),
};
let Some((request, arrival_ms)) = next_ready else {
break;
};
self.assign_request(request, arrival_ms, None)?;
released_any = true;
}
return Ok(released_any);
}
let available = max_in_flight.saturating_sub(self.cluster_in_flight()); let available = max_in_flight.saturating_sub(self.cluster_in_flight());
if available == 0 { if available == 0 {
return Ok(false); return Ok(false);
} }
let mut ready_requests = Vec::new(); let ready_requests = match &mut self.admission {
match &mut self.admission { AdmissionSource::Requests(_) => unreachable!(),
AdmissionSource::Requests(pending) => { AdmissionSource::Workload(driver) => driver.pop_ready(self.now_ms, available),
for _ in 0..available { };
let Some(request) = pending.pop_front() else { for ready in ready_requests {
break; self.assign_request(ready.request, self.now_ms, ready.replay_hashes)?;
};
ready_requests.push((request, None));
}
}
AdmissionSource::Workload(driver) => {
ready_requests.extend(
driver
.pop_ready(self.now_ms, available)
.into_iter()
.map(|ready| (ready.request, ready.replay_hashes)),
);
}
}
for (request, replay_hashes) in ready_requests {
self.assign_request(request, self.now_ms, replay_hashes)?;
released_any = true; released_any = true;
} }
...@@ -531,9 +534,12 @@ impl OfflineRuntime { ...@@ -531,9 +534,12 @@ impl OfflineRuntime {
} }
self.workers[worker_idx].mark_busy(); self.workers[worker_idx].mark_busy();
self.push_event( push_worker_completion(
&mut self.events,
&mut self.next_event_seq,
executed.end_ms, executed.end_ms,
SimulationEventKind::WorkerCompletion { WorkerCompletionPayload {
stage: SimulationWorkerStage::Aggregated,
worker_idx, worker_idx,
completed_requests: executed.completed_requests, completed_requests: executed.completed_requests,
output_signals: executed.output_signals, output_signals: executed.output_signals,
...@@ -616,10 +622,19 @@ impl OfflineRuntime { ...@@ -616,10 +622,19 @@ impl OfflineRuntime {
#[cfg(test)] #[cfg(test)]
fn debug_snapshot(&self) -> OfflineRuntimeSnapshot { fn debug_snapshot(&self) -> OfflineRuntimeSnapshot {
let mut router_pending_request_ids = let mut router_pending_request_ids = self
self.router_pending.keys().copied().collect::<Vec<_>>(); .requests
.iter()
.filter(|(_, state)| state.is_queued_at_router())
.map(|(uuid, _)| *uuid)
.collect::<Vec<_>>();
router_pending_request_ids.sort_unstable(); router_pending_request_ids.sort_unstable();
let mut prefill_completed = self.prefill_completed.iter().copied().collect::<Vec<_>>(); let mut prefill_completed = self
.requests
.iter()
.filter(|(_, state)| state.prefill_completed())
.map(|(uuid, _)| *uuid)
.collect::<Vec<_>>();
prefill_completed.sort_unstable(); prefill_completed.sort_unstable();
OfflineRuntimeSnapshot { OfflineRuntimeSnapshot {
...@@ -1483,16 +1498,16 @@ mod tests { ...@@ -1483,16 +1498,16 @@ mod tests {
let request_1 = collector.snapshot(Uuid::from_u128(1)).unwrap(); let request_1 = collector.snapshot(Uuid::from_u128(1)).unwrap();
let request_2 = collector.snapshot(Uuid::from_u128(2)).unwrap(); let request_2 = collector.snapshot(Uuid::from_u128(2)).unwrap();
let request_3 = collector.snapshot(Uuid::from_u128(3)).unwrap(); let request_3 = collector.snapshot(Uuid::from_u128(3)).unwrap();
let first_unblock_ms = request_1
.first_token_ms
.unwrap()
.min(request_2.first_token_ms.unwrap());
assert!(stats.max_router_pending > 0); assert!(stats.max_router_pending > 0);
assert!(request_3.first_admit_ms.unwrap() > request_3.arrival_time_ms); assert!(request_3.first_admit_ms.unwrap() > request_3.arrival_time_ms);
assert!( assert_eq!(request_3.first_admit_ms.unwrap(), first_unblock_ms);
request_3.first_admit_ms.unwrap() assert!(request_3.first_admit_ms.unwrap() < request_1.last_token_ms.unwrap());
< request_1 assert!(request_3.first_admit_ms.unwrap() < request_2.last_token_ms.unwrap());
.last_token_ms
.unwrap()
.min(request_2.last_token_ms.unwrap())
);
} }
#[test] #[test]
...@@ -1561,6 +1576,72 @@ mod tests { ...@@ -1561,6 +1576,72 @@ mod tests {
); );
} }
#[test]
fn test_multi_worker_trace_kv_router_fcfs_and_lcfs_admit_queued_requests_in_opposite_timestamp_order()
{
let requests = vec![
DirectRequest {
tokens: vec![10; 64],
max_output_tokens: 8,
uuid: Some(Uuid::from_u128(10)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![20; 128],
max_output_tokens: 8,
uuid: Some(Uuid::from_u128(20)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![30; 64],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(30)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.1),
},
DirectRequest {
tokens: vec![40; 64],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(40)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.2),
},
];
let (fcfs_collector, fcfs_stats) = run_trace_multi_collect_with_stats(
&queueing_router_args(RouterQueuePolicy::Fcfs),
requests.clone(),
2,
ReplayRouterMode::KvRouter,
);
let (lcfs_collector, lcfs_stats) = run_trace_multi_collect_with_stats(
&queueing_router_args(RouterQueuePolicy::Lcfs),
requests,
2,
ReplayRouterMode::KvRouter,
);
let fcfs_request_30 = fcfs_collector.snapshot(Uuid::from_u128(30)).unwrap();
let fcfs_request_40 = fcfs_collector.snapshot(Uuid::from_u128(40)).unwrap();
let lcfs_request_30 = lcfs_collector.snapshot(Uuid::from_u128(30)).unwrap();
let lcfs_request_40 = lcfs_collector.snapshot(Uuid::from_u128(40)).unwrap();
assert!(fcfs_stats.max_router_pending > 0);
assert!(lcfs_stats.max_router_pending > 0);
assert_eq!(
&fcfs_stats.dispatch_order[2..4],
&[Uuid::from_u128(30), Uuid::from_u128(40)]
);
assert_eq!(
&lcfs_stats.dispatch_order[2..4],
&[Uuid::from_u128(40), Uuid::from_u128(30)]
);
assert!(fcfs_request_30.first_admit_ms.unwrap() < fcfs_request_40.first_admit_ms.unwrap());
assert!(lcfs_request_40.first_admit_ms.unwrap() < lcfs_request_30.first_admit_ms.unwrap());
}
#[test] #[test]
fn test_multi_worker_concurrency_kv_router_respects_max_in_flight() { fn test_multi_worker_concurrency_kv_router_respects_max_in_flight() {
let args = queueing_router_args(RouterQueuePolicy::Fcfs); let args = queueing_router_args(RouterQueuePolicy::Fcfs);
...@@ -1605,6 +1686,51 @@ mod tests { ...@@ -1605,6 +1686,51 @@ mod tests {
assert!(stats.max_router_pending > 0); assert!(stats.max_router_pending > 0);
} }
#[test]
fn test_multi_worker_concurrency_kv_router_records_backfill_timing() {
let args = queueing_router_args(RouterQueuePolicy::Fcfs);
let (collector, stats) = run_concurrency_multi_collect_with_stats(
&args,
vec![
DirectRequest {
tokens: vec![1; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(11)),
dp_rank: 0,
arrival_timestamp_ms: None,
},
DirectRequest {
tokens: vec![2; 64],
max_output_tokens: 4,
uuid: Some(Uuid::from_u128(22)),
dp_rank: 0,
arrival_timestamp_ms: None,
},
DirectRequest {
tokens: vec![3; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(33)),
dp_rank: 0,
arrival_timestamp_ms: None,
},
],
2,
2,
ReplayRouterMode::KvRouter,
);
let request_1 = collector.snapshot(Uuid::from_u128(11)).unwrap();
let request_2 = collector.snapshot(Uuid::from_u128(22)).unwrap();
let request_3 = collector.snapshot(Uuid::from_u128(33)).unwrap();
assert_eq!(request_1.arrival_time_ms, 0.0);
assert_eq!(request_2.arrival_time_ms, 0.0);
assert_eq!(request_3.arrival_time_ms, request_1.last_token_ms.unwrap());
assert!(request_3.arrival_time_ms < request_2.last_token_ms.unwrap());
assert_eq!(request_3.first_admit_ms.unwrap(), request_3.arrival_time_ms);
assert_eq!(stats.max_in_flight_seen, 2);
}
#[test] #[test]
fn test_multi_worker_trace_single_worker_round_robin_matches_single_runtime() { fn test_multi_worker_trace_single_worker_round_robin_matches_single_runtime() {
let args = replay_args(true, true); let args = replay_args(true, true);
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{BinaryHeap, VecDeque};
use dynamo_kv_router::protocols::RouterEvent;
use super::events::{SimulationEvent, SimulationEventKind, SimulationWorkerStage};
use crate::common::protocols::{DirectRequest, OutputSignal};
#[derive(Debug)]
pub(super) struct WorkerCompletionPayload {
pub stage: SimulationWorkerStage,
pub worker_idx: usize,
pub completed_requests: usize,
pub output_signals: Vec<OutputSignal>,
pub kv_events: Vec<RouterEvent>,
}
pub(super) fn next_timestamp(
next_arrival_ms: Option<f64>,
next_event_ms: Option<f64>,
) -> Option<f64> {
match (next_arrival_ms, next_event_ms) {
(Some(arrival_ms), Some(event_ms)) => Some(arrival_ms.min(event_ms)),
(Some(arrival_ms), None) => Some(arrival_ms),
(None, Some(event_ms)) => Some(event_ms),
(None, None) => None,
}
}
pub(super) fn pop_next_trace_ready(
pending: &mut VecDeque<DirectRequest>,
now_ms: f64,
) -> Option<(DirectRequest, f64)> {
let arrival_ms = pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
.filter(|arrival_ms| *arrival_ms <= now_ms)?;
let request = pending
.pop_front()
.expect("front request must exist when arrival is ready");
Some((request, arrival_ms))
}
pub(super) fn pop_next_concurrency_ready(
pending: &mut VecDeque<DirectRequest>,
now_ms: f64,
cluster_in_flight: usize,
max_in_flight: usize,
) -> Option<(DirectRequest, f64)> {
if cluster_in_flight >= max_in_flight {
return None;
}
let request = pending.pop_front()?;
Some((request, now_ms))
}
pub(super) fn push_worker_completion(
events: &mut BinaryHeap<SimulationEvent>,
next_event_seq: &mut u64,
at_ms: f64,
payload: WorkerCompletionPayload,
) {
events.push(SimulationEvent {
at_ms,
seq_no: *next_event_seq,
kind: SimulationEventKind::WorkerCompletion {
stage: payload.stage,
worker_idx: payload.worker_idx,
completed_requests: payload.completed_requests,
output_signals: payload.output_signals,
kv_events: payload.kv_events,
},
});
*next_event_seq += 1;
}
pub(super) fn pop_ready_worker_completion(
events: &mut BinaryHeap<SimulationEvent>,
now_ms: f64,
) -> Option<WorkerCompletionPayload> {
let event = events.peek()?;
if event.at_ms != now_ms {
return None;
}
let SimulationEventKind::WorkerCompletion { .. } = &event.kind else {
return None;
};
let event = events.pop().expect("event must exist after peek");
let (stage, worker_idx, completed_requests, output_signals, kv_events) = match event.kind {
SimulationEventKind::WorkerCompletion {
stage,
worker_idx,
completed_requests,
output_signals,
kv_events,
} => (
stage,
worker_idx,
completed_requests,
output_signals,
kv_events,
),
SimulationEventKind::DecodeHandoff { .. } => {
unreachable!("peeked worker completion event must match popped event")
}
};
Some(WorkerCompletionPayload {
stage,
worker_idx,
completed_requests,
output_signals,
kv_events,
})
}
pub(super) fn push_decode_handoff(
events: &mut BinaryHeap<SimulationEvent>,
next_event_seq: &mut u64,
at_ms: f64,
uuid: uuid::Uuid,
) {
events.push(SimulationEvent {
at_ms,
seq_no: *next_event_seq,
kind: SimulationEventKind::DecodeHandoff { uuid },
});
*next_event_seq += 1;
}
pub(super) fn pop_ready_decode_handoff(
events: &mut BinaryHeap<SimulationEvent>,
now_ms: f64,
) -> Option<uuid::Uuid> {
let event = events.peek()?;
if event.at_ms != now_ms {
return None;
}
let SimulationEventKind::DecodeHandoff { .. } = &event.kind else {
return None;
};
let event = events.pop().expect("event must exist after peek");
let SimulationEventKind::DecodeHandoff { uuid } = event.kind else {
unreachable!("peeked decode handoff event must match popped event");
};
Some(uuid)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::replay::offline::events::SimulationWorkerStage;
use uuid::Uuid;
fn direct_request(uuid: u128, arrival_timestamp_ms: Option<f64>) -> DirectRequest {
DirectRequest {
tokens: vec![1; 8],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(uuid)),
dp_rank: 0,
arrival_timestamp_ms,
}
}
#[test]
fn test_next_timestamp_matches_current_choice_logic() {
assert_eq!(next_timestamp(Some(1.0), Some(2.0)), Some(1.0));
assert_eq!(next_timestamp(Some(2.0), Some(1.0)), Some(1.0));
assert_eq!(next_timestamp(Some(3.0), None), Some(3.0));
assert_eq!(next_timestamp(None, Some(4.0)), Some(4.0));
assert_eq!(next_timestamp(None, None), None);
}
#[test]
fn test_pop_next_trace_ready_releases_only_arrivals_at_or_before_now() {
let mut pending = VecDeque::from(vec![
direct_request(1, Some(1.0)),
direct_request(2, Some(1.1)),
direct_request(3, Some(2.0)),
]);
let (request_1, arrival_1) = pop_next_trace_ready(&mut pending, 1.0).unwrap();
assert_eq!(request_1.uuid, Some(Uuid::from_u128(1)));
assert_eq!(arrival_1, 1.0);
assert!(pop_next_trace_ready(&mut pending, 1.0).is_none());
let (request_2, arrival_2) = pop_next_trace_ready(&mut pending, 1.1).unwrap();
assert_eq!(request_2.uuid, Some(Uuid::from_u128(2)));
assert_eq!(arrival_2, 1.1);
assert_eq!(pending.len(), 1);
}
#[test]
fn test_pop_next_concurrency_ready_stops_at_max_in_flight() {
let mut pending = VecDeque::from(vec![direct_request(1, None), direct_request(2, None)]);
assert!(pop_next_concurrency_ready(&mut pending, 5.0, 2, 2).is_none());
let (request, arrival_ms) = pop_next_concurrency_ready(&mut pending, 5.0, 1, 2).unwrap();
assert_eq!(request.uuid, Some(Uuid::from_u128(1)));
assert_eq!(arrival_ms, 5.0);
assert_eq!(pending.len(), 1);
}
#[test]
fn test_worker_completion_helpers_preserve_same_time_sequence_ordering() {
let mut events = BinaryHeap::new();
let mut next_event_seq = 0;
push_worker_completion(
&mut events,
&mut next_event_seq,
10.0,
WorkerCompletionPayload {
stage: SimulationWorkerStage::Aggregated,
worker_idx: 7,
completed_requests: 1,
output_signals: vec![OutputSignal {
uuid: Uuid::from_u128(7),
completed: true,
handoff_delay_ms: None,
}],
kv_events: Vec::new(),
},
);
push_worker_completion(
&mut events,
&mut next_event_seq,
10.0,
WorkerCompletionPayload {
stage: SimulationWorkerStage::Aggregated,
worker_idx: 8,
completed_requests: 2,
output_signals: vec![OutputSignal {
uuid: Uuid::from_u128(8),
completed: false,
handoff_delay_ms: None,
}],
kv_events: Vec::new(),
},
);
assert!(pop_ready_worker_completion(&mut events, 9.0).is_none());
let first = pop_ready_worker_completion(&mut events, 10.0).unwrap();
let second = pop_ready_worker_completion(&mut events, 10.0).unwrap();
assert_eq!(first.stage, SimulationWorkerStage::Aggregated);
assert_eq!(first.worker_idx, 7);
assert_eq!(first.completed_requests, 1);
assert_eq!(second.stage, SimulationWorkerStage::Aggregated);
assert_eq!(second.worker_idx, 8);
assert_eq!(second.completed_requests, 2);
assert!(events.is_empty());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use anyhow::{Result, anyhow, bail};
use crate::common::protocols::DirectRequest; use crate::common::protocols::DirectRequest;
use crate::common::protocols::MockEngineArgs; use crate::common::protocols::MockEngineArgs;
use crate::replay::TraceCollector; use crate::replay::TraceCollector;
use crate::scheduler::{EngineCore, EnginePassResult}; use crate::scheduler::{EngineCore, EnginePassResult};
use uuid::Uuid;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum AggRequestPhase {
QueuedAtRouter,
Running,
}
pub(crate) struct AggRequestState {
request: Option<DirectRequest>,
phase: AggRequestPhase,
prefill_completed: bool,
}
impl AggRequestState {
pub(crate) fn new_queued(request: DirectRequest) -> Self {
Self {
request: Some(request),
phase: AggRequestPhase::QueuedAtRouter,
prefill_completed: false,
}
}
pub(crate) fn new_running() -> Self {
Self {
request: None,
phase: AggRequestPhase::Running,
prefill_completed: false,
}
}
pub(crate) fn is_queued_at_router(&self) -> bool {
self.phase == AggRequestPhase::QueuedAtRouter
}
pub(crate) fn take_queued_request(&mut self, uuid: Uuid) -> Result<DirectRequest> {
if !self.is_queued_at_router() {
bail!("offline replay expected queued request state for {uuid}");
}
let request = self
.request
.take()
.ok_or_else(|| anyhow!("offline replay missing queued request payload for {uuid}"))?;
self.phase = AggRequestPhase::Running;
Ok(request)
}
pub(crate) fn prefill_completed(&self) -> bool {
self.prefill_completed
}
pub(crate) fn mark_prefill_completed(&mut self) {
self.prefill_completed = true;
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum DisaggPhase {
QueuedPrefill,
RunningPrefill,
QueuedDecode,
RunningDecode,
Done,
}
pub(crate) struct DisaggRequestState {
original: Option<DirectRequest>,
#[cfg(test)]
arrival_ms: f64,
phase: DisaggPhase,
prefill_worker_idx: Option<usize>,
decode_worker_idx: Option<usize>,
}
#[cfg(test)]
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct DisaggRequestSnapshot {
pub(crate) arrival_ms: f64,
pub(crate) phase: DisaggPhase,
pub(crate) prefill_worker_idx: Option<usize>,
pub(crate) decode_worker_idx: Option<usize>,
}
impl DisaggRequestState {
pub(crate) fn new(request: DirectRequest, arrival_ms: f64) -> Self {
#[cfg(not(test))]
let _ = arrival_ms;
Self {
original: Some(request),
#[cfg(test)]
arrival_ms,
phase: DisaggPhase::QueuedPrefill,
prefill_worker_idx: None,
decode_worker_idx: None,
}
}
pub(crate) fn is_queued_prefill(&self) -> bool {
self.phase == DisaggPhase::QueuedPrefill
}
pub(crate) fn is_queued_decode(&self) -> bool {
self.phase == DisaggPhase::QueuedDecode
}
pub(crate) fn original_request(&self) -> Result<&DirectRequest> {
self.original
.as_ref()
.ok_or_else(|| anyhow!("offline disagg replay request payload was already released"))
}
pub(crate) fn build_prefill_request(&self) -> Result<DirectRequest> {
let mut request = self.original_request()?.clone();
request.max_output_tokens = 1;
Ok(request)
}
pub(crate) fn build_decode_request(&self) -> Result<DirectRequest> {
Ok(self.original_request()?.clone())
}
pub(crate) fn start_prefill(&mut self, worker_idx: usize) {
self.phase = DisaggPhase::RunningPrefill;
self.prefill_worker_idx = Some(worker_idx);
}
pub(crate) fn queue_decode(&mut self) {
self.phase = DisaggPhase::QueuedDecode;
}
pub(crate) fn start_decode(&mut self, worker_idx: usize) {
self.phase = DisaggPhase::RunningDecode;
self.decode_worker_idx = Some(worker_idx);
}
pub(crate) fn mark_done(&mut self) {
self.phase = DisaggPhase::Done;
self.original = None;
}
#[cfg(test)]
pub(crate) fn debug_snapshot(&self) -> DisaggRequestSnapshot {
DisaggRequestSnapshot {
arrival_ms: self.arrival_ms,
phase: self.phase,
prefill_worker_idx: self.prefill_worker_idx,
decode_worker_idx: self.decode_worker_idx,
}
}
}
pub(crate) struct OfflineWorkerState { pub(crate) struct OfflineWorkerState {
core: EngineCore, core: EngineCore,
...@@ -91,6 +243,10 @@ impl OfflineWorkerState { ...@@ -91,6 +243,10 @@ impl OfflineWorkerState {
self.core.execute_pass(collector, now_ms) self.core.execute_pass(collector, now_ms)
} }
pub(crate) fn execute_hidden_pass(&mut self, now_ms: f64) -> EnginePassResult {
self.core.execute_hidden_pass(now_ms)
}
#[cfg(test)] #[cfg(test)]
pub(crate) fn debug_snapshot(&self) -> OfflineWorkerSnapshot { pub(crate) fn debug_snapshot(&self) -> OfflineWorkerSnapshot {
OfflineWorkerSnapshot { OfflineWorkerSnapshot {
......
...@@ -110,6 +110,7 @@ struct PendingRequest { ...@@ -110,6 +110,7 @@ struct PendingRequest {
token_seq: Option<Vec<SequenceHash>>, token_seq: Option<Vec<SequenceHash>>,
isl_tokens: usize, isl_tokens: usize,
overlaps: OverlapScores, overlaps: OverlapScores,
track_prefill_tokens: bool,
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
} }
...@@ -130,6 +131,7 @@ impl PendingRequest { ...@@ -130,6 +131,7 @@ impl PendingRequest {
overlaps: self.overlaps.clone(), overlaps: self.overlaps.clone(),
decode_blocks, decode_blocks,
prefill_tokens, prefill_tokens,
track_prefill_tokens: self.track_prefill_tokens,
router_config_override: None, router_config_override: None,
update_states: true, update_states: true,
lora_name: None, lora_name: None,
...@@ -258,7 +260,6 @@ impl OfflineReplayRouter { ...@@ -258,7 +260,6 @@ impl OfflineReplayRouter {
self.drain_pending() self.drain_pending()
} }
#[cfg(test)]
pub(crate) fn pending_count(&self) -> usize { pub(crate) fn pending_count(&self) -> usize {
self.pending.len() self.pending.len()
} }
...@@ -371,6 +372,7 @@ impl OfflineReplayRouter { ...@@ -371,6 +372,7 @@ impl OfflineReplayRouter {
token_seq, token_seq,
isl_tokens: request.tokens.len(), isl_tokens: request.tokens.len(),
overlaps, overlaps,
track_prefill_tokens: self.config.router_track_prefill_tokens,
expected_output_tokens: Some( expected_output_tokens: Some(
u32::try_from(request.max_output_tokens) u32::try_from(request.max_output_tokens)
.context("max_output_tokens does not fit into u32")?, .context("max_output_tokens does not fit into u32")?,
...@@ -379,11 +381,14 @@ impl OfflineReplayRouter { ...@@ -379,11 +381,14 @@ impl OfflineReplayRouter {
} }
fn admit_request(&mut self, request: PendingRequest) -> Result<usize> { fn admit_request(&mut self, request: PendingRequest) -> Result<usize> {
let (decode_blocks, prefill_tokens) = self.slots.potential_blocks_and_tokens( let (decode_blocks, prefill_tokens) = self
request.token_seq.as_deref(), .slots
request.isl_tokens, .potential_blocks_and_tokens_with_prefill_tracking(
request.overlaps.clone(), request.token_seq.as_deref(),
); request.isl_tokens,
request.overlaps.clone(),
request.track_prefill_tokens,
);
let scheduling_request = request.scheduling_request(decode_blocks, prefill_tokens); let scheduling_request = request.scheduling_request(decode_blocks, prefill_tokens);
let selection = self.selector.select_worker( let selection = self.selector.select_worker(
&self.workers_with_configs, &self.workers_with_configs,
...@@ -400,6 +405,7 @@ impl OfflineReplayRouter { ...@@ -400,6 +405,7 @@ impl OfflineReplayRouter {
token_sequence: request.token_seq, token_sequence: request.token_seq,
isl: request.isl_tokens, isl: request.isl_tokens,
overlap: selection.overlap_blocks, overlap: selection.overlap_blocks,
track_prefill_tokens: request.track_prefill_tokens,
expected_output_tokens: request.expected_output_tokens, expected_output_tokens: request.expected_output_tokens,
worker: selection.worker, worker: selection.worker,
lora_name: None, lora_name: None,
......
...@@ -138,6 +138,7 @@ impl KvReplayRouter { ...@@ -138,6 +138,7 @@ impl KvReplayRouter {
args.block_size as u32, args.block_size as u32,
selector, selector,
policy, policy,
config.router_track_prefill_tokens,
CancellationToken::new(), CancellationToken::new(),
"replay", "replay",
false, false,
......
...@@ -107,7 +107,7 @@ pub(super) fn replay_selector(config: &KvRouterConfig) -> DefaultWorkerSelector ...@@ -107,7 +107,7 @@ pub(super) fn replay_selector(config: &KvRouterConfig) -> DefaultWorkerSelector
DefaultWorkerSelector::new(Some(config.clone()), "replay") DefaultWorkerSelector::new(Some(config.clone()), "replay")
} }
pub(super) fn replay_router_config( pub(crate) fn replay_router_config(
args: &MockEngineArgs, args: &MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
) -> KvRouterConfig { ) -> KvRouterConfig {
......
...@@ -3,9 +3,45 @@ ...@@ -3,9 +3,45 @@
use anyhow::{Result, bail}; use anyhow::{Result, bail};
use super::ReplayRouterMode; use super::{OfflineDisaggReplayConfig, ReplayArgsMode, ReplayRouterMode};
use crate::common::protocols::{MockEngineArgs, WorkerType}; use crate::common::protocols::{MockEngineArgs, WorkerType};
pub fn validate_replay_args_mode(
aggregated_args: Option<&MockEngineArgs>,
prefill_args: Option<&MockEngineArgs>,
decode_args: Option<&MockEngineArgs>,
num_workers: usize,
num_prefill_workers: usize,
num_decode_workers: usize,
) -> Result<ReplayArgsMode> {
if aggregated_args.is_some() && (prefill_args.is_some() || decode_args.is_some()) {
bail!("extra_engine_args cannot be combined with prefill_engine_args/decode_engine_args");
}
match (aggregated_args, prefill_args, decode_args) {
(Some(_), None, None) | (None, None, None) => {
if num_prefill_workers != 1 || num_decode_workers != 1 {
bail!(
"num_prefill_workers and num_decode_workers are only used for disagg replay; use num_workers for aggregated replay"
);
}
Ok(ReplayArgsMode::Aggregated)
}
(None, Some(_), Some(_)) => {
if num_workers != 1 {
bail!(
"num_workers is only used for aggregated replay; use num_prefill_workers and num_decode_workers for disagg replay"
);
}
Ok(ReplayArgsMode::Disagg)
}
(None, Some(_), None) | (None, None, Some(_)) => {
bail!("prefill_engine_args and decode_engine_args must be provided together")
}
(Some(_), Some(_), _) | (Some(_), _, Some(_)) => unreachable!(),
}
}
fn validate_replay_args(args: &MockEngineArgs, num_workers: usize, mode: &str) -> Result<()> { fn validate_replay_args(args: &MockEngineArgs, num_workers: usize, mode: &str) -> Result<()> {
if num_workers == 0 { if num_workers == 0 {
bail!("{mode} requires num_workers >= 1"); bail!("{mode} requires num_workers >= 1");
...@@ -75,3 +111,63 @@ pub(super) fn validate_online_concurrency_args( ...@@ -75,3 +111,63 @@ pub(super) fn validate_online_concurrency_args(
validate_replay_args(args, num_workers, "online replay") validate_replay_args(args, num_workers, "online replay")
} }
fn validate_disagg_args(config: &OfflineDisaggReplayConfig, mode: &str) -> Result<()> {
if config.num_prefill_workers == 0 {
bail!("{mode} requires num_prefill_workers >= 1");
}
if config.num_decode_workers == 0 {
bail!("{mode} requires num_decode_workers >= 1");
}
if config.prefill_args.worker_type != WorkerType::Prefill {
bail!(
"{mode} requires prefill_engine_args.worker_type=prefill, got {:?}",
config.prefill_args.worker_type,
);
}
if config.decode_args.worker_type != WorkerType::Decode {
bail!(
"{mode} requires decode_engine_args.worker_type=decode, got {:?}",
config.decode_args.worker_type,
);
}
if config.prefill_args.dp_size != 1 {
bail!(
"{mode} only supports prefill data_parallel_size=1, got {}",
config.prefill_args.dp_size,
);
}
if config.decode_args.dp_size != 1 {
bail!(
"{mode} only supports decode data_parallel_size=1, got {}",
config.decode_args.dp_size,
);
}
if config.prefill_args.block_size != config.decode_args.block_size {
bail!(
"{mode} requires matching prefill/decode block_size, got {} and {}",
config.prefill_args.block_size,
config.decode_args.block_size,
);
}
Ok(())
}
pub(super) fn validate_offline_disagg_replay_args(
config: &OfflineDisaggReplayConfig,
_router_mode: ReplayRouterMode,
) -> Result<()> {
validate_disagg_args(config, "trace replay")
}
pub(super) fn validate_offline_disagg_concurrency_args(
config: &OfflineDisaggReplayConfig,
max_in_flight: usize,
_router_mode: ReplayRouterMode,
) -> Result<()> {
if max_in_flight == 0 {
bail!("concurrency replay requires max_in_flight >= 1");
}
validate_disagg_args(config, "concurrency replay")
}
...@@ -89,6 +89,13 @@ impl EngineCore { ...@@ -89,6 +89,13 @@ impl EngineCore {
Self::Sglang(core) => core.execute_pass(collector, now_ms), Self::Sglang(core) => core.execute_pass(collector, now_ms),
} }
} }
pub(crate) fn execute_hidden_pass(&mut self, now_ms: f64) -> EnginePassResult {
match self {
Self::Vllm(core) => core.execute_hidden_pass(now_ms),
Self::Sglang(core) => core.execute_hidden_pass(now_ms),
}
}
} }
#[derive(Clone)] #[derive(Clone)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment