Unverified Commit b7fe46b1 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(mocker): add multi-worker replay and router startup fixes (#7553)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 82794761
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod collector;
mod entrypoints;
mod loader;
pub(crate) mod offline;
mod online;
mod router;
mod validate;
use std::collections::VecDeque;
use crate::common::protocols::DirectRequest;
pub(crate) use collector::TraceCollector;
#[cfg(test)]
pub(crate) use collector::TraceRequestStatsSnapshot;
pub use collector::{
TraceDistributionStats, TraceInterTokenLatencyStats, TraceLatencyStats, TraceRequestCounts,
TraceSimulationReport, TraceThroughputStats,
};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ReplayRouterMode {
RoundRobin,
KvRouter,
}
pub use entrypoints::{
simulate_concurrency_file, simulate_concurrency_file_with_router_mode,
simulate_concurrency_live_file, simulate_concurrency_live_file_with_router_mode,
simulate_concurrency_live_requests, simulate_concurrency_live_requests_with_router_mode,
simulate_concurrency_requests, simulate_concurrency_requests_with_router_mode,
simulate_trace_file, simulate_trace_file_with_router_mode, simulate_trace_live_file,
simulate_trace_live_file_with_router_mode, simulate_trace_live_requests,
simulate_trace_live_requests_with_router_mode, simulate_trace_requests,
simulate_trace_requests_with_router_mode,
};
pub(crate) fn normalize_trace_requests(
mut requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
) -> anyhow::Result<VecDeque<DirectRequest>> {
if !arrival_speedup_ratio.is_finite() || arrival_speedup_ratio <= 0.0 {
anyhow::bail!(
"arrival_speedup_ratio must be a finite positive number, got {arrival_speedup_ratio}"
);
}
requests.sort_by(|left, right| {
let left_ts = left
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
let right_ts = right
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
left_ts.total_cmp(&right_ts)
});
let first_arrival_ms = requests
.first()
.and_then(|request| request.arrival_timestamp_ms)
.ok_or_else(|| anyhow::anyhow!("trace replay requires at least one timestamped request"))?;
Ok(VecDeque::from(
requests
.into_iter()
.map(|mut request| {
let arrival_timestamp_ms = request
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp")
- first_arrival_ms;
let arrival_timestamp_ms = arrival_timestamp_ms / arrival_speedup_ratio;
request.arrival_timestamp_ms = Some(arrival_timestamp_ms);
request
})
.collect::<Vec<_>>(),
))
}
#[cfg(test)]
mod tests {
use super::*;
use uuid::Uuid;
#[test]
fn test_replay_itl_uses_per_token_gaps() {
let mut collector = TraceCollector::default();
let uuid = Uuid::from_u128(11);
collector.on_arrival(uuid, 0.0, 4, 4);
collector.on_admit(uuid, 0.0, 0);
collector.on_token(uuid, 10.0);
collector.on_token(uuid, 11.0);
collector.on_token(uuid, 12.0);
collector.on_token(uuid, 110.0);
let report = collector.finish();
assert!((report.latency.tpot.mean_ms - (100.0 / 3.0)).abs() < 1e-9);
assert!((report.latency.itl.distribution.mean_ms - (100.0 / 3.0)).abs() < 1e-9);
assert_eq!(report.latency.itl.distribution.median_ms, 1.0);
assert_eq!(report.latency.itl.distribution.p75_ms, 98.0);
assert_eq!(report.latency.itl.distribution.p90_ms, 98.0);
assert_eq!(report.latency.itl.distribution.p95_ms, 98.0);
assert_eq!(report.latency.itl.max_ms, 98.0);
assert_eq!(report.latency.ttst.min_ms, 1.0);
assert_eq!(report.latency.ttst.max_ms, 1.0);
assert_eq!(
report.latency.output_token_throughput_per_user.min_ms,
1000.0 / 98.0
);
assert_eq!(
report.latency.output_token_throughput_per_user.max_ms,
1000.0
);
}
#[test]
fn test_normalize_trace_requests_applies_arrival_speedup_ratio() {
let requests = vec![
DirectRequest {
tokens: vec![1; 4],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: Some(100.0),
},
DirectRequest {
tokens: vec![2; 4],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(2)),
dp_rank: 0,
arrival_timestamp_ms: Some(200.0),
},
];
let normalized = normalize_trace_requests(requests, 10.0).unwrap();
let arrivals = normalized
.into_iter()
.map(|request| request.arrival_timestamp_ms.unwrap())
.collect::<Vec<_>>();
assert_eq!(arrivals, vec![0.0, 10.0]);
}
}
# Offline Replay Harness
This directory contains the in-process offline replay harness used by `dynamo_mocker::replay`.
The goal is to simulate trace execution without spinning up async runtimes, network planes, or real worker tasks. Instead, the harness advances a logical clock, steps mock engine cores directly, and records request/token timing into `TraceCollector` in `lib/mocker/src/replay/collector.rs`.
## Where It Sits
The public replay entrypoints live one level up in `lib/mocker/src/replay/entrypoints.rs`. They:
- normalize `MockEngineArgs`
- load or accept `DirectRequest`s
- validate replay arguments
- dispatch to offline or online replay
Offline replay starts in `lib/mocker/src/replay/offline/mod.rs`.
`offline/mod.rs` chooses between two implementations:
- `lib/mocker/src/replay/offline/single.rs` for the special case `num_workers == 1` with the vLLM engine
- `lib/mocker/src/replay/offline/multi.rs` for everything else, including multi-worker replay and `kv_router` replay
## File Map
- `lib/mocker/src/replay/offline/mod.rs`
Chooses single-worker fast path vs multi-worker harness.
- `lib/mocker/src/replay/offline/single.rs`
Minimal replay loop for one vLLM worker.
- `lib/mocker/src/replay/offline/multi.rs`
General offline cluster simulator for multi-worker replay and KV-router replay.
- `lib/mocker/src/replay/offline/state.rs`
Per-worker wrapper around `EngineCore`, including optional KV event capture.
- `lib/mocker/src/replay/offline/events.rs`
Priority-queue event type used by the multi-worker harness.
- `lib/mocker/src/replay/offline/core.rs`
Small `ReplayWorkerCore` wrapper used by the single-worker path.
## Single-Worker Fast Path
The single-worker path is intentionally simple and only used when:
- `num_workers == 1`
- engine type is `vllm`
That path avoids the cluster event queue and router machinery entirely.
```mermaid
flowchart TD
A["single.rs::SingleRuntime"] --> B["pending requests"]
B --> C{"mode"}
C -->|trace| D["enqueue arrivals whose arrival_timestamp_ms <= current_time_ms"]
C -->|concurrency| E["enqueue until max_in_flight"]
D --> F["ReplayWorkerCore::execute_pass"]
E --> F
F --> G["update current_time_ms = pass.end_ms"]
G --> H["TraceCollector records arrivals/tokens/completions"]
H --> I{"done?"}
I -->|no| C
I -->|yes| J["TraceCollector::finish"]
```
Important details:
- Trace mode uses `normalize_trace_requests` in `lib/mocker/src/replay/mod.rs` so the first request starts at `0 ms`, then applies `arrival_speedup_ratio`.
- Concurrency mode ignores original arrival spacing and keeps the worker filled up to `max_in_flight`.
- The worker itself is still the real mocker engine core; only the scheduling loop is simplified.
## Multi-Worker Harness
The general harness lives in `lib/mocker/src/replay/offline/multi.rs`. It models a cluster with:
- a logical clock `now_ms`
- a pending request queue
- one [`OfflineWorkerState`](/Users/peabrane/Documents/codes/dynamo/lib/mocker/src/replay/offline/state.rs) per worker
- a binary heap of future completion events
- an optional synchronous offline router
### Main Loop
The harness is event-driven. It does not sleep. Instead, `OfflineRuntime` repeatedly:
1. picks the next meaningful timestamp
2. advances `now_ms`
3. applies any worker completion events scheduled for that time
4. admits newly available requests, either from trace arrivals or concurrency backfill
5. starts passes on workers that are ready to run
6. pushes new `WorkerCompletion` events back into the binary heap
It only advances `now_ms` to the next meaningful timestamp:
- next request arrival
- next worker completion event
### Worker Model
Each worker is represented by `OfflineWorkerState` in `lib/mocker/src/replay/offline/state.rs`:
- wraps an `EngineCore`
- tracks whether a pass is currently in progress
- tracks in-flight request count separately from engine internals
- optionally enables KV event capture when replay is running with `kv_router` mode
The pass execution itself still comes from the real scheduler core:
- `VllmCore::execute_pass(...)`
- `SglangCore::execute_pass(...)`
So offline replay is not a toy simulator. It reuses the real per-pass mocker scheduling logic, but drives it deterministically.
## Completion Event Queue
The multi-worker harness uses `SimulationEvent` from `lib/mocker/src/replay/offline/events.rs` as a min-time priority queue implemented with `BinaryHeap`.
Right now the only scheduled event type is:
- `WorkerCompletion`
That event carries:
- `worker_idx`
- `completed_requests`
- `output_signals`
- router-visible `kv_events`
Those are emitted after a worker pass is executed and then applied later when the harness clock reaches `pass.end_ms`.
## Router Integration
Offline replay can run in:
- `round_robin`
- `kv_router`
The router implementation for offline mode lives in `lib/mocker/src/replay/router/offline.rs`.
This router is synchronous and in-process:
- no async worker tasks
- no event plane
- no background indexer thread
Instead it maintains:
- a local radix tree indexer
- local `ActiveSequencesMultiWorker` state
- a pending queue for queued requests
```mermaid
flowchart LR
A["request arrives"] --> B{"router mode"}
B -->|round_robin| C["assign next worker"]
B -->|kv_router| D["OfflineReplayRouter::submit_request"]
D --> E["sync index lookup + scheduling policy"]
E --> F{"admit now?"}
F -->|yes| G["dispatch to worker"]
F -->|no| H["store in router_pending"]
I["worker pass emits RouterEvent + OutputSignal"] --> J["OfflineRuntime::process_completed_pass"]
J --> K["apply router events to sync indexer"]
J --> L["mark_prefill_completed / free"]
L --> M["drain queued admissions"]
M --> G
```
### Why KV events are captured only here
When offline replay uses `kv_router`, workers are created with KV event capture enabled via:
- `VllmCore::new_with_kv_capture` in `lib/mocker/src/scheduler/vllm/core.rs`
- `SglangCore::new_with_kv_capture` in `lib/mocker/src/scheduler/sglang/core.rs`
That causes each pass to return router-visible `kv_events`, which the harness applies synchronously to the offline router indexer after the pass completes.
In round-robin mode, this capture is skipped because nothing consumes those events.
## Trace vs Concurrency Modes
Both single and multi harnesses support two admission modes:
- Trace mode
- respects input arrival timestamps
- timestamps are normalized so the first request starts at `0 ms`
- `arrival_speedup_ratio` compresses or stretches inter-arrival gaps
- Concurrency mode
- ignores original spacing
- keeps up to `max_in_flight` requests resident in the cluster
- stamps synthetic arrival times as requests are admitted
This split is why `lib/mocker/src/replay/offline/mod.rs` exposes both:
- `simulate_trace(...)`
- `simulate_concurrency(...)`
## Metrics Collection
Both harnesses emit request timing into `TraceCollector` in `lib/mocker/src/replay/collector.rs`:
- arrival
- admission
- token emission
- completion
The harness itself does not compute final throughput/latency metrics incrementally. It records events, then `TraceCollector::finish()` derives the final `TraceSimulationReport` from `lib/mocker/src/replay/collector.rs`.
## Mental Model
The easiest way to think about offline replay is:
1. Reuse the real mocker scheduling pass logic.
2. Replace wall-clock async execution with a deterministic logical clock.
3. Optionally replace networked router behavior with a synchronous in-process router model.
4. Record the same request lifecycle timings into `TraceCollector`.
That keeps the harness fast, reproducible, and close to the real scheduler behavior without needing to boot a live runtime.
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::common::protocols::MockEngineArgs;
use crate::replay::TraceCollector;
use crate::scheduler::{EngineCore, EnginePassResult, SglangCore, VllmCore};
pub(crate) struct ReplayWorkerCore {
core: EngineCore,
}
impl ReplayWorkerCore {
pub(crate) fn new(args: MockEngineArgs) -> Self {
let core = match args.engine_type {
crate::common::protocols::EngineType::Vllm => EngineCore::Vllm(VllmCore::new(args)),
crate::common::protocols::EngineType::Sglang => {
EngineCore::Sglang(SglangCore::new(args))
}
};
Self { core }
}
pub(crate) fn is_empty(&self) -> bool {
self.core.is_empty()
}
pub(crate) fn receive(
&mut self,
request: crate::common::protocols::DirectRequest,
) -> uuid::Uuid {
self.core.receive(request)
}
pub(crate) fn num_requests(&self) -> usize {
self.core.num_requests()
}
pub(crate) fn execute_pass(
&mut self,
collector: &mut TraceCollector,
now_ms: f64,
) -> EnginePassResult {
self.core.execute_pass(collector, now_ms)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::cmp::Ordering;
use crate::common::protocols::OutputSignal;
#[derive(Debug)]
pub(crate) enum SimulationEventKind {
WorkerCompletion {
worker_idx: usize,
completed_requests: usize,
output_signals: Vec<OutputSignal>,
kv_events: Vec<dynamo_kv_router::protocols::RouterEvent>,
},
}
#[derive(Debug)]
pub(crate) struct SimulationEvent {
pub(crate) at_ms: f64,
pub(crate) seq_no: u64,
pub(crate) kind: SimulationEventKind,
}
impl SimulationEvent {
fn kind_priority(&self) -> u8 {
0
}
}
impl PartialEq for SimulationEvent {
fn eq(&self, other: &Self) -> bool {
self.at_ms.to_bits() == other.at_ms.to_bits()
&& self.seq_no == other.seq_no
&& self.kind_priority() == other.kind_priority()
}
}
impl Eq for SimulationEvent {}
impl PartialOrd for SimulationEvent {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for SimulationEvent {
fn cmp(&self, other: &Self) -> Ordering {
other
.at_ms
.partial_cmp(&self.at_ms)
.unwrap_or(Ordering::Equal)
.then_with(|| self.kind_priority().cmp(&other.kind_priority()))
.then_with(|| other.seq_no.cmp(&self.seq_no))
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::common::protocols::{DirectRequest, MockEngineArgs};
pub(crate) use crate::replay::normalize_trace_requests;
use crate::replay::{ReplayRouterMode, TraceSimulationReport};
use dynamo_kv_router::config::KvRouterConfig;
pub(crate) mod core;
pub(crate) mod events;
pub(crate) mod multi;
pub(crate) mod single;
pub(crate) mod state;
pub(crate) fn simulate_trace(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm {
single::simulate_trace_single(args, requests, arrival_speedup_ratio)
} else {
multi::simulate_trace_multi(
args,
router_config,
requests,
num_workers,
arrival_speedup_ratio,
router_mode,
)
}
}
pub(crate) fn simulate_concurrency(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm {
single::simulate_concurrency_single(args, requests, max_in_flight)
} else {
multi::simulate_concurrency_multi(
args,
router_config,
requests,
max_in_flight,
num_workers,
router_mode,
)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::events::{SimulationEvent, SimulationEventKind};
use super::normalize_trace_requests;
use super::state::OfflineWorkerState;
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::replay::router::OfflineReplayRouter;
use crate::replay::{ReplayRouterMode, TraceCollector, TraceSimulationReport};
use crate::scheduler::RouterEventVisibility;
use anyhow::bail;
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::RouterEvent;
use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
use uuid::Uuid;
#[derive(Debug, Clone, Copy)]
enum ReplayMode {
Trace,
Concurrency { max_in_flight: usize },
}
#[cfg(test)]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct OfflineRuntimeStats {
dispatch_history: Vec<usize>,
dispatch_order: Vec<Uuid>,
assigned_worker_by_uuid: HashMap<Uuid, usize>,
max_in_flight_seen: usize,
prefill_marked_count: usize,
freed_count: usize,
max_router_pending: usize,
}
#[cfg(not(test))]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct OfflineRuntimeStats;
struct OfflineRuntime {
now_ms: f64,
next_worker_idx: usize,
next_event_seq: u64,
pending: VecDeque<DirectRequest>,
router_pending: HashMap<Uuid, DirectRequest>,
workers: Vec<OfflineWorkerState>,
collector: TraceCollector,
events: BinaryHeap<SimulationEvent>,
mode: ReplayMode,
router: Option<OfflineReplayRouter>,
prefill_completed: HashSet<Uuid>,
stats: OfflineRuntimeStats,
}
impl OfflineRuntime {
fn new(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
pending: VecDeque<DirectRequest>,
num_workers: usize,
mode: ReplayMode,
router_mode: ReplayRouterMode,
) -> anyhow::Result<Self> {
let args = args.clone().normalized()?;
let router = match router_mode {
ReplayRouterMode::RoundRobin => None,
ReplayRouterMode::KvRouter => {
Some(OfflineReplayRouter::new(&args, router_config, num_workers)?)
}
};
let capture_kv_events = router.is_some();
Ok(Self {
now_ms: 0.0,
next_worker_idx: 0,
next_event_seq: 0,
pending,
router_pending: HashMap::new(),
workers: (0..num_workers)
.map(|worker_idx| {
OfflineWorkerState::new(worker_idx, args.clone(), capture_kv_events)
})
.collect(),
collector: TraceCollector::default(),
events: BinaryHeap::new(),
mode,
router,
prefill_completed: HashSet::new(),
#[cfg(test)]
stats: OfflineRuntimeStats::default(),
#[cfg(not(test))]
stats: OfflineRuntimeStats,
})
}
fn cluster_in_flight(&self) -> usize {
self.workers
.iter()
.map(OfflineWorkerState::in_flight)
.sum::<usize>()
+ self.router_pending.len()
}
fn record_in_flight_peak(&mut self) {
#[cfg(test)]
{
self.stats.max_in_flight_seen =
self.stats.max_in_flight_seen.max(self.cluster_in_flight());
}
}
fn record_router_pending(&mut self) {
#[cfg(test)]
let Some(router) = self.router.as_ref() else {
return;
};
#[cfg(test)]
{
self.stats.max_router_pending =
self.stats.max_router_pending.max(router.pending_count());
}
}
fn record_dispatch(&mut self, _uuid: Uuid, _worker_idx: usize) {
#[cfg(test)]
{
self.stats.dispatch_history.push(_worker_idx);
self.stats.dispatch_order.push(_uuid);
self.stats
.assigned_worker_by_uuid
.insert(_uuid, _worker_idx);
}
self.record_in_flight_peak();
}
fn validate_worker_idx(&self, worker_idx: usize) -> anyhow::Result<()> {
if worker_idx >= self.workers.len() {
bail!("offline replay selected unknown worker index {worker_idx}");
}
Ok(())
}
fn dispatch_to_worker(
&mut self,
request: DirectRequest,
uuid: Uuid,
worker_idx: usize,
) -> anyhow::Result<()> {
self.validate_worker_idx(worker_idx)?;
self.workers[worker_idx].receive_request(request);
self.record_dispatch(uuid, worker_idx);
Ok(())
}
fn dispatch_router_admissions(&mut self, admissions: Vec<(Uuid, usize)>) -> anyhow::Result<()> {
for (uuid, worker_idx) in admissions {
let request = self.router_pending.remove(&uuid).ok_or_else(|| {
anyhow::anyhow!("offline replay missing queued request state for {uuid}")
})?;
self.dispatch_to_worker(request, uuid, worker_idx)?;
}
Ok(())
}
fn assign_request(
&mut self,
mut request: DirectRequest,
arrival_time_ms: f64,
) -> anyhow::Result<Uuid> {
let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
request.uuid = Some(uuid);
if matches!(self.mode, ReplayMode::Concurrency { .. }) {
request.arrival_timestamp_ms = Some(arrival_time_ms);
}
self.collector.on_arrival(
uuid,
arrival_time_ms,
request.tokens.len(),
request.max_output_tokens,
);
let Some(router) = self.router.as_mut() else {
let worker_idx = self.next_worker_idx;
self.next_worker_idx = (self.next_worker_idx + 1) % self.workers.len();
self.dispatch_to_worker(request, uuid, worker_idx)?;
return Ok(uuid);
};
let maybe_worker_idx = router.submit_request(&request, self.now_ms)?;
self.record_router_pending();
if let Some(worker_idx) = maybe_worker_idx {
self.dispatch_to_worker(request, uuid, worker_idx)?;
return Ok(uuid);
}
self.router_pending.insert(uuid, request);
self.record_in_flight_peak();
Ok(uuid)
}
fn is_done(&self) -> bool {
self.pending.is_empty()
&& self.events.is_empty()
&& self.cluster_in_flight() == 0
&& self.workers.iter().all(OfflineWorkerState::is_drained)
}
fn next_timestamp(&self) -> Option<f64> {
let next_event_ms = self.events.peek().map(|event| event.at_ms);
let next_arrival_ms = match self.mode {
ReplayMode::Trace => self
.pending
.front()
.and_then(|request| request.arrival_timestamp_ms),
ReplayMode::Concurrency { .. } => None,
};
match (next_arrival_ms, next_event_ms) {
(Some(arrival_ms), Some(event_ms)) => Some(arrival_ms.min(event_ms)),
(Some(arrival_ms), None) => Some(arrival_ms),
(None, Some(event_ms)) => Some(event_ms),
(None, None) => None,
}
}
fn push_event(&mut self, at_ms: f64, kind: SimulationEventKind) {
self.events.push(SimulationEvent {
at_ms,
seq_no: self.next_event_seq,
kind,
});
self.next_event_seq += 1;
}
fn apply_completed_requests(&mut self, worker_idx: usize, completed_requests: usize) {
self.workers[worker_idx].mark_completed(completed_requests);
}
fn apply_router_events(&mut self, events: Vec<RouterEvent>) -> anyhow::Result<()> {
let Some(router) = self.router.as_mut() else {
return Ok(());
};
for event in events {
router.apply_event(event)?;
}
Ok(())
}
fn process_output_signal(&mut self, signal: OutputSignal) -> anyhow::Result<()> {
let mut admissions = Vec::new();
if signal.completed {
if let Some(router) = self.router.as_mut() {
admissions = router.free(signal.uuid)?;
#[cfg(test)]
{
self.stats.freed_count += 1;
}
self.record_router_pending();
}
self.prefill_completed.remove(&signal.uuid);
self.dispatch_router_admissions(admissions)?;
return Ok(());
}
if !self.prefill_completed.insert(signal.uuid) {
return Ok(());
}
if let Some(router) = self.router.as_mut() {
admissions = router.mark_prefill_completed(signal.uuid)?;
#[cfg(test)]
{
self.stats.prefill_marked_count += 1;
}
self.record_router_pending();
}
self.dispatch_router_admissions(admissions)?;
Ok(())
}
fn process_completed_pass(
&mut self,
worker_idx: usize,
completed_requests: usize,
output_signals: Vec<OutputSignal>,
kv_events: Vec<RouterEvent>,
) -> anyhow::Result<()> {
self.apply_completed_requests(worker_idx, completed_requests);
self.apply_router_events(kv_events)?;
for signal in output_signals {
self.process_output_signal(signal)?;
}
Ok(())
}
fn apply_worker_completions(&mut self) -> anyhow::Result<bool> {
let mut changed = false;
loop {
let Some(event) = self.events.peek() else {
break;
};
if event.at_ms != self.now_ms {
break;
}
if !matches!(event.kind, SimulationEventKind::WorkerCompletion { .. }) {
break;
}
let event = self.events.pop().expect("event must exist after peek");
let SimulationEventKind::WorkerCompletion {
worker_idx,
completed_requests,
output_signals,
kv_events,
} = event.kind;
self.workers[worker_idx].mark_idle();
self.process_completed_pass(worker_idx, completed_requests, output_signals, kv_events)?;
changed = true;
}
Ok(changed)
}
fn release_trace_arrivals(&mut self) -> anyhow::Result<bool> {
let mut released_any = false;
while self
.pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
.is_some_and(|arrival_ms| arrival_ms <= self.now_ms)
{
let request = self
.pending
.pop_front()
.expect("front request must exist when arrival is ready");
let arrival_ms = request
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
self.assign_request(request, arrival_ms)?;
released_any = true;
}
Ok(released_any)
}
fn top_off_concurrency(&mut self, max_in_flight: usize) -> anyhow::Result<bool> {
let mut released_any = false;
while self.cluster_in_flight() < max_in_flight {
let Some(request) = self.pending.pop_front() else {
break;
};
self.assign_request(request, self.now_ms)?;
released_any = true;
}
Ok(released_any)
}
fn drive_ready_workers(&mut self) -> anyhow::Result<bool> {
let mut changed = false;
for worker_idx in 0..self.workers.len() {
loop {
if !self.workers[worker_idx].is_ready() {
break;
}
let executed = {
let (workers, collector) = (&mut self.workers, &mut self.collector);
workers[worker_idx].execute_pass(collector, self.now_ms)
};
changed = true;
let completion_kv_events =
if executed.router_event_visibility == RouterEventVisibility::PassStart {
self.apply_router_events(executed.kv_events)?;
Vec::new()
} else {
executed.kv_events
};
if executed.end_ms == self.now_ms {
self.process_completed_pass(
worker_idx,
executed.completed_requests,
executed.output_signals,
completion_kv_events,
)?;
continue;
}
self.workers[worker_idx].mark_busy();
self.push_event(
executed.end_ms,
SimulationEventKind::WorkerCompletion {
worker_idx,
completed_requests: executed.completed_requests,
output_signals: executed.output_signals,
kv_events: completion_kv_events,
},
);
break;
}
}
Ok(changed)
}
fn drain_current_timestamp(&mut self) -> anyhow::Result<()> {
loop {
let mut changed = self.apply_worker_completions()?;
changed |= match self.mode {
ReplayMode::Trace => self.release_trace_arrivals()?,
ReplayMode::Concurrency { max_in_flight } => {
self.top_off_concurrency(max_in_flight)?
}
};
changed |= self.drive_ready_workers()?;
if !changed {
break;
}
}
Ok(())
}
fn run(mut self) -> anyhow::Result<(TraceCollector, OfflineRuntimeStats)> {
self.drain_current_timestamp()?;
while !self.is_done() {
let Some(next_timestamp_ms) = self.next_timestamp() else {
bail!(
"offline replay reached a dead end with {} in-flight requests remaining",
self.cluster_in_flight()
);
};
self.now_ms = next_timestamp_ms;
self.drain_current_timestamp()?;
}
if let Some(router) = self.router.as_mut() {
router.shutdown();
}
Ok((self.collector, self.stats))
}
}
pub(crate) fn simulate_trace_multi(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
let (collector, _) = OfflineRuntime::new(
&args,
router_config,
pending,
num_workers,
ReplayMode::Trace,
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_multi(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = VecDeque::from(requests);
let (collector, _) = OfflineRuntime::new(
&args,
router_config,
pending,
num_workers,
ReplayMode::Concurrency { max_in_flight },
router_mode,
)?
.run()?;
Ok(collector.finish())
}
#[cfg(test)]
fn run_trace_multi_collect_with_stats(
args: &MockEngineArgs,
requests: Vec<DirectRequest>,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> (TraceCollector, OfflineRuntimeStats) {
let pending = normalize_trace_requests(requests, 1.0).unwrap();
OfflineRuntime::new(
args,
None,
pending,
num_workers,
ReplayMode::Trace,
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
fn run_concurrency_multi_collect_with_stats(
args: &MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> (TraceCollector, OfflineRuntimeStats) {
OfflineRuntime::new(
args,
None,
VecDeque::from(requests),
num_workers,
ReplayMode::Concurrency { max_in_flight },
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
mod tests {
use super::super::single::{run_concurrency_single_collect, run_trace_single_collect};
use super::*;
use crate::common::protocols::{EngineType, SglangArgs};
use dynamo_kv_router::config::RouterQueuePolicy;
fn replay_args(enable_prefix_caching: bool, enable_chunked_prefill: bool) -> MockEngineArgs {
MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(32)
.max_num_batched_tokens(Some(8))
.max_num_seqs(Some(2))
.enable_prefix_caching(enable_prefix_caching)
.enable_chunked_prefill(enable_chunked_prefill)
.speedup_ratio(0.0)
.build()
.unwrap()
}
fn fast_router_args() -> MockEngineArgs {
MockEngineArgs::builder()
.block_size(64)
.num_gpu_blocks(256)
.max_num_batched_tokens(Some(8192))
.max_num_seqs(Some(8))
.enable_prefix_caching(true)
.enable_chunked_prefill(true)
.speedup_ratio(1000.0)
.build()
.unwrap()
}
fn queueing_router_args(policy: RouterQueuePolicy) -> MockEngineArgs {
MockEngineArgs::builder()
.block_size(64)
.num_gpu_blocks(256)
.max_num_batched_tokens(Some(8))
.max_num_seqs(Some(8))
.enable_prefix_caching(true)
.enable_chunked_prefill(true)
.speedup_ratio(10.0)
.router_queue_policy(Some(policy))
.build()
.unwrap()
}
fn sglang_replay_args() -> MockEngineArgs {
MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.num_gpu_blocks(512)
.speedup_ratio(1000.0)
.sglang(Some(SglangArgs {
page_size: Some(2),
..Default::default()
}))
.build()
.unwrap()
}
#[test]
fn test_multi_worker_trace_round_robin_assigns_same_timestamp_requests_deterministically() {
let args = replay_args(false, true);
let (collector, _) = run_trace_multi_collect_with_stats(
&args,
vec![
DirectRequest {
tokens: vec![1, 1, 1, 1, 2, 2, 2, 2],
max_output_tokens: 4,
uuid: Some(Uuid::from_u128(11)),
dp_rank: 0,
arrival_timestamp_ms: Some(100.0),
},
DirectRequest {
tokens: vec![3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(22)),
dp_rank: 0,
arrival_timestamp_ms: Some(100.0),
},
DirectRequest {
tokens: vec![5, 5, 5, 5, 6, 6, 6, 6],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(33)),
dp_rank: 0,
arrival_timestamp_ms: Some(101.0),
},
DirectRequest {
tokens: vec![7, 7, 7, 7, 8, 8, 8, 8],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(44)),
dp_rank: 0,
arrival_timestamp_ms: Some(101.0),
},
],
2,
ReplayRouterMode::RoundRobin,
);
let request_1 = collector.snapshot(Uuid::from_u128(11)).unwrap();
let request_2 = collector.snapshot(Uuid::from_u128(22)).unwrap();
let request_3 = collector.snapshot(Uuid::from_u128(33)).unwrap();
let request_4 = collector.snapshot(Uuid::from_u128(44)).unwrap();
let report = collector.finish();
assert_eq!(request_1.arrival_time_ms, 0.0);
assert_eq!(request_2.arrival_time_ms, 0.0);
assert_eq!(request_3.arrival_time_ms, 1.0);
assert_eq!(request_4.arrival_time_ms, 1.0);
assert!(request_3.first_admit_ms.unwrap() >= request_1.first_token_ms.unwrap());
assert!(request_4.first_admit_ms.unwrap() >= request_2.first_token_ms.unwrap());
assert!(request_3.first_admit_ms.unwrap() < request_4.first_admit_ms.unwrap());
assert_eq!(report.request_counts.completed_requests, 4);
assert_eq!(report.request_counts.total_input_tokens, 40);
assert_eq!(report.request_counts.total_output_tokens, 10);
}
#[test]
fn test_multi_worker_trace_round_robin_records_dispatch_history() {
let args = replay_args(false, true);
let (_, stats) = run_trace_multi_collect_with_stats(
&args,
vec![
DirectRequest {
tokens: vec![1; 8],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![2; 8],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(2)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![3; 8],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(3)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![4; 8],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(4)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![5; 8],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(5)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
],
4,
ReplayRouterMode::RoundRobin,
);
assert_eq!(stats.dispatch_history, vec![0, 1, 2, 3, 0]);
}
#[test]
fn test_offline_trace_replay_sglang_single_worker_completes() {
let args = sglang_replay_args();
let (collector, stats) = run_trace_multi_collect_with_stats(
&args,
vec![
DirectRequest {
tokens: vec![1; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(901)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![2; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(902)),
dp_rank: 0,
arrival_timestamp_ms: Some(5.0),
},
],
1,
ReplayRouterMode::RoundRobin,
);
let report = collector.finish();
assert_eq!(report.request_counts.completed_requests, 2);
assert_eq!(report.request_counts.total_output_tokens, 4);
assert_eq!(stats.dispatch_history, vec![0, 0]);
}
#[test]
fn test_offline_trace_replay_sglang_kv_router_smoke() {
let args = sglang_replay_args();
let (collector, stats) = run_trace_multi_collect_with_stats(
&args,
vec![
DirectRequest {
tokens: vec![7; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(911)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![7; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(912)),
dp_rank: 0,
arrival_timestamp_ms: Some(500.0),
},
],
2,
ReplayRouterMode::KvRouter,
);
let report = collector.finish();
assert_eq!(report.request_counts.completed_requests, 2);
assert_eq!(stats.dispatch_history.len(), 2);
}
#[test]
fn test_multi_worker_concurrency_uses_worker_in_flight_for_cap_checks() {
let args = replay_args(false, false);
let (collector, _) = run_concurrency_multi_collect_with_stats(
&args,
vec![
DirectRequest {
tokens: vec![1, 1, 1, 1, 2, 2, 2, 2],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(11)),
dp_rank: 0,
arrival_timestamp_ms: Some(900.0),
},
DirectRequest {
tokens: vec![3, 3, 3, 3, 4, 4, 4, 4],
max_output_tokens: 4,
uuid: Some(Uuid::from_u128(22)),
dp_rank: 0,
arrival_timestamp_ms: Some(1000.0),
},
DirectRequest {
tokens: vec![5, 5, 5, 5, 6, 6, 6, 6],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(33)),
dp_rank: 0,
arrival_timestamp_ms: Some(100.0),
},
],
2,
2,
ReplayRouterMode::RoundRobin,
);
let request_1 = collector.snapshot(Uuid::from_u128(11)).unwrap();
let request_2 = collector.snapshot(Uuid::from_u128(22)).unwrap();
let request_3 = collector.snapshot(Uuid::from_u128(33)).unwrap();
let report = collector.finish();
assert_eq!(request_1.arrival_time_ms, 0.0);
assert_eq!(request_2.arrival_time_ms, 0.0);
assert_eq!(request_3.arrival_time_ms, request_1.last_token_ms.unwrap());
assert!(request_3.arrival_time_ms < request_2.last_token_ms.unwrap());
assert_eq!(request_3.first_admit_ms.unwrap(), request_3.arrival_time_ms);
assert_eq!(report.request_counts.completed_requests, 3);
assert_eq!(report.request_counts.total_input_tokens, 24);
assert_eq!(report.request_counts.total_output_tokens, 8);
}
#[test]
fn test_multi_worker_trace_kv_router_prefers_cached_workers_after_delay() {
let args = fast_router_args();
let (_, stats) = run_trace_multi_collect_with_stats(
&args,
vec![
DirectRequest {
tokens: vec![11; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(11)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![22; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(22)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![11; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(33)),
dp_rank: 0,
arrival_timestamp_ms: Some(2.0),
},
DirectRequest {
tokens: vec![22; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(44)),
dp_rank: 0,
arrival_timestamp_ms: Some(2.0),
},
],
2,
ReplayRouterMode::KvRouter,
);
let worker_a1 = stats.assigned_worker_by_uuid[&Uuid::from_u128(11)];
let worker_b1 = stats.assigned_worker_by_uuid[&Uuid::from_u128(22)];
let worker_a2 = stats.assigned_worker_by_uuid[&Uuid::from_u128(33)];
let worker_b2 = stats.assigned_worker_by_uuid[&Uuid::from_u128(44)];
assert_ne!(worker_a1, worker_b1);
assert_eq!(worker_a1, worker_a2);
assert_eq!(worker_b1, worker_b2);
}
#[test]
fn test_multi_worker_trace_kv_router_marks_prefill_and_free_correctly() {
let args = fast_router_args();
let (_, stats) = run_trace_multi_collect_with_stats(
&args,
vec![
DirectRequest {
tokens: vec![9; 64],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(9)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![8; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(8)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
],
2,
ReplayRouterMode::KvRouter,
);
assert_eq!(stats.prefill_marked_count, 1);
assert_eq!(stats.freed_count, 2);
assert_eq!(stats.max_router_pending, 0);
}
#[test]
fn test_multi_worker_trace_kv_router_queues_until_prefill_completion() {
let args = queueing_router_args(RouterQueuePolicy::Fcfs);
let (collector, stats) = run_trace_multi_collect_with_stats(
&args,
vec![
DirectRequest {
tokens: vec![1; 64],
max_output_tokens: 8,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![2; 64],
max_output_tokens: 8,
uuid: Some(Uuid::from_u128(2)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![3; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(3)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.1),
},
],
2,
ReplayRouterMode::KvRouter,
);
let request_1 = collector.snapshot(Uuid::from_u128(1)).unwrap();
let request_2 = collector.snapshot(Uuid::from_u128(2)).unwrap();
let request_3 = collector.snapshot(Uuid::from_u128(3)).unwrap();
assert!(stats.max_router_pending > 0);
assert!(request_3.first_admit_ms.unwrap() > request_3.arrival_time_ms);
assert!(
request_3.first_admit_ms.unwrap()
< request_1
.last_token_ms
.unwrap()
.min(request_2.last_token_ms.unwrap())
);
}
#[test]
fn test_multi_worker_trace_kv_router_fcfs_and_lcfs_dispatch_in_opposite_queue_order() {
let requests = vec![
DirectRequest {
tokens: vec![10; 64],
max_output_tokens: 8,
uuid: Some(Uuid::from_u128(10)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![20; 64],
max_output_tokens: 8,
uuid: Some(Uuid::from_u128(20)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![30; 64],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(30)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.1),
},
DirectRequest {
tokens: vec![40; 64],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(40)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.2),
},
];
let (_, fcfs_stats) = run_trace_multi_collect_with_stats(
&queueing_router_args(RouterQueuePolicy::Fcfs),
requests.clone(),
2,
ReplayRouterMode::KvRouter,
);
let (_, lcfs_stats) = run_trace_multi_collect_with_stats(
&queueing_router_args(RouterQueuePolicy::Lcfs),
requests,
2,
ReplayRouterMode::KvRouter,
);
assert!(fcfs_stats.max_router_pending > 0);
assert!(lcfs_stats.max_router_pending > 0);
assert_eq!(
&fcfs_stats.dispatch_order[..2],
&[Uuid::from_u128(10), Uuid::from_u128(20)]
);
assert_eq!(
&lcfs_stats.dispatch_order[..2],
&[Uuid::from_u128(10), Uuid::from_u128(20)]
);
assert_eq!(
&fcfs_stats.dispatch_order[2..4],
&[Uuid::from_u128(30), Uuid::from_u128(40)]
);
assert_eq!(
&lcfs_stats.dispatch_order[2..4],
&[Uuid::from_u128(40), Uuid::from_u128(30)]
);
}
#[test]
fn test_multi_worker_concurrency_kv_router_respects_max_in_flight() {
let args = queueing_router_args(RouterQueuePolicy::Fcfs);
let (_, stats) = run_concurrency_multi_collect_with_stats(
&args,
vec![
DirectRequest {
tokens: vec![1; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: None,
},
DirectRequest {
tokens: vec![2; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(2)),
dp_rank: 0,
arrival_timestamp_ms: None,
},
DirectRequest {
tokens: vec![1; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(3)),
dp_rank: 0,
arrival_timestamp_ms: None,
},
DirectRequest {
tokens: vec![2; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(4)),
dp_rank: 0,
arrival_timestamp_ms: None,
},
],
3,
2,
ReplayRouterMode::KvRouter,
);
assert_eq!(stats.max_in_flight_seen, 3);
assert!(stats.max_router_pending > 0);
}
#[test]
fn test_multi_worker_trace_single_worker_round_robin_matches_single_runtime() {
let args = replay_args(true, true);
let requests = vec![
DirectRequest {
tokens: vec![1, 1, 1, 1, 2, 2, 2, 2],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(11)),
dp_rank: 0,
arrival_timestamp_ms: Some(100.0),
},
DirectRequest {
tokens: vec![1, 1, 1, 1, 2, 2, 2, 2],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(22)),
dp_rank: 0,
arrival_timestamp_ms: Some(101.0),
},
DirectRequest {
tokens: vec![9, 9, 9, 9, 8, 8, 8, 8],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(33)),
dp_rank: 0,
arrival_timestamp_ms: Some(500.0),
},
];
let single = run_trace_single_collect(args.clone(), requests.clone(), 1.0);
let (multi, stats) =
run_trace_multi_collect_with_stats(&args, requests, 1, ReplayRouterMode::RoundRobin);
assert_eq!(stats.dispatch_history, vec![0, 0, 0]);
for uuid in [11_u128, 22, 33] {
assert_eq!(
multi.snapshot(Uuid::from_u128(uuid)),
single.snapshot(Uuid::from_u128(uuid))
);
}
assert_eq!(multi.finish().request_counts.completed_requests, 3);
assert_eq!(single.finish().request_counts.completed_requests, 3);
}
#[test]
fn test_multi_worker_trace_single_worker_kv_router_matches_single_runtime() {
let args = replay_args(true, true);
let requests = vec![
DirectRequest {
tokens: vec![1, 1, 1, 1, 2, 2, 2, 2],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(11)),
dp_rank: 0,
arrival_timestamp_ms: Some(100.0),
},
DirectRequest {
tokens: vec![1, 1, 1, 1, 2, 2, 2, 2],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(22)),
dp_rank: 0,
arrival_timestamp_ms: Some(101.0),
},
DirectRequest {
tokens: vec![9, 9, 9, 9, 8, 8, 8, 8],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(33)),
dp_rank: 0,
arrival_timestamp_ms: Some(500.0),
},
];
let single = run_trace_single_collect(args.clone(), requests.clone(), 1.0);
let (multi, stats) =
run_trace_multi_collect_with_stats(&args, requests, 1, ReplayRouterMode::KvRouter);
assert_eq!(stats.dispatch_history, vec![0, 0, 0]);
assert_eq!(stats.max_router_pending, 0);
for uuid in [11_u128, 22, 33] {
assert_eq!(
multi.snapshot(Uuid::from_u128(uuid)),
single.snapshot(Uuid::from_u128(uuid))
);
}
assert_eq!(multi.finish().request_counts.completed_requests, 3);
assert_eq!(single.finish().request_counts.completed_requests, 3);
}
#[test]
fn test_multi_worker_concurrency_single_worker_round_robin_matches_single_runtime() {
let args = replay_args(true, true);
let requests = vec![
DirectRequest {
tokens: vec![1, 1, 1, 1, 2, 2, 2, 2],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(11)),
dp_rank: 0,
arrival_timestamp_ms: Some(900.0),
},
DirectRequest {
tokens: vec![3, 3, 3, 3, 4, 4, 4, 4],
max_output_tokens: 4,
uuid: Some(Uuid::from_u128(22)),
dp_rank: 0,
arrival_timestamp_ms: Some(1000.0),
},
DirectRequest {
tokens: vec![5, 5, 5, 5, 6, 6, 6, 6],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(33)),
dp_rank: 0,
arrival_timestamp_ms: Some(100.0),
},
];
let single = run_concurrency_single_collect(args.clone(), requests.clone(), 2);
let (multi, stats) = run_concurrency_multi_collect_with_stats(
&args,
requests,
2,
1,
ReplayRouterMode::RoundRobin,
);
assert_eq!(stats.dispatch_history, vec![0, 0, 0]);
for uuid in [11_u128, 22, 33] {
assert_eq!(
multi.snapshot(Uuid::from_u128(uuid)),
single.snapshot(Uuid::from_u128(uuid))
);
}
}
#[test]
fn test_multi_worker_concurrency_single_worker_kv_router_matches_single_runtime() {
let args = replay_args(true, true);
let requests = vec![
DirectRequest {
tokens: vec![1, 1, 1, 1, 2, 2, 2, 2],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(11)),
dp_rank: 0,
arrival_timestamp_ms: Some(900.0),
},
DirectRequest {
tokens: vec![3, 3, 3, 3, 4, 4, 4, 4],
max_output_tokens: 4,
uuid: Some(Uuid::from_u128(22)),
dp_rank: 0,
arrival_timestamp_ms: Some(1000.0),
},
DirectRequest {
tokens: vec![5, 5, 5, 5, 6, 6, 6, 6],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(33)),
dp_rank: 0,
arrival_timestamp_ms: Some(100.0),
},
];
let single = run_concurrency_single_collect(args.clone(), requests.clone(), 2);
let (multi, stats) = run_concurrency_multi_collect_with_stats(
&args,
requests,
2,
1,
ReplayRouterMode::KvRouter,
);
assert_eq!(stats.dispatch_history, vec![0, 0, 0]);
assert_eq!(stats.max_router_pending, 0);
for uuid in [11_u128, 22, 33] {
assert_eq!(
multi.snapshot(Uuid::from_u128(uuid)),
single.snapshot(Uuid::from_u128(uuid))
);
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::core::ReplayWorkerCore;
use super::normalize_trace_requests;
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::replay::{TraceCollector, TraceSimulationReport};
use anyhow::bail;
use std::collections::VecDeque;
use uuid::Uuid;
#[derive(Debug, Clone, Copy)]
enum SingleReplayMode {
Trace,
Concurrency { max_in_flight: usize },
}
struct SingleRuntime {
current_time_ms: f64,
pending: VecDeque<DirectRequest>,
worker: ReplayWorkerCore,
collector: TraceCollector,
mode: SingleReplayMode,
}
impl SingleRuntime {
fn new(args: MockEngineArgs, pending: VecDeque<DirectRequest>, mode: SingleReplayMode) -> Self {
Self {
current_time_ms: 0.0,
pending,
worker: ReplayWorkerCore::new(args),
collector: TraceCollector::default(),
mode,
}
}
fn enqueue_trace_arrivals(&mut self) {
loop {
let Some(next_arrival_ms) = self
.pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
else {
break;
};
if next_arrival_ms > self.current_time_ms {
break;
}
let request = self
.pending
.pop_front()
.expect("front request must exist when arrival is available");
let arrival_ms = request
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
self.record_arrival(request, arrival_ms);
}
}
fn enqueue_concurrency_arrivals(&mut self, max_in_flight: usize) {
while self.worker.num_requests() < max_in_flight {
let Some(mut request) = self.pending.pop_front() else {
break;
};
request.arrival_timestamp_ms = Some(self.current_time_ms);
self.record_arrival(request, self.current_time_ms);
}
}
fn record_arrival(&mut self, request: DirectRequest, arrival_ms: f64) -> Uuid {
let input_length = request.tokens.len();
let output_length = request.max_output_tokens;
let uuid = self.worker.receive(request);
self.collector
.on_arrival(uuid, arrival_ms, input_length, output_length);
uuid
}
fn is_done(&self) -> bool {
self.pending.is_empty() && self.worker.is_empty()
}
fn advance_to_next_trace_arrival(&mut self) -> anyhow::Result<()> {
let Some(next_arrival_ms) = self
.pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
else {
bail!("trace replay reached an idle state without a pending arrival");
};
self.current_time_ms = next_arrival_ms;
Ok(())
}
fn drive_worker(&mut self, admit_arrivals_between_steps: bool) {
let pass = self
.worker
.execute_pass(&mut self.collector, self.current_time_ms);
self.current_time_ms = pass.end_ms;
if admit_arrivals_between_steps {
self.enqueue_trace_arrivals();
}
}
fn run(mut self) -> anyhow::Result<TraceCollector> {
while !self.is_done() {
match self.mode {
SingleReplayMode::Trace => {
self.enqueue_trace_arrivals();
if self.worker.is_empty() {
self.advance_to_next_trace_arrival()?;
self.enqueue_trace_arrivals();
continue;
}
self.drive_worker(true);
}
SingleReplayMode::Concurrency { max_in_flight } => {
self.enqueue_concurrency_arrivals(max_in_flight);
if self.worker.is_empty() {
break;
}
self.drive_worker(false);
}
}
}
Ok(self.collector)
}
}
pub(crate) fn simulate_trace_single(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
let collector = SingleRuntime::new(args, pending, SingleReplayMode::Trace).run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_single(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = VecDeque::from(requests);
let collector = SingleRuntime::new(
args,
pending,
SingleReplayMode::Concurrency { max_in_flight },
)
.run()?;
Ok(collector.finish())
}
#[cfg(test)]
pub(super) fn run_trace_single_collect(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
) -> TraceCollector {
let pending = normalize_trace_requests(requests, arrival_speedup_ratio).unwrap();
SingleRuntime::new(args, pending, SingleReplayMode::Trace)
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_concurrency_single_collect(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
) -> TraceCollector {
SingleRuntime::new(
args,
VecDeque::from(requests),
SingleReplayMode::Concurrency { max_in_flight },
)
.run()
.unwrap()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::replay::{TraceRequestStatsSnapshot, TraceSimulationReport};
use rstest::rstest;
use std::collections::{HashMap, VecDeque};
use uuid::Uuid;
#[derive(Debug)]
struct ManualReplayResult {
report: TraceSimulationReport,
snapshots: HashMap<Uuid, TraceRequestStatsSnapshot>,
idle_jump_ms: f64,
first_decode_end_ms: f64,
}
#[derive(Debug)]
struct ManualConcurrencyResult {
report: TraceSimulationReport,
snapshots: HashMap<Uuid, TraceRequestStatsSnapshot>,
}
fn enqueue_trace_arrivals_manual(
pending: &mut VecDeque<DirectRequest>,
worker: &mut ReplayWorkerCore,
collector: &mut TraceCollector,
current_time_ms: f64,
) {
loop {
let Some(next_arrival_ms) = pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
else {
break;
};
if next_arrival_ms > current_time_ms {
break;
}
let request = pending
.pop_front()
.expect("front request must exist when arrival is available");
let arrival_ms = request
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
let input_length = request.tokens.len();
let output_length = request.max_output_tokens;
let uuid = worker.receive(request);
collector.on_arrival(uuid, arrival_ms, input_length, output_length);
}
}
fn enqueue_concurrency_arrivals_manual(
pending: &mut VecDeque<DirectRequest>,
worker: &mut ReplayWorkerCore,
collector: &mut TraceCollector,
current_time_ms: f64,
max_in_flight: usize,
) {
while worker.num_requests() < max_in_flight {
let Some(mut request) = pending.pop_front() else {
break;
};
request.arrival_timestamp_ms = Some(current_time_ms);
let input_length = request.tokens.len();
let output_length = request.max_output_tokens;
let uuid = worker.receive(request);
collector.on_arrival(uuid, current_time_ms, input_length, output_length);
}
}
fn replay_args(enable_prefix_caching: bool, enable_chunked_prefill: bool) -> MockEngineArgs {
MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(32)
.max_num_batched_tokens(Some(8))
.max_num_seqs(Some(2))
.enable_prefix_caching(enable_prefix_caching)
.enable_chunked_prefill(enable_chunked_prefill)
.speedup_ratio(0.0)
.build()
.unwrap()
}
fn replay_fixture() -> Vec<DirectRequest> {
vec![
DirectRequest {
tokens: vec![1, 1, 1, 1, 2, 2, 2, 2],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(11)),
dp_rank: 0,
arrival_timestamp_ms: Some(100.0),
},
DirectRequest {
tokens: vec![1, 1, 1, 1, 2, 2, 2, 2],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(22)),
dp_rank: 0,
arrival_timestamp_ms: Some(101.0),
},
DirectRequest {
tokens: vec![9, 9, 9, 9, 8, 8, 8, 8],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(33)),
dp_rank: 0,
arrival_timestamp_ms: Some(500.0),
},
]
}
fn run_trace_manually(
args: &MockEngineArgs,
requests: Vec<DirectRequest>,
) -> ManualReplayResult {
let mut requests = requests;
requests.sort_by(|left, right| {
let left_ts = left.arrival_timestamp_ms.unwrap();
let right_ts = right.arrival_timestamp_ms.unwrap();
left_ts.total_cmp(&right_ts)
});
let first_arrival_ms = requests.first().unwrap().arrival_timestamp_ms.unwrap();
let mut pending = VecDeque::from(
requests
.into_iter()
.map(|mut request| {
request.arrival_timestamp_ms =
Some(request.arrival_timestamp_ms.unwrap() - first_arrival_ms);
request
})
.collect::<Vec<_>>(),
);
let mut worker = ReplayWorkerCore::new(args.clone());
let mut collector = TraceCollector::default();
let mut current_time_ms = 0.0;
let mut idle_jump_ms = 0.0;
let mut first_decode_end_ms = 0.0;
while !pending.is_empty() || !worker.is_empty() {
enqueue_trace_arrivals_manual(
&mut pending,
&mut worker,
&mut collector,
current_time_ms,
);
if worker.is_empty() {
let next_arrival_ms = pending.front().unwrap().arrival_timestamp_ms.unwrap();
current_time_ms = next_arrival_ms;
if idle_jump_ms == 0.0 && current_time_ms > 0.0 {
idle_jump_ms = current_time_ms;
}
enqueue_trace_arrivals_manual(
&mut pending,
&mut worker,
&mut collector,
current_time_ms,
);
continue;
}
let pass = worker.execute_pass(&mut collector, current_time_ms);
if first_decode_end_ms == 0.0 && !pass.output_signals.is_empty() {
first_decode_end_ms = pass.end_ms;
}
current_time_ms = pass.end_ms;
enqueue_trace_arrivals_manual(
&mut pending,
&mut worker,
&mut collector,
current_time_ms,
);
}
let snapshots = [
Uuid::from_u128(11),
Uuid::from_u128(22),
Uuid::from_u128(33),
]
.into_iter()
.map(|uuid| (uuid, collector.snapshot(uuid).unwrap()))
.collect();
ManualReplayResult {
report: collector.finish(),
snapshots,
idle_jump_ms,
first_decode_end_ms,
}
}
fn run_concurrency_manually(
args: &MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
) -> ManualConcurrencyResult {
let mut pending = VecDeque::from(requests);
let mut worker = ReplayWorkerCore::new(args.clone());
let mut collector = TraceCollector::default();
let mut current_time_ms = 0.0;
while !pending.is_empty() || !worker.is_empty() {
enqueue_concurrency_arrivals_manual(
&mut pending,
&mut worker,
&mut collector,
current_time_ms,
max_in_flight,
);
if worker.is_empty() {
break;
}
let pass = worker.execute_pass(&mut collector, current_time_ms);
current_time_ms = pass.end_ms;
}
let snapshots = [
Uuid::from_u128(11),
Uuid::from_u128(22),
Uuid::from_u128(33),
]
.into_iter()
.map(|uuid| (uuid, collector.snapshot(uuid).unwrap()))
.collect();
ManualConcurrencyResult {
report: collector.finish(),
snapshots,
}
}
fn assert_report_close(left: &TraceSimulationReport, right: &TraceSimulationReport) {
let epsilon = 1e-9;
assert_eq!(
left.request_counts.num_requests,
right.request_counts.num_requests
);
assert_eq!(
left.request_counts.completed_requests,
right.request_counts.completed_requests
);
assert_eq!(
left.request_counts.total_input_tokens,
right.request_counts.total_input_tokens
);
assert_eq!(
left.request_counts.total_output_tokens,
right.request_counts.total_output_tokens
);
assert!((left.throughput.duration_ms - right.throughput.duration_ms).abs() <= epsilon);
assert!(
(left.throughput.request_throughput_rps - right.throughput.request_throughput_rps)
.abs()
<= epsilon
);
assert!(
(left.throughput.input_throughput_tok_s - right.throughput.input_throughput_tok_s)
.abs()
<= epsilon
);
assert!(
(left.throughput.output_throughput_tok_s - right.throughput.output_throughput_tok_s)
.abs()
<= epsilon
);
assert!(
(left.throughput.total_throughput_tok_s - right.throughput.total_throughput_tok_s)
.abs()
<= epsilon
);
assert!(
(left.prefix_cache_reused_ratio - right.prefix_cache_reused_ratio).abs() <= epsilon
);
assert!((left.latency.ttft.mean_ms - right.latency.ttft.mean_ms).abs() <= epsilon);
assert!((left.latency.ttft.min_ms - right.latency.ttft.min_ms).abs() <= epsilon);
assert!((left.latency.ttft.max_ms - right.latency.ttft.max_ms).abs() <= epsilon);
assert!((left.latency.ttft.median_ms - right.latency.ttft.median_ms).abs() <= epsilon);
assert!((left.latency.ttft.p75_ms - right.latency.ttft.p75_ms).abs() <= epsilon);
assert!((left.latency.ttft.p90_ms - right.latency.ttft.p90_ms).abs() <= epsilon);
assert!((left.latency.ttft.p95_ms - right.latency.ttft.p95_ms).abs() <= epsilon);
assert!((left.latency.ttft.p99_ms - right.latency.ttft.p99_ms).abs() <= epsilon);
assert!((left.latency.ttft.std_ms - right.latency.ttft.std_ms).abs() <= epsilon);
assert!((left.latency.ttst.mean_ms - right.latency.ttst.mean_ms).abs() <= epsilon);
assert!((left.latency.ttst.min_ms - right.latency.ttst.min_ms).abs() <= epsilon);
assert!((left.latency.ttst.max_ms - right.latency.ttst.max_ms).abs() <= epsilon);
assert!((left.latency.ttst.median_ms - right.latency.ttst.median_ms).abs() <= epsilon);
assert!((left.latency.ttst.p75_ms - right.latency.ttst.p75_ms).abs() <= epsilon);
assert!((left.latency.ttst.p90_ms - right.latency.ttst.p90_ms).abs() <= epsilon);
assert!((left.latency.ttst.p95_ms - right.latency.ttst.p95_ms).abs() <= epsilon);
assert!((left.latency.ttst.p99_ms - right.latency.ttst.p99_ms).abs() <= epsilon);
assert!((left.latency.ttst.std_ms - right.latency.ttst.std_ms).abs() <= epsilon);
assert!((left.latency.tpot.mean_ms - right.latency.tpot.mean_ms).abs() <= epsilon);
assert!((left.latency.tpot.min_ms - right.latency.tpot.min_ms).abs() <= epsilon);
assert!((left.latency.tpot.max_ms - right.latency.tpot.max_ms).abs() <= epsilon);
assert!((left.latency.tpot.median_ms - right.latency.tpot.median_ms).abs() <= epsilon);
assert!((left.latency.tpot.p75_ms - right.latency.tpot.p75_ms).abs() <= epsilon);
assert!((left.latency.tpot.p90_ms - right.latency.tpot.p90_ms).abs() <= epsilon);
assert!((left.latency.tpot.p95_ms - right.latency.tpot.p95_ms).abs() <= epsilon);
assert!((left.latency.tpot.p99_ms - right.latency.tpot.p99_ms).abs() <= epsilon);
assert!((left.latency.tpot.std_ms - right.latency.tpot.std_ms).abs() <= epsilon);
assert!(
(left.latency.itl.distribution.mean_ms - right.latency.itl.distribution.mean_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.min_ms - right.latency.itl.distribution.min_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.max_ms - right.latency.itl.distribution.max_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.median_ms - right.latency.itl.distribution.median_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.p75_ms - right.latency.itl.distribution.p75_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.p90_ms - right.latency.itl.distribution.p90_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.p95_ms - right.latency.itl.distribution.p95_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.p99_ms - right.latency.itl.distribution.p99_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.std_ms - right.latency.itl.distribution.std_ms).abs()
<= epsilon
);
assert!((left.latency.itl.max_ms - right.latency.itl.max_ms).abs() <= epsilon);
assert!((left.latency.e2e.mean_ms - right.latency.e2e.mean_ms).abs() <= epsilon);
assert!((left.latency.e2e.min_ms - right.latency.e2e.min_ms).abs() <= epsilon);
assert!((left.latency.e2e.max_ms - right.latency.e2e.max_ms).abs() <= epsilon);
assert!((left.latency.e2e.median_ms - right.latency.e2e.median_ms).abs() <= epsilon);
assert!((left.latency.e2e.p75_ms - right.latency.e2e.p75_ms).abs() <= epsilon);
assert!((left.latency.e2e.p90_ms - right.latency.e2e.p90_ms).abs() <= epsilon);
assert!((left.latency.e2e.p95_ms - right.latency.e2e.p95_ms).abs() <= epsilon);
assert!((left.latency.e2e.p99_ms - right.latency.e2e.p99_ms).abs() <= epsilon);
assert!((left.latency.e2e.std_ms - right.latency.e2e.std_ms).abs() <= epsilon);
assert!(
(left.latency.output_token_throughput_per_user.mean_ms
- right.latency.output_token_throughput_per_user.mean_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.min_ms
- right.latency.output_token_throughput_per_user.min_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.max_ms
- right.latency.output_token_throughput_per_user.max_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.median_ms
- right.latency.output_token_throughput_per_user.median_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.p75_ms
- right.latency.output_token_throughput_per_user.p75_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.p90_ms
- right.latency.output_token_throughput_per_user.p90_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.p95_ms
- right.latency.output_token_throughput_per_user.p95_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.p99_ms
- right.latency.output_token_throughput_per_user.p99_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.std_ms
- right.latency.output_token_throughput_per_user.std_ms)
.abs()
<= epsilon
);
}
#[rstest]
#[case(false, false)]
#[case(false, true)]
#[case(true, false)]
#[case(true, true)]
fn test_trace_replay_matches_manual_steps(
#[case] enable_prefix_caching: bool,
#[case] enable_chunked_prefill: bool,
) {
let args = replay_args(enable_prefix_caching, enable_chunked_prefill);
let manual = run_trace_manually(&args, replay_fixture());
let replay_report = simulate_trace_single(args, replay_fixture(), 1.0).unwrap();
let request_1 = manual.snapshots.get(&Uuid::from_u128(11)).unwrap();
let request_2 = manual.snapshots.get(&Uuid::from_u128(22)).unwrap();
let request_3 = manual.snapshots.get(&Uuid::from_u128(33)).unwrap();
assert_eq!(request_1.arrival_time_ms, 0.0);
assert_eq!(request_2.arrival_time_ms, 1.0);
assert_eq!(request_3.arrival_time_ms, 400.0);
assert_eq!(manual.idle_jump_ms, 400.0);
assert_eq!(
request_1.first_token_ms.unwrap(),
manual.first_decode_end_ms
);
assert!(request_2.first_admit_ms.unwrap() >= request_2.arrival_time_ms);
assert!(request_3.first_admit_ms.unwrap() >= request_3.arrival_time_ms);
assert!(manual.report.latency.e2e.mean_ms >= manual.report.latency.ttft.mean_ms);
if enable_prefix_caching {
assert!(request_2.reused_input_tokens > 0);
assert!(manual.report.prefix_cache_reused_ratio > 0.0);
} else {
assert_eq!(request_2.reused_input_tokens, 0);
assert_eq!(manual.report.prefix_cache_reused_ratio, 0.0);
}
assert_report_close(&replay_report, &manual.report);
}
#[test]
fn test_concurrency_replay_matches_manual_steps() {
let args = replay_args(false, false);
let requests = vec![
DirectRequest {
tokens: vec![1, 2, 3, 4, 5, 6, 7, 8],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(11)),
dp_rank: 0,
arrival_timestamp_ms: Some(900.0),
},
DirectRequest {
tokens: vec![1, 2, 3, 4, 5, 9, 10, 11],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(22)),
dp_rank: 0,
arrival_timestamp_ms: Some(1000.0),
},
DirectRequest {
tokens: vec![12, 13, 14, 15, 16, 17, 18, 19],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(33)),
dp_rank: 0,
arrival_timestamp_ms: Some(100.0),
},
];
let manual = run_concurrency_manually(&args, requests.clone(), 2);
let replay_report = simulate_concurrency_single(args, requests, 2).unwrap();
let request_1 = manual.snapshots.get(&Uuid::from_u128(11)).unwrap();
let request_2 = manual.snapshots.get(&Uuid::from_u128(22)).unwrap();
let request_3 = manual.snapshots.get(&Uuid::from_u128(33)).unwrap();
assert_eq!(request_1.arrival_time_ms, 0.0);
assert_eq!(request_2.arrival_time_ms, 0.0);
assert_eq!(request_3.arrival_time_ms, request_1.last_token_ms.unwrap());
assert!(request_3.arrival_time_ms < request_2.last_token_ms.unwrap());
assert_eq!(manual.report.request_counts.completed_requests, 3);
assert_eq!(manual.report.request_counts.total_input_tokens, 24);
assert_eq!(manual.report.request_counts.total_output_tokens, 6);
assert_report_close(&replay_report, &manual.report);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::common::protocols::DirectRequest;
use crate::common::protocols::MockEngineArgs;
use crate::replay::TraceCollector;
use crate::scheduler::{EngineCore, EnginePassResult};
pub(crate) struct OfflineWorkerState {
core: EngineCore,
busy: bool,
in_flight: usize,
}
impl OfflineWorkerState {
pub(crate) fn new(worker_idx: usize, args: MockEngineArgs, capture_kv_events: bool) -> Self {
let core = match args.engine_type {
crate::common::protocols::EngineType::Vllm => {
if capture_kv_events {
EngineCore::Vllm(crate::scheduler::VllmCore::new_with_kv_capture(
args,
worker_idx as u64,
))
} else {
EngineCore::Vllm(crate::scheduler::VllmCore::new(args))
}
}
crate::common::protocols::EngineType::Sglang => {
if capture_kv_events {
EngineCore::Sglang(crate::scheduler::SglangCore::new_with_kv_capture(
args,
worker_idx as u64,
))
} else {
EngineCore::Sglang(crate::scheduler::SglangCore::new(args))
}
}
};
Self {
core,
busy: false,
in_flight: 0,
}
}
pub(crate) fn in_flight(&self) -> usize {
debug_assert!(self.in_flight >= self.core.num_requests());
self.in_flight
}
pub(crate) fn receive_request(&mut self, request: DirectRequest) {
self.in_flight += 1;
self.core.receive(request);
}
pub(crate) fn mark_completed(&mut self, completed_requests: usize) {
self.in_flight = self.in_flight.saturating_sub(completed_requests);
}
pub(crate) fn mark_busy(&mut self) {
self.busy = true;
}
pub(crate) fn mark_idle(&mut self) {
self.busy = false;
}
pub(crate) fn is_ready(&self) -> bool {
!self.busy && !self.core.is_empty()
}
pub(crate) fn is_drained(&self) -> bool {
self.in_flight == 0 && !self.busy && self.core.is_empty()
}
pub(crate) fn execute_pass(
&mut self,
collector: &mut TraceCollector,
now_ms: f64,
) -> EnginePassResult {
self.core.execute_pass(collector, now_ms)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod runtime;
pub(crate) use runtime::{simulate_concurrency_requests, simulate_trace_requests};
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use anyhow::{Result, anyhow, bail};
use dashmap::DashMap;
use dynamo_kv_router::config::KvRouterConfig;
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, mpsc};
use tokio::task::JoinSet;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::replay::router::ReplayRouter;
use crate::replay::{
ReplayRouterMode, TraceCollector, TraceSimulationReport, normalize_trace_requests,
};
use crate::scheduler::{AdmissionEvent, EngineScheduler, SchedulerHandle};
#[derive(Clone, Copy, Debug)]
enum LiveReplayMode {
Trace,
Concurrency { max_in_flight: usize },
}
#[derive(Debug, Default, PartialEq, Eq)]
pub(super) struct LiveRuntimeStats {
pub(super) dispatch_history: Vec<usize>,
pub(super) max_in_flight_seen: usize,
pub(super) prefill_marked_count: usize,
pub(super) freed_count: usize,
}
#[derive(Default)]
struct SharedLiveRuntimeStats {
dispatch_history: Mutex<Vec<usize>>,
current_in_flight: AtomicUsize,
max_in_flight_seen: AtomicUsize,
prefill_marked_count: AtomicUsize,
freed_count: AtomicUsize,
}
impl SharedLiveRuntimeStats {
fn record_dispatch(&self, worker_idx: usize) {
self.dispatch_history.lock().unwrap().push(worker_idx);
let current = self.current_in_flight.fetch_add(1, Ordering::AcqRel) + 1;
self.max_in_flight_seen.fetch_max(current, Ordering::AcqRel);
}
fn record_completion(&self) {
self.current_in_flight.fetch_sub(1, Ordering::AcqRel);
}
fn record_prefill_marked(&self) {
self.prefill_marked_count.fetch_add(1, Ordering::AcqRel);
}
fn record_freed(&self) {
self.freed_count.fetch_add(1, Ordering::AcqRel);
}
fn snapshot(&self) -> LiveRuntimeStats {
LiveRuntimeStats {
dispatch_history: self.dispatch_history.lock().unwrap().clone(),
max_in_flight_seen: self.max_in_flight_seen.load(Ordering::Acquire),
prefill_marked_count: self.prefill_marked_count.load(Ordering::Acquire),
freed_count: self.freed_count.load(Ordering::Acquire),
}
}
}
#[derive(Default)]
struct RequestState {
first_token_seen: AtomicBool,
completed_seen: AtomicBool,
completion_notify: Notify,
}
impl RequestState {
fn mark_first_token_once(&self) -> bool {
!self.first_token_seen.swap(true, Ordering::AcqRel)
}
fn mark_completed_once(&self) -> bool {
!self.completed_seen.swap(true, Ordering::AcqRel)
}
fn notify_completion(&self) {
self.completion_notify.notify_waiters();
}
async fn wait_for_completion(&self) {
loop {
let notified = self.completion_notify.notified();
if self.completed_seen.load(Ordering::Acquire) {
return;
}
notified.await;
}
}
}
#[derive(Clone, Copy)]
struct ArrivalEvent {
uuid: Uuid,
at_ms: f64,
input_tokens: usize,
output_tokens: usize,
}
type RequestRegistry = Arc<DashMap<Uuid, Arc<RequestState>>>;
async fn run_demux(
start: Instant,
mut arrival_rx: mpsc::UnboundedReceiver<ArrivalEvent>,
mut admission_rx: mpsc::UnboundedReceiver<AdmissionEvent>,
mut output_rx: mpsc::UnboundedReceiver<OutputSignal>,
requests: RequestRegistry,
router: Arc<ReplayRouter>,
stats: Arc<SharedLiveRuntimeStats>,
) -> TraceSimulationReport {
let mut collector = TraceCollector::default();
let mut arrivals_open = true;
let mut admissions_open = true;
let mut outputs_open = true;
loop {
if !arrivals_open && !admissions_open && !outputs_open {
break;
}
tokio::select! {
biased;
arrival = arrival_rx.recv(), if arrivals_open => {
match arrival {
Some(arrival) => collector.on_arrival(
arrival.uuid,
arrival.at_ms,
arrival.input_tokens,
arrival.output_tokens,
),
None => arrivals_open = false,
}
}
admission = admission_rx.recv(), if admissions_open => {
match admission {
Some(admission) => {
let now_ms = start.elapsed().as_secs_f64() * 1000.0;
collector.on_admit(admission.uuid, now_ms, admission.reused_input_tokens);
}
None => admissions_open = false,
}
}
output = output_rx.recv(), if outputs_open => {
match output {
Some(output) => {
let now_ms = start.elapsed().as_secs_f64() * 1000.0;
collector.on_token(output.uuid, now_ms);
if let Some(state) = requests.get(&output.uuid) {
if state.mark_first_token_once() {
match router.on_first_token(output.uuid).await {
Ok(true) => stats.record_prefill_marked(),
Ok(false) => {}
Err(error) => tracing::warn!(
uuid = %output.uuid,
error = %error,
"online replay failed to mark prefill completed"
),
}
}
if output.completed && state.mark_completed_once() {
match router.on_complete(output.uuid).await {
Ok(true) => stats.record_freed(),
Ok(false) => {}
Err(error) => tracing::warn!(
uuid = %output.uuid,
error = %error,
"online replay failed to free completed request"
),
}
state.notify_completion();
}
}
}
None => outputs_open = false,
}
}
}
}
let wall_time_ms = start.elapsed().as_secs_f64() * 1000.0;
collector.finish().with_wall_time_ms(wall_time_ms)
}
struct LiveRuntime {
pending: VecDeque<DirectRequest>,
senders: Arc<[mpsc::UnboundedSender<DirectRequest>]>,
schedulers: Vec<EngineScheduler>,
output_rx: mpsc::UnboundedReceiver<OutputSignal>,
admission_rx: mpsc::UnboundedReceiver<AdmissionEvent>,
cancel_token: CancellationToken,
start: Instant,
mode: LiveReplayMode,
router: Arc<ReplayRouter>,
}
fn now_ms(start: Instant) -> f64 {
start.elapsed().as_secs_f64() * 1000.0
}
fn request_uuid(request: &DirectRequest) -> Result<Uuid> {
request
.uuid
.ok_or_else(|| anyhow!("online replay requires requests to have stable UUIDs"))
}
fn record_arrival(
arrival_tx: &mpsc::UnboundedSender<ArrivalEvent>,
request: &DirectRequest,
arrival_at_ms: f64,
) -> Result<Uuid> {
let uuid = request_uuid(request)?;
let input_tokens = request.tokens.len();
let output_tokens = request.max_output_tokens;
arrival_tx
.send(ArrivalEvent {
uuid,
at_ms: arrival_at_ms,
input_tokens,
output_tokens,
})
.map_err(|_| anyhow!("online replay arrival channel closed"))?;
Ok(uuid)
}
#[derive(Clone)]
struct RequestTaskContext {
senders: Arc<[mpsc::UnboundedSender<DirectRequest>]>,
router: Arc<ReplayRouter>,
requests: RequestRegistry,
stats: Arc<SharedLiveRuntimeStats>,
}
async fn run_request_task(
ctx: RequestTaskContext,
request: DirectRequest,
permit: Option<OwnedSemaphorePermit>,
) -> Result<()> {
let uuid = request_uuid(&request)?;
let worker_idx = ctx
.router
.select_worker(&request, ctx.senders.len())
.await?;
if worker_idx >= ctx.senders.len() {
bail!("online replay selected unknown worker index {worker_idx}");
}
let state = Arc::new(RequestState::default());
ctx.requests.insert(uuid, Arc::clone(&state));
if let Err(error) = ctx.senders[worker_idx].send(request) {
ctx.requests.remove(&uuid);
return Err(anyhow!(
"online replay failed to dispatch request to worker {worker_idx}: {error}"
));
}
ctx.stats.record_dispatch(worker_idx);
state.wait_for_completion().await;
ctx.stats.record_completion();
ctx.requests.remove(&uuid);
drop(permit);
Ok(())
}
impl LiveRuntime {
fn new(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
pending: VecDeque<DirectRequest>,
num_workers: usize,
mode: LiveReplayMode,
router_mode: ReplayRouterMode,
) -> Result<Self> {
if pending.is_empty() {
bail!("online replay requires at least one request");
}
let cancel_token = CancellationToken::new();
let (output_tx, output_rx) = mpsc::unbounded_channel();
let (admission_tx, admission_rx) = mpsc::unbounded_channel();
let router = Arc::new(ReplayRouter::new(
router_mode,
&args,
router_config,
num_workers,
));
let mut schedulers = Vec::with_capacity(num_workers);
let mut senders = Vec::with_capacity(num_workers);
for worker_idx in 0..num_workers {
let scheduler = EngineScheduler::new_with_admission(
args.clone(),
0,
Some(output_tx.clone()),
router.sink(worker_idx as _),
Some(cancel_token.clone()),
Some(admission_tx.clone()),
);
senders.push(scheduler.request_sender());
schedulers.push(scheduler);
}
drop(output_tx);
drop(admission_tx);
Ok(Self {
pending,
senders: Arc::from(senders),
schedulers,
output_rx,
admission_rx,
cancel_token,
start: Instant::now(),
mode,
router,
})
}
async fn run(mut self) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let requests = Arc::new(DashMap::with_capacity(self.pending.len()));
let stats = Arc::new(SharedLiveRuntimeStats::default());
let (arrival_tx, arrival_rx) = mpsc::unbounded_channel();
let demux_requests = Arc::clone(&requests);
let start = self.start;
let router = Arc::clone(&self.router);
let senders = Arc::clone(&self.senders);
let output_rx = self.output_rx;
let admission_rx = self.admission_rx;
let demux_stats = Arc::clone(&stats);
let demux_router = Arc::clone(&router);
let demux_task = tokio::spawn(async move {
run_demux(
start,
arrival_rx,
admission_rx,
output_rx,
demux_requests,
demux_router,
demux_stats,
)
.await
});
let mut tasks = JoinSet::new();
let task_ctx = RequestTaskContext {
senders,
router: Arc::clone(&self.router),
requests: Arc::clone(&requests),
stats: Arc::clone(&stats),
};
match self.mode {
LiveReplayMode::Trace => {
while let Some(request) = self.pending.pop_front() {
let arrival_ms = request.arrival_timestamp_ms.unwrap_or(0.0);
let deadline =
start + tokio::time::Duration::from_secs_f64(arrival_ms / 1000.0);
tokio::time::sleep_until(deadline).await;
record_arrival(&arrival_tx, &request, arrival_ms)?;
tasks.spawn(run_request_task(task_ctx.clone(), request, None));
}
}
LiveReplayMode::Concurrency { max_in_flight } => {
let semaphore = Arc::new(Semaphore::new(max_in_flight));
while let Some(request) = self.pending.pop_front() {
let permit = semaphore
.clone()
.acquire_owned()
.await
.map_err(|_| anyhow!("online replay concurrency semaphore closed"))?;
record_arrival(&arrival_tx, &request, now_ms(start))?;
tasks.spawn(run_request_task(task_ctx.clone(), request, Some(permit)));
}
}
}
while let Some(result) = tasks.join_next().await {
result.map_err(|e| anyhow!("online replay request task failed: {e}"))??;
}
drop(arrival_tx);
self.cancel_token.cancel();
self.schedulers.clear();
let report = demux_task
.await
.map_err(|e| anyhow!("online replay demux task failed: {e}"))?;
router.shutdown().await?;
Ok((report, stats.snapshot()))
}
}
fn run_live_runtime(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
pending: VecDeque<DirectRequest>,
num_workers: usize,
mode: LiveReplayMode,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.map_err(|e| anyhow!("failed to create online replay runtime: {e}"))?;
runtime.block_on(async move {
LiveRuntime::new(args, router_config, pending, num_workers, mode, router_mode)?
.run()
.await
})
}
pub(crate) fn simulate_trace_requests(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
let (report, _) = run_live_runtime(
args,
router_config,
pending,
num_workers,
LiveReplayMode::Trace,
router_mode,
)?;
Ok(report)
}
pub(crate) fn simulate_concurrency_requests(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
if requests.is_empty() {
bail!("online concurrency replay requires at least one request");
}
let pending = VecDeque::from(requests);
let (report, _) = run_live_runtime(
args,
router_config,
pending,
num_workers,
LiveReplayMode::Concurrency { max_in_flight },
router_mode,
)?;
Ok(report)
}
#[cfg(test)]
fn simulate_trace_requests_with_stats(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let args = args.normalized()?;
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
run_live_runtime(
args,
None,
pending,
num_workers,
LiveReplayMode::Trace,
router_mode,
)
}
#[cfg(test)]
fn simulate_concurrency_requests_with_stats(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let args = args.normalized()?;
let pending = VecDeque::from(requests);
run_live_runtime(
args,
None,
pending,
num_workers,
LiveReplayMode::Concurrency { max_in_flight },
router_mode,
)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::common::protocols::{DirectRequest, EngineType, SglangArgs};
fn replay_args() -> MockEngineArgs {
MockEngineArgs::builder()
.speedup_ratio(1000.0)
.block_size(64)
.build()
.unwrap()
}
fn sglang_replay_args() -> MockEngineArgs {
MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.num_gpu_blocks(512)
.speedup_ratio(1000.0)
.sglang(Some(SglangArgs {
page_size: Some(2),
..Default::default()
}))
.build()
.unwrap()
}
fn request(uuid: u128, token: u32, arrival_timestamp_ms: Option<f64>) -> DirectRequest {
DirectRequest {
tokens: vec![token; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(uuid)),
dp_rank: 0,
arrival_timestamp_ms,
}
}
#[test]
fn test_online_trace_replay_single_worker_completes() {
let args = replay_args();
let requests = vec![request(1, 11, Some(0.0)), request(2, 22, Some(1.0))];
let report =
simulate_trace_requests(args, None, requests, 1, 1.0, ReplayRouterMode::RoundRobin)
.unwrap();
assert_eq!(report.request_counts.num_requests, 2);
assert_eq!(report.request_counts.completed_requests, 2);
assert_eq!(report.request_counts.total_output_tokens, 4);
assert!(report.throughput.wall_time_ms >= 0.0);
}
#[tokio::test]
async fn test_record_arrival_uses_caller_arrival_timestamp() {
let (arrival_tx, mut arrival_rx) = mpsc::unbounded_channel();
let uuid = Uuid::from_u128(999);
let arrival_at_ms = 123.0;
let request = request(999, 42, Some(arrival_at_ms));
let recorded_uuid = record_arrival(&arrival_tx, &request, arrival_at_ms).unwrap();
let arrival = arrival_rx.recv().await.unwrap();
assert_eq!(recorded_uuid, uuid);
assert_eq!(arrival.uuid, uuid);
assert_eq!(arrival.at_ms, arrival_at_ms);
}
#[tokio::test]
async fn test_trace_arrivals_are_not_blocked_by_queued_router_selection() {
let args = MockEngineArgs::builder()
.speedup_ratio(1000.0)
.block_size(64)
.max_num_seqs(Some(1))
.max_num_batched_tokens(Some(8))
.build()
.unwrap();
let start = Instant::now();
let router = Arc::new(ReplayRouter::new(
ReplayRouterMode::KvRouter,
&args,
None,
1,
));
let senders: Arc<[mpsc::UnboundedSender<DirectRequest>]> =
Arc::from(vec![mpsc::unbounded_channel::<DirectRequest>().0]);
let requests = Arc::new(DashMap::new());
let stats = Arc::new(SharedLiveRuntimeStats::default());
let (arrival_tx, mut arrival_rx) = mpsc::unbounded_channel();
let task_ctx = RequestTaskContext {
senders,
router: Arc::clone(&router),
requests,
stats,
};
let mut tasks = JoinSet::new();
let mut pending = VecDeque::from(vec![
request(1, 11, Some(0.0)),
request(2, 22, Some(1.0)),
request(3, 33, Some(2.0)),
]);
while let Some(request) = pending.pop_front() {
let arrival_ms = request.arrival_timestamp_ms.unwrap_or(0.0);
let deadline = start + tokio::time::Duration::from_secs_f64(arrival_ms / 1000.0);
tokio::time::sleep_until(deadline).await;
record_arrival(&arrival_tx, &request, arrival_ms).unwrap();
tasks.spawn(run_request_task(task_ctx.clone(), request, None));
}
let first = tokio::time::timeout(tokio::time::Duration::from_millis(50), arrival_rx.recv())
.await
.unwrap()
.unwrap();
let second =
tokio::time::timeout(tokio::time::Duration::from_millis(50), arrival_rx.recv())
.await
.unwrap()
.unwrap();
let third = tokio::time::timeout(tokio::time::Duration::from_millis(50), arrival_rx.recv())
.await
.unwrap()
.unwrap();
assert_eq!(first.uuid, Uuid::from_u128(1));
assert_eq!(second.uuid, Uuid::from_u128(2));
assert_eq!(third.uuid, Uuid::from_u128(3));
assert_eq!(first.at_ms, 0.0);
assert_eq!(second.at_ms, 1.0);
assert_eq!(third.at_ms, 2.0);
tasks.abort_all();
router.shutdown().await.unwrap();
}
#[test]
fn test_online_trace_replay_uses_round_robin_dispatch() {
let args = replay_args();
let requests = vec![
request(1, 1, Some(0.0)),
request(2, 2, Some(100.0)),
request(3, 3, Some(200.0)),
request(4, 4, Some(300.0)),
request(5, 5, Some(400.0)),
];
let (_, stats) = simulate_trace_requests_with_stats(
args,
requests,
3,
1.0,
ReplayRouterMode::RoundRobin,
)
.unwrap();
assert_eq!(stats.dispatch_history, vec![0, 1, 2, 0, 1]);
}
#[test]
fn test_online_concurrency_replay_respects_max_in_flight() {
let args = replay_args();
let requests = vec![
request(1, 10, None),
request(2, 20, None),
request(3, 30, None),
request(4, 40, None),
];
let (report, stats) = simulate_concurrency_requests_with_stats(
args,
requests,
2,
2,
ReplayRouterMode::RoundRobin,
)
.unwrap();
assert_eq!(report.request_counts.completed_requests, 4);
assert_eq!(stats.max_in_flight_seen, 2);
}
#[test]
fn test_online_trace_replay_populates_admit_reuse_stats() {
let args = replay_args();
let requests = vec![request(1, 77, Some(0.0)), request(2, 77, Some(5.0))];
let report =
simulate_trace_requests(args, None, requests, 1, 1.0, ReplayRouterMode::RoundRobin)
.unwrap();
assert_eq!(report.request_counts.completed_requests, 2);
assert!(report.prefix_cache_reused_ratio > 0.0);
}
#[test]
fn test_online_trace_replay_kv_router_prefers_cached_worker() {
let args = replay_args();
let requests = vec![request(1, 88, Some(0.0)), request(2, 88, Some(500.0))];
let (_, stats) =
simulate_trace_requests_with_stats(args, requests, 2, 1.0, ReplayRouterMode::KvRouter)
.unwrap();
assert_eq!(stats.dispatch_history.len(), 2);
assert_eq!(stats.dispatch_history[0], stats.dispatch_history[1]);
}
#[test]
fn test_online_trace_replay_sglang_single_worker_completes() {
let args = sglang_replay_args();
let requests = vec![request(101, 7, Some(0.0)), request(102, 8, Some(1.0))];
let report =
simulate_trace_requests(args, None, requests, 1, 1.0, ReplayRouterMode::RoundRobin)
.unwrap();
assert_eq!(report.request_counts.completed_requests, 2);
assert_eq!(report.request_counts.total_output_tokens, 4);
}
#[test]
fn test_online_trace_replay_sglang_kv_router_smoke() {
let args = sglang_replay_args();
let requests = vec![request(111, 9, Some(0.0)), request(112, 9, Some(500.0))];
let (report, stats) =
simulate_trace_requests_with_stats(args, requests, 2, 1.0, ReplayRouterMode::KvRouter)
.unwrap();
assert_eq!(report.request_counts.completed_requests, 2);
assert_eq!(stats.dispatch_history.len(), 2);
}
#[test]
fn test_online_concurrency_replay_kv_router_respects_max_in_flight() {
let args = replay_args();
let requests = vec![
request(1, 10, None),
request(2, 20, None),
request(3, 10, None),
request(4, 20, None),
];
let (report, stats) = simulate_concurrency_requests_with_stats(
args,
requests,
2,
2,
ReplayRouterMode::KvRouter,
)
.unwrap();
assert_eq!(report.request_counts.completed_requests, 4);
assert_eq!(stats.max_in_flight_seen, 2);
}
#[test]
fn test_online_trace_replay_kv_router_marks_prefill_and_free_once() {
let args = replay_args();
let requests = vec![DirectRequest {
tokens: vec![9; 64],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(9)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
}];
let (_, stats) =
simulate_trace_requests_with_stats(args, requests, 1, 1.0, ReplayRouterMode::KvRouter)
.unwrap();
assert_eq!(stats.prefill_marked_count, 1);
assert_eq!(stats.freed_count, 1);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod offline;
mod online;
mod shared;
pub(crate) use offline::OfflineReplayRouter;
pub(crate) use online::ReplayRouter;
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result, anyhow};
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::{
OverlapScores, RouterEvent, WorkerConfigLike, WorkerId, WorkerWithDpRank,
compute_block_hash_for_seq,
};
use dynamo_kv_router::queue::DEFAULT_MAX_BATCHED_TOKENS;
use dynamo_kv_router::{
ActiveSequencesMultiWorker, DefaultWorkerSelector, RadixTree, RouterSchedulingPolicy,
SchedulingPolicy, SchedulingRequest, SequenceRequest, WorkerSelector,
};
use dynamo_tokens::SequenceHash;
use uuid::Uuid;
use super::shared::{
ReplayNoopPublisher, ReplayWorkerConfig, replay_policy, replay_router_config, replay_selector,
replay_slots, replay_workers_with_configs,
};
use crate::common::protocols::DirectRequest;
use crate::common::protocols::MockEngineArgs;
type ReplayQueueKey = <RouterSchedulingPolicy as SchedulingPolicy>::Key;
struct SyncReplayIndexer {
block_size: u32,
tree: RadixTree,
}
impl SyncReplayIndexer {
fn new(block_size: u32) -> Self {
Self {
block_size,
tree: RadixTree::new(),
}
}
fn find_matches_for_request(&self, tokens: &[u32], lora_name: Option<&str>) -> OverlapScores {
let sequence = compute_block_hash_for_seq(tokens, self.block_size, None, lora_name);
self.tree.find_matches(sequence, false)
}
fn apply_event(&mut self, event: RouterEvent) -> Result<()> {
self.tree.apply_event(event).map_err(Into::into)
}
}
struct PendingRequest {
uuid: Uuid,
token_seq: Option<Vec<SequenceHash>>,
isl_tokens: usize,
overlaps: OverlapScores,
expected_output_tokens: Option<u32>,
}
impl PendingRequest {
fn request_id(&self) -> String {
self.uuid.to_string()
}
fn scheduling_request(
&self,
decode_blocks: HashMap<WorkerWithDpRank, usize>,
prefill_tokens: HashMap<WorkerWithDpRank, usize>,
) -> SchedulingRequest {
SchedulingRequest {
maybe_request_id: Some(self.request_id()),
token_seq: self.token_seq.clone(),
isl_tokens: self.isl_tokens,
overlaps: self.overlaps.clone(),
decode_blocks,
prefill_tokens,
router_config_override: None,
update_states: true,
lora_name: None,
priority_jump: 0.0,
expected_output_tokens: self.expected_output_tokens,
allowed_worker_ids: None,
resp_tx: None,
}
}
}
struct QueueEntry {
key: ReplayQueueKey,
_enqueue_time_ms: f64,
enqueue_seq: u64,
request: PendingRequest,
}
impl Eq for QueueEntry {}
impl PartialEq for QueueEntry {
fn eq(&self, other: &Self) -> bool {
self.key == other.key && self.enqueue_seq == other.enqueue_seq
}
}
impl Ord for QueueEntry {
fn cmp(&self, other: &Self) -> Ordering {
self.key
.cmp(&other.key)
.then_with(|| other.enqueue_seq.cmp(&self.enqueue_seq))
}
}
impl PartialOrd for QueueEntry {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
pub(crate) struct OfflineReplayRouter {
config: KvRouterConfig,
block_size: u32,
runtime: tokio::runtime::Runtime,
queue_threshold: Option<f64>,
workers_with_configs: HashMap<WorkerId, ReplayWorkerConfig>,
slots: Arc<ActiveSequencesMultiWorker<ReplayNoopPublisher>>,
selector: DefaultWorkerSelector,
policy: RouterSchedulingPolicy,
pending: BinaryHeap<QueueEntry>,
next_enqueue_seq: u64,
indexer: SyncReplayIndexer,
}
impl OfflineReplayRouter {
pub(crate) fn new(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
num_workers: usize,
) -> Result<Self> {
let config = replay_router_config(args, router_config);
let workers_with_configs = replay_workers_with_configs(args, num_workers);
let slots = replay_slots(args, &workers_with_configs);
let selector = replay_selector(&config);
let policy = replay_policy(&config, args);
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| anyhow!("failed to create offline replay router runtime: {e}"))?;
let queue_threshold = if num_workers > 1 {
config.router_queue_threshold
} else {
None
};
Ok(Self {
config,
block_size: args.block_size as u32,
runtime,
queue_threshold,
workers_with_configs,
slots,
selector,
policy,
pending: BinaryHeap::new(),
next_enqueue_seq: 0,
indexer: SyncReplayIndexer::new(args.block_size as u32),
})
}
pub(crate) fn submit_request(
&mut self,
request: &DirectRequest,
now_ms: f64,
) -> Result<Option<usize>> {
let pending = self.build_pending_request(request)?;
let should_queue = self
.queue_threshold
.is_some_and(|threshold| self.all_workers_busy(threshold));
if should_queue {
let key = self.enqueue_key(now_ms, &pending);
self.pending.push(QueueEntry {
key,
_enqueue_time_ms: now_ms,
enqueue_seq: self.next_enqueue_seq,
request: pending,
});
self.next_enqueue_seq += 1;
return Ok(None);
}
self.admit_request(pending).map(Some)
}
pub(crate) fn apply_event(&mut self, event: RouterEvent) -> Result<()> {
self.indexer.apply_event(event)
}
pub(crate) fn mark_prefill_completed(&mut self, uuid: Uuid) -> Result<Vec<(Uuid, usize)>> {
self.runtime
.block_on(self.slots.mark_prefill_completed(&uuid.to_string()))
.map_err(anyhow::Error::from)?;
self.drain_pending()
}
pub(crate) fn free(&mut self, uuid: Uuid) -> Result<Vec<(Uuid, usize)>> {
self.runtime
.block_on(self.slots.free(&uuid.to_string()))
.map_err(anyhow::Error::from)?;
self.drain_pending()
}
#[cfg(test)]
pub(crate) fn pending_count(&self) -> usize {
self.pending.len()
}
pub(crate) fn shutdown(&mut self) {}
fn enqueue_key(&self, now_ms: f64, request: &PendingRequest) -> ReplayQueueKey {
let arrival_offset = Duration::from_secs_f64((now_ms.max(0.0)) / 1000.0);
self.policy.enqueue_key(
arrival_offset,
&request.scheduling_request(HashMap::new(), HashMap::new()),
)
}
fn build_pending_request(&self, request: &DirectRequest) -> Result<PendingRequest> {
let uuid = request
.uuid
.ok_or_else(|| anyhow!("offline replay requires requests to have stable UUIDs"))?;
let overlaps = self.indexer.find_matches_for_request(&request.tokens, None);
let token_seq = self.config.compute_seq_hashes_for_tracking(
&request.tokens,
self.block_size,
None,
None,
);
Ok(PendingRequest {
uuid,
token_seq,
isl_tokens: request.tokens.len(),
overlaps,
expected_output_tokens: Some(
u32::try_from(request.max_output_tokens)
.context("max_output_tokens does not fit into u32")?,
),
})
}
fn admit_request(&mut self, request: PendingRequest) -> Result<usize> {
let (decode_blocks, prefill_tokens) = self.slots.potential_blocks_and_tokens(
request.token_seq.as_deref(),
request.isl_tokens,
request.overlaps.clone(),
);
let scheduling_request = request.scheduling_request(decode_blocks, prefill_tokens);
let selection = self.selector.select_worker(
&self.workers_with_configs,
&scheduling_request,
self.block_size,
)?;
let worker_idx = usize::try_from(selection.worker.worker_id)
.map_err(|_| anyhow!("selected worker id does not fit into usize"))?;
let request_id = request.request_id();
self.runtime
.block_on(self.slots.add_request(SequenceRequest {
request_id,
token_sequence: request.token_seq,
isl: request.isl_tokens,
overlap: selection.overlap_blocks,
expected_output_tokens: request.expected_output_tokens,
worker: selection.worker,
lora_name: None,
}))
.map_err(anyhow::Error::from)?;
Ok(worker_idx)
}
fn drain_pending(&mut self) -> Result<Vec<(Uuid, usize)>> {
let Some(threshold) = self.queue_threshold else {
return Ok(Vec::new());
};
let mut admissions = Vec::new();
while !self.all_workers_busy(threshold) {
let Some(QueueEntry { request, .. }) = self.pending.pop() else {
break;
};
let uuid = request.uuid;
let worker_idx = self.admit_request(request)?;
admissions.push((uuid, worker_idx));
}
Ok(admissions)
}
fn all_workers_busy(&self, threshold: f64) -> bool {
let mut checked_any = false;
let any_worker_not_busy = self
.slots
.any_worker_matches_active_tokens(|worker, tokens| {
let Some(config) = self.workers_with_configs.get(&worker.worker_id) else {
return false;
};
checked_any = true;
let max_batched = config
.max_num_batched_tokens()
.unwrap_or(DEFAULT_MAX_BATCHED_TOKENS);
(tokens as f64) <= threshold * (max_batched as f64)
});
checked_any && !any_worker_not_busy
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
use anyhow::{Context, Result, anyhow};
use dynamo_kv_router::ConcurrentRadixTree;
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::indexer::{
KvIndexer, KvIndexerInterface, KvIndexerMetrics, ThreadPoolIndexer,
};
use dynamo_kv_router::protocols::{OverlapScores, RouterEvent, WorkerId};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use super::shared::{
ReplayScheduler, replay_policy, replay_router_config, replay_selector, replay_slots,
replay_workers_with_configs,
};
use crate::common::protocols::{
DirectRequest, KvCacheEventSink, KvEventPublishers, MockEngineArgs,
};
use crate::replay::ReplayRouterMode;
#[derive(Clone)]
enum ReplayIndexer {
Single(KvIndexer),
Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTree>>),
}
impl ReplayIndexer {
async fn apply_event(&self, event: RouterEvent) {
match self {
Self::Single(indexer) => indexer.apply_event(event).await,
Self::Concurrent(indexer) => indexer.apply_event(event).await,
}
}
async fn find_matches_for_request(
&self,
tokens: &[u32],
lora_name: Option<&str>,
) -> Result<OverlapScores> {
match self {
Self::Single(indexer) => indexer
.find_matches_for_request(tokens, lora_name)
.await
.map_err(Into::into),
Self::Concurrent(indexer) => indexer
.find_matches_for_request(tokens, lora_name)
.await
.map_err(Into::into),
}
}
async fn flush(&self) -> usize {
match self {
Self::Single(indexer) => indexer.flush().await,
Self::Concurrent(indexer) => KvIndexerInterface::flush(indexer.as_ref()).await,
}
}
}
fn create_replay_indexer(block_size: u32, num_threads: usize) -> ReplayIndexer {
if num_threads > 1 {
return ReplayIndexer::Concurrent(Arc::new(ThreadPoolIndexer::new(
ConcurrentRadixTree::new(),
num_threads,
block_size,
)));
}
ReplayIndexer::Single(KvIndexer::new_with_frequency(
CancellationToken::new(),
None,
block_size,
Arc::new(KvIndexerMetrics::new_unregistered()),
None,
))
}
#[derive(Clone)]
struct ReplayKvEventSink {
worker_id: WorkerId,
event_tx: mpsc::UnboundedSender<RouterEvent>,
}
impl KvCacheEventSink for ReplayKvEventSink {
fn publish(&self, event: dynamo_kv_router::protocols::KvCacheEvent) -> anyhow::Result<()> {
self.event_tx
.send(RouterEvent::new(self.worker_id, event))
.map_err(|_| anyhow!("replay router event channel closed"))
}
}
#[derive(Default)]
pub(crate) struct RoundRobinRouter {
next_worker_idx: AtomicUsize,
}
impl RoundRobinRouter {
fn select_worker(&self, num_workers: usize) -> usize {
self.next_worker_idx.fetch_add(1, Ordering::AcqRel) % num_workers
}
}
pub(crate) struct KvReplayRouter {
config: KvRouterConfig,
block_size: u32,
scheduler: Arc<ReplayScheduler>,
event_tx: Mutex<Option<mpsc::UnboundedSender<RouterEvent>>>,
event_task: Mutex<Option<tokio::task::JoinHandle<()>>>,
indexer: ReplayIndexer,
}
impl KvReplayRouter {
fn new(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
num_workers: usize,
) -> Self {
let config = replay_router_config(args, router_config);
let indexer =
create_replay_indexer(args.block_size as u32, config.router_event_threads as usize);
let workers_with_configs = replay_workers_with_configs(args, num_workers);
let slots = replay_slots(args, &workers_with_configs);
let (_worker_config_tx, worker_config_rx) =
tokio::sync::watch::channel(workers_with_configs);
let selector = replay_selector(&config);
let policy = replay_policy(&config, args);
let scheduler = Arc::new(dynamo_kv_router::LocalScheduler::new(
slots,
worker_config_rx,
config.router_queue_threshold,
args.block_size as u32,
selector,
policy,
CancellationToken::new(),
"replay",
false,
));
let (event_tx, mut event_rx) = mpsc::unbounded_channel();
let indexer_clone = indexer.clone();
let event_task = tokio::spawn(async move {
while let Some(event) = event_rx.recv().await {
indexer_clone.apply_event(event).await;
}
let _ = indexer_clone.flush().await;
});
Self {
config,
block_size: args.block_size as u32,
scheduler,
event_tx: Mutex::new(Some(event_tx)),
event_task: Mutex::new(Some(event_task)),
indexer,
}
}
fn sink(&self, worker_id: WorkerId) -> Arc<dyn KvCacheEventSink> {
let event_tx = self
.event_tx
.lock()
.unwrap()
.as_ref()
.expect("router event channel should exist while runtime is active")
.clone();
Arc::new(ReplayKvEventSink {
worker_id,
event_tx,
})
}
async fn select_worker(&self, request: &DirectRequest) -> Result<usize> {
let uuid = request
.uuid
.ok_or_else(|| anyhow!("online replay requires requests to have stable UUIDs"))?;
let overlaps = self
.indexer
.find_matches_for_request(&request.tokens, None)
.await?;
let token_seq = self.config.compute_seq_hashes_for_tracking(
&request.tokens,
self.block_size,
None,
None,
);
let response = self
.scheduler
.schedule(
Some(uuid.to_string()),
request.tokens.len(),
token_seq,
overlaps,
None,
true,
None,
0.0,
Some(
u32::try_from(request.max_output_tokens)
.context("max_output_tokens does not fit into u32")?,
),
None,
)
.await?;
usize::try_from(response.best_worker.worker_id)
.map_err(|_| anyhow!("selected worker id does not fit into usize"))
}
async fn mark_prefill_completed(&self, uuid: Uuid) -> Result<()> {
self.scheduler
.mark_prefill_completed(&uuid.to_string())
.await
.map_err(anyhow::Error::from)
}
async fn free(&self, uuid: Uuid) -> Result<()> {
self.scheduler
.free(&uuid.to_string())
.await
.map_err(anyhow::Error::from)
}
async fn shutdown(&self) -> Result<()> {
self.event_tx.lock().unwrap().take();
let Some(event_task) = self.event_task.lock().unwrap().take() else {
return Ok(());
};
event_task
.await
.map_err(|e| anyhow!("replay router event task failed: {e}"))?;
Ok(())
}
}
#[expect(
clippy::large_enum_variant,
reason = "ReplayRouter is long-lived and the KV router variant is intentional"
)]
pub(crate) enum ReplayRouter {
RoundRobin(RoundRobinRouter),
Kv(KvReplayRouter),
}
impl ReplayRouter {
pub(crate) fn new(
mode: ReplayRouterMode,
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
num_workers: usize,
) -> Self {
match mode {
ReplayRouterMode::RoundRobin => Self::RoundRobin(RoundRobinRouter::default()),
ReplayRouterMode::KvRouter => {
Self::Kv(KvReplayRouter::new(args, router_config, num_workers))
}
}
}
pub(crate) fn sink(&self, worker_id: WorkerId) -> KvEventPublishers {
match self {
Self::RoundRobin(_) => KvEventPublishers::default(),
Self::Kv(router) => KvEventPublishers::new(Some(router.sink(worker_id)), None),
}
}
pub(crate) async fn select_worker(
&self,
request: &DirectRequest,
num_workers: usize,
) -> Result<usize> {
match self {
Self::RoundRobin(router) => Ok(router.select_worker(num_workers)),
Self::Kv(router) => router.select_worker(request).await,
}
}
pub(crate) async fn on_first_token(&self, uuid: Uuid) -> Result<bool> {
match self {
Self::RoundRobin(_) => Ok(false),
Self::Kv(router) => {
router.mark_prefill_completed(uuid).await?;
Ok(true)
}
}
}
pub(crate) async fn on_complete(&self, uuid: Uuid) -> Result<bool> {
match self {
Self::RoundRobin(_) => Ok(false),
Self::Kv(router) => {
router.free(uuid).await?;
Ok(true)
}
}
}
pub(crate) async fn shutdown(&self) -> Result<()> {
match self {
Self::RoundRobin(_) => Ok(()),
Self::Kv(router) => router.shutdown().await,
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use std::future;
use std::sync::Arc;
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::{
ActiveLoad, ActiveSequenceEvent, WorkerConfigLike, WorkerId, WorkerWithDpRank,
};
use dynamo_kv_router::scheduling::queue::DEFAULT_MAX_BATCHED_TOKENS;
use dynamo_kv_router::{
ActiveSequencesMultiWorker, DefaultWorkerSelector, LocalScheduler, RouterSchedulingPolicy,
SequencePublisher,
};
use crate::common::protocols::MockEngineArgs;
#[derive(Clone, Copy, Debug, Default)]
pub(super) struct ReplayNoopPublisher;
impl SequencePublisher for ReplayNoopPublisher {
fn publish_event(
&self,
_event: &ActiveSequenceEvent,
) -> impl future::Future<Output = anyhow::Result<()>> + Send {
future::ready(Ok(()))
}
fn publish_load(&self, _load: ActiveLoad) {}
fn observe_load(&self, _: &WorkerWithDpRank, _: &str, _: usize, _: usize) {}
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub(super) struct ReplayWorkerConfig {
pub(super) max_num_batched_tokens: u64,
pub(super) total_kv_blocks: u64,
}
impl WorkerConfigLike for ReplayWorkerConfig {
fn data_parallel_start_rank(&self) -> u32 {
0
}
fn data_parallel_size(&self) -> u32 {
1
}
fn max_num_batched_tokens(&self) -> Option<u64> {
Some(self.max_num_batched_tokens)
}
fn total_kv_blocks(&self) -> Option<u64> {
Some(self.total_kv_blocks)
}
}
pub(super) type ReplayScheduler = LocalScheduler<
ReplayNoopPublisher,
ReplayWorkerConfig,
RouterSchedulingPolicy,
DefaultWorkerSelector,
>;
fn replay_worker_config(args: &MockEngineArgs) -> ReplayWorkerConfig {
ReplayWorkerConfig {
max_num_batched_tokens: args
.max_num_batched_tokens
.map(|tokens| tokens as u64)
.unwrap_or(DEFAULT_MAX_BATCHED_TOKENS),
total_kv_blocks: args.num_gpu_blocks as u64,
}
}
pub(super) fn replay_workers_with_configs(
args: &MockEngineArgs,
num_workers: usize,
) -> HashMap<WorkerId, ReplayWorkerConfig> {
let worker_config = replay_worker_config(args);
(0..num_workers)
.map(|worker_idx| (worker_idx as WorkerId, worker_config.clone()))
.collect()
}
pub(super) fn replay_slots(
args: &MockEngineArgs,
workers_with_configs: &HashMap<WorkerId, ReplayWorkerConfig>,
) -> Arc<ActiveSequencesMultiWorker<ReplayNoopPublisher>> {
let dp_range = workers_with_configs
.keys()
.copied()
.map(|worker_id| (worker_id, (0, 1)))
.collect();
Arc::new(ActiveSequencesMultiWorker::new(
ReplayNoopPublisher,
args.block_size,
dp_range,
false,
0,
"replay",
))
}
pub(super) fn replay_selector(config: &KvRouterConfig) -> DefaultWorkerSelector {
DefaultWorkerSelector::new(Some(config.clone()), "replay")
}
pub(super) fn replay_router_config(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
) -> KvRouterConfig {
let mut config = router_config.unwrap_or_default();
if let Some(policy) = args.router_queue_policy {
config.router_queue_policy = policy;
}
config
}
pub(super) fn replay_policy(
config: &KvRouterConfig,
args: &MockEngineArgs,
) -> RouterSchedulingPolicy {
RouterSchedulingPolicy::new(config.router_queue_policy, args.block_size)
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::{Result, bail};
use super::ReplayRouterMode;
use crate::common::protocols::{MockEngineArgs, WorkerType};
fn validate_replay_args(args: &MockEngineArgs, num_workers: usize, mode: &str) -> Result<()> {
if num_workers == 0 {
bail!("{mode} requires num_workers >= 1");
}
if args.worker_type != WorkerType::Aggregated {
bail!(
"{mode} only supports aggregated workers, got {:?}",
args.worker_type,
);
}
if args.dp_size != 1 {
bail!(
"{mode} only supports data_parallel_size=1, got {}",
args.dp_size,
);
}
Ok(())
}
fn validate_offline_router_mode(router_mode: ReplayRouterMode, num_workers: usize) -> Result<()> {
if router_mode != ReplayRouterMode::KvRouter {
return Ok(());
}
if num_workers > 1 {
return Ok(());
}
bail!("offline replay only supports router_mode=kv_router when num_workers > 1");
}
pub(super) fn validate_offline_replay_args(
args: &MockEngineArgs,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<()> {
validate_offline_router_mode(router_mode, num_workers)?;
validate_replay_args(args, num_workers, "trace replay")
}
pub(super) fn validate_offline_concurrency_args(
args: &MockEngineArgs,
num_workers: usize,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> Result<()> {
if max_in_flight == 0 {
bail!("concurrency replay requires max_in_flight >= 1");
}
validate_offline_router_mode(router_mode, num_workers)?;
validate_replay_args(args, num_workers, "concurrency replay")
}
pub(super) fn validate_online_replay_args(args: &MockEngineArgs, num_workers: usize) -> Result<()> {
validate_replay_args(args, num_workers, "online replay")
}
pub(super) fn validate_online_concurrency_args(
args: &MockEngineArgs,
num_workers: usize,
max_in_flight: usize,
) -> Result<()> {
if max_in_flight == 0 {
bail!("online concurrency replay requires max_in_flight >= 1");
}
validate_replay_args(args, num_workers, "online replay")
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::{Arc, Mutex};
use anyhow::Result;
use dynamo_kv_router::protocols::{KvCacheEvent, RouterEvent, WorkerId};
use crate::common::protocols::{KvCacheEventSink, KvEventPublishers, RawKvEvent, RawKvEventSink};
/// Captures router-ready events for offline replay and scheduler tests.
///
/// This path converts raw KV events into `RouterEvent`s immediately because the
/// caller only needs worker-tagged router events, not the original token-id
/// payloads used by the live publisher path.
#[derive(Clone, Default)]
pub(crate) struct CapturedRouterEventBuffer {
events: Arc<Mutex<Vec<RouterEvent>>>,
}
impl CapturedRouterEventBuffer {
pub(crate) fn push(&self, event: RouterEvent) {
self.events.lock().unwrap().push(event);
}
pub(crate) fn drain(&self) -> Vec<RouterEvent> {
std::mem::take(&mut *self.events.lock().unwrap())
}
}
/// Sink implementation that records `RouterEvent`s into
/// `CapturedRouterEventBuffer`.
#[derive(Clone)]
struct RouterEventCaptureSink {
worker_id: WorkerId,
buffer: CapturedRouterEventBuffer,
}
impl KvCacheEventSink for RouterEventCaptureSink {
fn publish(&self, event: KvCacheEvent) -> Result<()> {
self.buffer.push(RouterEvent::new(self.worker_id, event));
Ok(())
}
}
/// Returns the capture buffer plus a sink handle that can be passed into a
/// scheduler core for offline replay or tests.
pub(crate) fn capture_router_event_sink(
worker_id: WorkerId,
) -> (CapturedRouterEventBuffer, Arc<dyn KvCacheEventSink>) {
let buffer = CapturedRouterEventBuffer::default();
let sink: Arc<dyn KvCacheEventSink> = Arc::new(RouterEventCaptureSink {
worker_id,
buffer: buffer.clone(),
});
(buffer, sink)
}
/// Raw KV event payload buffered by the live scheduler so it can forward the
/// event to the real publisher sink at the correct pass phase.
#[derive(Debug, Clone)]
pub(crate) struct DeferredKvPublish {
pub(crate) event: KvCacheEvent,
pub(crate) block_token_ids: Option<Vec<Vec<u32>>>,
}
/// Captures raw KV publishes for the live `python -m dynamo.mocker` and online
/// replay paths.
///
/// Unlike `CapturedRouterEventBuffer`, this keeps `block_token_ids` so delayed
/// forwarding still works for sinks like ZMQ publishers that need the original
/// token-id payloads.
#[derive(Clone, Default)]
pub(crate) struct DeferredKvPublishBuffer {
events: Arc<Mutex<Vec<DeferredKvPublish>>>,
}
impl DeferredKvPublishBuffer {
pub(crate) fn push(&self, event: KvCacheEvent, block_token_ids: Option<Vec<Vec<u32>>>) {
self.events.lock().unwrap().push(DeferredKvPublish {
event,
block_token_ids,
});
}
pub(crate) fn drain(&self) -> Vec<DeferredKvPublish> {
std::mem::take(&mut *self.events.lock().unwrap())
}
}
/// Sink implementation that records raw KV publishes into
/// `DeferredKvPublishBuffer` instead of forwarding them immediately.
#[derive(Clone, Default)]
struct DeferredKvEventSink {
buffer: DeferredKvPublishBuffer,
}
impl KvCacheEventSink for DeferredKvEventSink {
fn publish(&self, event: KvCacheEvent) -> Result<()> {
self.buffer.push(event, None);
Ok(())
}
}
#[derive(Clone, Default)]
struct DeferredRawKvEventSink {
buffer: DeferredKvPublishBuffer,
}
impl RawKvEventSink for DeferredRawKvEventSink {
fn publish(&self, event: RawKvEvent) -> Result<()> {
let mut events = self.buffer.events.lock().unwrap();
if let Some(last) = events.last_mut()
&& last.event.event_id == event.event.event_id
&& last.event.dp_rank == event.event.dp_rank
{
last.block_token_ids = event.block_token_ids;
return Ok(());
}
events.push(DeferredKvPublish {
event: event.event,
block_token_ids: event.block_token_ids,
});
Ok(())
}
}
/// Returns the deferred-publish buffer plus a sink handle that can be passed
/// into the live scheduler core while `live.rs` retains control over when the
/// buffered events are forwarded to the real sink.
pub(crate) fn capture_deferred_kv_publish_sink(
capture_raw: bool,
) -> (DeferredKvPublishBuffer, KvEventPublishers) {
let buffer = DeferredKvPublishBuffer::default();
let event_sink: Arc<dyn KvCacheEventSink> = Arc::new(DeferredKvEventSink {
buffer: buffer.clone(),
});
let raw_sink = capture_raw.then(|| {
Arc::new(DeferredRawKvEventSink {
buffer: buffer.clone(),
}) as Arc<dyn RawKvEventSink>
});
(buffer, KvEventPublishers::new(Some(event_sink), raw_sink))
}
/// Forwards buffered live-scheduler KV events to the real sink once the pass
/// reaches the configured visibility point.
pub(crate) fn publish_deferred_kv_events(
sinks: &KvEventPublishers,
events: Vec<DeferredKvPublish>,
) {
for event in events {
if let Err(error) = sinks.publish(event.event, event.block_token_ids.as_deref()) {
tracing::warn!("Failed to forward buffered KV event: {error}");
}
}
}
......@@ -3,15 +3,157 @@
//! Engine-specific scheduling implementations.
mod kv_event_sink;
#[path = "sglang/mod.rs"]
pub mod sglang;
pub mod vllm;
use crate::common::protocols::DirectRequest;
use crate::common::protocols::{DirectRequest, KvEventPublishers, OutputSignal};
use dynamo_kv_router::protocols::RouterEvent;
pub(crate) use kv_event_sink::{
CapturedRouterEventBuffer, capture_deferred_kv_publish_sink, capture_router_event_sink,
publish_deferred_kv_events,
};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
pub(crate) use sglang::SglangCore;
pub use sglang::SglangScheduler;
pub(crate) use vllm::VllmCore;
pub use vllm::{MockerMetrics, Scheduler};
#[derive(Debug, Clone)]
pub(crate) struct AdmissionEvent {
pub(crate) uuid: Uuid,
pub(crate) reused_input_tokens: usize,
}
#[derive(Debug, Clone)]
pub(crate) struct EnginePassResult {
pub(crate) end_ms: f64,
pub(crate) completed_requests: usize,
pub(crate) output_signals: Vec<OutputSignal>,
pub(crate) admissions: Vec<AdmissionEvent>,
pub(crate) active_decode_blocks: u64,
/// Controls when replay/live schedulers should expose this pass's buffered
/// KV events to the real router or publisher sink.
pub(crate) router_event_visibility: RouterEventVisibility,
/// Router-visible KV events emitted during this pass.
pub(crate) kv_events: Vec<RouterEvent>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RouterEventVisibility {
/// Expose buffered KV events when the pass starts, before the modeled sleep.
PassStart,
/// Expose buffered KV events when the pass finishes, before output flush.
PassEnd,
}
#[allow(clippy::large_enum_variant)]
pub(crate) enum EngineCore {
Vllm(VllmCore),
Sglang(SglangCore),
}
impl EngineCore {
pub(crate) fn receive(&mut self, request: DirectRequest) -> Uuid {
match self {
Self::Vllm(core) => core.receive(request),
Self::Sglang(core) => core.receive(request),
}
}
pub(crate) fn is_empty(&self) -> bool {
match self {
Self::Vllm(core) => core.is_empty(),
Self::Sglang(core) => core.is_empty(),
}
}
pub(crate) fn num_requests(&self) -> usize {
match self {
Self::Vllm(core) => core.num_requests(),
Self::Sglang(core) => core.num_requests(),
}
}
pub(crate) fn execute_pass(
&mut self,
collector: &mut crate::replay::TraceCollector,
now_ms: f64,
) -> EnginePassResult {
match self {
Self::Vllm(core) => core.execute_pass(collector, now_ms),
Self::Sglang(core) => core.execute_pass(collector, now_ms),
}
}
}
#[derive(Clone)]
pub(crate) enum EngineScheduler {
Vllm(Scheduler),
Sglang(SglangScheduler),
}
impl EngineScheduler {
pub(crate) fn new_with_admission(
args: crate::common::protocols::MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
) -> Self {
match args.engine_type {
crate::common::protocols::EngineType::Vllm => {
Self::Vllm(Scheduler::new_with_admission(
args,
dp_rank,
output_tx,
kv_event_publishers,
cancellation_token,
admission_tx,
))
}
crate::common::protocols::EngineType::Sglang => {
Self::Sglang(SglangScheduler::new_with_admission(
args,
dp_rank,
output_tx,
kv_event_publishers,
cancellation_token,
admission_tx,
))
}
}
}
}
impl SchedulerHandle for EngineScheduler {
fn receive(&self, request: DirectRequest) {
match self {
Self::Vllm(scheduler) => scheduler.receive(request),
Self::Sglang(scheduler) => scheduler.receive(request),
}
}
fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
match self {
Self::Vllm(scheduler) => scheduler.request_sender(),
Self::Sglang(scheduler) => scheduler.request_sender(),
}
}
fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics> {
match self {
Self::Vllm(scheduler) => scheduler.metrics_receiver(),
Self::Sglang(scheduler) => scheduler.metrics_receiver(),
}
}
}
/// Engine-agnostic scheduler interface.
///
/// Both vLLM and SGLang schedulers implement this trait so that the engine
......@@ -29,87 +171,4 @@ pub trait SchedulerHandle: Send + Sync {
/// Shared test utilities for scheduler stress tests.
#[cfg(test)]
pub(crate) mod test_utils {
use super::*;
use crate::common::protocols::OutputSignal;
use tokio::time::Duration;
/// Send `num_requests` to a scheduler, collect all output signals, and assert
/// that the scheduler produces exactly `num_requests * max_output_tokens` signals
/// and returns to idle (0 active decode blocks).
///
/// When `use_shared_tokens` is true, the first half of each request shares a
/// common prefix to exercise prefix caching / radix tree reuse.
pub async fn assert_scheduler_completes_all(
scheduler: &dyn SchedulerHandle,
output_rx: &mut mpsc::UnboundedReceiver<OutputSignal>,
num_requests: usize,
input_len: usize,
max_output_tokens: usize,
use_shared_tokens: bool,
) {
let shared_tokens = if use_shared_tokens {
Some(
(0..input_len / 2)
.map(|_| rand::random::<u32>() % 50000)
.collect::<Vec<_>>(),
)
} else {
None
};
for _ in 0..num_requests {
let input_tokens = if let Some(ref shared) = shared_tokens {
let mut tokens = shared.clone();
tokens.extend((0..input_len / 2).map(|_| rand::random::<u32>() % 50000));
tokens
} else {
(0..input_len)
.map(|_| rand::random::<u32>() % 50000)
.collect::<Vec<_>>()
};
scheduler.receive(DirectRequest {
tokens: input_tokens,
max_output_tokens,
uuid: None,
dp_rank: 0,
arrival_timestamp_ms: None,
});
}
let expected_tokens = num_requests * max_output_tokens;
let mut received_tokens = 0;
let timeout = tokio::time::sleep(Duration::from_secs(2));
tokio::pin!(timeout);
loop {
tokio::select! {
biased;
Some(_) = output_rx.recv() => {
received_tokens += 1;
if received_tokens >= expected_tokens {
break;
}
timeout.set(tokio::time::sleep(Duration::from_secs(2)));
}
_ = &mut timeout => break,
}
}
assert_eq!(
received_tokens, expected_tokens,
"Expected {expected_tokens} output signals, got {received_tokens}"
);
// Verify scheduler returns to idle
tokio::time::sleep(Duration::from_millis(100)).await;
let metrics = scheduler.metrics_receiver().borrow().clone();
assert_eq!(
metrics.active_decode_blocks, 0,
"Scheduler should be idle after all requests complete, got {} active blocks",
metrics.active_decode_blocks
);
}
}
pub(crate) mod test_utils;
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! SGLang scheduler simulation with adaptive admission control.
//!
//! Reference: sglang/python/sglang/srt/managers/scheduler.py
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use validator::Validate;
use crate::cache::radix_cache::NodeId;
use crate::common::perf_model::PerfModel;
use crate::common::protocols::{
DirectRequest, KvCacheEventSink, MockEngineArgs, OutputSignal, WorkerType,
};
use crate::common::utils::sleep_until_precise;
use crate::kv_manager::SglangKvManager;
use super::MockerMetrics;
// SGLang default constants
const DEFAULT_MAX_PREFILL_TOKENS: usize = 16384;
const DEFAULT_CHUNKED_PREFILL_SIZE: usize = 8192;
const DEFAULT_CLIP_MAX_NEW_TOKENS: usize = 4096;
const DEFAULT_INIT_NEW_TOKEN_RATIO: f64 = 0.7;
const DEFAULT_MIN_NEW_TOKEN_RATIO_FACTOR: f64 = 0.14;
const DEFAULT_NEW_TOKEN_RATIO_DECAY_STEPS: f64 = 600.0;
const LPM_FALLBACK_THRESHOLD: usize = 128;
/// Tracks a single request inside the SGLang scheduler.
struct SglangRequest {
uuid: Uuid,
token_ids: Vec<u64>,
max_output_tokens: usize,
output_len: usize,
/// Deepest matched node in radix tree.
last_node: Option<NodeId>,
/// Pool page indices for the full sequence.
kv_indices: Vec<usize>,
/// Number of input tokens already prefilled (for chunked prefill).
prefilled_tokens: usize,
}
impl SglangRequest {
fn total_tokens_needed(&self, clip_max_new_tokens: usize) -> usize {
let remaining_input = self.token_ids.len() - self.prefilled_tokens;
let clipped_output = self.max_output_tokens.min(clip_max_new_tokens);
remaining_input + clipped_output
}
fn extend_input_len(&self) -> usize {
self.token_ids.len() - self.prefilled_tokens
}
}
/// SGLang scheduler with adaptive admission control.
///
/// The scheduling loop mirrors SGLang's `Scheduler.event_loop_normal`:
/// `receive_requests → apply_schedule_policy → get_new_batch_prefill →
/// simulate_prefill → simulate_decode → decay_new_token_ratio`
pub struct SglangScheduler {
request_tx: mpsc::UnboundedSender<DirectRequest>,
metrics_rx: tokio::sync::watch::Receiver<MockerMetrics>,
_cancel_guard: Arc<CancelGuard>,
}
struct CancelGuard(CancellationToken);
impl Drop for CancelGuard {
fn drop(&mut self) {
self.0.cancel();
}
}
/// Scheduling policy for reordering the waiting queue.
#[derive(Clone, Copy, Debug, Default)]
pub enum SchedulePolicy {
/// Process in arrival order.
#[default]
Fifo,
/// Longest prefix match — prioritise requests with the most cached tokens.
/// Falls back to FIFO when `waiting.len() > 128` (prefix matching is expensive).
Lpm,
}
/// Configuration extracted from MockEngineArgs for SGLang-specific params.
struct SglangConfig {
schedule_policy: SchedulePolicy,
max_prefill_tokens: usize,
chunked_prefill_size: usize,
clip_max_new_tokens: usize,
init_new_token_ratio: f64,
min_new_token_ratio: f64,
new_token_ratio_decay_step: f64,
perf_model: Arc<PerfModel>,
speedup_ratio: f64,
worker_type: WorkerType,
page_size: usize,
}
impl SglangConfig {
fn from_args(args: &MockEngineArgs) -> Self {
let sglang = args.sglang.as_ref();
let schedule_conservativeness = sglang
.and_then(|s| s.schedule_conservativeness)
.unwrap_or(1.0);
let init_new_token_ratio = DEFAULT_INIT_NEW_TOKEN_RATIO * schedule_conservativeness;
let min_new_token_ratio = init_new_token_ratio * DEFAULT_MIN_NEW_TOKEN_RATIO_FACTOR;
let decay_steps = DEFAULT_NEW_TOKEN_RATIO_DECAY_STEPS;
let decay_step = (init_new_token_ratio - min_new_token_ratio) / decay_steps;
let policy_str = sglang.and_then(|s| s.schedule_policy.as_deref());
let schedule_policy = match policy_str {
Some("lpm") => SchedulePolicy::Lpm,
Some("fifo") | Some("fcfs") | None => SchedulePolicy::Fifo,
Some(other) => {
tracing::warn!(
"Unknown sglang schedule_policy '{}', falling back to FIFO",
other
);
SchedulePolicy::Fifo
}
};
Self {
schedule_policy,
max_prefill_tokens: sglang
.and_then(|s| s.max_prefill_tokens)
.unwrap_or(DEFAULT_MAX_PREFILL_TOKENS),
chunked_prefill_size: sglang
.and_then(|s| s.chunked_prefill_size)
.unwrap_or(DEFAULT_CHUNKED_PREFILL_SIZE),
clip_max_new_tokens: sglang
.and_then(|s| s.clip_max_new_tokens)
.unwrap_or(DEFAULT_CLIP_MAX_NEW_TOKENS),
init_new_token_ratio,
min_new_token_ratio,
new_token_ratio_decay_step: decay_step,
perf_model: args.perf_model.clone(),
speedup_ratio: args.speedup_ratio,
worker_type: args.worker_type,
page_size: sglang.and_then(|s| s.page_size).unwrap_or(1),
}
}
}
impl SglangScheduler {
pub fn new(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
cancellation_token: Option<CancellationToken>,
) -> Self {
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
let initial_metrics = MockerMetrics {
dp_rank,
active_decode_blocks: 0,
};
let (metrics_tx, metrics_rx) =
tokio::sync::watch::channel::<MockerMetrics>(initial_metrics);
let cancel_token = cancellation_token.unwrap_or_default();
let cancel_token_clone = cancel_token.clone();
let cancel_guard = Arc::new(CancelGuard(cancel_token));
args.validate().expect("invalid MockEngineArgs");
let config = SglangConfig::from_args(&args);
let total_tokens = args.num_gpu_blocks * args.block_size;
tokio::spawn(async move {
let mut kv_manager =
SglangKvManager::new(total_tokens, config.page_size, kv_event_sink, dp_rank);
let mut waiting: VecDeque<SglangRequest> = VecDeque::new();
let mut running: Vec<SglangRequest> = Vec::new();
let mut new_token_ratio = config.init_new_token_ratio;
loop {
// 1. Receive requests
if receive_requests(&mut waiting, &mut request_rx, &cancel_token_clone, &running)
.await
.is_none()
{
break;
}
// 2. Apply scheduling policy
apply_schedule_policy(&mut waiting, &kv_manager, &config);
// 3. Admit new requests for prefill
let admit = get_new_batch_prefill(
&mut waiting,
&mut kv_manager,
&config,
new_token_ratio,
&running,
);
if admit.oom {
new_token_ratio = config.init_new_token_ratio;
}
// 4. Simulate prefill
let batch_size = admit.can_run.len();
let mean_isl = if batch_size > 0 {
admit.total_isl / batch_size
} else {
0
};
let mean_prefix = if batch_size > 0 {
admit.total_prefix / batch_size
} else {
0
};
simulate_prefill(batch_size, mean_isl, mean_prefix, &config).await;
// Separate fully-prefilled from chunked requests
for mut req in admit.can_run {
if req.prefilled_tokens < req.token_ids.len() {
// Chunked prefill: cache partial sequence, put back in waiting
if let Some(last_node) = req.last_node {
let new_last = kv_manager.cache_unfinished_req(
&req.token_ids[..req.prefilled_tokens],
&req.kv_indices,
last_node,
);
req.last_node = Some(new_last);
}
waiting.push_front(req);
} else {
running.push(req);
}
}
// 5. Simulate decode (may retract requests under memory pressure)
let retracted = simulate_decode(
&mut running,
&mut kv_manager,
&output_tx,
&config,
dp_rank,
&metrics_tx,
)
.await;
if !retracted.is_empty() {
// Retracted requests go back to the front of the waiting queue
for req in retracted.into_iter().rev() {
waiting.push_front(req);
}
// Reset new_token_ratio like SGLang does after retraction
new_token_ratio = config.init_new_token_ratio;
}
// 6. Decay new_token_ratio
new_token_ratio = (new_token_ratio - config.new_token_ratio_decay_step)
.max(config.min_new_token_ratio);
}
});
Self {
request_tx,
metrics_rx,
_cancel_guard: cancel_guard,
}
}
}
impl super::SchedulerHandle for SglangScheduler {
fn receive(&self, request: DirectRequest) {
let _ = self.request_tx.send(request);
}
fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
self.request_tx.clone()
}
fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics> {
self.metrics_rx.clone()
}
}
async fn receive_requests(
waiting: &mut VecDeque<SglangRequest>,
request_rx: &mut mpsc::UnboundedReceiver<DirectRequest>,
cancel_token: &CancellationToken,
running: &[SglangRequest],
) -> Option<()> {
if cancel_token.is_cancelled() {
return None;
}
if waiting.is_empty() && running.is_empty() {
// Fully idle — block until request or shutdown
tokio::select! {
biased;
_ = cancel_token.cancelled() => return None,
result = request_rx.recv() => {
let request = result?;
waiting.push_back(direct_to_sglang(request));
}
}
}
// Drain any pending requests without blocking
while let Ok(request) = request_rx.try_recv() {
waiting.push_back(direct_to_sglang(request));
}
Some(())
}
fn direct_to_sglang(req: DirectRequest) -> SglangRequest {
SglangRequest {
uuid: req.uuid.unwrap_or_else(Uuid::new_v4),
token_ids: req.tokens.iter().map(|&t| t as u64).collect(),
max_output_tokens: req.max_output_tokens,
output_len: 0,
last_node: None,
kv_indices: Vec::new(),
prefilled_tokens: 0,
}
}
/// Reorder waiting queue based on scheduling policy.
fn apply_schedule_policy(
waiting: &mut VecDeque<SglangRequest>,
kv_manager: &SglangKvManager,
config: &SglangConfig,
) {
match config.schedule_policy {
SchedulePolicy::Fifo => {} // already in arrival order
SchedulePolicy::Lpm => {
if waiting.len() > LPM_FALLBACK_THRESHOLD {
return; // too expensive, fall back to FIFO
}
// Score each request by prefix match length (read-only, no mutation)
let mut scored: Vec<(usize, SglangRequest)> = waiting
.drain(..)
.map(|req| {
let prefix_len = kv_manager.cache().prefix_match_len(&req.token_ids);
(prefix_len, req)
})
.collect();
// Sort descending by prefix match length (stable sort preserves FIFO for ties)
scored.sort_by(|a, b| b.0.cmp(&a.0));
for (_, req) in scored {
waiting.push_back(req);
}
}
}
}
struct AdmitResult {
can_run: Vec<SglangRequest>,
/// Sum of ISL values across admitted requests (for computing mean).
total_isl: usize,
/// Sum of prefix (cached tokens) across admitted requests (for computing mean).
total_prefix: usize,
oom: bool,
}
/// Admit requests from waiting queue within budget constraints.
fn get_new_batch_prefill(
waiting: &mut VecDeque<SglangRequest>,
kv_manager: &mut SglangKvManager,
config: &SglangConfig,
new_token_ratio: f64,
running: &[SglangRequest],
) -> AdmitResult {
let cache = kv_manager.cache();
let reserved: f64 = running
.iter()
.map(|req| {
let remaining_output =
(req.max_output_tokens - req.output_len).min(config.clip_max_new_tokens);
remaining_output as f64 * new_token_ratio
})
.sum();
let mut rem_total_tokens = (cache.available_tokens() + cache.evictable_size) as f64 - reserved;
let mut rem_input_tokens = config.max_prefill_tokens as f64;
let mut rem_chunk_tokens = config.chunked_prefill_size as f64;
let mut can_run = Vec::new();
let mut rejected = VecDeque::new();
let mut oom = false;
let mut total_isl: usize = 0;
let mut total_prefix: usize = 0;
while let Some(mut req) = waiting.pop_front() {
let extend_input = req.extend_input_len() as f64;
let total_needed = req.total_tokens_needed(config.clip_max_new_tokens) as f64;
// For chunked prefill: check against the chunk size, not the full input.
let effective_input = extend_input.min(config.chunked_prefill_size as f64);
if total_needed > rem_total_tokens || effective_input > rem_input_tokens {
rejected.push_back(req);
break;
}
// Keep previous chunk lock alive to protect cached prefix from eviction.
// Released after allocate_for_request secures its own lock.
let prev_node = req.last_node.take();
// Determine chunk boundary before allocation
let chunk_end = if extend_input > rem_chunk_tokens && rem_chunk_tokens > 0.0 {
let chunk = (rem_chunk_tokens as usize) / config.page_size * config.page_size;
if chunk > 0 {
req.prefilled_tokens + chunk
} else {
req.token_ids.len()
}
} else {
req.token_ids.len()
};
let alloc_tokens = &req.token_ids[..chunk_end];
let prefix_len = kv_manager.cache().prefix_match_len(alloc_tokens);
let needed_new = alloc_tokens.len() - prefix_len;
let available = kv_manager.cache().token_pool.available();
if available < needed_new {
kv_manager.evict(needed_new - available);
}
let alloc = kv_manager.allocate_for_request(alloc_tokens);
let Some(alloc) = alloc else {
// Restore lock on rejection so the cached prefix stays protected
req.last_node = prev_node;
rejected.push_back(req);
oom = true;
break;
};
// New allocation has its own lock; release the previous one
if let Some(node) = prev_node {
kv_manager.free_request(node);
}
req.last_node = Some(alloc.last_node);
req.kv_indices = alloc.kv_indices;
req.prefilled_tokens = chunk_end;
let actual_prefilled = (chunk_end - (req.token_ids.len() - extend_input as usize)) as f64;
total_isl += chunk_end;
total_prefix += alloc.prefix_len;
rem_total_tokens -= total_needed;
rem_input_tokens -= actual_prefilled;
rem_chunk_tokens -= actual_prefilled;
can_run.push(req);
if rem_chunk_tokens <= 0.0 {
break;
}
}
while let Some(req) = rejected.pop_back() {
waiting.push_front(req);
}
AdmitResult {
can_run,
total_isl,
total_prefix,
oom,
}
}
async fn simulate_prefill(
batch_size: usize,
mean_isl: usize,
mean_prefix: usize,
config: &SglangConfig,
) {
if batch_size == 0 || config.worker_type == WorkerType::Decode {
return;
}
let start = Instant::now();
let prefill_time = config
.perf_model
.predict_prefill_time(batch_size, mean_isl, mean_prefix);
let total_time = Duration::from_secs_f64(prefill_time / 1000.0);
if config.speedup_ratio > 0.0 && total_time > Duration::ZERO {
let sleep_duration =
Duration::from_secs_f64(total_time.as_secs_f64() / config.speedup_ratio);
sleep_until_precise(start + sleep_duration).await;
}
}
/// Check if the pool has enough tokens for one decode step of the entire batch.
/// Tries eviction first; if still short, retracts requests by output_len desc
/// (matching SGLang's retract_decode policy) until enough memory is available.
/// Returns retracted requests that should go back to the waiting queue.
fn check_decode_mem(
running: &mut Vec<SglangRequest>,
kv_manager: &mut SglangKvManager,
) -> Vec<SglangRequest> {
let needed = running.len();
let available = kv_manager.cache().token_pool.available();
let evictable = kv_manager.cache().evictable_size;
if available + evictable >= needed {
// Evict just enough to cover the deficit
if available < needed {
kv_manager.evict(needed - available);
}
return Vec::new();
}
// Not enough even after full eviction — retract requests.
// Sort indices by output_len descending (longest-running first, like SGLang).
let mut sorted_indices: Vec<usize> = (0..running.len()).collect();
sorted_indices.sort_by(|&a, &b| running[b].output_len.cmp(&running[a].output_len));
let mut freed = 0usize;
while available + evictable + freed < sorted_indices.len() {
if sorted_indices.len() <= 1 {
break; // always keep at least one request
}
let idx = sorted_indices.pop().unwrap();
let req = &running[idx];
// Free this request's KV indices and radix lock
let kv_len = req.kv_indices.len();
kv_manager.cache_mut().token_pool.free(&req.kv_indices);
if let Some(last_node) = req.last_node {
kv_manager.free_request(last_node);
}
freed += kv_len;
// Mark index for removal (we'll collect in a second pass)
sorted_indices.retain(|&i| i != idx);
}
// Remove retracted requests from running (those NOT in sorted_indices).
let remaining_set: std::collections::HashSet<usize> = sorted_indices.into_iter().collect();
let mut remove_indices: Vec<usize> = (0..running.len())
.filter(|i| !remaining_set.contains(i))
.collect();
remove_indices.sort_unstable_by(|a, b| b.cmp(a));
let mut retracted = Vec::with_capacity(remove_indices.len());
for idx in remove_indices {
let mut req = running.swap_remove(idx);
// Reset decode state so it re-enters as a fresh prefill
req.output_len = 0;
req.kv_indices.clear();
req.last_node = None;
req.prefilled_tokens = 0;
retracted.push(req);
}
// Now evict to cover remaining deficit
let available = kv_manager.cache().token_pool.available();
let needed = running.len();
if available < needed {
kv_manager.evict(needed - available);
}
if !retracted.is_empty() {
tracing::warn!(
num_retracted = retracted.len(),
remaining = running.len(),
"SGLang decode retract requests because KV pool is full"
);
}
retracted
}
async fn simulate_decode(
running: &mut Vec<SglangRequest>,
kv_manager: &mut SglangKvManager,
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
config: &SglangConfig,
dp_rank: u32,
metrics_tx: &tokio::sync::watch::Sender<MockerMetrics>,
) -> Vec<SglangRequest> {
if running.is_empty() {
return Vec::new();
}
let start = Instant::now();
let total_context: usize = running
.iter()
.map(|r| r.token_ids.len() + r.output_len)
.sum();
let avg_context = total_context / running.len();
let decode_time =
config
.perf_model
.predict_decode_time(running.len(), total_context, avg_context);
let total_time = Duration::from_secs_f64(decode_time / 1000.0);
// Retract requests if not enough memory for one decode step
let retracted = check_decode_mem(running, kv_manager);
for req in running.iter_mut() {
if kv_manager.cache().token_pool.available() == 0 {
kv_manager.evict(1);
}
let last_idx = req.kv_indices.last().copied();
if let Some(new_idx) = kv_manager.allocate_decode_token(last_idx) {
req.kv_indices.push(new_idx);
req.output_len += 1;
} else {
tracing::warn!(uuid = %req.uuid, "Failed to allocate decode token, skipping output");
}
}
// Send output signals and handle completions
let mut completed_indices = Vec::new();
for (i, req) in running.iter_mut().enumerate() {
let is_complete = req.output_len >= req.max_output_tokens;
if let Some(tx) = output_tx {
let _ = tx.send(OutputSignal {
uuid: req.uuid,
completed: is_complete,
});
}
if is_complete {
let mut all_tokens = req.token_ids.clone();
for j in 0..req.output_len {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
req.uuid.hash(&mut hasher);
j.hash(&mut hasher);
all_tokens.push(hasher.finish());
}
// Page-align and cap by available indices.
let aligned_tokens = (all_tokens.len() / config.page_size) * config.page_size;
let tokens_to_cache = aligned_tokens.min(req.kv_indices.len());
all_tokens.truncate(tokens_to_cache);
// Free excess token indices not covered by the cached sequence.
if req.kv_indices.len() > tokens_to_cache {
let excess = req.kv_indices[tokens_to_cache..].to_vec();
kv_manager.cache_mut().token_pool.free(&excess);
}
if let Some(last_node) = req.last_node {
if tokens_to_cache > 0 {
kv_manager.cache_finished_req(
&all_tokens,
&req.kv_indices[..tokens_to_cache],
last_node,
);
} else {
kv_manager.free_request(last_node);
}
}
completed_indices.push(i);
}
}
// Remove completed requests in reverse order so swap_remove doesn't
// invalidate pending indices (completed_indices is built in ascending order).
for &i in completed_indices.iter().rev() {
running.swap_remove(i);
}
// Publish metrics: active blocks from running requests' total context
let remaining_context: usize = running
.iter()
.map(|r| r.token_ids.len() + r.output_len)
.sum();
let active_blocks = remaining_context / config.page_size;
let _ = metrics_tx.send(MockerMetrics {
dp_rank,
active_decode_blocks: active_blocks as u64,
});
if config.speedup_ratio > 0.0 && total_time > Duration::ZERO {
let sleep_duration =
Duration::from_secs_f64(total_time.as_secs_f64() / config.speedup_ratio);
sleep_until_precise(start + sleep_duration).await;
}
retracted
}
#[cfg(test)]
mod tests {
use super::*;
use crate::common::protocols::SglangArgs;
use crate::scheduler::SchedulerHandle;
use rstest::rstest;
#[tokio::test]
async fn test_sglang_scheduler_fifo_ordering() {
let args = MockEngineArgs::builder()
.num_gpu_blocks(100)
.block_size(64)
.speedup_ratio(100.0)
.build()
.unwrap();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let scheduler = SglangScheduler::new(args, 0, Some(output_tx), None, None);
let num_requests = 5;
let max_output = 3;
for i in 0..num_requests {
scheduler.receive(DirectRequest {
tokens: vec![i as u32; 10],
max_output_tokens: max_output,
uuid: None,
dp_rank: 0,
arrival_timestamp_ms: None,
});
}
// Collect all output signals
let expected_signals = num_requests * max_output;
let mut received = 0;
let timeout = tokio::time::sleep(Duration::from_secs(5));
tokio::pin!(timeout);
loop {
tokio::select! {
Some(_) = output_rx.recv() => {
received += 1;
if received >= expected_signals {
break;
}
timeout.set(tokio::time::sleep(Duration::from_secs(2)));
}
_ = &mut timeout => break,
}
}
assert_eq!(
received, expected_signals,
"Expected {expected_signals} signals, got {received}"
);
}
#[tokio::test]
async fn test_sglang_scheduler_admission_budget() {
// Small pool — only enough for a few requests
let args = MockEngineArgs::builder()
.num_gpu_blocks(2) // 2 * 64 = 128 tokens
.block_size(64)
.speedup_ratio(100.0)
.build()
.unwrap();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let scheduler = SglangScheduler::new(args, 0, Some(output_tx), None, None);
// Send requests that collectively exceed budget
for _ in 0..10 {
scheduler.receive(DirectRequest {
tokens: vec![1; 20],
max_output_tokens: 5,
uuid: None,
dp_rank: 0,
arrival_timestamp_ms: None,
});
}
// Should still complete all eventually (as earlier ones finish, budget frees up)
let mut received = 0;
let timeout = tokio::time::sleep(Duration::from_secs(10));
tokio::pin!(timeout);
loop {
tokio::select! {
Some(_) = output_rx.recv() => {
received += 1;
timeout.set(tokio::time::sleep(Duration::from_secs(2)));
}
_ = &mut timeout => break,
}
}
let expected = 10 * 5;
assert_eq!(
received, expected,
"Expected {expected} signals, got {received}"
);
}
#[test]
fn test_lpm_reorders_by_prefix_match() {
let mut kv_manager = SglangKvManager::new(1000, 1, None, 0);
// Seed cache with [1,2,3,4,5]
kv_manager
.cache_mut()
.insert(&[1, 2, 3, 4, 5], &[0, 1, 2, 3, 4]);
let config = SglangConfig {
schedule_policy: SchedulePolicy::Lpm,
..SglangConfig::from_args(
&MockEngineArgs::builder()
.speedup_ratio(1.0)
.build()
.unwrap(),
)
};
let no_match_uuid = Uuid::new_v4();
let match_uuid = Uuid::new_v4();
let mut waiting: VecDeque<SglangRequest> = VecDeque::new();
// no_match first in FIFO order
waiting.push_back(SglangRequest {
uuid: no_match_uuid,
token_ids: vec![9, 8, 7],
max_output_tokens: 1,
output_len: 0,
last_node: None,
kv_indices: Vec::new(),
prefilled_tokens: 0,
});
// match second in FIFO order
waiting.push_back(SglangRequest {
uuid: match_uuid,
token_ids: vec![1, 2, 3, 4, 5, 6, 7],
max_output_tokens: 1,
output_len: 0,
last_node: None,
kv_indices: Vec::new(),
prefilled_tokens: 0,
});
apply_schedule_policy(&mut waiting, &kv_manager, &config);
// LPM should reorder: match (5-token prefix) before no_match (0-token)
assert_eq!(waiting[0].uuid, match_uuid);
assert_eq!(waiting[1].uuid, no_match_uuid);
}
#[test]
fn test_chunked_prefill_budget() {
let config = SglangConfig {
chunked_prefill_size: 10,
..SglangConfig::from_args(
&MockEngineArgs::builder()
.speedup_ratio(1.0)
.build()
.unwrap(),
)
};
let mut kv_manager = SglangKvManager::new(10000, 1, None, 0);
let mut waiting: VecDeque<SglangRequest> = VecDeque::new();
waiting.push_back(SglangRequest {
uuid: Uuid::new_v4(),
token_ids: vec![1; 20], // 20 tokens > chunked_prefill_size=10
max_output_tokens: 3,
output_len: 0,
last_node: None,
kv_indices: Vec::new(),
prefilled_tokens: 0,
});
let admit = get_new_batch_prefill(&mut waiting, &mut kv_manager, &config, 0.7, &[]);
assert_eq!(admit.can_run.len(), 1);
// Should only prefill 10 tokens (chunked_prefill_size), not all 20
assert_eq!(admit.can_run[0].prefilled_tokens, 10);
assert!(admit.can_run[0].prefilled_tokens < admit.can_run[0].token_ids.len());
}
#[test]
fn test_new_token_ratio_decay_and_oom_reset() {
let config = SglangConfig::from_args(
&MockEngineArgs::builder()
.speedup_ratio(1.0)
.build()
.unwrap(),
);
let mut ratio = config.init_new_token_ratio;
for _ in 0..600 {
ratio = (ratio - config.new_token_ratio_decay_step).max(config.min_new_token_ratio);
}
// After 600 steps, ratio should be at or near minimum
assert!(
(ratio - config.min_new_token_ratio).abs() < 0.01,
"ratio={ratio}, min={}",
config.min_new_token_ratio
);
// Simulate OOM reset
ratio = config.init_new_token_ratio;
assert!((ratio - 0.7).abs() < 0.001);
}
/// Stress test mirroring vLLM's `test_scheduler_token_generation_patterns`.
/// Sends 200 requests × 1000 input × 100 output under heavy eviction pressure
/// and parametrises across `(shared_tokens, schedule_policy, page_size)`.
#[rstest]
#[case::case_1(false, "fifo", 1)]
#[case::case_2(true, "fifo", 1)]
#[case::case_3(false, "lpm", 1)]
#[case::case_4(true, "lpm", 1)]
#[case::case_5(false, "fifo", 4)]
#[case::case_6(true, "fifo", 4)]
#[case::case_7(false, "lpm", 4)]
#[case::case_8(true, "lpm", 4)]
#[tokio::test]
async fn test_sglang_scheduler_token_generation_patterns(
#[case] use_shared_tokens: bool,
#[case] schedule_policy: &str,
#[case] page_size: usize,
) {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let args = MockEngineArgs::builder()
.num_gpu_blocks(500)
.block_size(64)
.speedup_ratio(10.0)
.sglang(Some(SglangArgs {
schedule_policy: Some(schedule_policy.to_string()),
page_size: Some(page_size),
..Default::default()
}))
.build()
.unwrap();
let scheduler = SglangScheduler::new(args, 0, Some(output_tx), None, None);
crate::scheduler::test_utils::assert_scheduler_completes_all(
&scheduler,
&mut output_rx,
200,
1000,
100,
use_shared_tokens,
)
.await;
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use crate::common::perf_model::PerfModel;
use crate::common::protocols::{MockEngineArgs, WorkerType};
const DEFAULT_MAX_PREFILL_TOKENS: usize = 16384;
const DEFAULT_CHUNKED_PREFILL_SIZE: usize = 8192;
const DEFAULT_CLIP_MAX_NEW_TOKENS: usize = 4096;
const DEFAULT_INIT_NEW_TOKEN_RATIO: f64 = 0.7;
const DEFAULT_MIN_NEW_TOKEN_RATIO_FACTOR: f64 = 0.14;
const DEFAULT_NEW_TOKEN_RATIO_DECAY_STEPS: f64 = 600.0;
pub(super) const LPM_FALLBACK_THRESHOLD: usize = 128;
pub(super) const IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD: usize = 32;
pub(super) const IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD: usize = 32;
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum SchedulePolicy {
#[default]
Fifo,
Lpm,
}
pub(super) struct SglangConfig {
pub(super) schedule_policy: SchedulePolicy,
pub(super) max_prefill_tokens: usize,
pub(super) chunked_prefill_size: usize,
pub(super) clip_max_new_tokens: usize,
pub(super) init_new_token_ratio: f64,
pub(super) min_new_token_ratio: f64,
pub(super) new_token_ratio_decay_step: f64,
pub(super) perf_model: Arc<PerfModel>,
pub(super) speedup_ratio: f64,
pub(super) decode_speedup_ratio: f64,
pub(super) worker_type: WorkerType,
pub(super) block_size: usize,
}
impl SglangConfig {
pub(super) fn from_args(args: &MockEngineArgs) -> Self {
let sglang = args.sglang.as_ref();
let schedule_conservativeness = sglang
.and_then(|s| s.schedule_conservativeness)
.unwrap_or(1.0);
let init_new_token_ratio = DEFAULT_INIT_NEW_TOKEN_RATIO * schedule_conservativeness;
let min_new_token_ratio = init_new_token_ratio * DEFAULT_MIN_NEW_TOKEN_RATIO_FACTOR;
let decay_steps = DEFAULT_NEW_TOKEN_RATIO_DECAY_STEPS;
let decay_step = (init_new_token_ratio - min_new_token_ratio) / decay_steps;
let policy_str = sglang.and_then(|s| s.schedule_policy.as_deref());
let schedule_policy = match policy_str {
Some("lpm") => SchedulePolicy::Lpm,
Some("fifo") | Some("fcfs") | None => SchedulePolicy::Fifo,
Some(other) => {
tracing::warn!(
"Unknown sglang schedule_policy '{}', falling back to FIFO",
other
);
SchedulePolicy::Fifo
}
};
Self {
schedule_policy,
max_prefill_tokens: sglang
.and_then(|s| s.max_prefill_tokens)
.unwrap_or(DEFAULT_MAX_PREFILL_TOKENS),
chunked_prefill_size: sglang
.and_then(|s| s.chunked_prefill_size)
.unwrap_or(DEFAULT_CHUNKED_PREFILL_SIZE),
clip_max_new_tokens: sglang
.and_then(|s| s.clip_max_new_tokens)
.unwrap_or(DEFAULT_CLIP_MAX_NEW_TOKENS),
init_new_token_ratio,
min_new_token_ratio,
new_token_ratio_decay_step: decay_step,
perf_model: args.perf_model.clone(),
speedup_ratio: args.speedup_ratio,
decode_speedup_ratio: args.decode_speedup_ratio,
worker_type: args.worker_type,
block_size: args.block_size,
}
}
}
pub(super) fn ceil_to_block(tokens: usize, block_size: usize) -> usize {
if tokens == 0 {
return 0;
}
tokens.div_ceil(block_size) * block_size
}
pub(super) fn floor_to_block(tokens: usize, block_size: usize) -> usize {
tokens / block_size * block_size
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use std::time::Duration;
use dynamo_kv_router::protocols::WorkerId;
use uuid::Uuid;
use crate::common::protocols::{DirectRequest, KvEventPublishers, MockEngineArgs, WorkerType};
use crate::kv_manager::SglangKvManager;
use crate::replay::TraceCollector;
use super::config::SglangConfig;
use super::decode::{cache_materialized_prefix, simulate_decode_step};
use super::policy::apply_schedule_policy;
use super::prefill::get_new_batch_prefill;
use super::request::{SglangRequest, direct_to_sglang};
use crate::scheduler::{
CapturedRouterEventBuffer, EnginePassResult, RouterEventVisibility, capture_router_event_sink,
};
pub(crate) struct SglangCore {
pub(super) config: SglangConfig,
pub(super) waiting: VecDeque<SglangRequest>,
pub(super) running: Vec<SglangRequest>,
pub(super) new_token_ratio: f64,
pub(super) kv_manager: SglangKvManager,
kv_event_buffer: Option<CapturedRouterEventBuffer>,
}
impl SglangCore {
pub(crate) fn new(args: MockEngineArgs) -> Self {
Self::new_internal(args, 0, None, KvEventPublishers::default())
}
pub(crate) fn new_with_kv_capture(args: MockEngineArgs, worker_id: WorkerId) -> Self {
let (buffer, sink) = capture_router_event_sink(worker_id);
Self::new_internal(
args,
worker_id as u32,
Some(buffer),
KvEventPublishers::new(Some(sink), None),
)
}
pub(super) fn new_with_sink(
args: MockEngineArgs,
dp_rank: u32,
kv_event_publishers: KvEventPublishers,
) -> Self {
Self::new_internal(args, dp_rank, None, kv_event_publishers)
}
fn new_internal(
args: MockEngineArgs,
dp_rank: u32,
kv_event_buffer: Option<CapturedRouterEventBuffer>,
kv_event_publishers: KvEventPublishers,
) -> Self {
let args = args.normalized().expect("invalid MockEngineArgs");
let config = SglangConfig::from_args(&args);
let total_tokens = args.num_gpu_blocks * args.block_size;
Self {
config,
waiting: VecDeque::new(),
running: Vec::new(),
new_token_ratio: SglangConfig::from_args(&args).init_new_token_ratio,
kv_manager: SglangKvManager::new(
total_tokens,
args.block_size,
kv_event_publishers,
dp_rank,
),
kv_event_buffer,
}
}
pub(crate) fn receive(&mut self, request: DirectRequest) -> Uuid {
let request = direct_to_sglang(request);
request.debug_assert_invariants(self.config.block_size);
let uuid = request.uuid;
self.waiting.push_back(request);
uuid
}
pub(crate) fn is_empty(&self) -> bool {
self.waiting.is_empty() && self.running.is_empty()
}
pub(crate) fn num_requests(&self) -> usize {
self.waiting.len() + self.running.len()
}
pub(crate) fn execute_pass(
&mut self,
collector: &mut TraceCollector,
now_ms: f64,
) -> EnginePassResult {
self.execute_pass_internal(Some(collector), now_ms)
}
pub(super) fn execute_pass_internal(
&mut self,
mut collector: Option<&mut TraceCollector>,
now_ms: f64,
) -> EnginePassResult {
apply_schedule_policy(&mut self.waiting, &self.kv_manager, &self.config);
let admit = get_new_batch_prefill(
&mut self.waiting,
&mut self.kv_manager,
&self.config,
self.new_token_ratio,
&self.running,
);
if admit.oom {
self.new_token_ratio = self.config.init_new_token_ratio;
}
for admission in &admit.admissions {
if let Some(collector) = collector.as_deref_mut() {
collector.on_admit(admission.uuid, now_ms, admission.reused_input_tokens);
}
}
let batch_size = admit.can_run.len();
let mean_isl = if batch_size > 0 {
admit.total_isl / batch_size
} else {
0
};
let mean_prefix = if batch_size > 0 {
admit.total_prefix / batch_size
} else {
0
};
let prefill_time =
simulate_prefill_duration(batch_size, mean_isl, mean_prefix, &self.config, true);
for mut req in admit.can_run {
if req.materialized_tokens < req.current_sequence_len() {
cache_materialized_prefix(&mut req, &mut self.kv_manager, &self.config);
self.waiting.push_front(req);
} else {
self.running.push(req);
}
}
let decode_start_ms = now_ms + prefill_time.as_secs_f64() * 1000.0;
let mut decode = simulate_decode_step(
&mut self.running,
&mut self.kv_manager,
&self.config,
decode_start_ms,
true,
);
if let Some(collector) = collector {
for signal in &decode.output_signals {
collector.on_token(signal.uuid, decode.end_ms);
}
}
for req in decode.requests.drain(..).rev() {
self.waiting.push_front(req);
}
if decode.retracted_any {
self.new_token_ratio = self.config.init_new_token_ratio;
}
self.new_token_ratio = (self.new_token_ratio - self.config.new_token_ratio_decay_step)
.max(self.config.min_new_token_ratio);
debug_assert_sglang_scheduler_state(&self.waiting, &self.running, self.config.block_size);
EnginePassResult {
end_ms: decode.end_ms,
completed_requests: decode
.output_signals
.iter()
.filter(|signal| signal.completed)
.count(),
output_signals: decode.output_signals,
admissions: admit.admissions,
active_decode_blocks: self.active_kv_blocks(),
router_event_visibility: RouterEventVisibility::PassEnd,
kv_events: self
.kv_event_buffer
.as_ref()
.map(CapturedRouterEventBuffer::drain)
.unwrap_or_default(),
}
}
fn active_kv_blocks(&self) -> u64 {
let active_reserved = self
.waiting
.iter()
.map(SglangRequest::extra_reserved_tokens)
.sum::<usize>()
+ self
.running
.iter()
.map(SglangRequest::extra_reserved_tokens)
.sum::<usize>();
let actual_used =
self.kv_manager.cache().total_tokens() - self.kv_manager.cache().available_tokens();
(actual_used + active_reserved).div_ceil(self.config.block_size) as u64
}
}
fn simulate_prefill_duration(
batch_size: usize,
mean_isl: usize,
mean_prefix: usize,
config: &SglangConfig,
apply_speedup: bool,
) -> Duration {
if batch_size == 0 || config.worker_type == WorkerType::Decode {
return Duration::ZERO;
}
let prefill_time = config
.perf_model
.predict_prefill_time(batch_size, mean_isl, mean_prefix);
let total_time = Duration::from_secs_f64(prefill_time / 1000.0);
if !apply_speedup || config.speedup_ratio <= 0.0 || total_time <= Duration::ZERO {
return total_time;
}
Duration::from_secs_f64(total_time.as_secs_f64() / config.speedup_ratio)
}
fn debug_assert_sglang_scheduler_state(
waiting: &VecDeque<SglangRequest>,
running: &[SglangRequest],
block_size: usize,
) {
#[cfg(debug_assertions)]
{
let mut seen = std::collections::HashSet::new();
for req in waiting {
debug_assert!(
seen.insert(req.uuid),
"request {} appears multiple times across waiting/running queues",
req.uuid
);
req.debug_assert_invariants(block_size);
}
for req in running {
debug_assert!(
seen.insert(req.uuid),
"request {} appears multiple times across waiting/running queues",
req.uuid
);
req.debug_assert_invariants(block_size);
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment