disagg.rs 41.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::{BinaryHeap, HashMap, VecDeque};

use anyhow::{Result, anyhow, bail};
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::RouterEvent;
use uuid::Uuid;

11
12
13
14
15
pub(super) use super::components::ReplayMode;
use super::components::{
    AdmissionQueue, EngineComponent, EngineEffects, EnginePassMode, OfflineReplayRouter,
    ScheduledWorkerCompletion, WorkerAdmission,
};
16
use super::events::{SimulationEvent, SimulationWorkerStage};
17
use super::progress::ReplayProgress;
18
use super::runtime_utils::{
19
    next_timestamp as choose_next_timestamp, pop_ready_decode_handoff, pop_ready_worker_completion,
20
21
22
23
    push_decode_handoff, push_worker_completion,
};
#[cfg(test)]
use super::state::DisaggRequestSnapshot;
24
use super::state::{DisaggPhase, DisaggRequestState};
25
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
26
use crate::loadgen::{ReplayRequestHashes, WorkloadDriver};
27
28
29
use crate::replay::{
    OfflineDisaggReplayConfig, ReplayPrefillLoadEstimator, ReplayRouterMode, TraceCollector,
};
30
use crate::scheduler::AdmissionEvent;
31

32
33
34
35
36
37
38
39
40
41
/// Test-only log entry describing one lifecycle step of a disaggregated request.
#[cfg(test)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum DisaggTransition {
    /// The prefill router was notified that the request's prefill finished.
    PrefillMarkCompleted { uuid: Uuid },
    /// The prefill router released the request.
    PrefillFree { uuid: Uuid },
    /// A positive-delay decode handoff was scheduled on the event heap.
    DecodeHandoffQueued { uuid: Uuid },
    /// The request entered the decode stage (direct dispatch or decode-router queue).
    DecodeEnqueued { uuid: Uuid },
    /// The decode router released the request.
    DecodeFree { uuid: Uuid },
    /// The request state was marked done after its decode completed.
    RequestMarkedDone { uuid: Uuid },
    /// A workload-driven admission queue was told the request completed.
    WorkloadCompleted { uuid: Uuid },
}

/// Test-only instrumentation collected while a `DisaggRuntime` runs.
#[cfg(test)]
#[derive(Debug, Default, Clone, PartialEq)]
pub(super) struct DisaggRuntimeStats {
    /// Per-request state snapshots, captured just before `run` returns.
    request_snapshots: HashMap<Uuid, DisaggRequestSnapshot>,
    /// Prefill worker index each request was dispatched to.
    prefill_assignments: HashMap<Uuid, usize>,
    /// Decode worker index each request was dispatched to.
    decode_assignments: HashMap<Uuid, usize>,
    /// Logical time (ms) at which each request entered the decode stage.
    handoff_ms: HashMap<Uuid, f64>,
    /// Number of prefill-completed notifications sent to the prefill router.
    prefill_marked_count: usize,
    /// Number of requests released by the prefill router.
    prefill_router_freed_count: usize,
    /// Number of requests released by the decode router.
    decode_router_freed_count: usize,
    /// Peak number of requests pending in the prefill router.
    max_prefill_router_pending_count: usize,
    /// Peak number of requests pending in the decode router.
    max_decode_router_pending_count: usize,
    /// Ordered log of lifecycle transitions across all requests.
    transition_log: Vec<DisaggTransition>,
}

#[cfg(not(test))]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
61
pub(super) struct DisaggRuntimeStats;
62

63
pub(super) struct DisaggRuntime {
64
65
66
67
    now_ms: f64,
    next_prefill_worker_idx: usize,
    next_decode_worker_idx: usize,
    next_event_seq: u64,
68
69
70
    admission: AdmissionQueue,
    prefill_engine: EngineComponent,
    decode_engine: EngineComponent,
71
72
73
74
75
    prefill_router: Option<OfflineReplayRouter>,
    decode_router: Option<OfflineReplayRouter>,
    requests: HashMap<Uuid, DisaggRequestState>,
    collector: TraceCollector,
    events: BinaryHeap<SimulationEvent>,
76
    progress: ReplayProgress,
77
78
79
80
    stats: DisaggRuntimeStats,
}

impl DisaggRuntime {
81
82
    /// Create a disaggregated offline runtime seeded from an explicit request queue.
    pub(super) fn new(
83
84
        config: &OfflineDisaggReplayConfig,
        router_config: Option<KvRouterConfig>,
85
        prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
86
87
88
89
90
91
92
        pending: VecDeque<DirectRequest>,
        mode: ReplayMode,
        router_mode: ReplayRouterMode,
    ) -> Result<Self> {
        Self::new_with_source(
            config,
            router_config,
93
94
            prefill_load_estimator,
            AdmissionQueue::new_requests(pending, mode),
95
96
97
98
            router_mode,
        )
    }

99
100
    /// Create a disaggregated offline runtime whose admissions come from a workload driver.
    pub(super) fn new_workload(
101
102
        config: &OfflineDisaggReplayConfig,
        router_config: Option<KvRouterConfig>,
103
        prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
104
105
106
107
108
109
110
        driver: WorkloadDriver,
        mode: ReplayMode,
        router_mode: ReplayRouterMode,
    ) -> Result<Self> {
        Self::new_with_source(
            config,
            router_config,
111
112
            prefill_load_estimator,
            AdmissionQueue::new_workload(driver, mode),
113
114
115
116
            router_mode,
        )
    }

117
    /// Shared constructor for both raw-request and workload-driven admissions.
118
119
120
    fn new_with_source(
        config: &OfflineDisaggReplayConfig,
        router_config: Option<KvRouterConfig>,
121
122
        prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
        admission: AdmissionQueue,
123
124
        router_mode: ReplayRouterMode,
    ) -> Result<Self> {
125
        let progress = ReplayProgress::new(admission.total_requests(), "offline disagg replay");
126
127
128
129
130
131
132
133
134
135
136
        let (prefill_router, decode_router) = match router_mode {
            ReplayRouterMode::RoundRobin => (None, None),
            ReplayRouterMode::KvRouter => {
                let prefill_router_config =
                    derive_prefill_router_config(&config.prefill_args, router_config.clone());
                let decode_router_config =
                    derive_decode_router_config(&config.decode_args, router_config);
                (
                    Some(OfflineReplayRouter::new(
                        &config.prefill_args,
                        Some(prefill_router_config),
137
                        prefill_load_estimator,
138
139
140
141
142
                        config.num_prefill_workers,
                    )?),
                    Some(OfflineReplayRouter::new(
                        &config.decode_args,
                        Some(decode_router_config),
143
                        None,
144
145
146
147
148
149
                        config.num_decode_workers,
                    )?),
                )
            }
        };

150
151
152
153
        let prefill_engine = EngineComponent::new(
            SimulationWorkerStage::Prefill,
            EnginePassMode::Hidden,
            (0..config.num_prefill_workers)
154
                .map(|worker_idx| {
155
                    super::state::OfflineWorkerState::new(
156
157
158
159
160
161
                        worker_idx,
                        config.prefill_args.clone(),
                        prefill_router.is_some(),
                    )
                })
                .collect(),
162
163
164
165
166
        );
        let decode_engine = EngineComponent::new(
            SimulationWorkerStage::Decode,
            EnginePassMode::Visible,
            (0..config.num_decode_workers)
167
                .map(|worker_idx| {
168
169
170
171
172
                    super::state::OfflineWorkerState::new(
                        worker_idx,
                        config.decode_args.clone(),
                        false,
                    )
173
174
                })
                .collect(),
175
176
177
178
179
180
181
182
183
184
        );

        Ok(Self {
            now_ms: 0.0,
            next_prefill_worker_idx: 0,
            next_decode_worker_idx: 0,
            next_event_seq: 0,
            admission,
            prefill_engine,
            decode_engine,
185
186
187
188
189
            prefill_router,
            decode_router,
            requests: HashMap::new(),
            collector: TraceCollector::default(),
            events: BinaryHeap::new(),
190
            progress,
191
192
193
194
195
196
197
            #[cfg(test)]
            stats: DisaggRuntimeStats::default(),
            #[cfg(not(test))]
            stats: DisaggRuntimeStats,
        })
    }

198
    /// Count all requests consuming cluster capacity across prefill, decode, and router queues.
199
    fn cluster_in_flight(&self) -> usize {
200
201
        self.prefill_engine.in_flight()
            + self.decode_engine.in_flight()
202
203
204
205
206
207
208
209
210
211
            + self
                .prefill_router
                .as_ref()
                .map_or(0, OfflineReplayRouter::pending_count)
            + self
                .decode_router
                .as_ref()
                .map_or(0, OfflineReplayRouter::pending_count)
    }

212
    /// Pick the next prefill worker in round-robin order.
213
214
215
    fn next_prefill_worker(&mut self) -> usize {
        let worker_idx = self.next_prefill_worker_idx;
        self.next_prefill_worker_idx =
216
            (self.next_prefill_worker_idx + 1) % self.prefill_engine.worker_count();
217
218
219
        worker_idx
    }

220
    /// Pick the next decode worker in round-robin order.
221
222
    fn next_decode_worker(&mut self) -> usize {
        let worker_idx = self.next_decode_worker_idx;
223
224
        self.next_decode_worker_idx =
            (self.next_decode_worker_idx + 1) % self.decode_engine.worker_count();
225
226
227
        worker_idx
    }

228
    /// Track the peak number of requests parked in each stage router.
229
230
231
    fn record_router_pending(&mut self) {
        #[cfg(test)]
        {
232
233
234
235
236
237
238
239
240
241
242
243
            self.stats.max_prefill_router_pending_count =
                self.stats.max_prefill_router_pending_count.max(
                    self.prefill_router
                        .as_ref()
                        .map_or(0, OfflineReplayRouter::pending_count),
                );
            self.stats.max_decode_router_pending_count =
                self.stats.max_decode_router_pending_count.max(
                    self.decode_router
                        .as_ref()
                        .map_or(0, OfflineReplayRouter::pending_count),
                );
244
245
246
        }
    }

247
    /// Borrow immutable request state with a structured missing-request error.
248
249
250
251
252
253
    fn state(&self, uuid: Uuid) -> Result<&DisaggRequestState> {
        self.requests
            .get(&uuid)
            .ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))
    }

254
    /// Borrow mutable request state with a structured missing-request error.
255
256
257
258
259
260
    fn state_mut(&mut self, uuid: Uuid) -> Result<&mut DisaggRequestState> {
        self.requests
            .get_mut(&uuid)
            .ok_or_else(|| anyhow!("offline disagg replay missing request state for {uuid}"))
    }

261
    /// Dispatch a request's prefill stage onto a specific prefill worker.
262
263
    fn dispatch_prefill(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> {
        let request = self.state(uuid)?.build_prefill_request()?;
264
        self.prefill_engine.dispatch(worker_idx, request)?;
265
266
267
268
269
270
271
272
        self.state_mut(uuid)?.start_prefill(worker_idx);
        #[cfg(test)]
        {
            self.stats.prefill_assignments.insert(uuid, worker_idx);
        }
        Ok(())
    }

273
    /// Dispatch a request's decode stage onto a specific decode worker.
274
    fn dispatch_decode(&mut self, uuid: Uuid, worker_idx: usize) -> Result<()> {
275
276
        let request = self.state(uuid)?.original_request()?.clone();
        self.decode_engine.dispatch(worker_idx, request)?;
277
278
279
280
281
282
283
284
        self.state_mut(uuid)?.start_decode(worker_idx);
        #[cfg(test)]
        {
            self.stats.decode_assignments.insert(uuid, worker_idx);
        }
        Ok(())
    }

285
    /// Turn prefill router admissions into concrete worker dispatches.
286
287
288
    fn dispatch_prefill_admissions(&mut self, admissions: Vec<WorkerAdmission>) -> Result<()> {
        for WorkerAdmission { uuid, worker_idx } in admissions {
            if self.state(uuid)?.phase != DisaggPhase::QueuedPrefill {
289
290
291
292
293
294
295
                bail!("offline disagg replay expected queued prefill request for {uuid}");
            }
            self.dispatch_prefill(uuid, worker_idx)?;
        }
        Ok(())
    }

296
    /// Turn decode router admissions into concrete worker dispatches.
297
298
299
    fn dispatch_decode_admissions(&mut self, admissions: Vec<WorkerAdmission>) -> Result<()> {
        for WorkerAdmission { uuid, worker_idx } in admissions {
            if self.state(uuid)?.phase != DisaggPhase::QueuedDecode {
300
301
302
303
304
305
306
                bail!("offline disagg replay expected queued decode request for {uuid}");
            }
            self.dispatch_decode(uuid, worker_idx)?;
        }
        Ok(())
    }

307
308
309
    /// Queue or dispatch a request into decode, depending on whether a decode router is active.
    fn enqueue_decode(&mut self, uuid: Uuid) -> Result<()> {
        if self.decode_router.is_none() {
310
311
312
313
314
315
316
            #[cfg(test)]
            {
                self.stats
                    .transition_log
                    .push(DisaggTransition::DecodeEnqueued { uuid });
                self.stats.handoff_ms.insert(uuid, self.now_ms);
            }
317
318
319
            let worker_idx = self.next_decode_worker();
            self.dispatch_decode(uuid, worker_idx)?;
            return Ok(());
320
        }
321
322
        let request = self.state(uuid)?.original_request()?.clone();
        self.state_mut(uuid)?.queue_decode();
323
324
        #[cfg(test)]
        {
325
326
327
            self.stats
                .transition_log
                .push(DisaggTransition::DecodeEnqueued { uuid });
328
329
            self.stats.handoff_ms.insert(uuid, self.now_ms);
        }
330
331
332
333
334
335
336
337
        let admissions = self
            .decode_router
            .as_mut()
            .expect("decode router presence checked above")
            .on_request_arrival(&request, None, self.now_ms)?
            .admissions;
        self.record_router_pending();
        self.dispatch_decode_admissions(admissions)?;
338
339
340
        Ok(())
    }

341
    /// Admit one external request into prefill-side state, collector state, and optional router.
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
    fn on_external_arrival(
        &mut self,
        mut request: DirectRequest,
        arrival_time_ms: f64,
        replay_hashes: Option<ReplayRequestHashes>,
    ) -> Result<Uuid> {
        let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
        request.uuid = Some(uuid);
        request.arrival_timestamp_ms = Some(arrival_time_ms);

        self.collector.on_arrival(
            uuid,
            arrival_time_ms,
            request.tokens.len(),
            request.max_output_tokens,
        );

359
        let queued_request = request.clone();
360
361
        self.requests
            .insert(uuid, DisaggRequestState::new(request, arrival_time_ms));
362
        if self.prefill_router.is_none() {
363
364
365
            let worker_idx = self.next_prefill_worker();
            self.dispatch_prefill(uuid, worker_idx)?;
            return Ok(uuid);
366
        }
367
368
369
370
371
372
373
374
        let admissions = self
            .prefill_router
            .as_mut()
            .expect("prefill router presence checked above")
            .on_request_arrival(&queued_request, replay_hashes, self.now_ms)?
            .admissions;
        self.record_router_pending();
        self.dispatch_prefill_admissions(admissions)?;
375
376
377
        Ok(uuid)
    }

378
    /// Return true once both stages, both routers, and all admissions are fully drained.
379
380
381
    fn is_done(&self) -> bool {
        self.events.is_empty()
            && self.cluster_in_flight() == 0
382
383
384
            && self.admission.is_drained()
            && self.prefill_engine.is_drained()
            && self.decode_engine.is_drained()
385
386
    }

387
    /// Pick the next logical timestamp from arrivals, worker completions, or decode handoffs.
388
389
    fn next_timestamp(&mut self) -> Option<f64> {
        let next_event_ms = self.events.peek().map(|event| event.at_ms);
390
391
392
393
        choose_next_timestamp(
            self.admission.next_ready_time_ms(self.cluster_in_flight()),
            next_event_ms,
        )
394
395
    }

396
    /// Apply prefill-side KV router events at the scheduler-selected visibility phase.
397
398
399
400
    fn apply_prefill_router_events(&mut self, events: Vec<RouterEvent>) -> Result<()> {
        let Some(prefill_router) = self.prefill_router.as_mut() else {
            return Ok(());
        };
401
402
403
        let effects = prefill_router.on_kv_events(events)?;
        if !effects.admissions.is_empty() {
            bail!("offline disagg replay prefill KV events must not admit requests");
404
405
406
407
        }
        Ok(())
    }

408
    /// Process one prefill output signal, including router updates and decode handoff scheduling.
409
410
411
412
413
414
415
416
    fn process_prefill_signal(&mut self, signal: OutputSignal) -> Result<()> {
        if !signal.completed {
            return Ok(());
        }

        if self.prefill_router.is_some() {
            let prefill_complete_admissions = {
                let prefill_router = self.prefill_router.as_mut().expect("router checked above");
417
418
419
                prefill_router
                    .on_prefill_completed(signal.uuid, self.now_ms)?
                    .admissions
420
421
422
423
            };
            #[cfg(test)]
            {
                self.stats.prefill_marked_count += 1;
424
425
426
                self.stats
                    .transition_log
                    .push(DisaggTransition::PrefillMarkCompleted { uuid: signal.uuid });
427
428
429
430
431
432
            }
            self.record_router_pending();
            self.dispatch_prefill_admissions(prefill_complete_admissions)?;

            let admissions = {
                let prefill_router = self.prefill_router.as_mut().expect("router checked above");
433
434
435
                prefill_router
                    .on_request_completed(signal.uuid, self.now_ms)?
                    .admissions
436
437
438
            };
            #[cfg(test)]
            {
439
                self.stats.prefill_router_freed_count += 1;
440
441
442
                self.stats
                    .transition_log
                    .push(DisaggTransition::PrefillFree { uuid: signal.uuid });
443
444
445
446
447
448
449
450
            }
            self.record_router_pending();
            self.dispatch_prefill_admissions(admissions)?;
        }

        self.enqueue_decode_after_handoff(signal.uuid, signal.handoff_delay_ms)
    }

451
    /// Process one decode output signal, including decode router frees and request completion.
452
453
454
455
456
457
    fn process_decode_signal(&mut self, signal: OutputSignal) -> Result<()> {
        if !signal.completed {
            return Ok(());
        }

        let admissions = if let Some(decode_router) = self.decode_router.as_mut() {
458
459
460
            let admissions = decode_router
                .on_request_completed(signal.uuid, self.now_ms)?
                .admissions;
461
462
            #[cfg(test)]
            {
463
                self.stats.decode_router_freed_count += 1;
464
465
466
                self.stats
                    .transition_log
                    .push(DisaggTransition::DecodeFree { uuid: signal.uuid });
467
468
469
470
471
472
            }
            admissions
        } else {
            Vec::new()
        };
        self.record_router_pending();
473
474
475
476
477
478
479
480
        self.admission
            .on_request_completed(signal.uuid, self.now_ms)?;
        self.progress.inc_completed();
        #[cfg(test)]
        if self.admission.is_workload() {
            self.stats
                .transition_log
                .push(DisaggTransition::WorkloadCompleted { uuid: signal.uuid });
481
482
        }
        self.state_mut(signal.uuid)?.mark_done();
483
484
485
486
487
488
        #[cfg(test)]
        {
            self.stats
                .transition_log
                .push(DisaggTransition::RequestMarkedDone { uuid: signal.uuid });
        }
489
490
491
492
        self.dispatch_decode_admissions(admissions)?;
        Ok(())
    }

493
    /// Apply the side effects of a finished prefill pass.
494
495
    fn process_prefill_pass(
        &mut self,
496
497
        _worker_idx: usize,
        _completed_requests: usize,
498
499
500
501
502
503
504
505
506
507
        output_signals: Vec<OutputSignal>,
        kv_events: Vec<RouterEvent>,
    ) -> Result<()> {
        self.apply_prefill_router_events(kv_events)?;
        for signal in output_signals {
            self.process_prefill_signal(signal)?;
        }
        Ok(())
    }

508
    /// Apply the side effects of a finished decode pass.
509
510
    fn process_decode_pass(
        &mut self,
511
512
        _worker_idx: usize,
        _completed_requests: usize,
513
514
515
516
517
518
519
520
        output_signals: Vec<OutputSignal>,
    ) -> Result<()> {
        for signal in output_signals {
            self.process_decode_signal(signal)?;
        }
        Ok(())
    }

521
    /// Drain all worker-completion events scheduled for the current logical timestamp.
522
523
    fn apply_worker_completions(&mut self) -> Result<bool> {
        let mut changed = false;
524
525
        while let Some(payload) = pop_ready_worker_completion(&mut self.events, self.now_ms) {
            match payload.stage {
526
                SimulationWorkerStage::Prefill => {
527
                    let payload = self.prefill_engine.on_scheduled_completion(payload)?;
528
                    self.process_prefill_pass(
529
530
531
532
                        payload.worker_idx,
                        payload.completed_requests,
                        payload.output_signals,
                        payload.kv_events,
533
534
535
                    )?;
                }
                SimulationWorkerStage::Decode => {
536
537
538
539
540
541
                    let payload = self.decode_engine.on_scheduled_completion(payload)?;
                    self.process_decode_pass(
                        payload.worker_idx,
                        payload.completed_requests,
                        payload.output_signals,
                    )?;
542
543
544
545
546
547
548
549
550
551
                }
                SimulationWorkerStage::Aggregated => {
                    bail!("offline disagg replay received an aggregated completion event")
                }
            }
            changed = true;
        }
        Ok(changed)
    }

552
    /// Drain all delayed decode handoff events scheduled for the current logical timestamp.
553
554
555
556
557
558
559
560
561
    fn apply_decode_handoffs(&mut self) -> Result<bool> {
        let mut changed = false;
        while let Some(uuid) = pop_ready_decode_handoff(&mut self.events, self.now_ms) {
            self.enqueue_decode(uuid)?;
            changed = true;
        }
        Ok(changed)
    }

562
    /// Either enqueue decode immediately or schedule a delayed handoff event on the event heap.
563
564
565
566
567
    fn enqueue_decode_after_handoff(
        &mut self,
        uuid: Uuid,
        handoff_delay_ms: Option<f64>,
    ) -> Result<()> {
568
569
570
571
        let Some(delay_ms) = handoff_delay_ms else {
            return self.enqueue_decode(uuid);
        };
        if delay_ms > 0.0 {
572
573
574
575
576
577
            push_decode_handoff(
                &mut self.events,
                &mut self.next_event_seq,
                self.now_ms + delay_ms,
                uuid,
            );
578
579
580
581
            #[cfg(test)]
            self.stats
                .transition_log
                .push(DisaggTransition::DecodeHandoffQueued { uuid });
582
583
584
585
586
            return Ok(());
        }
        self.enqueue_decode(uuid)
    }

587
588
    /// Release every admission made ready by the shared admission queue.
    ///
    /// Readiness is evaluated against the current cluster-wide in-flight count. Returns
    /// whether at least one arrival was admitted.
    fn release_ready_arrivals(&mut self) -> Result<bool> {
        let in_flight = self.cluster_in_flight();
        let ready_arrivals = self.admission.drain_ready(self.now_ms, in_flight)?;
        let mut released_any = false;
        for arrival in ready_arrivals {
            self.on_external_arrival(
                arrival.request,
                arrival.arrival_time_ms,
                arrival.replay_hashes,
            )?;
            released_any = true;
        }
        Ok(released_any)
    }

600
    /// Start passes on every idle prefill worker that can make progress at the current timestamp.
601
602
    fn drive_prefill_workers(&mut self) -> Result<bool> {
        let mut changed = false;
603
604
605
606
        loop {
            let effects = self.prefill_engine.drive_ready(self.now_ms, None)?;
            if effects.is_empty() {
                return Ok(changed);
607
            }
608
609
            changed = true;
            self.handle_prefill_engine_effects(effects)?;
610
611
612
        }
    }

613
    /// Start passes on every idle decode worker that can make progress at the current timestamp.
614
615
    fn drive_decode_workers(&mut self) -> Result<bool> {
        let mut changed = false;
616
617
618
619
620
621
622
623
624
625
626
        loop {
            let effects = self
                .decode_engine
                .drive_ready(self.now_ms, Some(&mut self.collector))?;
            if effects.is_empty() {
                return Ok(changed);
            }
            changed = true;
            self.handle_decode_engine_effects(effects)?;
        }
    }
627

628
    /// Apply the effects returned by one `drive_ready` call on the prefill engine.
    ///
    /// The steps run in this order:
    /// 1. Record admissions with the collector.
    /// 2. Apply pass-start KV events to the prefill router.
    /// 3. Process completions that finished immediately.
    /// 4. Push the remaining completions onto the event heap.
    fn handle_prefill_engine_effects(&mut self, effects: EngineEffects) -> Result<()> {
        self.record_prefill_admissions(effects.admissions);
        self.apply_prefill_router_events(effects.pass_start_kv_events)?;
        for completion in effects.immediate_completions {
            let pass = self.prefill_engine.on_scheduled_completion(completion)?;
            self.process_prefill_pass(
                pass.worker_idx,
                pass.completed_requests,
                pass.output_signals,
                pass.kv_events,
            )?;
        }
        for ScheduledWorkerCompletion { at_ms, payload } in effects.scheduled_completions {
            push_worker_completion(&mut self.events, &mut self.next_event_seq, at_ms, payload);
        }
        Ok(())
    }
645

646
647
648
649
650
651
652
    /// Report each prefill admission to the trace collector, stamped with the current time
    /// and the number of reused input tokens.
    fn record_prefill_admissions(&mut self, admissions: Vec<AdmissionEvent>) {
        let admitted_at_ms = self.now_ms;
        admissions.into_iter().for_each(|event| {
            self.collector
                .on_admit(event.uuid, admitted_at_ms, event.reused_input_tokens);
        });
    }

653
654
655
656
657
658
659
660
    /// Apply the effects returned by one `drive_ready` call on the decode engine.
    ///
    /// Immediate completions are processed first. Remaining completions are then pushed
    /// onto the event heap.
    ///
    /// NOTE(review): `effects.admissions` and `effects.pass_start_kv_events` are not consumed
    /// here. Decode admissions are presumably recorded via the collector passed to
    /// `drive_ready` — confirm.
    fn handle_decode_engine_effects(&mut self, effects: EngineEffects) -> Result<()> {
        for completion in effects.immediate_completions {
            let pass = self.decode_engine.on_scheduled_completion(completion)?;
            self.process_decode_pass(
                pass.worker_idx,
                pass.completed_requests,
                pass.output_signals,
            )?;
        }
        for ScheduledWorkerCompletion { at_ms, payload } in effects.scheduled_completions {
            push_worker_completion(&mut self.events, &mut self.next_event_seq, at_ms, payload);
        }
        Ok(())
    }

668
    /// Repeatedly process all work that becomes possible without advancing logical time.
669
670
671
672
    fn drain_current_timestamp(&mut self) -> Result<()> {
        loop {
            let mut changed = self.apply_worker_completions()?;
            changed |= self.apply_decode_handoffs()?;
673
            changed |= self.release_ready_arrivals()?;
674
675
676
677
678
679
680
681
682
683
            changed |= self.drive_prefill_workers()?;
            changed |= self.drive_decode_workers()?;

            if !changed {
                break;
            }
        }
        Ok(())
    }

684
    /// Finalize test-only request snapshots before returning.
685
686
687
688
689
690
691
692
693
694
695
    fn finish_test_stats(&mut self) {
        #[cfg(test)]
        {
            self.stats.request_snapshots = self
                .requests
                .iter()
                .map(|(uuid, state)| (*uuid, state.debug_snapshot()))
                .collect();
        }
    }

696
697
    /// Run the staged offline replay until both prefill and decode pipelines are drained.
    pub(super) fn run(mut self) -> Result<(TraceCollector, DisaggRuntimeStats)> {
698
699
700
701
702
703
704
705
706
707
708
709
710
        self.drain_current_timestamp()?;

        while !self.is_done() {
            let Some(next_timestamp_ms) = self.next_timestamp() else {
                bail!(
                    "offline disagg replay reached a dead end with {} in-flight requests remaining",
                    self.cluster_in_flight()
                );
            };
            self.now_ms = next_timestamp_ms;
            self.drain_current_timestamp()?;
        }

711
        self.progress.finish();
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
        self.finish_test_stats();
        Ok((self.collector, self.stats))
    }
}

/// Common starting point for both stage routers.
///
/// Uses the caller's router config (or the default). A queue policy set on the worker
/// args overrides the config's value.
fn base_router_config(
    args: &MockEngineArgs,
    router_config: Option<KvRouterConfig>,
) -> KvRouterConfig {
    let mut base = router_config.unwrap_or_default();
    base.router_queue_policy = args.router_queue_policy.unwrap_or(base.router_queue_policy);
    base
}

/// Router configuration for the prefill stage.
///
/// Identical to the base configuration except that active-block tracking is disabled.
fn derive_prefill_router_config(
    args: &MockEngineArgs,
    router_config: Option<KvRouterConfig>,
) -> KvRouterConfig {
    let mut prefill_config = base_router_config(args, router_config);
    prefill_config.router_track_active_blocks = false;
    prefill_config
}

fn derive_decode_router_config(
    args: &MockEngineArgs,
    router_config: Option<KvRouterConfig>,
) -> KvRouterConfig {
    let mut config = base_router_config(args, router_config);
    config.overlap_score_weight = 0.0;
    config.router_assume_kv_reuse = false;
    config.router_track_prefill_tokens = false;
745
    config.router_prefill_load_model = dynamo_kv_router::config::RouterPrefillLoadModel::None;
746
747
748
749
750
    config
}

#[cfg(test)]
mod tests {
    use super::super::entrypoints::{
        run_concurrency_collect, run_concurrency_workload_collect, run_trace_collect,
        run_trace_workload_collect,
    };
    use super::*;

    use crate::common::protocols::{EngineType, MockEngineArgs, SglangArgs, WorkerType};
    use crate::loadgen::{SessionTrace, Trace, TurnTrace};

    /// Small staged worker args for `worker_type`.
    ///
    /// Uses 64-token blocks and 256 GPU blocks, with prefix caching and chunked prefill
    /// enabled. The same `speedup_ratio` is applied to prefill and decode.
    fn staged_args(worker_type: WorkerType, speedup_ratio: f64) -> MockEngineArgs {
        MockEngineArgs::builder()
            .block_size(64)
            .num_gpu_blocks(256)
            .max_num_batched_tokens(Some(8192))
            .max_num_seqs(Some(8))
            .enable_prefix_caching(true)
            .enable_chunked_prefill(true)
            .speedup_ratio(speedup_ratio)
            .decode_speedup_ratio(speedup_ratio)
            .worker_type(worker_type)
            .build()
            .unwrap()
    }

    /// SGLang-engine variant of [`staged_args`], with 512 GPU blocks and a 64-token page size.
    fn sglang_staged_args(worker_type: WorkerType, speedup_ratio: f64) -> MockEngineArgs {
        MockEngineArgs::builder()
            .engine_type(EngineType::Sglang)
            .block_size(64)
            .num_gpu_blocks(512)
            .max_num_batched_tokens(Some(8192))
            .max_num_seqs(Some(8))
            .enable_prefix_caching(true)
            .enable_chunked_prefill(true)
            .speedup_ratio(speedup_ratio)
            .decode_speedup_ratio(speedup_ratio)
            .worker_type(worker_type)
            .sglang(Some(SglangArgs {
                page_size: Some(64),
                ..Default::default()
            }))
            .build()
            .unwrap()
    }

    /// Two prefill and two decode workers built from [`staged_args`] at a 1000x speedup.
    fn disagg_config() -> OfflineDisaggReplayConfig {
        OfflineDisaggReplayConfig {
            prefill_args: staged_args(WorkerType::Prefill, 1000.0),
            decode_args: staged_args(WorkerType::Decode, 1000.0),
            num_prefill_workers: 2,
            num_decode_workers: 2,
        }
    }

    /// SGLang counterpart of [`disagg_config`].
    fn sglang_disagg_config() -> OfflineDisaggReplayConfig {
        OfflineDisaggReplayConfig {
            prefill_args: sglang_staged_args(WorkerType::Prefill, 1000.0),
            decode_args: sglang_staged_args(WorkerType::Decode, 1000.0),
            num_prefill_workers: 2,
            num_decode_workers: 2,
        }
    }

    /// [`disagg_config`] with KV-transfer bandwidth and per-token KV size set on the prefill
    /// args, so the prefill -> decode handoff is expected to take non-zero time.
    fn disagg_config_with_handoff_delay() -> OfflineDisaggReplayConfig {
        let mut config = disagg_config();
        config.prefill_args.kv_transfer_bandwidth = Some(1.0);
        config.prefill_args.kv_bytes_per_token = Some(1_000_000);
        config
    }

    /// Default KV router config with a queue threshold of 1.25.
    fn router_config() -> KvRouterConfig {
        KvRouterConfig {
            router_queue_threshold: Some(1.25),
            ..KvRouterConfig::default()
        }
    }

    /// A trace request with a deterministic UUID.
    ///
    /// The prompt is `prompt_tokens` copies of token `1`, and the request arrives at
    /// `arrival_ms`.
    fn request(
        uuid: u128,
        prompt_tokens: usize,
        output_tokens: usize,
        arrival_ms: f64,
    ) -> DirectRequest {
        DirectRequest {
            tokens: vec![1; prompt_tokens],
            max_output_tokens: output_tokens,
            uuid: Some(Uuid::from_u128(uuid)),
            dp_rank: 0,
            arrival_timestamp_ms: Some(arrival_ms),
        }
    }

    /// Two sessions with distinguishable input lengths.
    ///
    /// - `session-a` has turns of 64 and 192 tokens; the second turn is delayed 10ms after
    ///   the first.
    /// - `session-b` has a single 128-token turn that first arrives at 5ms.
    fn multiturn_trace() -> Trace {
        Trace {
            block_size: 64,
            sessions: vec![
                SessionTrace {
                    session_id: "session-a".to_string(),
                    first_arrival_timestamp_ms: Some(0.0),
                    turns: vec![
                        TurnTrace {
                            input_length: 64,
                            max_output_tokens: 2,
                            hash_ids: vec![11],
                            delay_after_previous_ms: 0.0,
                        },
                        TurnTrace {
                            input_length: 192,
                            max_output_tokens: 2,
                            hash_ids: vec![21, 22, 23],
                            delay_after_previous_ms: 10.0,
                        },
                    ],
                },
                SessionTrace {
                    session_id: "session-b".to_string(),
                    first_arrival_timestamp_ms: Some(5.0),
                    turns: vec![TurnTrace {
                        input_length: 128,
                        max_output_tokens: 2,
                        hash_ids: vec![31, 32],
                        delay_after_previous_ms: 0.0,
                    }],
                },
            ],
        }
    }

    /// Position of `needle` in the recorded transition log.
    ///
    /// Panics with the full log when the transition was never recorded, so ordering
    /// assertions fail with actionable context instead of a bare `unwrap` on `None`.
    fn transition_index(transitions: &[DisaggTransition], needle: DisaggTransition) -> usize {
        transitions
            .iter()
            .position(|transition| *transition == needle)
            .unwrap_or_else(|| panic!("transition {needle:?} not found in log {transitions:?}"))
    }

    #[test]
    fn test_derive_stage_router_configs_force_required_overrides() {
        let config = KvRouterConfig {
            overlap_score_weight: 2.0,
            router_track_active_blocks: true,
            router_assume_kv_reuse: true,
            router_track_prefill_tokens: true,
            ..KvRouterConfig::default()
        };
        let args = staged_args(WorkerType::Prefill, 1.0);
        let prefill = derive_prefill_router_config(&args, Some(config.clone()));
        let decode = derive_decode_router_config(&args, Some(config));

        assert!(!prefill.router_track_active_blocks);
        assert_eq!(decode.overlap_score_weight, 0.0);
        assert!(!decode.router_assume_kv_reuse);
        assert!(!decode.router_track_prefill_tokens);
        // The decode stage must also disable the prefill load model.
        assert!(matches!(
            decode.router_prefill_load_model,
            dynamo_kv_router::config::RouterPrefillLoadModel::None
        ));
    }

    #[rstest::rstest]
    #[case(ReplayRouterMode::RoundRobin)]
    #[case(ReplayRouterMode::KvRouter)]
    fn test_trace_smoke_reports_decode_only_tokens(#[case] router_mode: ReplayRouterMode) {
        let config = disagg_config();
        let requests = vec![request(1, 128, 3, 5.0)];

        let router_config = (router_mode == ReplayRouterMode::KvRouter).then(router_config);
        let (collector, stats) =
            run_trace_collect(&config, requests, router_config, 1.0, router_mode);
        let snapshot = collector.snapshot(Uuid::from_u128(1)).unwrap();
        let report = collector.finish();

        assert_eq!(snapshot.arrival_time_ms, 0.0);
        assert!(snapshot.first_admit_ms.is_some());
        assert!(snapshot.first_token_ms.is_some());
        assert_eq!(snapshot.output_length, 3);
        assert_eq!(report.request_counts.completed_requests, 1);
        assert_eq!(report.request_counts.total_output_tokens, 3);
        assert_eq!(
            stats.request_snapshots[&Uuid::from_u128(1)].phase,
            DisaggPhase::Done
        );
    }

    #[rstest::rstest]
    #[case(ReplayRouterMode::RoundRobin)]
    #[case(ReplayRouterMode::KvRouter)]
    fn test_prefill_and_decode_use_separate_worker_pools(#[case] router_mode: ReplayRouterMode) {
        let config = disagg_config();
        let requests = vec![request(1, 128, 2, 0.0), request(2, 128, 2, 10.0)];

        let router_config = (router_mode == ReplayRouterMode::KvRouter).then(router_config);
        let (_, stats) = run_trace_collect(&config, requests, router_config, 1.0, router_mode);

        for uuid in [Uuid::from_u128(1), Uuid::from_u128(2)] {
            assert!(stats.prefill_assignments.contains_key(&uuid));
            assert!(stats.decode_assignments.contains_key(&uuid));
            assert_eq!(stats.request_snapshots[&uuid].phase, DisaggPhase::Done);
            assert_eq!(
                stats.request_snapshots[&uuid].prefill_worker_idx,
                Some(stats.prefill_assignments[&uuid])
            );
            assert_eq!(
                stats.request_snapshots[&uuid].decode_worker_idx,
                Some(stats.decode_assignments[&uuid])
            );
        }
    }

    #[test]
    fn test_prefill_overlap_prefers_same_worker_after_handoff_delay() {
        let requests = vec![request(1, 128, 2, 0.0), request(2, 128, 2, 100.0)];

        // NOTE(review): the SGLang configuration is only checked for successful completion,
        // not for prefill worker affinity — confirm that this relaxed expectation is intended.
        let cases = [(disagg_config(), true), (sglang_disagg_config(), false)];
        for (config, expect_same_worker) in cases {
            let (_, stats) = run_trace_collect(
                &config,
                requests.clone(),
                Some(router_config()),
                1.0,
                ReplayRouterMode::KvRouter,
            );

            if expect_same_worker {
                assert_eq!(
                    stats.prefill_assignments[&Uuid::from_u128(1)],
                    stats.prefill_assignments[&Uuid::from_u128(2)],
                );
            } else {
                for uuid in [Uuid::from_u128(1), Uuid::from_u128(2)] {
                    assert!(stats.prefill_assignments.contains_key(&uuid));
                    assert!(stats.decode_assignments.contains_key(&uuid));
                    assert_eq!(stats.request_snapshots[&uuid].phase, DisaggPhase::Done);
                }
            }
        }
    }

    #[test]
    fn test_hidden_prefill_reports_reused_tokens_even_when_decode_prefix_caching_is_disabled() {
        let mut config = disagg_config();
        config.num_prefill_workers = 1;
        config.num_decode_workers = 1;
        config.decode_args.enable_prefix_caching = false;

        let requests = vec![request(1, 128, 2, 0.0), request(2, 128, 2, 100.0)];
        let (collector, _) = run_trace_collect(
            &config,
            requests,
            Some(router_config()),
            1.0,
            ReplayRouterMode::KvRouter,
        );

        let request_2 = collector.snapshot(Uuid::from_u128(2)).unwrap();
        let report = collector.finish();

        assert!(request_2.reused_input_tokens > 0);
        assert!(report.prefix_cache_reused_ratio > 0.0);
    }

    #[rstest::rstest]
    #[case(ReplayRouterMode::RoundRobin)]
    #[case(ReplayRouterMode::KvRouter)]
    fn test_concurrency_backfill_waits_for_decode_completion(
        #[case] router_mode: ReplayRouterMode,
    ) {
        let config = disagg_config();
        let requests = vec![
            DirectRequest {
                tokens: vec![1; 128],
                max_output_tokens: 3,
                uuid: Some(Uuid::from_u128(1)),
                dp_rank: 0,
                arrival_timestamp_ms: None,
            },
            DirectRequest {
                tokens: vec![2; 128],
                max_output_tokens: 3,
                uuid: Some(Uuid::from_u128(2)),
                dp_rank: 0,
                arrival_timestamp_ms: None,
            },
        ];

        let router_config = (router_mode == ReplayRouterMode::KvRouter).then(router_config);
        let (collector, stats) =
            run_concurrency_collect(&config, requests, router_config, 1, router_mode);
        let first = collector.snapshot(Uuid::from_u128(1)).unwrap();
        let second = collector.snapshot(Uuid::from_u128(2)).unwrap();

        assert_eq!(first.arrival_time_ms, 0.0);
        assert_eq!(second.arrival_time_ms, first.last_token_ms.unwrap());
        assert_eq!(
            stats.request_snapshots[&Uuid::from_u128(1)].phase,
            DisaggPhase::Done
        );
        assert_eq!(
            stats.request_snapshots[&Uuid::from_u128(2)].phase,
            DisaggPhase::Done
        );
    }

    #[test]
    fn test_prefill_completion_marks_and_frees_before_decode_handoff() {
        let config = disagg_config();
        let requests = vec![request(1, 128, 2, 0.0)];

        let (_, stats) = run_trace_collect(
            &config,
            requests,
            Some(router_config()),
            1.0,
            ReplayRouterMode::KvRouter,
        );

        assert_eq!(stats.prefill_marked_count, 1);
        assert_eq!(stats.prefill_router_freed_count, 1);
        assert_eq!(stats.decode_router_freed_count, 1);

        let transitions = &stats.transition_log;
        let uuid = Uuid::from_u128(1);
        let mark_idx =
            transition_index(transitions, DisaggTransition::PrefillMarkCompleted { uuid });
        let free_idx = transition_index(transitions, DisaggTransition::PrefillFree { uuid });
        let enqueue_idx = transition_index(transitions, DisaggTransition::DecodeEnqueued { uuid });
        assert!(mark_idx < free_idx);
        assert!(free_idx < enqueue_idx);
    }

    #[test]
    fn test_handoff_delay_increases_decode_visible_ttft() {
        let requests = vec![request(1, 128, 2, 0.0)];

        let (baseline_collector, _) = run_trace_collect(
            &disagg_config(),
            requests.clone(),
            None,
            1.0,
            ReplayRouterMode::RoundRobin,
        );
        let (delayed_collector, delayed_stats) = run_trace_collect(
            &disagg_config_with_handoff_delay(),
            requests,
            None,
            1.0,
            ReplayRouterMode::RoundRobin,
        );

        let baseline = baseline_collector.snapshot(Uuid::from_u128(1)).unwrap();
        let delayed = delayed_collector.snapshot(Uuid::from_u128(1)).unwrap();
        let baseline_ttft = baseline.first_token_ms.unwrap() - baseline.arrival_time_ms;
        let delayed_ttft = delayed.first_token_ms.unwrap() - delayed.arrival_time_ms;

        assert!(
            delayed_ttft >= baseline_ttft + 120.0,
            "expected delayed TTFT to include roughly 128ms of handoff delay, baseline={baseline_ttft}, delayed={delayed_ttft}"
        );

        let uuid = Uuid::from_u128(1);
        let queued_idx = transition_index(
            &delayed_stats.transition_log,
            DisaggTransition::DecodeHandoffQueued { uuid },
        );
        let enqueued_idx = transition_index(
            &delayed_stats.transition_log,
            DisaggTransition::DecodeEnqueued { uuid },
        );
        assert!(queued_idx < enqueued_idx);
        assert!(delayed_stats.handoff_ms[&uuid] >= 120.0);
    }

    #[test]
    fn test_trace_workload_follow_up_turn_arrives_after_completion_plus_delay() {
        let (collector, _) = run_trace_workload_collect(
            &disagg_config(),
            multiturn_trace(),
            None,
            ReplayRouterMode::RoundRobin,
        );
        let snapshots = collector.snapshots();
        let first_turn = snapshots
            .iter()
            .find(|snapshot| snapshot.input_length == 64)
            .unwrap();
        let second_turn = snapshots
            .iter()
            .find(|snapshot| snapshot.input_length == 192)
            .unwrap();
        let session_b = snapshots
            .iter()
            .find(|snapshot| snapshot.input_length == 128)
            .unwrap();

        assert_eq!(first_turn.arrival_time_ms, 0.0);
        assert_eq!(session_b.arrival_time_ms, 5.0);
        assert!(
            second_turn.arrival_time_ms >= first_turn.last_token_ms.unwrap() + 10.0,
            "follow-up turn should unlock after completion plus delay"
        );
    }

    #[test]
    fn test_concurrency_workload_delayed_follow_up_does_not_bypass_other_ready_sessions() {
        let (collector, _) = run_concurrency_workload_collect(
            &disagg_config(),
            multiturn_trace(),
            None,
            1,
            ReplayRouterMode::RoundRobin,
        );
        let mut input_lengths = collector
            .snapshots()
            .into_iter()
            .map(|snapshot| (snapshot.arrival_time_ms, snapshot.input_length))
            .collect::<Vec<_>>();
        input_lengths.sort_by(|left, right| left.0.total_cmp(&right.0));

        assert_eq!(
            input_lengths
                .into_iter()
                .map(|(_, input_length)| input_length)
                .collect::<Vec<_>>(),
            vec![64, 128, 192]
        );
    }
}