Unverified Commit 0bbba988 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat(replay): simulate non-zero worker startup time in offline replay (#8231)


Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 6633be54
...@@ -7,7 +7,8 @@ pub(super) use super::components::ReplayMode; ...@@ -7,7 +7,8 @@ pub(super) use super::components::ReplayMode;
use super::events::{SimulationEvent, SimulationWorkerStage}; use super::events::{SimulationEvent, SimulationWorkerStage};
use super::progress::ReplayProgress; use super::progress::ReplayProgress;
use super::runtime_utils::{ use super::runtime_utils::{
next_timestamp as choose_next_timestamp, pop_ready_worker_completion, push_worker_completion, next_timestamp as choose_next_timestamp, pop_ready_worker_completion, pop_ready_worker_ready,
push_worker_completion, push_worker_ready,
}; };
#[cfg(test)] #[cfg(test)]
use super::state::AggRequestPhase; use super::state::AggRequestPhase;
...@@ -313,7 +314,7 @@ impl AggRuntime { ...@@ -313,7 +314,7 @@ impl AggRuntime {
Ok(uuid) Ok(uuid)
} }
/// Return true once no workers, router queues, or admissions remain. /// Return true once no events, workers, router queues, or admissions remain.
fn is_done(&self) -> bool { fn is_done(&self) -> bool {
self.events.is_empty() self.events.is_empty()
&& self.cluster_in_flight() == 0 && self.cluster_in_flight() == 0
...@@ -321,6 +322,25 @@ impl AggRuntime { ...@@ -321,6 +322,25 @@ impl AggRuntime {
&& self.engine.is_drained() && self.engine.is_drained()
} }
/// Report whether all request work has finished, ignoring worker startups
/// that are still scheduled.
///
/// Unlike `is_done`, this does not require the event heap to be empty: the
/// heap may still hold `WorkerReady` events. `advance_to` returns this value
/// so a caller can stop once nothing is in flight, rather than waiting on
/// startups of workers that would never be given a request.
fn is_workload_done(&self) -> bool {
    let requests_settled = self.cluster_in_flight() == 0
        && self.admission.is_drained()
        && self.engine.is_drained();
    requests_settled && self.only_worker_ready_events_remain()
}
/// Check that no event other than `WorkerReady` is left in the heap.
///
/// An empty heap also satisfies this condition.
fn only_worker_ready_events_remain(&self) -> bool {
    use super::events::SimulationEventKind;
    let has_other_event = self
        .events
        .iter()
        .any(|event| !matches!(event.kind, SimulationEventKind::WorkerReady { .. }));
    !has_other_event
}
/// Pick the next logical timestamp from either arrivals or scheduled worker completions. /// Pick the next logical timestamp from either arrivals or scheduled worker completions.
fn next_timestamp(&mut self) -> Option<f64> { fn next_timestamp(&mut self) -> Option<f64> {
let next_event_ms = self.events.peek().map(|event| event.at_ms); let next_event_ms = self.events.peek().map(|event| event.at_ms);
...@@ -496,10 +516,32 @@ impl AggRuntime { ...@@ -496,10 +516,32 @@ impl AggRuntime {
Ok(()) Ok(())
} }
/// Promote every worker whose `WorkerReady` event is due at `now_ms`.
///
/// Returns `Ok(true)` when at least one worker left the startup phase, so the
/// caller knows more work may have become possible at this timestamp.
fn apply_worker_ready_events(&mut self) -> anyhow::Result<bool> {
    let mut any_activated = false;
    loop {
        let Some((stage, worker_id)) = pop_ready_worker_ready(&mut self.events, self.now_ms)
        else {
            break;
        };
        debug_assert_eq!(stage, SimulationWorkerStage::Aggregated);

        // A `false` result means the worker is no longer pending startup
        // (it was cancelled by a scale-down), so this event is stale and is
        // dropped without further action.
        if !self.engine.mark_worker_ready(worker_id) {
            continue;
        }
        any_activated = true;

        if let Some(router) = self.router.as_mut() {
            router.add_worker(worker_id)?;
            // Requests may have queued up while every worker was busy; the
            // freshly registered worker may be able to take some of them.
            let effects = router.try_drain_pending(self.now_ms)?;
            self.dispatch_router_admissions(effects.admissions)?;
        }
    }
    Ok(any_activated)
}
/// Repeatedly process all work that becomes possible without advancing logical time. /// Repeatedly process all work that becomes possible without advancing logical time.
fn drain_current_timestamp(&mut self) -> anyhow::Result<()> { fn drain_current_timestamp(&mut self) -> anyhow::Result<()> {
loop { loop {
let mut changed = self.apply_worker_completions()?; let mut changed = self.apply_worker_completions()?;
changed |= self.apply_worker_ready_events()?;
changed |= self.release_ready_arrivals()?; changed |= self.release_ready_arrivals()?;
changed |= self.drive_ready_workers()?; changed |= self.drive_ready_workers()?;
...@@ -516,7 +558,8 @@ impl AggRuntime { ...@@ -516,7 +558,8 @@ impl AggRuntime {
// ------------------------------------------------------------------ // ------------------------------------------------------------------
/// Advance the simulation up to `until_ms` simulated time, then pause. /// Advance the simulation up to `until_ms` simulated time, then pause.
/// Returns `true` if the replay is done (no more work). /// Returns `true` if the request workload is done — pending `WorkerReady`
/// events do not block completion since there is no work for those workers.
pub(in crate::replay) fn advance_to(&mut self, until_ms: f64) -> anyhow::Result<bool> { pub(in crate::replay) fn advance_to(&mut self, until_ms: f64) -> anyhow::Result<bool> {
self.drain_current_timestamp()?; self.drain_current_timestamp()?;
...@@ -536,7 +579,7 @@ impl AggRuntime { ...@@ -536,7 +579,7 @@ impl AggRuntime {
self.drain_current_timestamp()?; self.drain_current_timestamp()?;
} }
Ok(self.is_done()) Ok(self.is_workload_done())
} }
/// Current simulated time in milliseconds. /// Current simulated time in milliseconds.
...@@ -565,19 +608,42 @@ impl AggRuntime { ...@@ -565,19 +608,42 @@ impl AggRuntime {
} }
/// Apply a scaling decision: set the target number of workers. /// Apply a scaling decision: set the target number of workers.
/// Scale-up is immediate; scale-down removes the worker from the router ///
/// immediately (so no new requests land on it) and lets it drain in-flight /// Scale-up: if `startup_time` is configured, new workers enter a startup
/// work in the engine. /// phase and a `WorkerReady` event is scheduled. They become active (and
/// are registered with the router) only when that event fires. Without
/// `startup_time`, workers are available immediately.
///
/// Scale-down: the worker is removed from the router immediately (so no
/// new requests land on it) and drains in-flight work in the engine.
pub(in crate::replay) fn apply_scaling(&mut self, target_workers: usize) -> anyhow::Result<()> { pub(in crate::replay) fn apply_scaling(&mut self, target_workers: usize) -> anyhow::Result<()> {
let (added, newly_marked) = self.engine.apply_target_count(target_workers); let (added, newly_marked) = self.engine.apply_target_count(target_workers);
#[cfg(test)] #[cfg(test)]
if let Some(new_len) = added.iter().max().map(|id| id + 1) { if let Some(new_len) = added.iter().max().map(|id| id + 1) {
self.worker_active_requests.resize(new_len, Vec::new()); self.worker_active_requests.resize(new_len, Vec::new());
} }
let admissions = if let Some(router) = self.router.as_mut() { let startup_delay_ms = self.engine.startup_time_ms();
for id in added {
router.add_worker(id)?; for &id in &added {
match startup_delay_ms {
Some(delay) => {
push_worker_ready(
&mut self.events,
&mut self.next_event_seq,
self.now_ms + delay,
SimulationWorkerStage::Aggregated,
id,
);
}
None => {
if let Some(router) = self.router.as_mut() {
router.add_worker(id)?;
}
}
} }
}
let admissions = if let Some(router) = self.router.as_mut() {
for id in newly_marked { for id in newly_marked {
router.remove_worker(id)?; router.remove_worker(id)?;
} }
...@@ -1830,4 +1896,155 @@ mod tests { ...@@ -1830,4 +1896,155 @@ mod tests {
); );
} }
} }
// ---- startup delay tests ----
/// Engine args for the startup-delay tests: a small, fast mock engine whose
/// workers take `startup_time_s` seconds to come up.
fn startup_args(startup_time_s: f64) -> MockEngineArgs {
    let startup_time = Some(startup_time_s);
    let built = MockEngineArgs::builder()
        .block_size(64)
        .num_gpu_blocks(256)
        .max_num_batched_tokens(Some(8192))
        .max_num_seqs(Some(8))
        .enable_prefix_caching(true)
        .enable_chunked_prefill(true)
        .speedup_ratio(1000.0)
        .startup_time(startup_time)
        .build();
    built.unwrap()
}
fn simple_requests(n: usize, arrival_interval_ms: f64) -> VecDeque<DirectRequest> {
(0..n)
.map(|i| DirectRequest {
tokens: vec![1; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(i as u128 + 1)),
dp_rank: 0,
arrival_timestamp_ms: Some(i as f64 * arrival_interval_ms),
})
.collect()
}
#[test]
fn test_apply_scaling_with_startup_delay_defers_activation() {
    // Twenty arrivals, one per simulated second, so requests keep arriving
    // for longer than the 5-second startup delay.
    let engine_args = startup_args(5.0);
    let trace = simple_requests(20, 1000.0);
    let mut runtime = AggRuntime::new(
        &engine_args,
        None,
        None,
        trace,
        1,
        ReplayMode::Trace,
        ReplayRouterMode::RoundRobin,
    )
    .unwrap();

    // Run the first half second with the single initial worker.
    runtime.advance_to(500.0).unwrap();
    assert_eq!(runtime.active_worker_count(), 1);
    assert_eq!(runtime.total_worker_count(), 1);

    // Request a second worker; it should exist but not be active until
    // 5000 ms after the current simulated time.
    runtime.apply_scaling(2).unwrap();
    let ready_at_ms = runtime.now_ms() + 5000.0;
    assert_eq!(runtime.active_worker_count(), 1);
    assert_eq!(runtime.total_worker_count(), 2);

    // One millisecond short of the deadline: still starting.
    runtime.advance_to(ready_at_ms - 1.0).unwrap();
    assert_eq!(runtime.active_worker_count(), 1);

    // At the deadline the worker is counted as active.
    runtime.advance_to(ready_at_ms).unwrap();
    assert_eq!(runtime.active_worker_count(), 2);
    assert_eq!(runtime.total_worker_count(), 2);
}
#[test]
fn test_apply_scaling_without_startup_is_immediate() {
    // These args are expected to carry no startup_time.
    let engine_args = fast_router_args();
    let trace = simple_requests(4, 100.0);
    let mut runtime = AggRuntime::new(
        &engine_args,
        None,
        None,
        trace,
        1,
        ReplayMode::Trace,
        ReplayRouterMode::RoundRobin,
    )
    .unwrap();

    runtime.advance_to(50.0).unwrap();
    runtime.apply_scaling(2).unwrap();

    // With no startup delay the added worker counts as active right away.
    assert_eq!(runtime.active_worker_count(), 2);
    assert_eq!(runtime.total_worker_count(), 2);
}
#[test]
fn test_startup_cancel_ignores_stale_event() {
    let engine_args = startup_args(5.0);
    // Arrivals span 20 simulated seconds, longer than the startup delay.
    let trace = simple_requests(20, 1000.0);
    let mut runtime = AggRuntime::new(
        &engine_args,
        None,
        None,
        trace,
        2,
        ReplayMode::Trace,
        ReplayRouterMode::RoundRobin,
    )
    .unwrap();

    // Grow from 2 to 4: two workers enter the startup phase.
    runtime.apply_scaling(4).unwrap();
    assert_eq!(runtime.active_worker_count(), 2);
    assert_eq!(runtime.total_worker_count(), 4);

    // Shrink straight back to 2: both starting workers are dropped.
    runtime.apply_scaling(2).unwrap();
    assert_eq!(runtime.active_worker_count(), 2);
    assert_eq!(runtime.total_worker_count(), 2);

    // Run past the time the cancelled workers would have become ready. The
    // leftover ready events must neither panic nor change the counts.
    runtime.advance_to(6000.0).unwrap();
    assert_eq!(runtime.active_worker_count(), 2);
    assert_eq!(runtime.total_worker_count(), 2);
}
#[test]
fn test_advance_to_reports_done_when_workload_finishes_before_startup() {
    // Four requests arriving between 0 and 300 ms, paired with a 30-second
    // startup delay: the trace is expected to finish long before the new
    // worker would become ready.
    let engine_args = startup_args(30.0);
    let trace = simple_requests(4, 100.0);
    let mut runtime = AggRuntime::new(
        &engine_args,
        None,
        None,
        trace,
        1,
        ReplayMode::Trace,
        ReplayRouterMode::RoundRobin,
    )
    .unwrap();

    // Ask for the extra worker before any simulated time has passed.
    runtime.apply_scaling(2).unwrap();
    assert_eq!(runtime.active_worker_count(), 1);

    // Stop at 10 s: after the requests, before the worker's ready time.
    let workload_done = runtime.advance_to(10_000.0).unwrap();

    // The still-pending WorkerReady event must not hold completion back.
    assert!(
        workload_done,
        "advance_to should report done when workload is complete"
    );
}
} }
...@@ -24,6 +24,8 @@ pub(in crate::replay::offline) struct EngineComponent { ...@@ -24,6 +24,8 @@ pub(in crate::replay::offline) struct EngineComponent {
next_id: usize, next_id: usize,
/// Workers marked for removal — skipped by round-robin, removed when drained. /// Workers marked for removal — skipped by round-robin, removed when drained.
pending_removal: BTreeSet<usize>, pending_removal: BTreeSet<usize>,
/// Workers still starting up — excluded from active set until ready.
pending_startup: BTreeSet<usize>,
/// Engine args used to construct new workers during scale-up. /// Engine args used to construct new workers during scale-up.
args: MockEngineArgs, args: MockEngineArgs,
/// Whether new workers should capture KV events (true when a router is present). /// Whether new workers should capture KV events (true when a router is present).
...@@ -44,6 +46,7 @@ impl EngineComponent { ...@@ -44,6 +46,7 @@ impl EngineComponent {
workers: map, workers: map,
next_id: count, next_id: count,
pending_removal: BTreeSet::new(), pending_removal: BTreeSet::new(),
pending_startup: BTreeSet::new(),
args: MockEngineArgs::default(), args: MockEngineArgs::default(),
capture_kv_events: false, capture_kv_events: false,
} }
...@@ -100,22 +103,48 @@ impl EngineComponent { ...@@ -100,22 +103,48 @@ impl EngineComponent {
/// router immediately. Newly marked workers should be removed from the /// router immediately. Newly marked workers should be removed from the
/// router right away to prevent new requests from landing on them, even /// router right away to prevent new requests from landing on them, even
/// though the workers themselves remain in the engine until fully drained. /// though the workers themselves remain in the engine until fully drained.
///
/// The effective count is `active + pending_startup` — workers that will
/// be active once all startups complete. On scale-down, pending startup
/// workers are cancelled first (cheapest: no in-flight work, no router
/// registration), then active workers are marked for removal.
pub(in crate::replay::offline) fn apply_target_count( pub(in crate::replay::offline) fn apply_target_count(
&mut self, &mut self,
target: usize, target: usize,
) -> (Vec<usize>, Vec<usize>) { ) -> (Vec<usize>, Vec<usize>) {
let active_ids = self.active_worker_ids(); let active_ids = self.active_worker_ids();
let current = active_ids.len(); let effective = active_ids.len() + self.pending_startup.len();
let mut added = Vec::new(); let mut added = Vec::new();
let mut newly_marked = Vec::new(); let mut newly_marked = Vec::new();
if target > current { if target > effective {
for _ in 0..(target - current) { let has_startup_delay = self.startup_time_ms().is_some();
added.push(self.add_worker()); for _ in 0..(target - effective) {
let id = self.add_worker();
if has_startup_delay {
self.pending_startup.insert(id);
}
added.push(id);
}
} else if target < effective {
let excess = effective - target;
// Cancel pending startup workers first (reverse order = highest IDs).
let to_cancel: Vec<usize> = self
.pending_startup
.iter()
.copied()
.rev()
.take(excess)
.collect();
for &id in &to_cancel {
self.pending_startup.remove(&id);
self.workers.remove(&id);
} }
} else if target < current {
let excess = current - target; // Mark active workers for removal if more excess remains.
for &id in active_ids.iter().rev().take(excess) { let remaining = excess - to_cancel.len();
for &id in active_ids.iter().rev().take(remaining) {
self.mark_for_removal(id); self.mark_for_removal(id);
newly_marked.push(id); newly_marked.push(id);
} }
...@@ -126,15 +155,31 @@ impl EngineComponent { ...@@ -126,15 +155,31 @@ impl EngineComponent {
(added, newly_marked) (added, newly_marked)
} }
/// Return stable IDs of all active (non-pending-removal) workers. /// Return stable IDs of all active workers — excludes both pending removal
/// and pending startup.
pub(in crate::replay::offline) fn active_worker_ids(&self) -> Vec<usize> { pub(in crate::replay::offline) fn active_worker_ids(&self) -> Vec<usize> {
self.workers self.workers
.keys() .keys()
.filter(|id| !self.pending_removal.contains(id)) .filter(|id| !self.pending_removal.contains(id) && !self.pending_startup.contains(id))
.copied() .copied()
.collect() .collect()
} }
/// Startup delay for newly added workers, converted from the configured
/// seconds to milliseconds.
///
/// Yields `None` when no `startup_time` is set or when it is not strictly
/// positive, so a value of zero is treated the same as "no delay".
pub(in crate::replay::offline) fn startup_time_ms(&self) -> Option<f64> {
    match self.args.startup_time {
        Some(seconds) if seconds > 0.0 => Some(seconds * 1000.0),
        _ => None,
    }
}
/// Take `worker_id` out of the pending-startup set.
///
/// Returns `true` only if the worker was in that set and is still present in
/// the worker map, meaning it now counts as active. Returns `false` for a
/// worker that was not pending startup (for example one cancelled earlier, or
/// an unknown ID), which lets callers discard stale ready events.
pub(in crate::replay::offline) fn mark_worker_ready(&mut self, worker_id: usize) -> bool {
    if !self.pending_startup.remove(&worker_id) {
        return false;
    }
    self.workers.contains_key(&worker_id)
}
pub(in crate::replay::offline) fn dispatch( pub(in crate::replay::offline) fn dispatch(
&mut self, &mut self,
worker_id: usize, worker_id: usize,
...@@ -269,3 +314,120 @@ impl EngineComponent { ...@@ -269,3 +314,120 @@ impl EngineComponent {
.collect() .collect()
} }
} }
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::protocols::MockEngineArgs;

    /// Build an aggregated engine with `num_workers` initial workers and the
    /// given `startup_time` (seconds) applied to workers added by scaling.
    fn engine_with_startup(num_workers: usize, startup_time: Option<f64>) -> EngineComponent {
        let args = MockEngineArgs {
            startup_time,
            ..MockEngineArgs::default()
        };
        let mut workers = Vec::with_capacity(num_workers);
        for worker_id in 0..num_workers {
            workers.push(OfflineWorkerState::new(worker_id, args.clone(), false));
        }
        let mut engine = EngineComponent::new(
            SimulationWorkerStage::Aggregated,
            EnginePassMode::Visible,
            workers,
        );
        engine.set_scaling_args(args, false);
        engine
    }

    #[test]
    fn test_apply_target_count_scale_up_with_startup() {
        let mut engine = engine_with_startup(2, Some(5.0));

        let (added, newly_marked) = engine.apply_target_count(4);

        assert_eq!(added.len(), 2);
        assert!(newly_marked.is_empty());
        // The two additions exist but are not yet counted as active.
        assert_eq!(engine.active_worker_ids().len(), 2);
        assert_eq!(engine.worker_count(), 4);
    }

    #[test]
    fn test_apply_target_count_scale_up_without_startup() {
        let mut engine = engine_with_startup(2, None);

        let (added, newly_marked) = engine.apply_target_count(4);

        assert_eq!(added.len(), 2);
        assert!(newly_marked.is_empty());
        // No delay configured: every worker is active straight away.
        assert_eq!(engine.active_worker_ids().len(), 4);
        assert_eq!(engine.worker_count(), 4);
    }

    #[test]
    fn test_scale_down_cancels_startup_before_active() {
        let mut engine = engine_with_startup(2, Some(5.0));

        // 2 -> 4: two workers begin starting up.
        engine.apply_target_count(4);
        assert_eq!(engine.active_worker_ids().len(), 2);
        assert_eq!(engine.worker_count(), 4);

        // 4 -> 3: one starting worker is dropped; no active worker is marked.
        let (_added, marked_at_three) = engine.apply_target_count(3);
        assert!(marked_at_three.is_empty());
        assert_eq!(engine.active_worker_ids().len(), 2);
        assert_eq!(engine.worker_count(), 3);

        // 3 -> 2: the last starting worker is dropped as well.
        let (_added, marked_at_two) = engine.apply_target_count(2);
        assert!(marked_at_two.is_empty());
        assert_eq!(engine.active_worker_ids().len(), 2);
        assert_eq!(engine.worker_count(), 2);
    }

    #[test]
    fn test_scale_down_past_startup_marks_active() {
        let mut engine = engine_with_startup(3, Some(5.0));

        // 3 -> 5: two workers begin starting up.
        engine.apply_target_count(5);

        // 5 -> 1: both starting workers are cancelled and two of the three
        // active workers are marked for removal.
        let (_added, newly_marked) = engine.apply_target_count(1);
        assert_eq!(newly_marked.len(), 2);
        assert_eq!(engine.active_worker_ids().len(), 1);
    }

    #[test]
    fn test_mark_worker_ready_activates_pending() {
        let mut engine = engine_with_startup(1, Some(5.0));
        let (added, _) = engine.apply_target_count(2);
        let starting_id = added[0];
        assert_eq!(engine.active_worker_ids().len(), 1);

        assert!(engine.mark_worker_ready(starting_id));

        assert_eq!(engine.active_worker_ids().len(), 2);
    }

    #[test]
    fn test_mark_worker_ready_returns_false_for_cancelled() {
        let mut engine = engine_with_startup(1, Some(5.0));
        let (added, _) = engine.apply_target_count(2);
        let cancelled_id = added[0];

        // Scaling back down cancels the worker that was still starting.
        engine.apply_target_count(1);

        // A late readiness notification for it must be rejected.
        assert!(!engine.mark_worker_ready(cancelled_id));
    }

    #[test]
    fn test_startup_time_ms_conversion() {
        // (configured seconds, expected milliseconds); zero means no delay.
        let cases = [
            (Some(5.0), Some(5000.0)),
            (None, None),
            (Some(0.0), None),
        ];
        for (configured, expected) in cases {
            let engine = engine_with_startup(1, configured);
            assert_eq!(engine.startup_time_ms(), expected);
        }
    }
}
...@@ -305,6 +305,19 @@ impl OfflineReplayRouter { ...@@ -305,6 +305,19 @@ impl OfflineReplayRouter {
}) })
} }
/// Admit whatever queued requests can be placed at `now_ms`, for example
/// right after a new worker has been registered.
///
/// Each drained `(uuid, worker_idx)` pair is returned as a `WorkerAdmission`
/// inside the resulting `RouterEffects`.
pub(crate) fn try_drain_pending(&mut self, now_ms: f64) -> Result<RouterEffects> {
    let decay_now = self.decay_now(now_ms);
    let drained = self.drain_pending(decay_now)?;
    let admissions = drained
        .into_iter()
        .map(|(uuid, worker_idx)| WorkerAdmission { uuid, worker_idx })
        .collect();
    Ok(RouterEffects { admissions })
}
pub(crate) fn pending_count(&self) -> usize { pub(crate) fn pending_count(&self) -> usize {
self.pending.len() self.pending.len()
} }
......
...@@ -17,7 +17,7 @@ use super::events::{SimulationEvent, SimulationWorkerStage}; ...@@ -17,7 +17,7 @@ use super::events::{SimulationEvent, SimulationWorkerStage};
use super::progress::ReplayProgress; use super::progress::ReplayProgress;
use super::runtime_utils::{ use super::runtime_utils::{
next_timestamp as choose_next_timestamp, pop_ready_decode_handoff, pop_ready_worker_completion, next_timestamp as choose_next_timestamp, pop_ready_decode_handoff, pop_ready_worker_completion,
push_decode_handoff, push_worker_completion, pop_ready_worker_ready, push_decode_handoff, push_worker_completion, push_worker_ready,
}; };
#[cfg(test)] #[cfg(test)]
use super::state::DisaggRequestSnapshot; use super::state::DisaggRequestSnapshot;
...@@ -402,6 +402,24 @@ impl DisaggRuntime { ...@@ -402,6 +402,24 @@ impl DisaggRuntime {
&& self.decode_engine.is_drained() && self.decode_engine.is_drained()
} }
/// Report whether all request work has finished on both the prefill and
/// decode sides, ignoring worker startups that are still scheduled.
///
/// The event heap may still hold `WorkerReady` events; any other remaining
/// event kind means the workload is not done.
fn is_workload_done(&self) -> bool {
    let requests_settled = self.cluster_in_flight() == 0
        && self.admission.is_drained()
        && self.prefill_engine.is_drained()
        && self.decode_engine.is_drained();
    requests_settled && self.only_worker_ready_events_remain()
}
/// Check that no event other than `WorkerReady` is left in the heap.
///
/// An empty heap also satisfies this condition.
fn only_worker_ready_events_remain(&self) -> bool {
    use super::events::SimulationEventKind;
    let has_other_event = self
        .events
        .iter()
        .any(|event| !matches!(event.kind, SimulationEventKind::WorkerReady { .. }));
    !has_other_event
}
/// Pick the next logical timestamp from arrivals, worker completions, or decode handoffs. /// Pick the next logical timestamp from arrivals, worker completions, or decode handoffs.
fn next_timestamp(&mut self) -> Option<f64> { fn next_timestamp(&mut self) -> Option<f64> {
let next_event_ms = self.events.peek().map(|event| event.at_ms); let next_event_ms = self.events.peek().map(|event| event.at_ms);
...@@ -690,10 +708,44 @@ impl DisaggRuntime { ...@@ -690,10 +708,44 @@ impl DisaggRuntime {
Ok(()) Ok(())
} }
/// Promote every prefill or decode worker whose `WorkerReady` event is due
/// at `now_ms`.
///
/// Returns `Ok(true)` when at least one worker left the startup phase.
///
/// # Panics
///
/// Panics if a ready event carries the `Aggregated` stage, which this
/// runtime never schedules.
fn apply_worker_ready_events(&mut self) -> Result<bool> {
    let mut any_activated = false;
    loop {
        let Some((stage, worker_id)) = pop_ready_worker_ready(&mut self.events, self.now_ms)
        else {
            break;
        };

        match stage {
            SimulationWorkerStage::Prefill => {
                // `false` means the worker is no longer pending startup, so
                // the event is stale and is skipped.
                if !self.prefill_engine.mark_worker_ready(worker_id) {
                    continue;
                }
                if let Some(router) = self.prefill_router.as_mut() {
                    router.add_worker(worker_id)?;
                    // Give queued requests a chance to land on the new worker.
                    let effects = router.try_drain_pending(self.now_ms)?;
                    self.dispatch_prefill_admissions(effects.admissions)?;
                }
            }
            SimulationWorkerStage::Decode => {
                if !self.decode_engine.mark_worker_ready(worker_id) {
                    continue;
                }
                if let Some(router) = self.decode_router.as_mut() {
                    router.add_worker(worker_id)?;
                    let effects = router.try_drain_pending(self.now_ms)?;
                    self.dispatch_decode_admissions(effects.admissions)?;
                }
            }
            SimulationWorkerStage::Aggregated => {
                unreachable!("disagg replay should not receive aggregated worker ready events")
            }
        }

        // Only reached when a worker was actually activated above.
        any_activated = true;
    }
    Ok(any_activated)
}
/// Repeatedly process all work that becomes possible without advancing logical time. /// Repeatedly process all work that becomes possible without advancing logical time.
fn drain_current_timestamp(&mut self) -> Result<()> { fn drain_current_timestamp(&mut self) -> Result<()> {
loop { loop {
let mut changed = self.apply_worker_completions()?; let mut changed = self.apply_worker_completions()?;
changed |= self.apply_worker_ready_events()?;
changed |= self.apply_decode_handoffs()?; changed |= self.apply_decode_handoffs()?;
changed |= self.release_ready_arrivals()?; changed |= self.release_ready_arrivals()?;
changed |= self.drive_prefill_workers()?; changed |= self.drive_prefill_workers()?;
...@@ -723,7 +775,8 @@ impl DisaggRuntime { ...@@ -723,7 +775,8 @@ impl DisaggRuntime {
// ------------------------------------------------------------------ // ------------------------------------------------------------------
/// Advance the simulation up to `until_ms` simulated time, then pause. /// Advance the simulation up to `until_ms` simulated time, then pause.
/// Returns `true` if the replay is done (no more work). /// Returns `true` if the request workload is done — pending `WorkerReady`
/// events do not block completion since there is no work for those workers.
pub(in crate::replay) fn advance_to(&mut self, until_ms: f64) -> Result<bool> { pub(in crate::replay) fn advance_to(&mut self, until_ms: f64) -> Result<bool> {
self.drain_current_timestamp()?; self.drain_current_timestamp()?;
...@@ -743,7 +796,7 @@ impl DisaggRuntime { ...@@ -743,7 +796,7 @@ impl DisaggRuntime {
self.drain_current_timestamp()?; self.drain_current_timestamp()?;
} }
Ok(self.is_done()) Ok(self.is_workload_done())
} }
/// Current simulated time in milliseconds. /// Current simulated time in milliseconds.
...@@ -783,18 +836,41 @@ impl DisaggRuntime { ...@@ -783,18 +836,41 @@ impl DisaggRuntime {
} }
/// Apply a scaling decision with separate prefill and decode targets. /// Apply a scaling decision with separate prefill and decode targets.
/// Newly marked workers are removed from the router immediately so no ///
/// new requests land on them while they drain in-flight work. /// Scale-up: if `startup_time` is configured on the respective engine args,
/// new workers enter a startup phase and a `WorkerReady` event is scheduled.
/// They become active (and are registered with the router) only when that
/// event fires. Without `startup_time`, workers are available immediately.
///
/// Scale-down: the worker is removed from the router immediately so no
/// new requests land on it while it drains in-flight work.
pub(in crate::replay) fn apply_scaling( pub(in crate::replay) fn apply_scaling(
&mut self, &mut self,
target_prefill: usize, target_prefill: usize,
target_decode: usize, target_decode: usize,
) -> Result<()> { ) -> Result<()> {
// -- prefill --
let (added, newly_marked) = self.prefill_engine.apply_target_count(target_prefill); let (added, newly_marked) = self.prefill_engine.apply_target_count(target_prefill);
let prefill_admissions = if let Some(router) = self.prefill_router.as_mut() { let prefill_delay = self.prefill_engine.startup_time_ms();
for id in added { for &id in &added {
router.add_worker(id)?; match prefill_delay {
Some(delay) => {
push_worker_ready(
&mut self.events,
&mut self.next_event_seq,
self.now_ms + delay,
SimulationWorkerStage::Prefill,
id,
);
}
None => {
if let Some(router) = self.prefill_router.as_mut() {
router.add_worker(id)?;
}
}
} }
}
let prefill_admissions = if let Some(router) = self.prefill_router.as_mut() {
for id in newly_marked { for id in newly_marked {
router.remove_worker(id)?; router.remove_worker(id)?;
} }
...@@ -802,11 +878,29 @@ impl DisaggRuntime { ...@@ -802,11 +878,29 @@ impl DisaggRuntime {
} else { } else {
Vec::new() Vec::new()
}; };
// -- decode --
let (added, newly_marked) = self.decode_engine.apply_target_count(target_decode); let (added, newly_marked) = self.decode_engine.apply_target_count(target_decode);
let decode_admissions = if let Some(router) = self.decode_router.as_mut() { let decode_delay = self.decode_engine.startup_time_ms();
for id in added { for &id in &added {
router.add_worker(id)?; match decode_delay {
Some(delay) => {
push_worker_ready(
&mut self.events,
&mut self.next_event_seq,
self.now_ms + delay,
SimulationWorkerStage::Decode,
id,
);
}
None => {
if let Some(router) = self.decode_router.as_mut() {
router.add_worker(id)?;
}
}
} }
}
let decode_admissions = if let Some(router) = self.decode_router.as_mut() {
for id in newly_marked { for id in newly_marked {
router.remove_worker(id)?; router.remove_worker(id)?;
} }
......
...@@ -25,6 +25,10 @@ pub(crate) enum SimulationEventKind { ...@@ -25,6 +25,10 @@ pub(crate) enum SimulationEventKind {
DecodeHandoff { DecodeHandoff {
uuid: Uuid, uuid: Uuid,
}, },
WorkerReady {
stage: SimulationWorkerStage,
worker_id: usize,
},
} }
#[derive(Debug)] #[derive(Debug)]
......
...@@ -108,7 +108,7 @@ pub(super) fn pop_ready_worker_completion( ...@@ -108,7 +108,7 @@ pub(super) fn pop_ready_worker_completion(
output_signals, output_signals,
kv_events, kv_events,
), ),
SimulationEventKind::DecodeHandoff { .. } => { SimulationEventKind::DecodeHandoff { .. } | SimulationEventKind::WorkerReady { .. } => {
unreachable!("peeked worker completion event must match popped event") unreachable!("peeked worker completion event must match popped event")
} }
}; };
...@@ -153,6 +153,39 @@ pub(super) fn pop_ready_decode_handoff( ...@@ -153,6 +153,39 @@ pub(super) fn pop_ready_decode_handoff(
Some(uuid) Some(uuid)
} }
/// Schedule a `WorkerReady` event for `worker_id` of the given `stage` at
/// simulated time `at_ms`.
///
/// The event is stamped with the current value of `next_event_seq`, which is
/// then incremented by one.
pub(super) fn push_worker_ready(
    events: &mut BinaryHeap<SimulationEvent>,
    next_event_seq: &mut u64,
    at_ms: f64,
    stage: SimulationWorkerStage,
    worker_id: usize,
) {
    let seq_no = *next_event_seq;
    let ready_event = SimulationEvent {
        at_ms,
        seq_no,
        kind: SimulationEventKind::WorkerReady { stage, worker_id },
    };
    events.push(ready_event);
    *next_event_seq = seq_no + 1;
}
/// Pop the event at the top of the heap if it is a `WorkerReady` event
/// scheduled for exactly `now_ms`.
///
/// Returns the worker's stage and ID on success. Returns `None`, leaving the
/// heap untouched, when the heap is empty, the top event is scheduled for a
/// different time, or the top event is of another kind.
pub(super) fn pop_ready_worker_ready(
    events: &mut BinaryHeap<SimulationEvent>,
    now_ms: f64,
) -> Option<(SimulationWorkerStage, usize)> {
    let head_is_due_ready = matches!(
        events.peek(),
        Some(head) if head.at_ms == now_ms
            && matches!(head.kind, SimulationEventKind::WorkerReady { .. })
    );
    if !head_is_due_ready {
        return None;
    }

    match events.pop().expect("event must exist after peek").kind {
        SimulationEventKind::WorkerReady { stage, worker_id } => Some((stage, worker_id)),
        _ => unreachable!("peeked worker ready event must match popped event"),
    }
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
...@@ -260,4 +293,84 @@ mod tests { ...@@ -260,4 +293,84 @@ mod tests {
assert_eq!(second.completed_requests, 2); assert_eq!(second.completed_requests, 2);
assert!(events.is_empty()); assert!(events.is_empty());
} }
#[test]
fn test_worker_ready_push_pop_round_trip() {
    let mut heap = BinaryHeap::new();
    let mut seq = 0;
    push_worker_ready(
        &mut heap,
        &mut seq,
        100.0,
        SimulationWorkerStage::Aggregated,
        3,
    );

    // Asking one millisecond early must not release the event.
    assert!(pop_ready_worker_ready(&mut heap, 99.0).is_none());

    // At the scheduled time the same stage and worker ID come back out.
    let (popped_stage, popped_worker) = pop_ready_worker_ready(&mut heap, 100.0).unwrap();
    assert_eq!(popped_stage, SimulationWorkerStage::Aggregated);
    assert_eq!(popped_worker, 3);
    assert!(heap.is_empty());
}
#[test]
fn test_worker_ready_does_not_interfere_with_completion_pop() {
    let mut heap = BinaryHeap::new();
    let mut seq = 0;
    push_worker_ready(
        &mut heap,
        &mut seq,
        10.0,
        SimulationWorkerStage::Aggregated,
        1,
    );

    // The completion popper sees a different event kind and declines,
    // leaving the ready event where it was.
    assert!(pop_ready_worker_completion(&mut heap, 10.0).is_none());
    assert_eq!(heap.len(), 1);

    // The matching popper then takes it.
    assert!(pop_ready_worker_ready(&mut heap, 10.0).is_some());
}
#[test]
fn test_worker_ready_interleaved_with_completion() {
    let mut heap = BinaryHeap::new();
    let mut seq = 0;

    // Two events share timestamp 10.0: a completion first, then a ready.
    push_worker_completion(
        &mut heap,
        &mut seq,
        10.0,
        WorkerCompletionPayload {
            stage: SimulationWorkerStage::Aggregated,
            worker_idx: 0,
            completed_requests: 1,
            output_signals: Vec::new(),
            kv_events: Vec::new(),
        },
    );
    push_worker_ready(
        &mut heap,
        &mut seq,
        10.0,
        SimulationWorkerStage::Aggregated,
        5,
    );

    // The earlier-pushed completion (lower sequence number) is served first.
    let first = pop_ready_worker_completion(&mut heap, 10.0).unwrap();
    assert_eq!(first.worker_idx, 0);

    // With the ready event now on top, the completion popper declines and
    // the ready popper returns it.
    assert!(pop_ready_worker_completion(&mut heap, 10.0).is_none());
    let (popped_stage, popped_worker) = pop_ready_worker_ready(&mut heap, 10.0).unwrap();
    assert_eq!(popped_stage, SimulationWorkerStage::Aggregated);
    assert_eq!(popped_worker, 5);
    assert!(heap.is_empty());
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment