Unverified Commit a2154ba5 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore(mocker): batch live output signal sends (#7647)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 06f17011
......@@ -16,7 +16,7 @@ use crate::scheduler::{Scheduler, SchedulerHandle, SglangScheduler};
pub fn create_engine(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
) -> Box<dyn SchedulerHandle> {
......
......@@ -18,7 +18,7 @@ Offline replay starts in `lib/mocker/src/replay/offline/mod.rs`.
`offline/mod.rs` chooses between three implementations:
- `lib/mocker/src/replay/offline/single.rs` for the special case `num_workers == 1` with the vLLM engine
- `lib/mocker/src/replay/offline/multi.rs` for everything else, including multi-worker replay and `kv_router` replay
- `lib/mocker/src/replay/offline/agg.rs` for everything else, including aggregated multi-worker replay and `kv_router` replay
- `lib/mocker/src/replay/offline/disagg.rs` for offline disaggregated prefill/decode replay
## File Map
......@@ -27,7 +27,7 @@ Offline replay starts in `lib/mocker/src/replay/offline/mod.rs`.
Chooses between the single-worker fast path, the aggregated multi-worker harness, and the disaggregated prefill/decode harness.
- `lib/mocker/src/replay/offline/single.rs`
Minimal replay loop for one vLLM worker.
- `lib/mocker/src/replay/offline/multi.rs`
- `lib/mocker/src/replay/offline/agg.rs`
General offline cluster simulator for multi-worker replay and KV-router replay.
- `lib/mocker/src/replay/offline/disagg.rs`
Offline two-stage replay harness with separate prefill and decode pools.
......@@ -75,7 +75,7 @@ Important details:
## Multi-Worker Harness
The general harness lives in `lib/mocker/src/replay/offline/multi.rs`. It models a cluster with:
The general aggregated harness lives in `lib/mocker/src/replay/offline/agg.rs`. It models a cluster with:
- a logical clock `now_ms`
- a pending request queue
......@@ -85,7 +85,7 @@ The general harness lives in `lib/mocker/src/replay/offline/multi.rs`. It models
### Main Loop
The harness is event-driven. It does not sleep. Instead, `OfflineRuntime` repeatedly:
The aggregated harness is event-driven. It does not sleep. Instead, `AggRuntime` repeatedly:
1. picks the next meaningful timestamp
2. advances `now_ms`
......@@ -164,7 +164,7 @@ flowchart LR
F -->|yes| G["dispatch to worker"]
F -->|no| H["store in router_pending"]
I["worker pass emits RouterEvent + OutputSignal"] --> J["OfflineRuntime::process_completed_pass"]
I["worker pass emits RouterEvent + OutputSignal"] --> J["AggRuntime::process_completed_pass"]
J --> K["apply router events to sync indexer"]
J --> L["mark_prefill_completed / free"]
L --> M["drain queued admissions"]
......
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use anyhow::Result;
use dynamo_kv_router::config::KvRouterConfig;
#[cfg(test)]
use super::agg::AggRuntimeStats;
use super::agg::{AggRuntime, ReplayMode as AggReplayMode};
#[cfg(test)]
use super::disagg::DisaggRuntimeStats;
use super::disagg::{DisaggRuntime, ReplayMode as DisaggReplayMode};
use super::normalize_trace_requests;
use super::single::{SingleReplayMode, SingleRuntime};
use crate::common::protocols::{DirectRequest, EngineType, MockEngineArgs};
use crate::loadgen::{Trace, WorkloadDriver};
use crate::replay::OfflineDisaggReplayConfig;
#[cfg(test)]
use crate::replay::TraceCollector;
use crate::replay::{ReplayRouterMode, TraceSimulationReport};
/// Runs an offline trace replay over `requests`, picking the runtime by cluster shape.
///
/// A single worker running the vLLM engine takes the dedicated single-worker
/// path, which does not receive `router_config` or `router_mode`. Every other
/// combination is handed to the aggregated multi-worker path.
///
/// # Errors
/// Propagates whatever error the selected replay path returns.
pub(crate) fn simulate_trace(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let single_vllm_worker = num_workers == 1 && args.engine_type == EngineType::Vllm;
    if !single_vllm_worker {
        return simulate_trace_multi(
            args,
            router_config,
            requests,
            num_workers,
            arrival_speedup_ratio,
            router_mode,
        );
    }
    simulate_trace_single(args, requests, arrival_speedup_ratio)
}
/// Runs an offline concurrency-capped replay over `requests`, picking the runtime
/// by cluster shape.
///
/// A single worker running the vLLM engine takes the dedicated single-worker
/// path, which does not receive `router_config` or `router_mode`. Every other
/// combination is handed to the aggregated multi-worker path.
///
/// # Errors
/// Propagates whatever error the selected replay path returns.
pub(crate) fn simulate_concurrency(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let single_vllm_worker = num_workers == 1 && args.engine_type == EngineType::Vllm;
    if !single_vllm_worker {
        return simulate_concurrency_multi(
            args,
            router_config,
            requests,
            max_in_flight,
            num_workers,
            router_mode,
        );
    }
    simulate_concurrency_single(args, requests, max_in_flight)
}
/// Runs an offline trace-mode replay of a workload `trace`, picking the runtime
/// by cluster shape.
///
/// A single worker running the vLLM engine takes the dedicated single-worker
/// path (which ignores the router settings); anything else goes to the
/// aggregated multi-worker path.
///
/// # Errors
/// Propagates whatever error the selected replay path returns.
pub(crate) fn simulate_trace_workload(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let single_vllm_worker = num_workers == 1 && args.engine_type == EngineType::Vllm;
    if !single_vllm_worker {
        return simulate_trace_workload_multi(args, router_config, trace, num_workers, router_mode);
    }
    simulate_trace_workload_single(args, trace)
}
/// Runs an offline concurrency-capped replay of a workload `trace`, picking the
/// runtime by cluster shape.
///
/// A single worker running the vLLM engine takes the dedicated single-worker
/// path (which ignores the router settings); anything else goes to the
/// aggregated multi-worker path.
///
/// # Errors
/// Propagates whatever error the selected replay path returns.
pub(crate) fn simulate_concurrency_workload(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let single_vllm_worker = num_workers == 1 && args.engine_type == EngineType::Vllm;
    if !single_vllm_worker {
        return simulate_concurrency_workload_multi(
            args,
            router_config,
            trace,
            max_in_flight,
            num_workers,
            router_mode,
        );
    }
    simulate_concurrency_workload_single(args, trace, max_in_flight)
}
/// Runs an offline disaggregated replay of `requests` in trace mode.
///
/// The requests are passed through `normalize_trace_requests` together with
/// `arrival_speedup_ratio`, and the resulting queue drives a `DisaggRuntime`.
/// The runtime's stats are discarded; only the collector's finished report is
/// returned.
///
/// # Errors
/// Propagates failures from normalization, runtime construction, or the run.
pub(crate) fn simulate_trace_disagg(
    config: OfflineDisaggReplayConfig,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let mode = DisaggReplayMode::Trace;
    let arrivals = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    let (collector, _) =
        DisaggRuntime::new(&config, router_config, arrivals, mode, router_mode)?.run()?;
    Ok(collector.finish())
}
/// Runs an offline disaggregated replay of `requests` in concurrency mode.
///
/// The requests are queued in their given order (no normalization step) and
/// drive a `DisaggRuntime` configured with `max_in_flight`. The runtime's stats
/// are discarded; only the collector's finished report is returned.
///
/// # Errors
/// Propagates failures from runtime construction or the run.
pub(crate) fn simulate_concurrency_disagg(
    config: OfflineDisaggReplayConfig,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let mode = DisaggReplayMode::Concurrency { max_in_flight };
    let queue: VecDeque<DirectRequest> = requests.into();
    let (collector, _) =
        DisaggRuntime::new(&config, router_config, queue, mode, router_mode)?.run()?;
    Ok(collector.finish())
}
/// Runs an offline disaggregated replay of a workload `trace` in trace mode.
///
/// A `WorkloadDriver` is built from the trace with `WorkloadDriver::new_trace`
/// and drives a `DisaggRuntime`. The runtime's stats are discarded; only the
/// collector's finished report is returned.
///
/// # Errors
/// Propagates failures from driver construction, runtime construction, or the run.
pub(crate) fn simulate_trace_workload_disagg(
    config: OfflineDisaggReplayConfig,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let mode = DisaggReplayMode::Trace;
    let workload = WorkloadDriver::new_trace(trace)?;
    let (collector, _) =
        DisaggRuntime::new_workload(&config, router_config, workload, mode, router_mode)?.run()?;
    Ok(collector.finish())
}
/// Runs an offline disaggregated replay of a workload `trace` in concurrency mode.
///
/// A `WorkloadDriver` is built from the trace with
/// `WorkloadDriver::new_concurrency` and drives a `DisaggRuntime` configured
/// with `max_in_flight`. The runtime's stats are discarded; only the
/// collector's finished report is returned.
///
/// # Errors
/// Propagates failures from driver construction, runtime construction, or the run.
pub(crate) fn simulate_concurrency_workload_disagg(
    config: OfflineDisaggReplayConfig,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    max_in_flight: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let mode = DisaggReplayMode::Concurrency { max_in_flight };
    let workload = WorkloadDriver::new_concurrency(trace)?;
    let (collector, _) =
        DisaggRuntime::new_workload(&config, router_config, workload, mode, router_mode)?.run()?;
    Ok(collector.finish())
}
/// Replays `requests` in trace mode on the single-worker runtime.
///
/// The engine args are normalized first, then the requests are passed through
/// `normalize_trace_requests` with `arrival_speedup_ratio`. The collector
/// produced by the run is finished into the returned report.
///
/// # Errors
/// Propagates failures from arg normalization, request normalization, or the run.
pub(crate) fn simulate_trace_single(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    arrival_speedup_ratio: f64,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    let arrivals = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    let runtime = SingleRuntime::new(engine_args, arrivals, SingleReplayMode::Trace);
    runtime.run().map(|collector| collector.finish())
}
/// Replays `requests` in concurrency mode on the single-worker runtime.
///
/// The engine args are normalized first; the requests are queued in their given
/// order and replayed with the `max_in_flight` setting. The collector produced
/// by the run is finished into the returned report.
///
/// # Errors
/// Propagates failures from arg normalization or the run.
pub(crate) fn simulate_concurrency_single(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    let mode = SingleReplayMode::Concurrency { max_in_flight };
    let queue: VecDeque<DirectRequest> = requests.into();
    let runtime = SingleRuntime::new(engine_args, queue, mode);
    runtime.run().map(|collector| collector.finish())
}
/// Replays a workload `trace` in trace mode on the single-worker runtime.
///
/// The engine args are normalized first, then the trace is converted with
/// `into_trace_driver`. The collector produced by the run is finished into the
/// returned report.
///
/// # Errors
/// Propagates failures from arg normalization, driver conversion, or the run.
pub(crate) fn simulate_trace_workload_single(
    args: MockEngineArgs,
    trace: Trace,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    let workload = trace.into_trace_driver()?;
    let runtime = SingleRuntime::new_workload(engine_args, workload, SingleReplayMode::Trace);
    runtime.run().map(|collector| collector.finish())
}
/// Replays a workload `trace` in concurrency mode on the single-worker runtime.
///
/// The engine args are normalized first, then the trace is converted with
/// `into_concurrency_driver` and replayed with the `max_in_flight` setting.
/// The collector produced by the run is finished into the returned report.
///
/// # Errors
/// Propagates failures from arg normalization, driver conversion, or the run.
pub(crate) fn simulate_concurrency_workload_single(
    args: MockEngineArgs,
    trace: Trace,
    max_in_flight: usize,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    let workload = trace.into_concurrency_driver()?;
    let mode = SingleReplayMode::Concurrency { max_in_flight };
    let runtime = SingleRuntime::new_workload(engine_args, workload, mode);
    runtime.run().map(|collector| collector.finish())
}
/// Replays `requests` in trace mode on the aggregated (`AggRuntime`) harness.
///
/// The engine args are normalized first, then the requests are passed through
/// `normalize_trace_requests` with `arrival_speedup_ratio`. The runtime's stats
/// are discarded; only the collector's finished report is returned.
///
/// # Errors
/// Propagates failures from arg normalization, request normalization, runtime
/// construction, or the run.
pub(crate) fn simulate_trace_multi(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    let arrivals = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    let mode = AggReplayMode::Trace;
    let (collector, _) = AggRuntime::new(
        &engine_args,
        router_config,
        arrivals,
        num_workers,
        mode,
        router_mode,
    )?
    .run()?;
    Ok(collector.finish())
}
/// Replays `requests` in concurrency mode on the aggregated (`AggRuntime`) harness.
///
/// The engine args are normalized first; the requests are queued in their given
/// order and replayed with the `max_in_flight` setting. The runtime's stats are
/// discarded; only the collector's finished report is returned.
///
/// # Errors
/// Propagates failures from arg normalization, runtime construction, or the run.
pub(crate) fn simulate_concurrency_multi(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    let queue: VecDeque<DirectRequest> = requests.into();
    let mode = AggReplayMode::Concurrency { max_in_flight };
    let (collector, _) = AggRuntime::new(
        &engine_args,
        router_config,
        queue,
        num_workers,
        mode,
        router_mode,
    )?
    .run()?;
    Ok(collector.finish())
}
/// Replays a workload `trace` in trace mode on the aggregated (`AggRuntime`) harness.
///
/// The engine args are normalized first, then the trace is converted with
/// `into_trace_driver`. The runtime's stats are discarded; only the collector's
/// finished report is returned.
///
/// # Errors
/// Propagates failures from arg normalization, driver conversion, runtime
/// construction, or the run.
pub(crate) fn simulate_trace_workload_multi(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    let workload = trace.into_trace_driver()?;
    let mode = AggReplayMode::Trace;
    let (collector, _) = AggRuntime::new_workload(
        &engine_args,
        router_config,
        workload,
        num_workers,
        mode,
        router_mode,
    )?
    .run()?;
    Ok(collector.finish())
}
/// Replays a workload `trace` in concurrency mode on the aggregated (`AggRuntime`)
/// harness.
///
/// The engine args are normalized first, then the trace is converted with
/// `into_concurrency_driver` and replayed with the `max_in_flight` setting. The
/// runtime's stats are discarded; only the collector's finished report is returned.
///
/// # Errors
/// Propagates failures from arg normalization, driver conversion, runtime
/// construction, or the run.
pub(crate) fn simulate_concurrency_workload_multi(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    let workload = trace.into_concurrency_driver()?;
    let mode = AggReplayMode::Concurrency { max_in_flight };
    let (collector, _) = AggRuntime::new_workload(
        &engine_args,
        router_config,
        workload,
        num_workers,
        mode,
        router_mode,
    )?
    .run()?;
    Ok(collector.finish())
}
/// Test helper: runs a single-worker trace replay and returns the raw collector.
///
/// Unlike `simulate_trace_single`, `args` is used as given (no `normalized()`
/// call) and the collector is not finished into a report.
///
/// # Panics
/// Panics if request normalization or the run fails.
#[cfg(test)]
pub(super) fn run_trace_single_collect(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    arrival_speedup_ratio: f64,
) -> TraceCollector {
    let arrivals = normalize_trace_requests(requests, arrival_speedup_ratio).unwrap();
    let runtime = SingleRuntime::new(args, arrivals, SingleReplayMode::Trace);
    runtime.run().unwrap()
}
/// Test helper: runs a single-worker concurrency replay and returns the raw collector.
///
/// Unlike `simulate_concurrency_single`, `args` is used as given (no
/// `normalized()` call) and the collector is not finished into a report.
///
/// # Panics
/// Panics if the run fails.
#[cfg(test)]
pub(super) fn run_concurrency_single_collect(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
) -> TraceCollector {
    let mode = SingleReplayMode::Concurrency { max_in_flight };
    let queue: VecDeque<DirectRequest> = requests.into();
    let runtime = SingleRuntime::new(args, queue, mode);
    runtime.run().unwrap()
}
/// Test helper: runs a single-worker trace-mode workload replay and returns the
/// raw collector.
///
/// `args` is used as given (no `normalized()` call) and the collector is not
/// finished into a report.
///
/// # Panics
/// Panics if the trace cannot be converted into a driver or the run fails.
#[cfg(test)]
pub(super) fn run_trace_workload_single_collect(
    args: MockEngineArgs,
    trace: Trace,
) -> TraceCollector {
    let workload = trace.into_trace_driver().unwrap();
    let runtime = SingleRuntime::new_workload(args, workload, SingleReplayMode::Trace);
    runtime.run().unwrap()
}
/// Test helper: runs a single-worker concurrency-mode workload replay and returns
/// the raw collector.
///
/// `args` is used as given (no `normalized()` call) and the collector is not
/// finished into a report.
///
/// # Panics
/// Panics if the trace cannot be converted into a driver or the run fails.
#[cfg(test)]
pub(super) fn run_concurrency_workload_single_collect(
    args: MockEngineArgs,
    trace: Trace,
    max_in_flight: usize,
) -> TraceCollector {
    let workload = trace.into_concurrency_driver().unwrap();
    let mode = SingleReplayMode::Concurrency { max_in_flight };
    let runtime = SingleRuntime::new_workload(args, workload, mode);
    runtime.run().unwrap()
}
/// Test helper: runs an aggregated trace replay and returns both the raw
/// collector and the runtime stats.
///
/// No router config is supplied, `args` is used as given, and the requests are
/// normalized with a fixed ratio of `1.0`.
///
/// # Panics
/// Panics if request normalization, runtime construction, or the run fails.
#[cfg(test)]
pub(super) fn run_trace_multi_collect_with_stats(
    args: &MockEngineArgs,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, AggRuntimeStats) {
    let arrivals = normalize_trace_requests(requests, 1.0).unwrap();
    let mode = AggReplayMode::Trace;
    AggRuntime::new(args, None, arrivals, num_workers, mode, router_mode)
        .unwrap()
        .run()
        .unwrap()
}
/// Test helper: runs an aggregated concurrency replay and returns both the raw
/// collector and the runtime stats.
///
/// No router config is supplied and `args` is used as given.
///
/// # Panics
/// Panics if runtime construction or the run fails.
#[cfg(test)]
pub(super) fn run_concurrency_multi_collect_with_stats(
    args: &MockEngineArgs,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, AggRuntimeStats) {
    let queue: VecDeque<DirectRequest> = requests.into();
    let mode = AggReplayMode::Concurrency { max_in_flight };
    AggRuntime::new(args, None, queue, num_workers, mode, router_mode)
        .unwrap()
        .run()
        .unwrap()
}
/// Test helper: runs an aggregated trace-mode workload replay and returns both
/// the raw collector and the runtime stats.
///
/// No router config is supplied and `args` is used as given.
///
/// # Panics
/// Panics if driver conversion, runtime construction, or the run fails.
#[cfg(test)]
pub(super) fn run_trace_workload_multi_collect_with_stats(
    args: &MockEngineArgs,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, AggRuntimeStats) {
    let workload = trace.into_trace_driver().unwrap();
    let mode = AggReplayMode::Trace;
    AggRuntime::new_workload(args, None, workload, num_workers, mode, router_mode)
        .unwrap()
        .run()
        .unwrap()
}
/// Test helper: runs an aggregated concurrency-mode workload replay and returns
/// both the raw collector and the runtime stats.
///
/// No router config is supplied and `args` is used as given.
///
/// # Panics
/// Panics if driver conversion, runtime construction, or the run fails.
#[cfg(test)]
pub(super) fn run_concurrency_workload_multi_collect_with_stats(
    args: &MockEngineArgs,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, AggRuntimeStats) {
    let workload = trace.into_concurrency_driver().unwrap();
    let mode = AggReplayMode::Concurrency { max_in_flight };
    AggRuntime::new_workload(args, None, workload, num_workers, mode, router_mode)
        .unwrap()
        .run()
        .unwrap()
}
/// Test helper: runs a disaggregated trace replay and returns both the raw
/// collector and the runtime stats.
///
/// # Panics
/// Panics if request normalization, runtime construction, or the run fails.
#[cfg(test)]
pub(super) fn run_trace_collect(
    config: &OfflineDisaggReplayConfig,
    requests: Vec<DirectRequest>,
    router_config: Option<KvRouterConfig>,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
    let arrivals = normalize_trace_requests(requests, arrival_speedup_ratio).unwrap();
    let mode = DisaggReplayMode::Trace;
    DisaggRuntime::new(config, router_config, arrivals, mode, router_mode)
        .unwrap()
        .run()
        .unwrap()
}
/// Test helper: runs a disaggregated concurrency replay and returns both the raw
/// collector and the runtime stats.
///
/// # Panics
/// Panics if runtime construction or the run fails.
#[cfg(test)]
pub(super) fn run_concurrency_collect(
    config: &OfflineDisaggReplayConfig,
    requests: Vec<DirectRequest>,
    router_config: Option<KvRouterConfig>,
    max_in_flight: usize,
    router_mode: ReplayRouterMode,
) -> (TraceCollector, DisaggRuntimeStats) {
    let queue: VecDeque<DirectRequest> = requests.into();
    let mode = DisaggReplayMode::Concurrency { max_in_flight };
    DisaggRuntime::new(config, router_config, queue, mode, router_mode)
        .unwrap()
        .run()
        .unwrap()
}
......@@ -34,17 +34,9 @@ pub(crate) struct SimulationEvent {
pub(crate) kind: SimulationEventKind,
}
impl SimulationEvent {
fn kind_priority(&self) -> u8 {
0
}
}
impl PartialEq for SimulationEvent {
fn eq(&self, other: &Self) -> bool {
self.at_ms.to_bits() == other.at_ms.to_bits()
&& self.seq_no == other.seq_no
&& self.kind_priority() == other.kind_priority()
self.at_ms.to_bits() == other.at_ms.to_bits() && self.seq_no == other.seq_no
}
}
......@@ -62,7 +54,6 @@ impl Ord for SimulationEvent {
.at_ms
.partial_cmp(&self.at_ms)
.unwrap_or(Ordering::Equal)
.then_with(|| self.kind_priority().cmp(&other.kind_priority()))
.then_with(|| other.seq_no.cmp(&self.seq_no))
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::Trace;
use crate::replay::OfflineDisaggReplayConfig;
pub(crate) use crate::replay::normalize_trace_requests;
use crate::replay::{ReplayRouterMode, TraceSimulationReport};
use dynamo_kv_router::config::KvRouterConfig;
pub(crate) mod agg;
pub(crate) mod core;
pub(crate) mod disagg;
mod entrypoints;
pub(crate) mod events;
pub(crate) mod multi;
pub(crate) mod runtime_utils;
pub(crate) mod single;
pub(crate) mod state;
pub(crate) fn simulate_trace(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm {
single::simulate_trace_single(args, requests, arrival_speedup_ratio)
} else {
multi::simulate_trace_multi(
args,
router_config,
requests,
num_workers,
arrival_speedup_ratio,
router_mode,
)
}
}
pub(crate) fn simulate_concurrency(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm {
single::simulate_concurrency_single(args, requests, max_in_flight)
} else {
multi::simulate_concurrency_multi(
args,
router_config,
requests,
max_in_flight,
num_workers,
router_mode,
)
}
}
pub(crate) fn simulate_trace_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm {
single::simulate_trace_workload_single(args, trace)
} else {
multi::simulate_trace_workload_multi(args, router_config, trace, num_workers, router_mode)
}
}
pub(crate) fn simulate_concurrency_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm {
single::simulate_concurrency_workload_single(args, trace, max_in_flight)
} else {
multi::simulate_concurrency_workload_multi(
args,
router_config,
trace,
max_in_flight,
num_workers,
router_mode,
)
}
}
pub(crate) fn simulate_trace_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_trace_disagg(
config,
router_config,
requests,
arrival_speedup_ratio,
router_mode,
)
}
pub(crate) fn simulate_concurrency_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_concurrency_disagg(config, router_config, requests, max_in_flight, router_mode)
}
pub(crate) fn simulate_trace_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_trace_workload_disagg(config, router_config, trace, router_mode)
}
pub(crate) fn simulate_concurrency_workload_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
disagg::simulate_concurrency_workload_disagg(
config,
router_config,
trace,
max_in_flight,
router_mode,
)
}
pub(crate) use entrypoints::{
simulate_concurrency, simulate_concurrency_disagg, simulate_concurrency_workload,
simulate_concurrency_workload_disagg, simulate_trace, simulate_trace_disagg,
simulate_trace_workload, simulate_trace_workload_disagg,
};
......@@ -2,16 +2,15 @@
// SPDX-License-Identifier: Apache-2.0
use super::core::ReplayWorkerCore;
use super::normalize_trace_requests;
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::{Trace, WorkloadDriver};
use crate::replay::{TraceCollector, TraceSimulationReport};
use crate::loadgen::WorkloadDriver;
use crate::replay::TraceCollector;
use anyhow::bail;
use std::collections::VecDeque;
use uuid::Uuid;
#[derive(Debug, Clone, Copy)]
enum SingleReplayMode {
pub(super) enum SingleReplayMode {
Trace,
Concurrency { max_in_flight: usize },
}
......@@ -21,7 +20,7 @@ enum AdmissionSource {
Workload(WorkloadDriver),
}
struct SingleRuntime {
pub(super) struct SingleRuntime {
current_time_ms: f64,
admission: AdmissionSource,
worker: ReplayWorkerCore,
......@@ -30,11 +29,19 @@ struct SingleRuntime {
}
impl SingleRuntime {
fn new(args: MockEngineArgs, pending: VecDeque<DirectRequest>, mode: SingleReplayMode) -> Self {
pub(super) fn new(
args: MockEngineArgs,
pending: VecDeque<DirectRequest>,
mode: SingleReplayMode,
) -> Self {
Self::new_with_source(args, AdmissionSource::Requests(pending), mode)
}
fn new_workload(args: MockEngineArgs, driver: WorkloadDriver, mode: SingleReplayMode) -> Self {
pub(super) fn new_workload(
args: MockEngineArgs,
driver: WorkloadDriver,
mode: SingleReplayMode,
) -> Self {
Self::new_with_source(args, AdmissionSource::Workload(driver), mode)
}
......@@ -166,7 +173,7 @@ impl SingleRuntime {
}
}
fn run(mut self) -> anyhow::Result<TraceCollector> {
pub(super) fn run(mut self) -> anyhow::Result<TraceCollector> {
while !self.is_done() {
match self.mode {
SingleReplayMode::Trace => {
......@@ -196,119 +203,14 @@ impl SingleRuntime {
}
}
pub(crate) fn simulate_trace_single(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
let collector = SingleRuntime::new(args, pending, SingleReplayMode::Trace).run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_single(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let pending = VecDeque::from(requests);
let collector = SingleRuntime::new(
args,
pending,
SingleReplayMode::Concurrency { max_in_flight },
)
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_trace_workload_single(
args: MockEngineArgs,
trace: Trace,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let collector =
SingleRuntime::new_workload(args, trace.into_trace_driver()?, SingleReplayMode::Trace)
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_workload_single(
args: MockEngineArgs,
trace: Trace,
max_in_flight: usize,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let collector = SingleRuntime::new_workload(
args,
trace.into_concurrency_driver()?,
SingleReplayMode::Concurrency { max_in_flight },
)
.run()?;
Ok(collector.finish())
}
#[cfg(test)]
pub(super) fn run_trace_single_collect(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
) -> TraceCollector {
let pending = normalize_trace_requests(requests, arrival_speedup_ratio).unwrap();
SingleRuntime::new(args, pending, SingleReplayMode::Trace)
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_concurrency_single_collect(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
) -> TraceCollector {
SingleRuntime::new(
args,
VecDeque::from(requests),
SingleReplayMode::Concurrency { max_in_flight },
)
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_trace_workload_single_collect(
args: MockEngineArgs,
trace: Trace,
) -> TraceCollector {
SingleRuntime::new_workload(
args,
trace.into_trace_driver().unwrap(),
SingleReplayMode::Trace,
)
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_concurrency_workload_single_collect(
args: MockEngineArgs,
trace: Trace,
max_in_flight: usize,
) -> TraceCollector {
SingleRuntime::new_workload(
args,
trace.into_concurrency_driver().unwrap(),
SingleReplayMode::Concurrency { max_in_flight },
)
.run()
.unwrap()
}
#[cfg(test)]
mod tests {
use super::super::entrypoints::{
run_concurrency_workload_single_collect, run_trace_workload_single_collect,
simulate_concurrency_single, simulate_trace_single,
};
use super::*;
use crate::loadgen::{SessionTrace, TurnTrace};
use crate::loadgen::{SessionTrace, Trace, TurnTrace};
use crate::replay::{TraceRequestStatsSnapshot, TraceSimulationReport};
use rstest::rstest;
use std::collections::{HashMap, VecDeque};
......
......@@ -13,11 +13,53 @@ use crate::scheduler::AdmissionEvent;
use super::state::{ArrivalEvent, RequestRegistry, SharedLiveRuntimeStats, now_ms};
/// Applies one worker output signal to the collector, router, and request state.
///
/// Steps, in order:
/// 1. Record a token for the request at `batch_time_ms`. This happens even when
///    the request is not found in `requests`.
/// 2. If the request is registered and this is its first token, notify the
///    router; a `true` reply is counted via `record_prefill_marked`.
/// 3. If the signal is marked completed and the request has not been completed
///    before, notify the router (a `true` reply is counted via `record_freed`)
///    and then signal completion on the request state.
///
/// Router errors are logged as warnings and do not stop processing; in
/// particular `notify_completion` still runs after a failed `on_complete`.
///
/// NOTE(review): if `RequestRegistry::get` returns a map guard (e.g. a DashMap
/// `Ref`), that guard stays alive across the router `.await`s below — confirm
/// this is intended.
async fn process_output_signal(
    output: OutputSignal,
    batch_time_ms: f64,
    collector: &mut TraceCollector,
    requests: &RequestRegistry,
    router: &ReplayRouter,
    stats: &SharedLiveRuntimeStats,
) {
    let request_id = output.uuid;
    collector.on_token(request_id, batch_time_ms);

    if let Some(state) = requests.get(&request_id) {
        if state.mark_first_token_once() {
            match router.on_first_token(request_id).await {
                Ok(true) => stats.record_prefill_marked(),
                Ok(false) => {}
                Err(err) => tracing::warn!(
                    uuid = %request_id,
                    error = %err,
                    "online replay failed to mark prefill completed"
                ),
            }
        }

        // `completed` is checked first so the once-flag is only consumed by a
        // signal that actually finishes the request.
        if output.completed && state.mark_completed_once() {
            match router.on_complete(request_id).await {
                Ok(true) => stats.record_freed(),
                Ok(false) => {}
                Err(err) => tracing::warn!(
                    uuid = %request_id,
                    error = %err,
                    "online replay failed to free completed request"
                ),
            }
            state.notify_completion();
        }
    }
}
pub(super) async fn run_demux(
start: Instant,
mut arrival_rx: mpsc::UnboundedReceiver<ArrivalEvent>,
mut admission_rx: mpsc::UnboundedReceiver<AdmissionEvent>,
mut output_rx: mpsc::UnboundedReceiver<OutputSignal>,
mut output_rx: mpsc::UnboundedReceiver<Vec<OutputSignal>>,
requests: RequestRegistry,
router: Arc<ReplayRouter>,
stats: Arc<SharedLiveRuntimeStats>,
......@@ -55,33 +97,18 @@ pub(super) async fn run_demux(
}
output = output_rx.recv(), if outputs_open => {
match output {
Some(output) => {
collector.on_token(output.uuid, now_ms(start));
if let Some(state) = requests.get(&output.uuid) {
if state.mark_first_token_once() {
match router.on_first_token(output.uuid).await {
Ok(true) => stats.record_prefill_marked(),
Ok(false) => {}
Err(error) => tracing::warn!(
uuid = %output.uuid,
error = %error,
"online replay failed to mark prefill completed"
),
}
}
if output.completed && state.mark_completed_once() {
match router.on_complete(output.uuid).await {
Ok(true) => stats.record_freed(),
Ok(false) => {}
Err(error) => tracing::warn!(
uuid = %output.uuid,
error = %error,
"online replay failed to free completed request"
),
}
state.notify_completion();
}
Some(output_batch) => {
let batch_time_ms = now_ms(start);
for output in output_batch {
process_output_signal(
output,
batch_time_ms,
&mut collector,
&requests,
&router,
&stats,
)
.await;
}
}
None => outputs_open = false,
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use anyhow::{Result, anyhow, bail};
use dynamo_kv_router::config::KvRouterConfig;
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::{Trace, WorkloadDriver};
use crate::replay::{ReplayRouterMode, TraceSimulationReport, normalize_trace_requests};
use super::live_runtime::LiveRuntime;
use super::state::{LiveReplayMode, LiveRuntimeStats};
/// Counts the turns across every session in `trace`.
fn total_turns(trace: &Trace) -> usize {
    let mut count = 0;
    for session in trace.sessions.iter() {
        count += session.turns.len();
    }
    count
}
/// Runs one live replay over a pre-built request queue on a fresh Tokio runtime.
///
/// A multi-threaded Tokio runtime is created for this call and the current
/// thread blocks until the replay finishes. `LiveRuntime::new` is invoked
/// inside the runtime's `block_on`, not before it.
///
/// # Errors
/// Returns an error if the Tokio runtime cannot be built, or propagates a
/// failure from `LiveRuntime::new` or `LiveRuntime::run`.
fn run_live_runtime(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    pending: VecDeque<DirectRequest>,
    num_workers: usize,
    mode: LiveReplayMode,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let tokio_runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .map_err(|err| anyhow!("failed to create online replay runtime: {err}"))?;

    // The future is lazy: construction of the live runtime only happens once
    // `block_on` polls it.
    let replay = async move {
        LiveRuntime::new(args, router_config, pending, num_workers, mode, router_mode)?
            .run()
            .await
    };
    tokio_runtime.block_on(replay)
}
/// Runs one live workload replay, driven by `driver`, on a fresh Tokio runtime.
///
/// A multi-threaded Tokio runtime is created for this call and the current
/// thread blocks until the replay finishes. The `LiveRuntime` is built with an
/// empty pending queue and then run via `run_workload(driver, total_turns)`.
///
/// # Errors
/// Returns an error if the Tokio runtime cannot be built, or propagates a
/// failure from `LiveRuntime::new` or `LiveRuntime::run_workload`.
fn run_live_workload_runtime(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    driver: WorkloadDriver,
    total_turns: usize,
    num_workers: usize,
    mode: LiveReplayMode,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let tokio_runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .map_err(|err| anyhow!("failed to create online replay runtime: {err}"))?;

    // The future is lazy: construction of the live runtime only happens once
    // `block_on` polls it.
    let replay = async move {
        let no_pending = VecDeque::new();
        LiveRuntime::new(args, router_config, no_pending, num_workers, mode, router_mode)?
            .run_workload(driver, total_turns)
            .await
    };
    tokio_runtime.block_on(replay)
}
/// Runs an online (live) trace replay over `requests` and returns its report.
///
/// The engine args are normalized first, then the requests are passed through
/// `normalize_trace_requests` with `arrival_speedup_ratio`. The live runtime's
/// stats are discarded.
///
/// # Errors
/// Propagates failures from arg normalization, request normalization, or the
/// live run.
pub(crate) fn simulate_trace_requests(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    let arrivals = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    let mode = LiveReplayMode::Trace;
    let (report, _) = run_live_runtime(
        engine_args,
        router_config,
        arrivals,
        num_workers,
        mode,
        router_mode,
    )?;
    Ok(report)
}
/// Runs an online (live) concurrency-capped replay over `requests` and returns
/// its report.
///
/// The engine args are normalized before the empty-input check, so an arg
/// normalization failure takes precedence over the empty-request error. The
/// live runtime's stats are discarded.
///
/// # Errors
/// Fails if arg normalization fails, if `requests` is empty, or if the live
/// run fails.
pub(crate) fn simulate_concurrency_requests(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    if requests.is_empty() {
        bail!("online concurrency replay requires at least one request");
    }

    let queue: VecDeque<DirectRequest> = requests.into();
    let mode = LiveReplayMode::Concurrency { max_in_flight };
    let (report, _) = run_live_runtime(
        engine_args,
        router_config,
        queue,
        num_workers,
        mode,
        router_mode,
    )?;
    Ok(report)
}
/// Runs an online (live) trace-mode replay of a workload `trace` and returns
/// its report.
///
/// The engine args are normalized first. The turn count is taken from the trace
/// before it is consumed by `into_trace_driver`. The live runtime's stats are
/// discarded.
///
/// # Errors
/// Propagates failures from arg normalization, driver conversion, or the live run.
pub(crate) fn simulate_trace_workload(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    let turn_count = total_turns(&trace);
    let workload = trace.into_trace_driver()?;
    let (report, _) = run_live_workload_runtime(
        engine_args,
        router_config,
        workload,
        turn_count,
        num_workers,
        LiveReplayMode::Trace,
        router_mode,
    )?;
    Ok(report)
}
/// Runs an online (live) concurrency-mode replay of a workload `trace` and
/// returns its report.
///
/// The engine args are normalized first. The turn count is taken from the trace
/// before it is consumed by `into_concurrency_driver`. The live runtime's stats
/// are discarded.
///
/// # Errors
/// Propagates failures from arg normalization, driver conversion, or the live run.
pub(crate) fn simulate_concurrency_workload(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let engine_args = args.normalized()?;
    let turn_count = total_turns(&trace);
    let workload = trace.into_concurrency_driver()?;
    let (report, _) = run_live_workload_runtime(
        engine_args,
        router_config,
        workload,
        turn_count,
        num_workers,
        LiveReplayMode::Concurrency { max_in_flight },
        router_mode,
    )?;
    Ok(report)
}
/// Test helper: like `simulate_trace_requests`, but without a router config and
/// returning the live runtime stats alongside the report.
///
/// # Errors
/// Propagates failures from arg normalization, request normalization, or the
/// live run.
#[cfg(test)]
pub(super) fn simulate_trace_requests_with_stats(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let engine_args = args.normalized()?;
    let arrivals = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    let mode = LiveReplayMode::Trace;
    run_live_runtime(engine_args, None, arrivals, num_workers, mode, router_mode)
}
/// Test helper: runs a live concurrency replay without a router config and
/// returns the live runtime stats alongside the report.
///
/// Unlike `simulate_concurrency_requests`, an empty `requests` list is not
/// rejected here.
///
/// # Errors
/// Propagates failures from arg normalization or the live run.
#[cfg(test)]
pub(super) fn simulate_concurrency_requests_with_stats(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let engine_args = args.normalized()?;
    let queue: VecDeque<DirectRequest> = requests.into();
    let mode = LiveReplayMode::Concurrency { max_in_flight };
    run_live_runtime(engine_args, None, queue, num_workers, mode, router_mode)
}
/// Test helper: like `simulate_trace_workload`, but without a router config and
/// returning the live runtime stats alongside the report.
///
/// # Errors
/// Propagates failures from arg normalization, driver conversion, or the live run.
#[cfg(test)]
pub(super) fn simulate_trace_workload_with_stats(
    args: MockEngineArgs,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let engine_args = args.normalized()?;
    let turn_count = total_turns(&trace);
    let workload = trace.into_trace_driver()?;
    run_live_workload_runtime(
        engine_args,
        None,
        workload,
        turn_count,
        num_workers,
        LiveReplayMode::Trace,
        router_mode,
    )
}
/// Test helper: like `simulate_concurrency_workload`, but without a router
/// config and returning the live runtime stats alongside the report.
///
/// # Errors
/// Propagates failures from arg normalization, driver conversion, or the live run.
#[cfg(test)]
pub(super) fn simulate_concurrency_workload_with_stats(
    args: MockEngineArgs,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let engine_args = args.normalized()?;
    let turn_count = total_turns(&trace);
    let workload = trace.into_concurrency_driver()?;
    run_live_workload_runtime(
        engine_args,
        None,
        workload,
        turn_count,
        num_workers,
        LiveReplayMode::Concurrency { max_in_flight },
        router_mode,
    )
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use std::sync::Arc;
use anyhow::{Result, anyhow, bail};
use anyhow::{Result, anyhow};
use dashmap::DashMap;
use dynamo_kv_router::config::KvRouterConfig;
use tokio::sync::{Notify, Semaphore, mpsc};
......@@ -13,9 +12,9 @@ use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::{Trace, WorkloadDriver};
use crate::loadgen::WorkloadDriver;
use crate::replay::router::ReplayRouter;
use crate::replay::{ReplayRouterMode, TraceSimulationReport, normalize_trace_requests};
use crate::replay::{ReplayRouterMode, TraceSimulationReport};
use crate::scheduler::{AdmissionEvent, EngineScheduler, SchedulerHandle};
use super::demux::run_demux;
......@@ -25,11 +24,11 @@ use super::state::{
};
use super::task::{RequestTaskContext, run_request_task, wait_for_workload_progress};
struct LiveRuntime {
pending: VecDeque<DirectRequest>,
pub(super) struct LiveRuntime {
pending: std::collections::VecDeque<DirectRequest>,
senders: Arc<[mpsc::UnboundedSender<DirectRequest>]>,
schedulers: Vec<EngineScheduler>,
output_rx: mpsc::UnboundedReceiver<OutputSignal>,
output_rx: mpsc::UnboundedReceiver<Vec<OutputSignal>>,
admission_rx: mpsc::UnboundedReceiver<AdmissionEvent>,
cancel_token: CancellationToken,
start: Instant,
......@@ -38,16 +37,17 @@ struct LiveRuntime {
}
impl LiveRuntime {
fn new(
/// Build the shared router, worker schedulers, and demux inputs for one live replay run.
pub(super) fn new(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
pending: VecDeque<DirectRequest>,
pending: std::collections::VecDeque<DirectRequest>,
num_workers: usize,
mode: LiveReplayMode,
router_mode: ReplayRouterMode,
) -> Result<Self> {
let cancel_token = CancellationToken::new();
let (output_tx, output_rx) = mpsc::unbounded_channel();
let (output_tx, output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let (admission_tx, admission_rx) = mpsc::unbounded_channel();
let router = Arc::new(ReplayRouter::new(
router_mode,
......@@ -70,10 +70,6 @@ impl LiveRuntime {
senders.push(scheduler.request_sender());
schedulers.push(scheduler);
}
drop(output_tx);
drop(admission_tx);
Ok(Self {
pending,
senders: Arc::from(senders),
......@@ -87,7 +83,8 @@ impl LiveRuntime {
})
}
async fn run(mut self) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
/// Replay a finite queue of requests and return the final trace report plus debug stats.
pub(super) async fn run(mut self) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let requests = Arc::new(DashMap::with_capacity(self.pending.len()));
let stats = Arc::new(SharedLiveRuntimeStats::default());
let (arrival_tx, arrival_rx) = mpsc::unbounded_channel();
......@@ -160,7 +157,8 @@ impl LiveRuntime {
Ok((report, stats.snapshot()))
}
async fn run_workload(
/// Drive a multi-turn workload driver until it is drained and all spawned request tasks finish.
pub(super) async fn run_workload(
mut self,
driver: WorkloadDriver,
total_turns: usize,
......@@ -283,237 +281,3 @@ impl LiveRuntime {
Ok((report, stats.snapshot()))
}
}
/// Runs a finite-queue live replay to completion on a dedicated Tokio runtime.
///
/// A fresh multi-threaded runtime with all drivers enabled is built for this call; the
/// calling thread then blocks until `LiveRuntime::run` resolves.
///
/// # Errors
///
/// Fails if the Tokio runtime cannot be built, if `LiveRuntime::new` returns an error,
/// or if the replay itself returns an error.
fn run_live_runtime(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    pending: VecDeque<DirectRequest>,
    num_workers: usize,
    mode: LiveReplayMode,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.enable_all();
    let tokio_runtime = match builder.build() {
        Ok(built) => built,
        Err(e) => return Err(anyhow!("failed to create online replay runtime: {e}")),
    };
    tokio_runtime.block_on(async move {
        let live =
            LiveRuntime::new(args, router_config, pending, num_workers, mode, router_mode)?;
        live.run().await
    })
}
/// Runs a driver-based workload replay to completion on a dedicated Tokio runtime.
///
/// A fresh multi-threaded runtime with all drivers enabled is built for this call. The
/// `LiveRuntime` is constructed with an empty pending queue, and the calling thread
/// blocks until `LiveRuntime::run_workload(driver, total_turns)` resolves.
///
/// # Errors
///
/// Fails if the Tokio runtime cannot be built, if `LiveRuntime::new` returns an error,
/// or if the workload replay itself returns an error.
fn run_live_workload_runtime(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    driver: WorkloadDriver,
    total_turns: usize,
    num_workers: usize,
    mode: LiveReplayMode,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.enable_all();
    let tokio_runtime = match builder.build() {
        Ok(built) => built,
        Err(e) => return Err(anyhow!("failed to create online replay runtime: {e}")),
    };
    tokio_runtime.block_on(async move {
        // Workload replays pull requests from the driver, so the pending queue starts empty.
        let no_pending = VecDeque::new();
        let live = LiveRuntime::new(
            args,
            router_config,
            no_pending,
            num_workers,
            mode,
            router_mode,
        )?;
        live.run_workload(driver, total_turns).await
    })
}
/// Replays a flat request list in trace mode and returns only the simulation report.
///
/// Normalizes `args`, passes `requests` and `arrival_speedup_ratio` through
/// `normalize_trace_requests` to obtain the pending queue, and runs the live runtime in
/// `LiveReplayMode::Trace`. The runtime stats produced alongside the report are discarded.
///
/// # Errors
///
/// Propagates failures from `args.normalized()`, `normalize_trace_requests`, and
/// `run_live_runtime`.
pub(crate) fn simulate_trace_requests(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    let queue = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    run_live_runtime(
        normalized_args,
        router_config,
        queue,
        num_workers,
        LiveReplayMode::Trace,
        router_mode,
    )
    .map(|(report, _stats)| report)
}
pub(crate) fn simulate_concurrency_requests(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
if requests.is_empty() {
bail!("online concurrency replay requires at least one request");
}
let pending = VecDeque::from(requests);
let (report, _) = run_live_runtime(
args,
router_config,
pending,
num_workers,
LiveReplayMode::Concurrency { max_in_flight },
router_mode,
)?;
Ok(report)
}
/// Replays a multi-session trace in trace mode and returns only the simulation report.
///
/// Normalizes `args`, sums the number of turns across all sessions of `trace`, converts
/// the trace into a trace driver, and runs it through `run_live_workload_runtime` in
/// `LiveReplayMode::Trace`. The runtime stats produced alongside the report are discarded.
///
/// # Errors
///
/// Propagates failures from `args.normalized()`, `trace.into_trace_driver()`, and
/// `run_live_workload_runtime`.
pub(crate) fn simulate_trace_workload(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    // Count turns before the trace is consumed by the driver conversion.
    let turn_count: usize = trace.sessions.iter().map(|s| s.turns.len()).sum();
    let driver = trace.into_trace_driver()?;
    run_live_workload_runtime(
        normalized_args,
        router_config,
        driver,
        turn_count,
        num_workers,
        LiveReplayMode::Trace,
        router_mode,
    )
    .map(|(report, _stats)| report)
}
/// Replays a multi-session trace in concurrency mode and returns only the simulation report.
///
/// Normalizes `args`, sums the number of turns across all sessions of `trace`, converts
/// the trace into a concurrency driver, and runs it through `run_live_workload_runtime`
/// in `LiveReplayMode::Concurrency` with the supplied `max_in_flight`. The runtime stats
/// produced alongside the report are discarded.
///
/// # Errors
///
/// Propagates failures from `args.normalized()`, `trace.into_concurrency_driver()`, and
/// `run_live_workload_runtime`.
pub(crate) fn simulate_concurrency_workload(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let normalized_args = args.normalized()?;
    // Count turns before the trace is consumed by the driver conversion.
    let turn_count: usize = trace.sessions.iter().map(|s| s.turns.len()).sum();
    let driver = trace.into_concurrency_driver()?;
    let mode = LiveReplayMode::Concurrency { max_in_flight };
    run_live_workload_runtime(
        normalized_args,
        router_config,
        driver,
        turn_count,
        num_workers,
        mode,
        router_mode,
    )
    .map(|(report, _stats)| report)
}
/// Test-only trace replay over a flat request list that also returns runtime stats.
///
/// Normalizes `args`, passes `requests` and `arrival_speedup_ratio` through
/// `normalize_trace_requests` to obtain the pending queue, and runs the live runtime in
/// `LiveReplayMode::Trace`. No KV router config is passed (`None`).
///
/// # Errors
///
/// Propagates failures from `args.normalized()`, `normalize_trace_requests`, and
/// `run_live_runtime`.
#[cfg(test)]
pub(super) fn simulate_trace_requests_with_stats(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let normalized_args = args.normalized()?;
    let queue = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    run_live_runtime(
        normalized_args,
        None,
        queue,
        num_workers,
        LiveReplayMode::Trace,
        router_mode,
    )
}
/// Test-only concurrency replay over a flat request list that also returns runtime stats.
///
/// Normalizes `args`, moves `requests` into the pending queue, and hands everything to
/// `run_live_runtime` in `LiveReplayMode::Concurrency` with the supplied `max_in_flight`.
/// No KV router config is passed (`None`).
///
/// # Errors
///
/// Propagates failures from `args.normalized()` and from `run_live_runtime`.
#[cfg(test)]
pub(super) fn simulate_concurrency_requests_with_stats(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let normalized_args = args.normalized()?;
    let queue: VecDeque<DirectRequest> = requests.into();
    let mode = LiveReplayMode::Concurrency { max_in_flight };
    run_live_runtime(normalized_args, None, queue, num_workers, mode, router_mode)
}
/// Test-only trace-mode workload replay that also returns runtime stats.
///
/// Normalizes `args`, sums the number of turns across all sessions of `trace`, converts
/// the trace into a trace driver, and runs it through `run_live_workload_runtime` in
/// `LiveReplayMode::Trace`. No KV router config is passed (`None`).
///
/// # Errors
///
/// Propagates failures from `args.normalized()`, `trace.into_trace_driver()`, and
/// `run_live_workload_runtime`.
#[cfg(test)]
pub(super) fn simulate_trace_workload_with_stats(
    args: MockEngineArgs,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let normalized_args = args.normalized()?;
    // Count turns before the trace is consumed by the driver conversion.
    let turn_count: usize = trace.sessions.iter().map(|s| s.turns.len()).sum();
    let driver = trace.into_trace_driver()?;
    run_live_workload_runtime(
        normalized_args,
        None,
        driver,
        turn_count,
        num_workers,
        LiveReplayMode::Trace,
        router_mode,
    )
}
/// Test-only concurrency-mode workload replay that also returns runtime stats.
///
/// Normalizes `args`, sums the number of turns across all sessions of `trace`, converts
/// the trace into a concurrency driver, and runs it through `run_live_workload_runtime`
/// in `LiveReplayMode::Concurrency` with the supplied `max_in_flight`. No KV router
/// config is passed (`None`).
///
/// # Errors
///
/// Propagates failures from `args.normalized()`, `trace.into_concurrency_driver()`, and
/// `run_live_workload_runtime`.
#[cfg(test)]
pub(super) fn simulate_concurrency_workload_with_stats(
    args: MockEngineArgs,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let normalized_args = args.normalized()?;
    // Count turns before the trace is consumed by the driver conversion.
    let turn_count: usize = trace.sessions.iter().map(|s| s.turns.len()).sum();
    let driver = trace.into_concurrency_driver()?;
    let mode = LiveReplayMode::Concurrency { max_in_flight };
    run_live_workload_runtime(
        normalized_args,
        None,
        driver,
        turn_count,
        num_workers,
        mode,
        router_mode,
    )
}
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
mod demux;
mod entrypoints;
mod live_runtime;
mod state;
mod task;
......@@ -9,7 +10,7 @@ mod task;
#[cfg(test)]
mod tests;
pub(crate) use live_runtime::{
pub(crate) use entrypoints::{
simulate_concurrency_requests, simulate_concurrency_workload, simulate_trace_requests,
simulate_trace_workload,
};
......@@ -16,7 +16,7 @@ use crate::loadgen::{SessionTrace, Trace, TurnTrace};
use crate::replay::ReplayRouterMode;
use crate::replay::router::ReplayRouter;
use super::live_runtime::{
use super::entrypoints::{
simulate_concurrency_requests_with_stats, simulate_concurrency_workload_with_stats,
simulate_trace_requests, simulate_trace_requests_with_stats,
simulate_trace_workload_with_stats,
......
......@@ -248,14 +248,14 @@ impl OfflineReplayRouter {
pub(crate) fn mark_prefill_completed(&mut self, uuid: Uuid) -> Result<Vec<(Uuid, usize)>> {
self.slots
.mark_prefill_completed_sync(&uuid.to_string())
.mark_prefill_completed(&uuid.to_string())
.map_err(anyhow::Error::from)?;
self.drain_pending()
}
pub(crate) fn free(&mut self, uuid: Uuid) -> Result<Vec<(Uuid, usize)>> {
self.slots
.free_sync(&uuid.to_string())
.free(&uuid.to_string())
.map_err(anyhow::Error::from)?;
self.drain_pending()
}
......@@ -316,8 +316,6 @@ impl OfflineReplayRouter {
}
}
pub(crate) fn shutdown(&mut self) {}
fn enqueue_key(&self, now_ms: f64, request: &PendingRequest) -> ReplayQueueKey {
let arrival_offset = Duration::from_secs_f64((now_ms.max(0.0)) / 1000.0);
self.policy.enqueue_key(
......@@ -400,7 +398,7 @@ impl OfflineReplayRouter {
let request_id = request.request_id();
self.slots
.add_request_sync(SequenceRequest {
.add_request(SequenceRequest {
request_id,
token_sequence: request.token_seq,
isl: request.isl_tokens,
......
......@@ -108,7 +108,7 @@ impl EngineScheduler {
pub(crate) fn new_with_admission(
args: crate::common::protocols::MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
......
......@@ -36,7 +36,7 @@ impl SglangScheduler {
pub fn new(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
) -> Self {
......@@ -53,7 +53,7 @@ impl SglangScheduler {
pub(crate) fn new_with_admission(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
......@@ -71,7 +71,7 @@ impl SglangScheduler {
fn new_internal(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
......@@ -121,11 +121,12 @@ impl SglangScheduler {
if pass.router_event_visibility == RouterEventVisibility::PassEnd {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
}
flush_output_signals(&output_tx, &pass.output_signals);
let active_decode_blocks = pass.active_decode_blocks;
flush_output_signals(&output_tx, pass.output_signals);
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
let _ = metrics_tx.send(MockerMetrics::new(
dp_rank,
pass.active_decode_blocks,
active_decode_blocks,
total_blocks,
));
}
......@@ -182,14 +183,16 @@ async fn receive_requests(
}
fn flush_output_signals(
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
output_signals: &[OutputSignal],
output_tx: &Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
output_signals: Vec<OutputSignal>,
) {
let Some(tx) = output_tx.as_ref() else {
return;
};
for signal in output_signals {
let _ = tx.send(signal.clone());
if output_signals.is_empty() {
return;
}
let _ = tx.send(output_signals);
}
......@@ -94,7 +94,7 @@ mod scheduling {
.build()
.unwrap();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let scheduler =
SglangScheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None);
......@@ -117,8 +117,8 @@ mod scheduling {
loop {
tokio::select! {
Some(_) = output_rx.recv() => {
received += 1;
Some(output_batch) = output_rx.recv() => {
received += output_batch.len();
if received >= expected_signals {
break;
}
......@@ -535,7 +535,7 @@ mod core_behavior {
async fn assert_sglang_scheduler_completes_all(
scheduler: &SglangScheduler,
output_rx: &mut mpsc::UnboundedReceiver<OutputSignal>,
output_rx: &mut mpsc::UnboundedReceiver<Vec<OutputSignal>>,
num_requests: usize,
prompt_len: usize,
max_output_tokens: usize,
......@@ -567,8 +567,8 @@ async fn assert_sglang_scheduler_completes_all(
loop {
tokio::select! {
biased;
Some(_) = output_rx.recv() => {
received_tokens += 1;
Some(output_batch) = output_rx.recv() => {
received_tokens += output_batch.len();
if received_tokens >= expected_tokens {
break;
}
......@@ -604,7 +604,7 @@ mod router_events {
#[case] schedule_policy: &str,
#[case] page_size: usize,
) {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let args = MockEngineArgs::builder()
.num_gpu_blocks(500)
.block_size(64)
......@@ -818,7 +818,7 @@ mod router_events {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let (sink, forward_task) = harness.spawn_forwarder();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let scheduler = SglangScheduler::new(
MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
......@@ -849,8 +849,8 @@ mod router_events {
loop {
tokio::select! {
Some(_) = output_rx.recv() => {
seen += 1;
Some(output_batch) = output_rx.recv() => {
seen += output_batch.len();
if seen == expected {
break;
}
......
......@@ -215,7 +215,7 @@ pub(crate) fn removed_event_count(events: &[RouterEvent]) -> usize {
/// common prefix to exercise prefix caching / radix tree reuse.
pub(crate) async fn assert_scheduler_completes_all(
scheduler: &dyn SchedulerHandle,
output_rx: &mut mpsc::UnboundedReceiver<OutputSignal>,
output_rx: &mut mpsc::UnboundedReceiver<Vec<OutputSignal>>,
num_requests: usize,
input_len: usize,
max_output_tokens: usize,
......@@ -260,8 +260,8 @@ pub(crate) async fn assert_scheduler_completes_all(
loop {
tokio::select! {
biased;
Some(_) = output_rx.recv() => {
received_tokens += 1;
Some(output_batch) = output_rx.recv() => {
received_tokens += output_batch.len();
if received_tokens >= expected_tokens {
break;
}
......
......@@ -63,7 +63,7 @@ impl Scheduler {
pub fn new(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
) -> Self {
......@@ -80,7 +80,7 @@ impl Scheduler {
pub(crate) fn new_with_admission(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
......@@ -98,7 +98,7 @@ impl Scheduler {
fn new_internal(
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
......@@ -137,7 +137,7 @@ impl Scheduler {
if pass.router_event_visibility == RouterEventVisibility::PassEnd {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
}
flush_output_signals(&mut core, &output_tx, &pass.output_signals);
flush_output_signals(&mut core, &output_tx, pass.output_signals);
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
let _ = metrics_tx.send(MockerMetrics::new(
dp_rank,
......@@ -198,17 +198,20 @@ async fn receive_requests(
fn flush_output_signals(
core: &mut VllmCore,
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
output_signals: &[OutputSignal],
output_tx: &Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
output_signals: Vec<OutputSignal>,
) {
let Some(tx) = output_tx.as_ref() else {
return;
};
for signal in output_signals {
if tx.send(signal.clone()).is_ok() {
continue;
if output_signals.is_empty() {
return;
}
if let Err(error) = tx.send(output_signals) {
for signal in error.0 {
core.drop_request(signal.uuid);
}
core.drop_request(signal.uuid);
}
}
......@@ -506,7 +506,7 @@ mod live_scheduler {
#[case] enable_prefix_caching: bool,
#[case] enable_chunked_prefill: bool,
) {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let args = MockEngineArgs::builder()
.num_gpu_blocks(500)
......@@ -539,7 +539,7 @@ mod live_scheduler {
let num_requests = 10;
let token_length = 65;
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let args = MockEngineArgs::builder()
.num_gpu_blocks(100)
......@@ -576,8 +576,8 @@ mod live_scheduler {
let _metrics = metrics_rx.borrow().clone();
tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
}
Some(_signal) = output_rx.recv() => {
received_tokens += 1;
Some(output_batch) = output_rx.recv() => {
received_tokens += output_batch.len();
timeout.set(tokio::time::sleep(Duration::from_millis(500)));
}
_ = &mut timeout => break,
......@@ -592,7 +592,7 @@ mod live_scheduler {
#[tokio::test]
async fn test_receiver_drop_cleans_up_resources() {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let args = MockEngineArgs::builder()
.num_gpu_blocks(10)
.block_size(64)
......@@ -612,8 +612,8 @@ mod live_scheduler {
let mut received_count = 0;
while received_count < 129 {
if output_rx.recv().await.is_some() {
received_count += 1;
if let Some(output_batch) = output_rx.recv().await {
received_count += output_batch.len();
continue;
}
panic!("Channel closed before receiving 129 tokens");
......@@ -639,7 +639,7 @@ mod live_scheduler {
#[tokio::test]
async fn test_live_scheduler_forwards_buffered_kv_token_ids() {
let sink = Arc::new(CapturingKvSink::default());
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(12)
......@@ -667,10 +667,14 @@ mod live_scheduler {
arrival_timestamp_ms: None,
});
let signal = tokio::time::timeout(Duration::from_secs(2), output_rx.recv())
let output_batch = tokio::time::timeout(Duration::from_secs(2), output_rx.recv())
.await
.expect("scheduler should emit output")
.expect("output channel should stay open");
let signal = output_batch
.into_iter()
.next()
.expect("live scheduler should emit one output signal");
assert!(signal.completed);
tokio::time::sleep(Duration::from_millis(50)).await;
......@@ -691,7 +695,7 @@ mod live_scheduler {
let harness = RouterIndexerHarness::new(4, ROUTER_TEST_WORKER_ID);
let (sink, forward_task) = harness.spawn_forwarder();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let scheduler = Scheduler::new(
MockEngineArgs::builder()
.block_size(4)
......@@ -726,8 +730,8 @@ mod live_scheduler {
loop {
tokio::select! {
Some(_) = output_rx.recv() => {
seen += 1;
Some(output_batch) = output_rx.recv() => {
seen += output_batch.len();
if seen == expected {
break;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment