Unverified Commit b2c59aa4 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(replay): add shared loadgen workload paths [DYN-2510] (#7593)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 2b36b175
......@@ -360,6 +360,22 @@ impl TraceCollector {
reused_input_tokens: stats.reused_input_tokens,
})
}
#[cfg(test)]
pub(crate) fn snapshots(&self) -> Vec<TraceRequestStatsSnapshot> {
self.requests
.values()
.map(|stats| TraceRequestStatsSnapshot {
arrival_time_ms: stats.arrival_time_ms,
first_admit_ms: stats.first_admit_ms,
first_token_ms: stats.first_token_ms(),
last_token_ms: stats.last_token_ms(),
input_length: stats.input_length,
output_length: stats.output_length,
reused_input_tokens: stats.reused_input_tokens,
})
.collect()
}
}
fn mean(values: &[f64]) -> f64 {
......
......@@ -7,7 +7,6 @@ use std::time::Instant;
use anyhow::{Result, bail};
use dynamo_kv_router::config::KvRouterConfig;
use super::loader::load_trace_requests;
use super::online;
use super::validate::{
validate_offline_concurrency_args, validate_offline_replay_args,
......@@ -15,6 +14,7 @@ use super::validate::{
};
use super::{ReplayRouterMode, TraceSimulationReport};
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::Trace;
pub fn simulate_trace_file(
args: MockEngineArgs,
......@@ -42,14 +42,15 @@ pub fn simulate_trace_file_with_router_mode(
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
validate_offline_replay_args(&args, num_workers, router_mode)?;
let requests = load_trace_requests(trace_path, args.block_size, true)?;
let trace = Trace::from_mooncake(trace_path, args.block_size)?
.normalize_session_starts()?
.speed_up_timing(arrival_speedup_ratio)?;
let started_at = Instant::now();
let report = crate::replay::offline::simulate_trace(
let report = crate::replay::offline::simulate_trace_workload(
args,
router_config,
requests,
trace,
num_workers,
arrival_speedup_ratio,
router_mode,
)?;
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
......@@ -81,15 +82,10 @@ pub fn simulate_trace_live_file_with_router_mode(
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
validate_online_replay_args(&args, num_workers)?;
let requests = load_trace_requests(trace_path, args.block_size, true)?;
online::simulate_trace_requests(
args,
router_config,
requests,
num_workers,
arrival_speedup_ratio,
router_mode,
)
let trace = Trace::from_mooncake(trace_path, args.block_size)?
.normalize_session_starts()?
.speed_up_timing(arrival_speedup_ratio)?;
online::simulate_trace_workload(args, router_config, trace, num_workers, router_mode)
}
pub fn simulate_trace_requests(
......@@ -199,12 +195,13 @@ pub fn simulate_concurrency_file_with_router_mode(
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let requests = load_trace_requests(trace_path, args.block_size, false)?;
validate_offline_concurrency_args(&args, num_workers, max_in_flight, router_mode)?;
let trace = Trace::from_mooncake(trace_path, args.block_size)?;
let started_at = Instant::now();
let report = simulate_concurrency_requests_with_router_mode(
let report = simulate_concurrency_workload_with_router_mode(
args,
router_config,
requests,
trace,
max_in_flight,
num_workers,
router_mode,
......@@ -238,11 +235,11 @@ pub fn simulate_concurrency_live_file_with_router_mode(
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
validate_online_concurrency_args(&args, num_workers, max_in_flight)?;
let requests = load_trace_requests(trace_path, args.block_size, false)?;
online::simulate_concurrency_requests(
let trace = Trace::from_mooncake(trace_path, args.block_size)?;
online::simulate_concurrency_workload(
args,
router_config,
requests,
trace,
max_in_flight,
num_workers,
router_mode,
......@@ -328,3 +325,135 @@ pub fn simulate_concurrency_requests_with_router_mode(
router_mode,
)
}
pub fn simulate_trace_workload(
args: MockEngineArgs,
trace: Trace,
num_workers: usize,
) -> Result<TraceSimulationReport> {
simulate_trace_workload_with_router_mode(
args,
None,
trace,
num_workers,
ReplayRouterMode::RoundRobin,
)
}
pub fn simulate_trace_workload_with_router_mode(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
validate_offline_replay_args(&args, num_workers, router_mode)?;
let started_at = Instant::now();
let report = crate::replay::offline::simulate_trace_workload(
args,
router_config,
trace,
num_workers,
router_mode,
)?;
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
}
pub fn simulate_trace_live_workload(
args: MockEngineArgs,
trace: Trace,
num_workers: usize,
) -> Result<TraceSimulationReport> {
simulate_trace_live_workload_with_router_mode(
args,
None,
trace,
num_workers,
ReplayRouterMode::RoundRobin,
)
}
pub fn simulate_trace_live_workload_with_router_mode(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
validate_online_replay_args(&args, num_workers)?;
online::simulate_trace_workload(args, router_config, trace, num_workers, router_mode)
}
pub fn simulate_concurrency_workload(
args: MockEngineArgs,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
) -> Result<TraceSimulationReport> {
simulate_concurrency_workload_with_router_mode(
args,
None,
trace,
max_in_flight,
num_workers,
ReplayRouterMode::RoundRobin,
)
}
pub fn simulate_concurrency_workload_with_router_mode(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
validate_offline_concurrency_args(&args, num_workers, max_in_flight, router_mode)?;
crate::replay::offline::simulate_concurrency_workload(
args,
router_config,
trace,
max_in_flight,
num_workers,
router_mode,
)
}
pub fn simulate_concurrency_live_workload(
args: MockEngineArgs,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
) -> Result<TraceSimulationReport> {
simulate_concurrency_live_workload_with_router_mode(
args,
None,
trace,
max_in_flight,
num_workers,
ReplayRouterMode::RoundRobin,
)
}
pub fn simulate_concurrency_live_workload_with_router_mode(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
validate_online_concurrency_args(&args, num_workers, max_in_flight)?;
online::simulate_concurrency_workload(
args,
router_config,
trace,
max_in_flight,
num_workers,
router_mode,
)
}
......@@ -3,7 +3,6 @@
mod collector;
mod entrypoints;
mod loader;
pub(crate) mod offline;
mod online;
mod router;
......@@ -30,11 +29,15 @@ pub use entrypoints::{
simulate_concurrency_file, simulate_concurrency_file_with_router_mode,
simulate_concurrency_live_file, simulate_concurrency_live_file_with_router_mode,
simulate_concurrency_live_requests, simulate_concurrency_live_requests_with_router_mode,
simulate_concurrency_live_workload, simulate_concurrency_live_workload_with_router_mode,
simulate_concurrency_requests, simulate_concurrency_requests_with_router_mode,
simulate_concurrency_workload, simulate_concurrency_workload_with_router_mode,
simulate_trace_file, simulate_trace_file_with_router_mode, simulate_trace_live_file,
simulate_trace_live_file_with_router_mode, simulate_trace_live_requests,
simulate_trace_live_requests_with_router_mode, simulate_trace_requests,
simulate_trace_requests_with_router_mode,
simulate_trace_live_requests_with_router_mode, simulate_trace_live_workload,
simulate_trace_live_workload_with_router_mode, simulate_trace_requests,
simulate_trace_requests_with_router_mode, simulate_trace_workload,
simulate_trace_workload_with_router_mode,
};
pub(crate) fn normalize_trace_requests(
......
......@@ -9,7 +9,7 @@ The goal is to simulate trace execution without spinning up async runtimes, netw
The public replay entrypoints live one level up in `lib/mocker/src/replay/entrypoints.rs`. They:
- normalize `MockEngineArgs`
- load or accept `DirectRequest`s
- load or accept `DirectRequest`s or `loadgen::Trace` workloads
- validate replay arguments
- dispatch to offline or online replay
......@@ -42,7 +42,10 @@ The single-worker path is intentionally simple and only used when:
- `num_workers == 1`
- engine type is `vllm`
That path avoids the cluster event queue and router machinery entirely.
That path avoids the cluster event queue and router machinery entirely, but it now supports both:
- flat request replay
- workload-driven replay through `WorkloadDriver` for multi-turn/session traces
```mermaid
flowchart TD
......@@ -63,6 +66,8 @@ Important details:
- Trace mode uses `normalize_trace_requests` in `lib/mocker/src/replay/mod.rs` so the first request starts at `0 ms`, then applies `arrival_speedup_ratio`.
- Concurrency mode ignores original arrival spacing and keeps the worker filled up to `max_in_flight`.
- Workload trace mode honors first-turn timestamps and inter-turn delays.
- Workload concurrency mode ignores first-turn timestamps but still enforces inter-turn delays after completion.
- The worker itself is still the real mocker engine core; only the scheduling loop is simplified.
## Multi-Worker Harness
......@@ -178,13 +183,15 @@ In round-robin mode, this capture is skipped because nothing consumes those even
Both single and multi harnesses support two admission modes:
- Trace mode
- respects input arrival timestamps
- timestamps are normalized so the first request starts at `0 ms`
- `arrival_speedup_ratio` compresses or stretches inter-arrival gaps
- for flat requests, respects input arrival timestamps
- for workloads, respects first-turn timestamps and inter-turn delays
- timestamps are normalized so the first request or first session starts at `0 ms`
- `arrival_speedup_ratio` compresses or stretches inter-arrival gaps and inter-turn delays
- Concurrency mode
- ignores original spacing
- ignores original first-turn spacing
- keeps up to `max_in_flight` requests resident in the cluster
- for workloads, still unlocks follow-up turns only after completion plus inter-turn delay
- stamps synthetic arrival times as requests are admitted
This split is why `lib/mocker/src/replay/offline/mod.rs` exposes both:
......
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::Trace;
pub(crate) use crate::replay::normalize_trace_requests;
use crate::replay::{ReplayRouterMode, TraceSimulationReport};
use dynamo_kv_router::config::KvRouterConfig;
......@@ -55,3 +56,39 @@ pub(crate) fn simulate_concurrency(
)
}
}
pub(crate) fn simulate_trace_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm {
single::simulate_trace_workload_single(args, trace)
} else {
multi::simulate_trace_workload_multi(args, router_config, trace, num_workers, router_mode)
}
}
pub(crate) fn simulate_concurrency_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
if num_workers == 1 && args.engine_type == crate::common::protocols::EngineType::Vllm {
single::simulate_concurrency_workload_single(args, trace, max_in_flight)
} else {
multi::simulate_concurrency_workload_multi(
args,
router_config,
trace,
max_in_flight,
num_workers,
router_mode,
)
}
}
......@@ -3,9 +3,14 @@
use super::events::{SimulationEvent, SimulationEventKind};
use super::normalize_trace_requests;
#[cfg(test)]
use super::state::OfflineWorkerSnapshot;
use super::state::OfflineWorkerState;
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::{ReplayRequestHashes, Trace, WorkloadDriver};
use crate::replay::router::OfflineReplayRouter;
#[cfg(test)]
use crate::replay::router::OfflineRouterSnapshot;
use crate::replay::{ReplayRouterMode, TraceCollector, TraceSimulationReport};
use crate::scheduler::RouterEventVisibility;
use anyhow::bail;
......@@ -20,6 +25,11 @@ enum ReplayMode {
Concurrency { max_in_flight: usize },
}
enum AdmissionSource {
Requests(VecDeque<DirectRequest>),
Workload(WorkloadDriver),
}
#[cfg(test)]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct OfflineRuntimeStats {
......@@ -32,6 +42,17 @@ struct OfflineRuntimeStats {
max_router_pending: usize,
}
#[cfg(test)]
#[derive(Debug, Clone, PartialEq)]
struct OfflineRuntimeSnapshot {
now_ms: f64,
worker_active_requests: Vec<Vec<Uuid>>,
workers: Vec<OfflineWorkerSnapshot>,
router_pending_request_ids: Vec<Uuid>,
prefill_completed: Vec<Uuid>,
router: Option<OfflineRouterSnapshot>,
}
#[cfg(not(test))]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct OfflineRuntimeStats;
......@@ -40,7 +61,7 @@ struct OfflineRuntime {
now_ms: f64,
next_worker_idx: usize,
next_event_seq: u64,
pending: VecDeque<DirectRequest>,
admission: AdmissionSource,
router_pending: HashMap<Uuid, DirectRequest>,
workers: Vec<OfflineWorkerState>,
collector: TraceCollector,
......@@ -49,6 +70,10 @@ struct OfflineRuntime {
router: Option<OfflineReplayRouter>,
prefill_completed: HashSet<Uuid>,
stats: OfflineRuntimeStats,
#[cfg(test)]
worker_active_requests: Vec<Vec<Uuid>>,
#[cfg(test)]
stepped: bool,
}
impl OfflineRuntime {
......@@ -59,6 +84,42 @@ impl OfflineRuntime {
num_workers: usize,
mode: ReplayMode,
router_mode: ReplayRouterMode,
) -> anyhow::Result<Self> {
Self::new_with_source(
args,
router_config,
AdmissionSource::Requests(pending),
num_workers,
mode,
router_mode,
)
}
fn new_workload(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
driver: WorkloadDriver,
num_workers: usize,
mode: ReplayMode,
router_mode: ReplayRouterMode,
) -> anyhow::Result<Self> {
Self::new_with_source(
args,
router_config,
AdmissionSource::Workload(driver),
num_workers,
mode,
router_mode,
)
}
fn new_with_source(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
admission: AdmissionSource,
num_workers: usize,
mode: ReplayMode,
router_mode: ReplayRouterMode,
) -> anyhow::Result<Self> {
let args = args.clone().normalized()?;
let router = match router_mode {
......@@ -73,7 +134,7 @@ impl OfflineRuntime {
now_ms: 0.0,
next_worker_idx: 0,
next_event_seq: 0,
pending,
admission,
router_pending: HashMap::new(),
workers: (0..num_workers)
.map(|worker_idx| {
......@@ -89,6 +150,10 @@ impl OfflineRuntime {
stats: OfflineRuntimeStats::default(),
#[cfg(not(test))]
stats: OfflineRuntimeStats,
#[cfg(test)]
worker_active_requests: vec![Vec::new(); num_workers],
#[cfg(test)]
stepped: false,
})
}
......@@ -148,6 +213,8 @@ impl OfflineRuntime {
self.validate_worker_idx(worker_idx)?;
self.workers[worker_idx].receive_request(request);
self.record_dispatch(uuid, worker_idx);
#[cfg(test)]
self.worker_active_requests[worker_idx].push(uuid);
Ok(())
}
......@@ -165,6 +232,7 @@ impl OfflineRuntime {
&mut self,
mut request: DirectRequest,
arrival_time_ms: f64,
replay_hashes: Option<ReplayRequestHashes>,
) -> anyhow::Result<Uuid> {
let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
request.uuid = Some(uuid);
......@@ -186,7 +254,8 @@ impl OfflineRuntime {
return Ok(uuid);
};
let maybe_worker_idx = router.submit_request(&request, self.now_ms)?;
let maybe_worker_idx =
router.submit_request_with_hashes(&request, replay_hashes, self.now_ms)?;
self.record_router_pending();
if let Some(worker_idx) = maybe_worker_idx {
self.dispatch_to_worker(request, uuid, worker_idx)?;
......@@ -199,20 +268,31 @@ impl OfflineRuntime {
}
fn is_done(&self) -> bool {
self.pending.is_empty()
&& self.events.is_empty()
self.events.is_empty()
&& self.cluster_in_flight() == 0
&& match &self.admission {
AdmissionSource::Requests(pending) => pending.is_empty(),
AdmissionSource::Workload(driver) => driver.is_drained(),
}
&& self.workers.iter().all(OfflineWorkerState::is_drained)
}
fn next_timestamp(&self) -> Option<f64> {
fn next_timestamp(&mut self) -> Option<f64> {
let next_event_ms = self.events.peek().map(|event| event.at_ms);
let next_arrival_ms = match self.mode {
ReplayMode::Trace => self
.pending
let cluster_in_flight = self.cluster_in_flight();
let next_arrival_ms = match (&self.mode, &mut self.admission) {
(ReplayMode::Trace, AdmissionSource::Requests(pending)) => pending
.front()
.and_then(|request| request.arrival_timestamp_ms),
ReplayMode::Concurrency { .. } => None,
(ReplayMode::Trace, AdmissionSource::Workload(driver)) => driver.next_ready_time_ms(),
(ReplayMode::Concurrency { max_in_flight }, AdmissionSource::Workload(driver)) => {
if cluster_in_flight < *max_in_flight {
driver.next_ready_time_ms()
} else {
None
}
}
(ReplayMode::Concurrency { .. }, AdmissionSource::Requests(_)) => None,
};
match (next_arrival_ms, next_event_ms) {
......@@ -249,6 +329,8 @@ impl OfflineRuntime {
fn process_output_signal(&mut self, signal: OutputSignal) -> anyhow::Result<()> {
let mut admissions = Vec::new();
if signal.completed {
#[cfg(test)]
self.remove_active_request(signal.uuid);
if let Some(router) = self.router.as_mut() {
admissions = router.free(signal.uuid)?;
#[cfg(test)]
......@@ -258,6 +340,9 @@ impl OfflineRuntime {
self.record_router_pending();
}
self.prefill_completed.remove(&signal.uuid);
if let AdmissionSource::Workload(driver) = &mut self.admission {
driver.on_complete(signal.uuid, self.now_ms)?;
}
self.dispatch_router_admissions(admissions)?;
return Ok(());
}
......@@ -279,6 +364,20 @@ impl OfflineRuntime {
Ok(())
}
#[cfg(test)]
fn remove_active_request(&mut self, uuid: Uuid) {
for active_requests in &mut self.worker_active_requests {
let Some(position) = active_requests
.iter()
.position(|candidate| *candidate == uuid)
else {
continue;
};
active_requests.remove(position);
return;
}
}
fn process_completed_pass(
&mut self,
worker_idx: usize,
......@@ -325,20 +424,39 @@ impl OfflineRuntime {
fn release_trace_arrivals(&mut self) -> anyhow::Result<bool> {
let mut released_any = false;
while self
.pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
.is_some_and(|arrival_ms| arrival_ms <= self.now_ms)
{
let request = self
.pending
.pop_front()
.expect("front request must exist when arrival is ready");
let arrival_ms = request
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
self.assign_request(request, arrival_ms)?;
let mut ready_requests = Vec::new();
match &mut self.admission {
AdmissionSource::Requests(pending) => {
while pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
.is_some_and(|arrival_ms| arrival_ms <= self.now_ms)
{
let request = pending
.pop_front()
.expect("front request must exist when arrival is ready");
let arrival_ms = request
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
ready_requests.push((request, arrival_ms, None));
}
}
AdmissionSource::Workload(driver) => {
ready_requests.extend(driver.pop_ready(self.now_ms, usize::MAX).into_iter().map(
|ready| {
(
ready.request,
ready.scheduled_ready_at_ms,
ready.replay_hashes,
)
},
));
}
}
for (request, arrival_ms, replay_hashes) in ready_requests {
self.assign_request(request, arrival_ms, replay_hashes)?;
released_any = true;
}
......@@ -347,11 +465,33 @@ impl OfflineRuntime {
fn top_off_concurrency(&mut self, max_in_flight: usize) -> anyhow::Result<bool> {
let mut released_any = false;
while self.cluster_in_flight() < max_in_flight {
let Some(request) = self.pending.pop_front() else {
break;
};
self.assign_request(request, self.now_ms)?;
let available = max_in_flight.saturating_sub(self.cluster_in_flight());
if available == 0 {
return Ok(false);
}
let mut ready_requests = Vec::new();
match &mut self.admission {
AdmissionSource::Requests(pending) => {
for _ in 0..available {
let Some(request) = pending.pop_front() else {
break;
};
ready_requests.push((request, None));
}
}
AdmissionSource::Workload(driver) => {
ready_requests.extend(
driver
.pop_ready(self.now_ms, available)
.into_iter()
.map(|ready| (ready.request, ready.replay_hashes)),
);
}
}
for (request, replay_hashes) in ready_requests {
self.assign_request(request, self.now_ms, replay_hashes)?;
released_any = true;
}
......@@ -449,6 +589,55 @@ impl OfflineRuntime {
Ok((self.collector, self.stats))
}
#[cfg(test)]
fn advance_one_timestamp(&mut self) -> anyhow::Result<bool> {
if self.is_done() {
return Ok(false);
}
if !self.stepped {
self.stepped = true;
self.drain_current_timestamp()?;
return Ok(true);
}
let Some(next_timestamp_ms) = self.next_timestamp() else {
bail!(
"offline replay reached a dead end with {} in-flight requests remaining",
self.cluster_in_flight()
);
};
self.now_ms = next_timestamp_ms;
self.drain_current_timestamp()?;
Ok(true)
}
#[cfg(test)]
fn debug_snapshot(&self) -> OfflineRuntimeSnapshot {
let mut router_pending_request_ids =
self.router_pending.keys().copied().collect::<Vec<_>>();
router_pending_request_ids.sort_unstable();
let mut prefill_completed = self.prefill_completed.iter().copied().collect::<Vec<_>>();
prefill_completed.sort_unstable();
OfflineRuntimeSnapshot {
now_ms: self.now_ms,
worker_active_requests: self.worker_active_requests.clone(),
workers: self
.workers
.iter()
.map(OfflineWorkerState::debug_snapshot)
.collect(),
router_pending_request_ids,
prefill_completed,
router: self
.router
.as_ref()
.map(OfflineReplayRouter::debug_snapshot),
}
}
}
pub(crate) fn simulate_trace_multi(
......@@ -495,6 +684,49 @@ pub(crate) fn simulate_concurrency_multi(
Ok(collector.finish())
}
pub(crate) fn simulate_trace_workload_multi(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let driver = trace.into_trace_driver()?;
let (collector, _) = OfflineRuntime::new_workload(
&args,
router_config,
driver,
num_workers,
ReplayMode::Trace,
router_mode,
)?
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_workload_multi(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let driver = trace.into_concurrency_driver()?;
let (collector, _) = OfflineRuntime::new_workload(
&args,
router_config,
driver,
num_workers,
ReplayMode::Concurrency { max_in_flight },
router_mode,
)?
.run()?;
Ok(collector.finish())
}
#[cfg(test)]
fn run_trace_multi_collect_with_stats(
args: &MockEngineArgs,
......@@ -537,11 +769,53 @@ fn run_concurrency_multi_collect_with_stats(
.unwrap()
}
#[cfg(test)]
fn run_trace_workload_multi_collect_with_stats(
args: &MockEngineArgs,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> (TraceCollector, OfflineRuntimeStats) {
OfflineRuntime::new_workload(
args,
None,
trace.into_trace_driver().unwrap(),
num_workers,
ReplayMode::Trace,
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
fn run_concurrency_workload_multi_collect_with_stats(
args: &MockEngineArgs,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> (TraceCollector, OfflineRuntimeStats) {
OfflineRuntime::new_workload(
args,
None,
trace.into_concurrency_driver().unwrap(),
num_workers,
ReplayMode::Concurrency { max_in_flight },
router_mode,
)
.unwrap()
.run()
.unwrap()
}
#[cfg(test)]
mod tests {
use super::super::single::{run_concurrency_single_collect, run_trace_single_collect};
use super::*;
use crate::common::protocols::{EngineType, SglangArgs};
use crate::loadgen::{SessionTrace, TurnTrace};
use dynamo_kv_router::config::RouterQueuePolicy;
fn replay_args(enable_prefix_caching: bool, enable_chunked_prefill: bool) -> MockEngineArgs {
......@@ -597,6 +871,286 @@ mod tests {
.unwrap()
}
fn multiturn_trace() -> Trace {
Trace {
block_size: 64,
sessions: vec![
SessionTrace {
session_id: "session-a".to_string(),
first_arrival_timestamp_ms: Some(0.0),
turns: vec![
TurnTrace {
input_length: 64,
max_output_tokens: 2,
hash_ids: vec![11],
delay_after_previous_ms: 0.0,
},
TurnTrace {
input_length: 192,
max_output_tokens: 2,
hash_ids: vec![21, 22, 23],
delay_after_previous_ms: 10.0,
},
],
},
SessionTrace {
session_id: "session-b".to_string(),
first_arrival_timestamp_ms: Some(5.0),
turns: vec![TurnTrace {
input_length: 128,
max_output_tokens: 2,
hash_ids: vec![31, 32],
delay_after_previous_ms: 0.0,
}],
},
],
}
}
#[test]
fn test_trace_workload_follow_up_turn_arrives_after_completion_plus_delay() {
let args = fast_router_args();
let (collector, stats) = run_trace_workload_multi_collect_with_stats(
&args,
multiturn_trace(),
2,
ReplayRouterMode::RoundRobin,
);
let first_turn_uuid = *stats
.dispatch_order
.iter()
.find(|uuid| {
collector
.snapshot(**uuid)
.is_some_and(|stats| stats.input_length == 64)
})
.unwrap();
let second_turn_uuid = *stats
.dispatch_order
.iter()
.find(|uuid| {
collector
.snapshot(**uuid)
.is_some_and(|stats| stats.input_length == 192)
})
.unwrap();
let session_b_uuid = *stats
.dispatch_order
.iter()
.find(|uuid| {
collector
.snapshot(**uuid)
.is_some_and(|stats| stats.input_length == 128)
})
.unwrap();
let first_turn = collector.snapshot(first_turn_uuid).unwrap();
let second_turn = collector.snapshot(second_turn_uuid).unwrap();
let session_b = collector.snapshot(session_b_uuid).unwrap();
assert_eq!(first_turn.arrival_time_ms, 0.0);
assert_eq!(session_b.arrival_time_ms, 5.0);
assert!(
second_turn.arrival_time_ms >= first_turn.last_token_ms.unwrap() + 10.0,
"follow-up turn should unlock after completion plus delay"
);
}
#[test]
fn test_concurrency_workload_delayed_follow_up_does_not_bypass_other_ready_sessions() {
let args = fast_router_args();
let (collector, stats) = run_concurrency_workload_multi_collect_with_stats(
&args,
multiturn_trace(),
1,
2,
ReplayRouterMode::RoundRobin,
);
assert_eq!(stats.max_in_flight_seen, 1);
let dispatch_input_lengths = stats
.dispatch_order
.iter()
.map(|uuid| collector.snapshot(*uuid).unwrap().input_length)
.collect::<Vec<_>>();
assert_eq!(dispatch_input_lengths, vec![64, 128, 192]);
}
#[test]
fn test_trace_workload_kv_router_precomputed_hashes_match_request_fallback() {
let args = fast_router_args();
let requests = vec![
DirectRequest {
tokens: [vec![11; 64], vec![21; 32]].concat(),
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(111)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: [vec![11; 64], vec![22; 32]].concat(),
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(222)),
dp_rank: 0,
arrival_timestamp_ms: Some(500.0),
},
];
let workload = Trace {
block_size: 64,
sessions: vec![
SessionTrace {
session_id: "session-a".to_string(),
first_arrival_timestamp_ms: Some(0.0),
turns: vec![TurnTrace {
input_length: 96,
max_output_tokens: 2,
hash_ids: vec![11, 21],
delay_after_previous_ms: 0.0,
}],
},
SessionTrace {
session_id: "session-b".to_string(),
first_arrival_timestamp_ms: Some(500.0),
turns: vec![TurnTrace {
input_length: 96,
max_output_tokens: 2,
hash_ids: vec![11, 22],
delay_after_previous_ms: 0.0,
}],
},
],
};
let (request_collector, request_stats) =
run_trace_multi_collect_with_stats(&args, requests, 2, ReplayRouterMode::KvRouter);
let (workload_collector, workload_stats) = run_trace_workload_multi_collect_with_stats(
&args,
workload,
2,
ReplayRouterMode::KvRouter,
);
let request_report = request_collector.finish();
let workload_report = workload_collector.finish();
assert_eq!(request_stats.dispatch_history.len(), 2);
assert_eq!(workload_stats.dispatch_history.len(), 2);
assert_eq!(
request_stats.dispatch_history[0],
request_stats.dispatch_history[1]
);
assert_eq!(
workload_stats.dispatch_history[0],
workload_stats.dispatch_history[1]
);
assert_eq!(
request_report.request_counts.completed_requests,
workload_report.request_counts.completed_requests
);
assert_eq!(
request_report.request_counts.total_input_tokens,
workload_report.request_counts.total_input_tokens
);
assert_eq!(
request_report.request_counts.total_output_tokens,
workload_report.request_counts.total_output_tokens
);
assert_eq!(
request_report.prefix_cache_reused_ratio,
workload_report.prefix_cache_reused_ratio
);
}
#[test]
fn test_multi_worker_trace_kv_router_debug_snapshot_tracks_queue_and_cached_dispatch() {
let args = queueing_router_args(RouterQueuePolicy::Fcfs);
let mut runtime = OfflineRuntime::new(
&args,
None,
normalize_trace_requests(
vec![
DirectRequest {
tokens: vec![11; 64],
max_output_tokens: 8,
uuid: Some(Uuid::from_u128(11)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![22; 64],
max_output_tokens: 8,
uuid: Some(Uuid::from_u128(22)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
},
DirectRequest {
tokens: vec![11; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(33)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.1),
},
],
1.0,
)
.unwrap(),
2,
ReplayMode::Trace,
ReplayRouterMode::KvRouter,
)
.unwrap();
assert!(runtime.advance_one_timestamp().unwrap());
let initial = runtime.debug_snapshot();
let initial_router = initial.router.as_ref().unwrap();
assert_eq!(initial.now_ms, 0.0);
assert!(initial.router_pending_request_ids.is_empty());
assert!(initial_router.pending.is_empty());
assert_eq!(
initial
.worker_active_requests
.iter()
.map(Vec::len)
.collect::<Vec<_>>(),
vec![1, 1]
);
assert!(initial_router.indexer.total_cached_blocks > 0);
assert!(runtime.advance_one_timestamp().unwrap());
let queued = runtime.debug_snapshot();
let queued_router = queued.router.as_ref().unwrap();
assert_eq!(queued.now_ms, 0.1);
assert_eq!(queued.router_pending_request_ids, vec![Uuid::from_u128(33)]);
assert_eq!(queued_router.pending.len(), 1);
assert_eq!(queued_router.pending[0].uuid, Uuid::from_u128(33));
let cached_workers = queued_router.pending[0]
.overlap_blocks_by_worker
.iter()
.filter(|(_, overlap)| *overlap > 0)
.map(|(worker_idx, _)| *worker_idx)
.collect::<Vec<_>>();
assert_eq!(cached_workers.len(), 1);
let cached_worker = cached_workers[0];
while !runtime
.stats
.assigned_worker_by_uuid
.contains_key(&Uuid::from_u128(33))
{
assert!(runtime.advance_one_timestamp().unwrap());
}
let dispatched = runtime.debug_snapshot();
assert!(dispatched.router_pending_request_ids.is_empty());
assert_eq!(
runtime.stats.assigned_worker_by_uuid[&Uuid::from_u128(33)],
cached_worker
);
}
#[test]
fn test_multi_worker_trace_round_robin_assigns_same_timestamp_requests_deterministically() {
let args = replay_args(false, true);
......
......@@ -4,6 +4,7 @@
use super::core::ReplayWorkerCore;
use super::normalize_trace_requests;
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::{Trace, WorkloadDriver};
use crate::replay::{TraceCollector, TraceSimulationReport};
use anyhow::bail;
use std::collections::VecDeque;
......@@ -15,9 +16,14 @@ enum SingleReplayMode {
Concurrency { max_in_flight: usize },
}
enum AdmissionSource {
Requests(VecDeque<DirectRequest>),
Workload(WorkloadDriver),
}
struct SingleRuntime {
current_time_ms: f64,
pending: VecDeque<DirectRequest>,
admission: AdmissionSource,
worker: ReplayWorkerCore,
collector: TraceCollector,
mode: SingleReplayMode,
......@@ -25,9 +31,21 @@ struct SingleRuntime {
impl SingleRuntime {
fn new(args: MockEngineArgs, pending: VecDeque<DirectRequest>, mode: SingleReplayMode) -> Self {
Self::new_with_source(args, AdmissionSource::Requests(pending), mode)
}
fn new_workload(args: MockEngineArgs, driver: WorkloadDriver, mode: SingleReplayMode) -> Self {
Self::new_with_source(args, AdmissionSource::Workload(driver), mode)
}
fn new_with_source(
args: MockEngineArgs,
admission: AdmissionSource,
mode: SingleReplayMode,
) -> Self {
Self {
current_time_ms: 0.0,
pending,
admission,
worker: ReplayWorkerCore::new(args),
collector: TraceCollector::default(),
mode,
......@@ -35,36 +53,67 @@ impl SingleRuntime {
}
fn enqueue_trace_arrivals(&mut self) {
loop {
let Some(next_arrival_ms) = self
.pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
else {
break;
};
if next_arrival_ms > self.current_time_ms {
break;
let mut ready_requests = Vec::new();
match &mut self.admission {
AdmissionSource::Requests(pending) => loop {
let Some(next_arrival_ms) = pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
else {
break;
};
if next_arrival_ms > self.current_time_ms {
break;
}
let request = pending
.pop_front()
.expect("front request must exist when arrival is available");
let arrival_ms = request
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
ready_requests.push((request, arrival_ms));
},
AdmissionSource::Workload(driver) => {
ready_requests.extend(
driver
.pop_ready(self.current_time_ms, usize::MAX)
.into_iter()
.map(|ready| (ready.request, ready.scheduled_ready_at_ms)),
);
}
}
let request = self
.pending
.pop_front()
.expect("front request must exist when arrival is available");
let arrival_ms = request
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
for (request, arrival_ms) in ready_requests {
self.record_arrival(request, arrival_ms);
}
}
fn enqueue_concurrency_arrivals(&mut self, max_in_flight: usize) {
while self.worker.num_requests() < max_in_flight {
let Some(mut request) = self.pending.pop_front() else {
break;
};
let available = max_in_flight.saturating_sub(self.worker.num_requests());
let mut ready_requests = Vec::new();
request.arrival_timestamp_ms = Some(self.current_time_ms);
match &mut self.admission {
AdmissionSource::Requests(pending) => {
for _ in 0..available {
let Some(mut request) = pending.pop_front() else {
break;
};
request.arrival_timestamp_ms = Some(self.current_time_ms);
ready_requests.push(request);
}
}
AdmissionSource::Workload(driver) => {
ready_requests.extend(
driver
.pop_ready(self.current_time_ms, available)
.into_iter()
.map(|ready| ready.request),
);
}
}
for request in ready_requests {
self.record_arrival(request, self.current_time_ms);
}
}
......@@ -79,15 +128,21 @@ impl SingleRuntime {
}
fn is_done(&self) -> bool {
self.pending.is_empty() && self.worker.is_empty()
self.worker.is_empty()
&& match &self.admission {
AdmissionSource::Requests(pending) => pending.is_empty(),
AdmissionSource::Workload(driver) => driver.is_drained(),
}
}
fn advance_to_next_trace_arrival(&mut self) -> anyhow::Result<()> {
let Some(next_arrival_ms) = self
.pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
else {
let next_arrival_ms = match &mut self.admission {
AdmissionSource::Requests(pending) => pending
.front()
.and_then(|request| request.arrival_timestamp_ms),
AdmissionSource::Workload(driver) => driver.next_ready_time_ms(),
};
let Some(next_arrival_ms) = next_arrival_ms else {
bail!("trace replay reached an idle state without a pending arrival");
};
self.current_time_ms = next_arrival_ms;
......@@ -99,6 +154,13 @@ impl SingleRuntime {
.worker
.execute_pass(&mut self.collector, self.current_time_ms);
self.current_time_ms = pass.end_ms;
if let AdmissionSource::Workload(driver) = &mut self.admission {
for signal in pass.output_signals.iter().filter(|signal| signal.completed) {
driver
.on_complete(signal.uuid, self.current_time_ms)
.expect("completed workload request must belong to a session");
}
}
if admit_arrivals_between_steps {
self.enqueue_trace_arrivals();
}
......@@ -119,7 +181,11 @@ impl SingleRuntime {
SingleReplayMode::Concurrency { max_in_flight } => {
self.enqueue_concurrency_arrivals(max_in_flight);
if self.worker.is_empty() {
break;
if self.is_done() {
break;
}
self.advance_to_next_trace_arrival()?;
continue;
}
self.drive_worker(false);
}
......@@ -157,6 +223,32 @@ pub(crate) fn simulate_concurrency_single(
Ok(collector.finish())
}
pub(crate) fn simulate_trace_workload_single(
args: MockEngineArgs,
trace: Trace,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let collector =
SingleRuntime::new_workload(args, trace.into_trace_driver()?, SingleReplayMode::Trace)
.run()?;
Ok(collector.finish())
}
pub(crate) fn simulate_concurrency_workload_single(
args: MockEngineArgs,
trace: Trace,
max_in_flight: usize,
) -> anyhow::Result<TraceSimulationReport> {
let args = args.normalized()?;
let collector = SingleRuntime::new_workload(
args,
trace.into_concurrency_driver()?,
SingleReplayMode::Concurrency { max_in_flight },
)
.run()?;
Ok(collector.finish())
}
#[cfg(test)]
pub(super) fn run_trace_single_collect(
args: MockEngineArgs,
......@@ -184,9 +276,39 @@ pub(super) fn run_concurrency_single_collect(
.unwrap()
}
#[cfg(test)]
pub(super) fn run_trace_workload_single_collect(
args: MockEngineArgs,
trace: Trace,
) -> TraceCollector {
SingleRuntime::new_workload(
args,
trace.into_trace_driver().unwrap(),
SingleReplayMode::Trace,
)
.run()
.unwrap()
}
#[cfg(test)]
pub(super) fn run_concurrency_workload_single_collect(
args: MockEngineArgs,
trace: Trace,
max_in_flight: usize,
) -> TraceCollector {
SingleRuntime::new_workload(
args,
trace.into_concurrency_driver().unwrap(),
SingleReplayMode::Concurrency { max_in_flight },
)
.run()
.unwrap()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::loadgen::{SessionTrace, TurnTrace};
use crate::replay::{TraceRequestStatsSnapshot, TraceSimulationReport};
use rstest::rstest;
use std::collections::{HashMap, VecDeque};
......@@ -295,6 +417,42 @@ mod tests {
]
}
fn multiturn_trace_fixture() -> Trace {
Trace {
block_size: 1,
sessions: vec![
SessionTrace {
session_id: "session-a".to_string(),
first_arrival_timestamp_ms: Some(0.0),
turns: vec![
TurnTrace {
input_length: 3,
max_output_tokens: 2,
hash_ids: vec![1, 2, 3],
delay_after_previous_ms: 0.0,
},
TurnTrace {
input_length: 5,
max_output_tokens: 2,
hash_ids: vec![4, 5, 6, 7, 8],
delay_after_previous_ms: 5.0,
},
],
},
SessionTrace {
session_id: "session-b".to_string(),
first_arrival_timestamp_ms: Some(1.0),
turns: vec![TurnTrace {
input_length: 4,
max_output_tokens: 2,
hash_ids: vec![9, 10, 11, 12],
delay_after_previous_ms: 0.0,
}],
},
],
}
}
fn run_trace_manually(
args: &MockEngineArgs,
requests: Vec<DirectRequest>,
......@@ -674,4 +832,44 @@ mod tests {
assert_report_close(&replay_report, &manual.report);
}
#[test]
fn test_trace_workload_single_unlocks_follow_up_turn_after_completion() {
let args = replay_args(false, true);
let collector = run_trace_workload_single_collect(args, multiturn_trace_fixture());
let snapshots = collector.snapshots();
let first = snapshots
.iter()
.find(|stats| stats.input_length == 3)
.unwrap();
let second = snapshots
.iter()
.find(|stats| stats.input_length == 5)
.unwrap();
let other = snapshots
.iter()
.find(|stats| stats.input_length == 4)
.unwrap();
assert_eq!(first.arrival_time_ms, 0.0);
assert_eq!(other.arrival_time_ms, 1.0);
assert!(second.arrival_time_ms >= first.last_token_ms.unwrap() + 5.0);
}
#[test]
fn test_concurrency_workload_single_ignores_first_turn_timestamps_but_keeps_delay() {
let args = replay_args(false, true);
let collector = run_concurrency_workload_single_collect(args, multiturn_trace_fixture(), 1);
let arrival_times = collector
.snapshots()
.into_iter()
.map(|stats| stats.arrival_time_ms)
.collect::<Vec<_>>();
let report = collector.finish();
assert!(arrival_times.contains(&0.0));
assert!(arrival_times.iter().all(|arrival| *arrival >= 0.0));
assert_eq!(report.request_counts.completed_requests, 3);
}
}
......@@ -12,6 +12,15 @@ pub(crate) struct OfflineWorkerState {
in_flight: usize,
}
#[cfg(test)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct OfflineWorkerSnapshot {
pub(crate) busy: bool,
pub(crate) in_flight: usize,
pub(crate) ready: bool,
pub(crate) drained: bool,
}
impl OfflineWorkerState {
pub(crate) fn new(worker_idx: usize, args: MockEngineArgs, capture_kv_events: bool) -> Self {
let core = match args.engine_type {
......@@ -81,4 +90,14 @@ impl OfflineWorkerState {
) -> EnginePassResult {
self.core.execute_pass(collector, now_ms)
}
#[cfg(test)]
pub(crate) fn debug_snapshot(&self) -> OfflineWorkerSnapshot {
OfflineWorkerSnapshot {
busy: self.busy,
in_flight: self.in_flight,
ready: self.is_ready(),
drained: self.is_drained(),
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::time::Instant;
use crate::common::protocols::OutputSignal;
use crate::replay::router::ReplayRouter;
use crate::replay::{TraceCollector, TraceSimulationReport};
use crate::scheduler::AdmissionEvent;
use super::state::{ArrivalEvent, RequestRegistry, SharedLiveRuntimeStats, now_ms};
pub(super) async fn run_demux(
start: Instant,
mut arrival_rx: mpsc::UnboundedReceiver<ArrivalEvent>,
mut admission_rx: mpsc::UnboundedReceiver<AdmissionEvent>,
mut output_rx: mpsc::UnboundedReceiver<OutputSignal>,
requests: RequestRegistry,
router: Arc<ReplayRouter>,
stats: Arc<SharedLiveRuntimeStats>,
) -> TraceSimulationReport {
let mut collector = TraceCollector::default();
let mut arrivals_open = true;
let mut admissions_open = true;
let mut outputs_open = true;
loop {
if !arrivals_open && !admissions_open && !outputs_open {
break;
}
tokio::select! {
biased;
arrival = arrival_rx.recv(), if arrivals_open => {
match arrival {
Some(arrival) => collector.on_arrival(
arrival.uuid,
arrival.at_ms,
arrival.input_tokens,
arrival.output_tokens,
),
None => arrivals_open = false,
}
}
admission = admission_rx.recv(), if admissions_open => {
match admission {
Some(admission) => {
collector.on_admit(admission.uuid, now_ms(start), admission.reused_input_tokens);
}
None => admissions_open = false,
}
}
output = output_rx.recv(), if outputs_open => {
match output {
Some(output) => {
collector.on_token(output.uuid, now_ms(start));
if let Some(state) = requests.get(&output.uuid) {
if state.mark_first_token_once() {
match router.on_first_token(output.uuid).await {
Ok(true) => stats.record_prefill_marked(),
Ok(false) => {}
Err(error) => tracing::warn!(
uuid = %output.uuid,
error = %error,
"online replay failed to mark prefill completed"
),
}
}
if output.completed && state.mark_completed_once() {
match router.on_complete(output.uuid).await {
Ok(true) => stats.record_freed(),
Ok(false) => {}
Err(error) => tracing::warn!(
uuid = %output.uuid,
error = %error,
"online replay failed to free completed request"
),
}
state.notify_completion();
}
}
}
None => outputs_open = false,
}
}
}
}
collector.finish().with_wall_time_ms(now_ms(start))
}
......@@ -3,200 +3,27 @@
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use anyhow::{Result, anyhow, bail};
use dashmap::DashMap;
use dynamo_kv_router::config::KvRouterConfig;
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, mpsc};
use tokio::sync::{Notify, Semaphore, mpsc};
use tokio::task::JoinSet;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::loadgen::{Trace, WorkloadDriver};
use crate::replay::router::ReplayRouter;
use crate::replay::{
ReplayRouterMode, TraceCollector, TraceSimulationReport, normalize_trace_requests,
};
use crate::replay::{ReplayRouterMode, TraceSimulationReport, normalize_trace_requests};
use crate::scheduler::{AdmissionEvent, EngineScheduler, SchedulerHandle};
#[derive(Clone, Copy, Debug)]
enum LiveReplayMode {
Trace,
Concurrency { max_in_flight: usize },
}
#[derive(Debug, Default, PartialEq, Eq)]
pub(super) struct LiveRuntimeStats {
pub(super) dispatch_history: Vec<usize>,
pub(super) max_in_flight_seen: usize,
pub(super) prefill_marked_count: usize,
pub(super) freed_count: usize,
}
#[derive(Default)]
struct SharedLiveRuntimeStats {
dispatch_history: Mutex<Vec<usize>>,
current_in_flight: AtomicUsize,
max_in_flight_seen: AtomicUsize,
prefill_marked_count: AtomicUsize,
freed_count: AtomicUsize,
}
impl SharedLiveRuntimeStats {
fn record_dispatch(&self, worker_idx: usize) {
self.dispatch_history.lock().unwrap().push(worker_idx);
let current = self.current_in_flight.fetch_add(1, Ordering::AcqRel) + 1;
self.max_in_flight_seen.fetch_max(current, Ordering::AcqRel);
}
fn record_completion(&self) {
self.current_in_flight.fetch_sub(1, Ordering::AcqRel);
}
fn record_prefill_marked(&self) {
self.prefill_marked_count.fetch_add(1, Ordering::AcqRel);
}
fn record_freed(&self) {
self.freed_count.fetch_add(1, Ordering::AcqRel);
}
fn snapshot(&self) -> LiveRuntimeStats {
LiveRuntimeStats {
dispatch_history: self.dispatch_history.lock().unwrap().clone(),
max_in_flight_seen: self.max_in_flight_seen.load(Ordering::Acquire),
prefill_marked_count: self.prefill_marked_count.load(Ordering::Acquire),
freed_count: self.freed_count.load(Ordering::Acquire),
}
}
}
#[derive(Default)]
struct RequestState {
first_token_seen: AtomicBool,
completed_seen: AtomicBool,
completion_notify: Notify,
}
impl RequestState {
fn mark_first_token_once(&self) -> bool {
!self.first_token_seen.swap(true, Ordering::AcqRel)
}
fn mark_completed_once(&self) -> bool {
!self.completed_seen.swap(true, Ordering::AcqRel)
}
fn notify_completion(&self) {
self.completion_notify.notify_waiters();
}
async fn wait_for_completion(&self) {
loop {
let notified = self.completion_notify.notified();
if self.completed_seen.load(Ordering::Acquire) {
return;
}
notified.await;
}
}
}
#[derive(Clone, Copy)]
struct ArrivalEvent {
uuid: Uuid,
at_ms: f64,
input_tokens: usize,
output_tokens: usize,
}
type RequestRegistry = Arc<DashMap<Uuid, Arc<RequestState>>>;
async fn run_demux(
start: Instant,
mut arrival_rx: mpsc::UnboundedReceiver<ArrivalEvent>,
mut admission_rx: mpsc::UnboundedReceiver<AdmissionEvent>,
mut output_rx: mpsc::UnboundedReceiver<OutputSignal>,
requests: RequestRegistry,
router: Arc<ReplayRouter>,
stats: Arc<SharedLiveRuntimeStats>,
) -> TraceSimulationReport {
let mut collector = TraceCollector::default();
let mut arrivals_open = true;
let mut admissions_open = true;
let mut outputs_open = true;
loop {
if !arrivals_open && !admissions_open && !outputs_open {
break;
}
tokio::select! {
biased;
arrival = arrival_rx.recv(), if arrivals_open => {
match arrival {
Some(arrival) => collector.on_arrival(
arrival.uuid,
arrival.at_ms,
arrival.input_tokens,
arrival.output_tokens,
),
None => arrivals_open = false,
}
}
admission = admission_rx.recv(), if admissions_open => {
match admission {
Some(admission) => {
let now_ms = start.elapsed().as_secs_f64() * 1000.0;
collector.on_admit(admission.uuid, now_ms, admission.reused_input_tokens);
}
None => admissions_open = false,
}
}
output = output_rx.recv(), if outputs_open => {
match output {
Some(output) => {
let now_ms = start.elapsed().as_secs_f64() * 1000.0;
collector.on_token(output.uuid, now_ms);
if let Some(state) = requests.get(&output.uuid) {
if state.mark_first_token_once() {
match router.on_first_token(output.uuid).await {
Ok(true) => stats.record_prefill_marked(),
Ok(false) => {}
Err(error) => tracing::warn!(
uuid = %output.uuid,
error = %error,
"online replay failed to mark prefill completed"
),
}
}
if output.completed && state.mark_completed_once() {
match router.on_complete(output.uuid).await {
Ok(true) => stats.record_freed(),
Ok(false) => {}
Err(error) => tracing::warn!(
uuid = %output.uuid,
error = %error,
"online replay failed to free completed request"
),
}
state.notify_completion();
}
}
}
None => outputs_open = false,
}
}
}
}
let wall_time_ms = start.elapsed().as_secs_f64() * 1000.0;
collector.finish().with_wall_time_ms(wall_time_ms)
}
use super::demux::run_demux;
use super::state::{
LiveReplayMode, LiveRuntimeStats, SharedLiveRuntimeStats, WorkloadDispatchState, now_ms,
record_arrival,
};
use super::task::{RequestTaskContext, run_request_task, wait_for_workload_progress};
struct LiveRuntime {
pending: VecDeque<DirectRequest>,
......@@ -210,75 +37,6 @@ struct LiveRuntime {
router: Arc<ReplayRouter>,
}
fn now_ms(start: Instant) -> f64 {
start.elapsed().as_secs_f64() * 1000.0
}
fn request_uuid(request: &DirectRequest) -> Result<Uuid> {
request
.uuid
.ok_or_else(|| anyhow!("online replay requires requests to have stable UUIDs"))
}
fn record_arrival(
arrival_tx: &mpsc::UnboundedSender<ArrivalEvent>,
request: &DirectRequest,
arrival_at_ms: f64,
) -> Result<Uuid> {
let uuid = request_uuid(request)?;
let input_tokens = request.tokens.len();
let output_tokens = request.max_output_tokens;
arrival_tx
.send(ArrivalEvent {
uuid,
at_ms: arrival_at_ms,
input_tokens,
output_tokens,
})
.map_err(|_| anyhow!("online replay arrival channel closed"))?;
Ok(uuid)
}
#[derive(Clone)]
struct RequestTaskContext {
senders: Arc<[mpsc::UnboundedSender<DirectRequest>]>,
router: Arc<ReplayRouter>,
requests: RequestRegistry,
stats: Arc<SharedLiveRuntimeStats>,
}
async fn run_request_task(
ctx: RequestTaskContext,
request: DirectRequest,
permit: Option<OwnedSemaphorePermit>,
) -> Result<()> {
let uuid = request_uuid(&request)?;
let worker_idx = ctx
.router
.select_worker(&request, ctx.senders.len())
.await?;
if worker_idx >= ctx.senders.len() {
bail!("online replay selected unknown worker index {worker_idx}");
}
let state = Arc::new(RequestState::default());
ctx.requests.insert(uuid, Arc::clone(&state));
if let Err(error) = ctx.senders[worker_idx].send(request) {
ctx.requests.remove(&uuid);
return Err(anyhow!(
"online replay failed to dispatch request to worker {worker_idx}: {error}"
));
}
ctx.stats.record_dispatch(worker_idx);
state.wait_for_completion().await;
ctx.stats.record_completion();
ctx.requests.remove(&uuid);
drop(permit);
Ok(())
}
impl LiveRuntime {
fn new(
args: MockEngineArgs,
......@@ -288,10 +46,6 @@ impl LiveRuntime {
mode: LiveReplayMode,
router_mode: ReplayRouterMode,
) -> Result<Self> {
if pending.is_empty() {
bail!("online replay requires at least one request");
}
let cancel_token = CancellationToken::new();
let (output_tx, output_rx) = mpsc::unbounded_channel();
let (admission_tx, admission_rx) = mpsc::unbounded_channel();
......@@ -363,6 +117,7 @@ impl LiveRuntime {
router: Arc::clone(&self.router),
requests: Arc::clone(&requests),
stats: Arc::clone(&stats),
workload: None,
};
match self.mode {
......@@ -404,6 +159,129 @@ impl LiveRuntime {
router.shutdown().await?;
Ok((report, stats.snapshot()))
}
async fn run_workload(
mut self,
driver: WorkloadDriver,
total_turns: usize,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let requests = Arc::new(DashMap::with_capacity(total_turns.max(1)));
let stats = Arc::new(SharedLiveRuntimeStats::default());
let (arrival_tx, arrival_rx) = mpsc::unbounded_channel();
let demux_requests = Arc::clone(&requests);
let start = self.start;
let router = Arc::clone(&self.router);
let senders = Arc::clone(&self.senders);
let output_rx = self.output_rx;
let admission_rx = self.admission_rx;
let demux_stats = Arc::clone(&stats);
let demux_router = Arc::clone(&router);
let demux_task = tokio::spawn(async move {
run_demux(
start,
arrival_rx,
admission_rx,
output_rx,
demux_requests,
demux_router,
demux_stats,
)
.await
});
let workload = Arc::new(WorkloadDispatchState {
driver: std::sync::Mutex::new(driver),
wakeup: Notify::new(),
start,
});
let mut tasks = JoinSet::new();
let task_ctx = RequestTaskContext {
senders,
router: Arc::clone(&self.router),
requests: Arc::clone(&requests),
stats: Arc::clone(&stats),
workload: Some(Arc::clone(&workload)),
};
let semaphore = match self.mode {
LiveReplayMode::Trace => None,
LiveReplayMode::Concurrency { max_in_flight } => {
Some(Arc::new(Semaphore::new(max_in_flight)))
}
};
loop {
let now = now_ms(start);
let dispatch_limit = match &semaphore {
Some(semaphore) => semaphore.available_permits(),
None => usize::MAX,
};
if dispatch_limit > 0 {
let ready_turns = workload
.driver
.lock()
.unwrap()
.pop_ready(now, dispatch_limit);
if !ready_turns.is_empty() {
for ready_turn in ready_turns {
let permit = match &semaphore {
Some(semaphore) => {
Some(semaphore.clone().try_acquire_owned().map_err(|_| {
anyhow!(
"online replay concurrency semaphore unexpectedly closed"
)
})?)
}
None => None,
};
let arrival_at_ms = match self.mode {
LiveReplayMode::Trace => ready_turn.scheduled_ready_at_ms,
LiveReplayMode::Concurrency { .. } => now_ms(start),
};
record_arrival(&arrival_tx, &ready_turn.request, arrival_at_ms)?;
tasks.spawn(run_request_task(
task_ctx.clone(),
ready_turn.request,
permit,
));
}
continue;
}
}
let wake = workload.wakeup.notified();
tokio::pin!(wake);
let (is_drained, next_ready_ms) = {
let mut driver = workload.driver.lock().unwrap();
(driver.is_drained(), driver.next_ready_time_ms())
};
if is_drained {
break;
}
wait_for_workload_progress(
self.mode,
semaphore.as_deref(),
next_ready_ms,
start,
wake.as_mut(),
)
.await;
}
while let Some(result) = tasks.join_next().await {
result.map_err(|e| anyhow!("online replay request task failed: {e}"))??;
}
drop(arrival_tx);
self.cancel_token.cancel();
self.schedulers.clear();
let report = demux_task
.await
.map_err(|e| anyhow!("online replay demux task failed: {e}"))?;
router.shutdown().await?;
Ok((report, stats.snapshot()))
}
}
fn run_live_runtime(
......@@ -426,6 +304,34 @@ fn run_live_runtime(
})
}
fn run_live_workload_runtime(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
driver: WorkloadDriver,
total_turns: usize,
num_workers: usize,
mode: LiveReplayMode,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.map_err(|e| anyhow!("failed to create online replay runtime: {e}"))?;
runtime.block_on(async move {
LiveRuntime::new(
args,
router_config,
VecDeque::new(),
num_workers,
mode,
router_mode,
)?
.run_workload(driver, total_turns)
.await
})
}
pub(crate) fn simulate_trace_requests(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
......@@ -472,8 +378,59 @@ pub(crate) fn simulate_concurrency_requests(
Ok(report)
}
pub(crate) fn simulate_trace_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let total_turns = trace
.sessions
.iter()
.map(|session| session.turns.len())
.sum();
let (report, _) = run_live_workload_runtime(
args,
router_config,
trace.into_trace_driver()?,
total_turns,
num_workers,
LiveReplayMode::Trace,
router_mode,
)?;
Ok(report)
}
pub(crate) fn simulate_concurrency_workload(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
let total_turns = trace
.sessions
.iter()
.map(|session| session.turns.len())
.sum();
let (report, _) = run_live_workload_runtime(
args,
router_config,
trace.into_concurrency_driver()?,
total_turns,
num_workers,
LiveReplayMode::Concurrency { max_in_flight },
router_mode,
)?;
Ok(report)
}
#[cfg(test)]
fn simulate_trace_requests_with_stats(
pub(super) fn simulate_trace_requests_with_stats(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
num_workers: usize,
......@@ -493,7 +450,7 @@ fn simulate_trace_requests_with_stats(
}
#[cfg(test)]
fn simulate_concurrency_requests_with_stats(
pub(super) fn simulate_concurrency_requests_with_stats(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
......@@ -513,275 +470,50 @@ fn simulate_concurrency_requests_with_stats(
}
#[cfg(test)]
mod tests {
use super::*;
use crate::common::protocols::{DirectRequest, EngineType, SglangArgs};
fn replay_args() -> MockEngineArgs {
MockEngineArgs::builder()
.speedup_ratio(1000.0)
.block_size(64)
.build()
.unwrap()
}
fn sglang_replay_args() -> MockEngineArgs {
MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.num_gpu_blocks(512)
.speedup_ratio(1000.0)
.sglang(Some(SglangArgs {
page_size: Some(2),
..Default::default()
}))
.build()
.unwrap()
}
fn request(uuid: u128, token: u32, arrival_timestamp_ms: Option<f64>) -> DirectRequest {
DirectRequest {
tokens: vec![token; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(uuid)),
dp_rank: 0,
arrival_timestamp_ms,
}
}
#[test]
fn test_online_trace_replay_single_worker_completes() {
let args = replay_args();
let requests = vec![request(1, 11, Some(0.0)), request(2, 22, Some(1.0))];
let report =
simulate_trace_requests(args, None, requests, 1, 1.0, ReplayRouterMode::RoundRobin)
.unwrap();
assert_eq!(report.request_counts.num_requests, 2);
assert_eq!(report.request_counts.completed_requests, 2);
assert_eq!(report.request_counts.total_output_tokens, 4);
assert!(report.throughput.wall_time_ms >= 0.0);
}
#[tokio::test]
async fn test_record_arrival_uses_caller_arrival_timestamp() {
let (arrival_tx, mut arrival_rx) = mpsc::unbounded_channel();
let uuid = Uuid::from_u128(999);
let arrival_at_ms = 123.0;
let request = request(999, 42, Some(arrival_at_ms));
let recorded_uuid = record_arrival(&arrival_tx, &request, arrival_at_ms).unwrap();
let arrival = arrival_rx.recv().await.unwrap();
assert_eq!(recorded_uuid, uuid);
assert_eq!(arrival.uuid, uuid);
assert_eq!(arrival.at_ms, arrival_at_ms);
}
#[tokio::test]
async fn test_trace_arrivals_are_not_blocked_by_queued_router_selection() {
let args = MockEngineArgs::builder()
.speedup_ratio(1000.0)
.block_size(64)
.max_num_seqs(Some(1))
.max_num_batched_tokens(Some(8))
.build()
.unwrap();
let start = Instant::now();
let router = Arc::new(ReplayRouter::new(
ReplayRouterMode::KvRouter,
&args,
None,
1,
));
let senders: Arc<[mpsc::UnboundedSender<DirectRequest>]> =
Arc::from(vec![mpsc::unbounded_channel::<DirectRequest>().0]);
let requests = Arc::new(DashMap::new());
let stats = Arc::new(SharedLiveRuntimeStats::default());
let (arrival_tx, mut arrival_rx) = mpsc::unbounded_channel();
let task_ctx = RequestTaskContext {
senders,
router: Arc::clone(&router),
requests,
stats,
};
let mut tasks = JoinSet::new();
let mut pending = VecDeque::from(vec![
request(1, 11, Some(0.0)),
request(2, 22, Some(1.0)),
request(3, 33, Some(2.0)),
]);
while let Some(request) = pending.pop_front() {
let arrival_ms = request.arrival_timestamp_ms.unwrap_or(0.0);
let deadline = start + tokio::time::Duration::from_secs_f64(arrival_ms / 1000.0);
tokio::time::sleep_until(deadline).await;
record_arrival(&arrival_tx, &request, arrival_ms).unwrap();
tasks.spawn(run_request_task(task_ctx.clone(), request, None));
}
let first = tokio::time::timeout(tokio::time::Duration::from_millis(50), arrival_rx.recv())
.await
.unwrap()
.unwrap();
let second =
tokio::time::timeout(tokio::time::Duration::from_millis(50), arrival_rx.recv())
.await
.unwrap()
.unwrap();
let third = tokio::time::timeout(tokio::time::Duration::from_millis(50), arrival_rx.recv())
.await
.unwrap()
.unwrap();
assert_eq!(first.uuid, Uuid::from_u128(1));
assert_eq!(second.uuid, Uuid::from_u128(2));
assert_eq!(third.uuid, Uuid::from_u128(3));
assert_eq!(first.at_ms, 0.0);
assert_eq!(second.at_ms, 1.0);
assert_eq!(third.at_ms, 2.0);
tasks.abort_all();
router.shutdown().await.unwrap();
}
#[test]
fn test_online_trace_replay_uses_round_robin_dispatch() {
let args = replay_args();
let requests = vec![
request(1, 1, Some(0.0)),
request(2, 2, Some(100.0)),
request(3, 3, Some(200.0)),
request(4, 4, Some(300.0)),
request(5, 5, Some(400.0)),
];
let (_, stats) = simulate_trace_requests_with_stats(
args,
requests,
3,
1.0,
ReplayRouterMode::RoundRobin,
)
.unwrap();
assert_eq!(stats.dispatch_history, vec![0, 1, 2, 0, 1]);
}
#[test]
fn test_online_concurrency_replay_respects_max_in_flight() {
let args = replay_args();
let requests = vec![
request(1, 10, None),
request(2, 20, None),
request(3, 30, None),
request(4, 40, None),
];
let (report, stats) = simulate_concurrency_requests_with_stats(
args,
requests,
2,
2,
ReplayRouterMode::RoundRobin,
)
.unwrap();
assert_eq!(report.request_counts.completed_requests, 4);
assert_eq!(stats.max_in_flight_seen, 2);
}
#[test]
fn test_online_trace_replay_populates_admit_reuse_stats() {
let args = replay_args();
let requests = vec![request(1, 77, Some(0.0)), request(2, 77, Some(5.0))];
let report =
simulate_trace_requests(args, None, requests, 1, 1.0, ReplayRouterMode::RoundRobin)
.unwrap();
assert_eq!(report.request_counts.completed_requests, 2);
assert!(report.prefix_cache_reused_ratio > 0.0);
}
#[test]
fn test_online_trace_replay_kv_router_prefers_cached_worker() {
let args = replay_args();
let requests = vec![request(1, 88, Some(0.0)), request(2, 88, Some(500.0))];
let (_, stats) =
simulate_trace_requests_with_stats(args, requests, 2, 1.0, ReplayRouterMode::KvRouter)
.unwrap();
assert_eq!(stats.dispatch_history.len(), 2);
assert_eq!(stats.dispatch_history[0], stats.dispatch_history[1]);
}
#[test]
fn test_online_trace_replay_sglang_single_worker_completes() {
let args = sglang_replay_args();
let requests = vec![request(101, 7, Some(0.0)), request(102, 8, Some(1.0))];
let report =
simulate_trace_requests(args, None, requests, 1, 1.0, ReplayRouterMode::RoundRobin)
.unwrap();
assert_eq!(report.request_counts.completed_requests, 2);
assert_eq!(report.request_counts.total_output_tokens, 4);
}
#[test]
fn test_online_trace_replay_sglang_kv_router_smoke() {
let args = sglang_replay_args();
let requests = vec![request(111, 9, Some(0.0)), request(112, 9, Some(500.0))];
let (report, stats) =
simulate_trace_requests_with_stats(args, requests, 2, 1.0, ReplayRouterMode::KvRouter)
.unwrap();
assert_eq!(report.request_counts.completed_requests, 2);
assert_eq!(stats.dispatch_history.len(), 2);
}
#[test]
fn test_online_concurrency_replay_kv_router_respects_max_in_flight() {
let args = replay_args();
let requests = vec![
request(1, 10, None),
request(2, 20, None),
request(3, 10, None),
request(4, 20, None),
];
let (report, stats) = simulate_concurrency_requests_with_stats(
args,
requests,
2,
2,
ReplayRouterMode::KvRouter,
)
.unwrap();
assert_eq!(report.request_counts.completed_requests, 4);
assert_eq!(stats.max_in_flight_seen, 2);
}
pub(super) fn simulate_trace_workload_with_stats(
args: MockEngineArgs,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let args = args.normalized()?;
let total_turns = trace
.sessions
.iter()
.map(|session| session.turns.len())
.sum();
run_live_workload_runtime(
args,
None,
trace.into_trace_driver()?,
total_turns,
num_workers,
LiveReplayMode::Trace,
router_mode,
)
}
#[test]
fn test_online_trace_replay_kv_router_marks_prefill_and_free_once() {
let args = replay_args();
let requests = vec![DirectRequest {
tokens: vec![9; 64],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(9)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
}];
let (_, stats) =
simulate_trace_requests_with_stats(args, requests, 1, 1.0, ReplayRouterMode::KvRouter)
.unwrap();
assert_eq!(stats.prefill_marked_count, 1);
assert_eq!(stats.freed_count, 1);
}
#[cfg(test)]
pub(super) fn simulate_concurrency_workload_with_stats(
args: MockEngineArgs,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
let args = args.normalized()?;
let total_turns = trace
.sessions
.iter()
.map(|session| session.turns.len())
.sum();
run_live_workload_runtime(
args,
None,
trace.into_concurrency_driver()?,
total_turns,
num_workers,
LiveReplayMode::Concurrency { max_in_flight },
router_mode,
)
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod runtime;
mod demux;
mod live_runtime;
mod state;
mod task;
pub(crate) use runtime::{simulate_concurrency_requests, simulate_trace_requests};
#[cfg(test)]
mod tests;
pub(crate) use live_runtime::{
simulate_concurrency_requests, simulate_concurrency_workload, simulate_trace_requests,
simulate_trace_workload,
};
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use anyhow::{Result, anyhow};
use dashmap::DashMap;
use tokio::sync::{Notify, mpsc};
use tokio::time::Instant;
use uuid::Uuid;
use crate::common::protocols::DirectRequest;
use crate::loadgen::WorkloadDriver;
#[derive(Clone, Copy, Debug)]
pub(super) enum LiveReplayMode {
Trace,
Concurrency { max_in_flight: usize },
}
#[derive(Debug, Default, PartialEq, Eq)]
pub(super) struct LiveRuntimeStats {
pub(super) dispatch_history: Vec<usize>,
pub(super) max_in_flight_seen: usize,
pub(super) prefill_marked_count: usize,
pub(super) freed_count: usize,
}
#[derive(Default)]
pub(super) struct SharedLiveRuntimeStats {
dispatch_history: Mutex<Vec<usize>>,
current_in_flight: AtomicUsize,
max_in_flight_seen: AtomicUsize,
prefill_marked_count: AtomicUsize,
freed_count: AtomicUsize,
}
impl SharedLiveRuntimeStats {
pub(super) fn record_dispatch(&self, worker_idx: usize) {
self.dispatch_history.lock().unwrap().push(worker_idx);
let current = self.current_in_flight.fetch_add(1, Ordering::AcqRel) + 1;
self.max_in_flight_seen.fetch_max(current, Ordering::AcqRel);
}
pub(super) fn record_completion(&self) {
self.current_in_flight.fetch_sub(1, Ordering::AcqRel);
}
pub(super) fn record_prefill_marked(&self) {
self.prefill_marked_count.fetch_add(1, Ordering::AcqRel);
}
pub(super) fn record_freed(&self) {
self.freed_count.fetch_add(1, Ordering::AcqRel);
}
pub(super) fn snapshot(&self) -> LiveRuntimeStats {
LiveRuntimeStats {
dispatch_history: self.dispatch_history.lock().unwrap().clone(),
max_in_flight_seen: self.max_in_flight_seen.load(Ordering::Acquire),
prefill_marked_count: self.prefill_marked_count.load(Ordering::Acquire),
freed_count: self.freed_count.load(Ordering::Acquire),
}
}
}
#[derive(Default)]
pub(super) struct RequestState {
first_token_seen: AtomicBool,
completed_seen: AtomicBool,
completion_notify: Notify,
}
impl RequestState {
pub(super) fn mark_first_token_once(&self) -> bool {
!self.first_token_seen.swap(true, Ordering::AcqRel)
}
pub(super) fn mark_completed_once(&self) -> bool {
!self.completed_seen.swap(true, Ordering::AcqRel)
}
pub(super) fn notify_completion(&self) {
self.completion_notify.notify_waiters();
}
pub(super) async fn wait_for_completion(&self) {
loop {
let notified = self.completion_notify.notified();
if self.completed_seen.load(Ordering::Acquire) {
return;
}
notified.await;
}
}
}
#[derive(Clone, Copy)]
pub(super) struct ArrivalEvent {
pub(super) uuid: Uuid,
pub(super) at_ms: f64,
pub(super) input_tokens: usize,
pub(super) output_tokens: usize,
}
pub(super) type RequestRegistry = Arc<DashMap<Uuid, Arc<RequestState>>>;
pub(super) struct WorkloadDispatchState {
pub(super) driver: Mutex<WorkloadDriver>,
pub(super) wakeup: Notify,
pub(super) start: Instant,
}
pub(super) fn now_ms(start: Instant) -> f64 {
start.elapsed().as_secs_f64() * 1000.0
}
pub(super) fn request_uuid(request: &DirectRequest) -> Result<Uuid> {
request
.uuid
.ok_or_else(|| anyhow!("online replay requires requests to have stable UUIDs"))
}
pub(super) fn record_arrival(
arrival_tx: &mpsc::UnboundedSender<ArrivalEvent>,
request: &DirectRequest,
arrival_at_ms: f64,
) -> Result<Uuid> {
let uuid = request_uuid(request)?;
let input_tokens = request.tokens.len();
let output_tokens = request.max_output_tokens;
arrival_tx
.send(ArrivalEvent {
uuid,
at_ms: arrival_at_ms,
input_tokens,
output_tokens,
})
.map_err(|_| anyhow!("online replay arrival channel closed"))?;
Ok(uuid)
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use anyhow::{Result, anyhow, bail};
use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};
use tokio::time::Instant;
use crate::common::protocols::DirectRequest;
use crate::replay::router::ReplayRouter;
use super::state::{
LiveReplayMode, RequestRegistry, RequestState, SharedLiveRuntimeStats, WorkloadDispatchState,
now_ms, request_uuid,
};
#[derive(Clone)]
pub(super) struct RequestTaskContext {
pub(super) senders: Arc<[mpsc::UnboundedSender<DirectRequest>]>,
pub(super) router: Arc<ReplayRouter>,
pub(super) requests: RequestRegistry,
pub(super) stats: Arc<SharedLiveRuntimeStats>,
pub(super) workload: Option<Arc<WorkloadDispatchState>>,
}
pub(super) async fn wait_for_workload_progress<F>(
mode: LiveReplayMode,
semaphore: Option<&Semaphore>,
next_ready_ms: Option<f64>,
start: Instant,
mut wake: Pin<&mut F>,
) where
F: Future<Output = ()>,
{
match (mode, semaphore, next_ready_ms) {
(LiveReplayMode::Trace, _, Some(next_ready_ms)) => {
let deadline = start + tokio::time::Duration::from_secs_f64(next_ready_ms / 1000.0);
tokio::select! {
_ = tokio::time::sleep_until(deadline) => {}
_ = wake.as_mut() => {}
}
}
(LiveReplayMode::Trace, _, None) => {
wake.as_mut().await;
}
(LiveReplayMode::Concurrency { .. }, Some(semaphore), Some(next_ready_ms)) => {
if semaphore.available_permits() == 0 {
wake.as_mut().await;
} else {
let deadline = start + tokio::time::Duration::from_secs_f64(next_ready_ms / 1000.0);
tokio::select! {
_ = tokio::time::sleep_until(deadline) => {}
_ = wake.as_mut() => {}
}
}
}
(LiveReplayMode::Concurrency { .. }, Some(_semaphore), None) => {
wake.as_mut().await;
}
(LiveReplayMode::Concurrency { .. }, None, _) => {
unreachable!("concurrency mode must have a semaphore");
}
}
}
pub(super) async fn run_request_task(
ctx: RequestTaskContext,
request: DirectRequest,
permit: Option<OwnedSemaphorePermit>,
) -> Result<()> {
let uuid = request_uuid(&request)?;
let worker_idx = ctx
.router
.select_worker(&request, ctx.senders.len())
.await?;
if worker_idx >= ctx.senders.len() {
bail!("online replay selected unknown worker index {worker_idx}");
}
let state = Arc::new(RequestState::default());
ctx.requests.insert(uuid, Arc::clone(&state));
if let Err(error) = ctx.senders[worker_idx].send(request) {
ctx.requests.remove(&uuid);
return Err(anyhow!(
"online replay failed to dispatch request to worker {worker_idx}: {error}"
));
}
ctx.stats.record_dispatch(worker_idx);
state.wait_for_completion().await;
ctx.stats.record_completion();
ctx.requests.remove(&uuid);
if let Some(workload) = ctx.workload.as_ref() {
let completion_ms = now_ms(workload.start);
workload
.driver
.lock()
.unwrap()
.on_complete(uuid, completion_ms)?;
workload.wakeup.notify_waiters();
}
drop(permit);
Ok(())
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::Mutex;
use dashmap::DashMap;
use tokio::sync::{Notify, Semaphore, mpsc};
use tokio::task::JoinSet;
use tokio::time::Instant;
use uuid::Uuid;
use crate::common::protocols::{DirectRequest, EngineType, MockEngineArgs, SglangArgs};
use crate::loadgen::{SessionTrace, Trace, TurnTrace};
use crate::replay::ReplayRouterMode;
use crate::replay::router::ReplayRouter;
use super::live_runtime::{
simulate_concurrency_requests_with_stats, simulate_concurrency_workload_with_stats,
simulate_trace_requests, simulate_trace_requests_with_stats,
simulate_trace_workload_with_stats,
};
use super::state::{LiveReplayMode, SharedLiveRuntimeStats, WorkloadDispatchState, record_arrival};
use super::task::{RequestTaskContext, run_request_task, wait_for_workload_progress};
fn replay_args() -> MockEngineArgs {
MockEngineArgs::builder()
.speedup_ratio(1000.0)
.block_size(64)
.build()
.unwrap()
}
fn sglang_replay_args() -> MockEngineArgs {
MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.num_gpu_blocks(512)
.speedup_ratio(1000.0)
.sglang(Some(SglangArgs {
page_size: Some(2),
..Default::default()
}))
.build()
.unwrap()
}
fn request(uuid: u128, token: u32, arrival_timestamp_ms: Option<f64>) -> DirectRequest {
DirectRequest {
tokens: vec![token; 64],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(uuid)),
dp_rank: 0,
arrival_timestamp_ms,
}
}
fn multiturn_trace() -> Trace {
Trace {
block_size: 1,
sessions: vec![
SessionTrace {
session_id: "session-a".to_string(),
first_arrival_timestamp_ms: Some(0.0),
turns: vec![
TurnTrace {
input_length: 4,
max_output_tokens: 2,
hash_ids: vec![11, 12, 13, 14],
delay_after_previous_ms: 0.0,
},
TurnTrace {
input_length: 6,
max_output_tokens: 2,
hash_ids: vec![21, 22, 23, 24, 25, 26],
delay_after_previous_ms: 5.0,
},
],
},
SessionTrace {
session_id: "session-b".to_string(),
first_arrival_timestamp_ms: Some(1.0),
turns: vec![TurnTrace {
input_length: 5,
max_output_tokens: 2,
hash_ids: vec![31, 32, 33, 34, 35],
delay_after_previous_ms: 0.0,
}],
},
],
}
}
#[test]
fn test_online_trace_replay_single_worker_completes() {
let args = replay_args();
let requests = vec![request(1, 11, Some(0.0)), request(2, 22, Some(1.0))];
let report =
simulate_trace_requests(args, None, requests, 1, 1.0, ReplayRouterMode::RoundRobin)
.unwrap();
assert_eq!(report.request_counts.num_requests, 2);
assert_eq!(report.request_counts.completed_requests, 2);
assert_eq!(report.request_counts.total_output_tokens, 4);
assert!(report.throughput.wall_time_ms >= 0.0);
}
#[test]
fn test_online_trace_workload_completes_multiturn_sessions() {
let args = replay_args();
let (report, _) =
simulate_trace_workload_with_stats(args, multiturn_trace(), 2, ReplayRouterMode::KvRouter)
.unwrap();
assert_eq!(report.request_counts.num_requests, 3);
assert_eq!(report.request_counts.completed_requests, 3);
assert_eq!(report.request_counts.total_input_tokens, 15);
assert_eq!(report.request_counts.total_output_tokens, 6);
}
#[test]
fn test_online_concurrency_workload_respects_global_cap() {
let args = replay_args();
let (report, stats) = simulate_concurrency_workload_with_stats(
args,
multiturn_trace(),
1,
2,
ReplayRouterMode::KvRouter,
)
.unwrap();
assert_eq!(report.request_counts.num_requests, 3);
assert_eq!(report.request_counts.completed_requests, 3);
assert_eq!(stats.max_in_flight_seen, 1);
}
#[tokio::test]
async fn test_record_arrival_uses_caller_arrival_timestamp() {
let (arrival_tx, mut arrival_rx) = mpsc::unbounded_channel();
let uuid = Uuid::from_u128(999);
let arrival_at_ms = 123.0;
let request = request(999, 42, Some(arrival_at_ms));
let recorded_uuid = record_arrival(&arrival_tx, &request, arrival_at_ms).unwrap();
let arrival = arrival_rx.recv().await.unwrap();
assert_eq!(recorded_uuid, uuid);
assert_eq!(arrival.uuid, uuid);
assert_eq!(arrival.at_ms, arrival_at_ms);
}
#[tokio::test]
async fn test_trace_arrivals_are_not_blocked_by_queued_router_selection() {
let args = MockEngineArgs::builder()
.speedup_ratio(1000.0)
.block_size(64)
.max_num_seqs(Some(1))
.max_num_batched_tokens(Some(8))
.build()
.unwrap();
let start = Instant::now();
let router = Arc::new(ReplayRouter::new(
ReplayRouterMode::KvRouter,
&args,
None,
1,
));
let senders: Arc<[mpsc::UnboundedSender<DirectRequest>]> =
Arc::from(vec![mpsc::unbounded_channel::<DirectRequest>().0]);
let requests = Arc::new(DashMap::new());
let stats = Arc::new(SharedLiveRuntimeStats::default());
let (arrival_tx, mut arrival_rx) = mpsc::unbounded_channel();
let task_ctx = RequestTaskContext {
senders,
router: Arc::clone(&router),
requests,
stats,
workload: None,
};
let mut tasks = JoinSet::new();
let mut pending = VecDeque::from(vec![
request(1, 11, Some(0.0)),
request(2, 22, Some(1.0)),
request(3, 33, Some(2.0)),
]);
while let Some(request) = pending.pop_front() {
let arrival_ms = request.arrival_timestamp_ms.unwrap_or(0.0);
let deadline = start + tokio::time::Duration::from_secs_f64(arrival_ms / 1000.0);
tokio::time::sleep_until(deadline).await;
record_arrival(&arrival_tx, &request, arrival_ms).unwrap();
tasks.spawn(run_request_task(task_ctx.clone(), request, None));
}
let first = tokio::time::timeout(tokio::time::Duration::from_millis(50), arrival_rx.recv())
.await
.unwrap()
.unwrap();
let second = tokio::time::timeout(tokio::time::Duration::from_millis(50), arrival_rx.recv())
.await
.unwrap()
.unwrap();
let third = tokio::time::timeout(tokio::time::Duration::from_millis(50), arrival_rx.recv())
.await
.unwrap()
.unwrap();
assert_eq!(first.uuid, Uuid::from_u128(1));
assert_eq!(second.uuid, Uuid::from_u128(2));
assert_eq!(third.uuid, Uuid::from_u128(3));
assert_eq!(first.at_ms, 0.0);
assert_eq!(second.at_ms, 1.0);
assert_eq!(third.at_ms, 2.0);
tasks.abort_all();
router.shutdown().await.unwrap();
}
#[tokio::test]
async fn test_workload_wakeup_is_not_lost_when_completion_happens_before_await() {
let mut driver = Trace {
block_size: 1,
sessions: vec![SessionTrace {
session_id: "session-a".to_string(),
first_arrival_timestamp_ms: Some(0.0),
turns: vec![
TurnTrace {
input_length: 4,
max_output_tokens: 1,
hash_ids: vec![1, 2, 3, 4],
delay_after_previous_ms: 0.0,
},
TurnTrace {
input_length: 4,
max_output_tokens: 1,
hash_ids: vec![5, 6, 7, 8],
delay_after_previous_ms: 5.0,
},
],
}],
}
.into_trace_driver()
.unwrap();
let first = driver.pop_ready(0.0, 1);
assert_eq!(first.len(), 1);
let workload = WorkloadDispatchState {
driver: Mutex::new(driver),
wakeup: Notify::new(),
start: Instant::now(),
};
let wake = workload.wakeup.notified();
tokio::pin!(wake);
let (is_drained, next_ready_ms) = {
let mut driver = workload.driver.lock().unwrap();
(driver.is_drained(), driver.next_ready_time_ms())
};
assert!(!is_drained);
assert_eq!(next_ready_ms, None);
{
let mut driver = workload.driver.lock().unwrap();
driver.on_complete(first[0].request_uuid, 5.0).unwrap();
}
workload.wakeup.notify_waiters();
tokio::time::timeout(tokio::time::Duration::from_millis(50), &mut wake)
.await
.unwrap();
assert_eq!(
workload.driver.lock().unwrap().next_ready_time_ms(),
Some(10.0)
);
}
#[tokio::test]
async fn test_concurrency_workload_waits_for_wakeup_when_next_turn_is_completion_gated() {
let semaphore = Arc::new(Semaphore::new(1));
let notify = Arc::new(Notify::new());
let wake = notify.notified();
tokio::pin!(wake);
assert!(
tokio::time::timeout(
tokio::time::Duration::from_millis(20),
wait_for_workload_progress(
LiveReplayMode::Concurrency { max_in_flight: 1 },
Some(semaphore.as_ref()),
None,
Instant::now(),
wake.as_mut(),
),
)
.await
.is_err(),
"concurrency workload should wait for wakeup when no turn is time-ready"
);
let wake = notify.notified();
tokio::pin!(wake);
let wait = wait_for_workload_progress(
LiveReplayMode::Concurrency { max_in_flight: 1 },
Some(semaphore.as_ref()),
None,
Instant::now(),
wake.as_mut(),
);
let notify_task = {
let notify = Arc::clone(&notify);
tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
notify.notify_waiters();
})
};
tokio::time::timeout(tokio::time::Duration::from_millis(50), wait)
.await
.unwrap();
notify_task.await.unwrap();
}
#[test]
fn test_online_trace_replay_uses_round_robin_dispatch() {
let args = replay_args();
let requests = vec![
request(1, 1, Some(0.0)),
request(2, 2, Some(100.0)),
request(3, 3, Some(200.0)),
request(4, 4, Some(300.0)),
request(5, 5, Some(400.0)),
];
let (_, stats) =
simulate_trace_requests_with_stats(args, requests, 3, 1.0, ReplayRouterMode::RoundRobin)
.unwrap();
assert_eq!(stats.dispatch_history, vec![0, 1, 2, 0, 1]);
}
#[test]
fn test_online_concurrency_replay_respects_max_in_flight() {
let args = replay_args();
let requests = vec![
request(1, 10, None),
request(2, 20, None),
request(3, 30, None),
request(4, 40, None),
];
let (report, stats) = simulate_concurrency_requests_with_stats(
args,
requests,
2,
2,
ReplayRouterMode::RoundRobin,
)
.unwrap();
assert_eq!(report.request_counts.completed_requests, 4);
assert_eq!(stats.max_in_flight_seen, 2);
}
#[test]
fn test_online_trace_replay_populates_admit_reuse_stats() {
let args = replay_args();
let requests = vec![request(1, 77, Some(0.0)), request(2, 77, Some(5.0))];
let report =
simulate_trace_requests(args, None, requests, 1, 1.0, ReplayRouterMode::RoundRobin)
.unwrap();
assert_eq!(report.request_counts.completed_requests, 2);
assert!(report.prefix_cache_reused_ratio > 0.0);
}
#[test]
fn test_online_trace_replay_kv_router_prefers_cached_worker() {
let args = replay_args();
let requests = vec![request(1, 88, Some(0.0)), request(2, 88, Some(500.0))];
let (_, stats) =
simulate_trace_requests_with_stats(args, requests, 2, 1.0, ReplayRouterMode::KvRouter)
.unwrap();
assert_eq!(stats.dispatch_history.len(), 2);
assert_eq!(stats.dispatch_history[0], stats.dispatch_history[1]);
}
#[test]
fn test_online_trace_replay_sglang_single_worker_completes() {
let args = sglang_replay_args();
let requests = vec![request(101, 7, Some(0.0)), request(102, 8, Some(1.0))];
let report =
simulate_trace_requests(args, None, requests, 1, 1.0, ReplayRouterMode::RoundRobin)
.unwrap();
assert_eq!(report.request_counts.completed_requests, 2);
assert_eq!(report.request_counts.total_output_tokens, 4);
}
#[test]
fn test_online_trace_replay_sglang_kv_router_smoke() {
let args = sglang_replay_args();
let requests = vec![request(111, 9, Some(0.0)), request(112, 9, Some(500.0))];
let (report, stats) =
simulate_trace_requests_with_stats(args, requests, 2, 1.0, ReplayRouterMode::KvRouter)
.unwrap();
assert_eq!(report.request_counts.completed_requests, 2);
assert_eq!(stats.dispatch_history.len(), 2);
}
#[test]
fn test_online_concurrency_replay_kv_router_respects_max_in_flight() {
let args = replay_args();
let requests = vec![
request(1, 10, None),
request(2, 20, None),
request(3, 10, None),
request(4, 20, None),
];
let (report, stats) =
simulate_concurrency_requests_with_stats(args, requests, 2, 2, ReplayRouterMode::KvRouter)
.unwrap();
assert_eq!(report.request_counts.completed_requests, 4);
assert_eq!(stats.max_in_flight_seen, 2);
}
#[test]
fn test_online_trace_replay_kv_router_marks_prefill_and_free_once() {
let args = replay_args();
let requests = vec![DirectRequest {
tokens: vec![9; 64],
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(9)),
dp_rank: 0,
arrival_timestamp_ms: Some(0.0),
}];
let (_, stats) =
simulate_trace_requests_with_stats(args, requests, 1, 1.0, ReplayRouterMode::KvRouter)
.unwrap();
assert_eq!(stats.prefill_marked_count, 1);
assert_eq!(stats.freed_count, 1);
}
......@@ -6,4 +6,6 @@ mod online;
mod shared;
pub(crate) use offline::OfflineReplayRouter;
#[cfg(test)]
pub(crate) use offline::OfflineRouterSnapshot;
pub(crate) use online::ReplayRouter;
......@@ -7,6 +7,7 @@ use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result, anyhow};
use dynamo_kv_router::LocalBlockHash;
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::{
OverlapScores, RouterEvent, WorkerConfigLike, WorkerId, WorkerWithDpRank,
......@@ -26,9 +27,33 @@ use super::shared::{
};
use crate::common::protocols::DirectRequest;
use crate::common::protocols::MockEngineArgs;
use crate::loadgen::ReplayRequestHashes;
type ReplayQueueKey = <RouterSchedulingPolicy as SchedulingPolicy>::Key;
#[cfg(test)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct OfflinePendingRequestSnapshot {
pub(crate) uuid: Uuid,
pub(crate) overlap_blocks_by_worker: Vec<(usize, u32)>,
}
#[cfg(test)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct OfflineIndexerSnapshot {
pub(crate) total_cached_blocks: usize,
pub(crate) cached_blocks_by_worker: Vec<(usize, usize)>,
}
#[cfg(test)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct OfflineRouterSnapshot {
pub(crate) pending: Vec<OfflinePendingRequestSnapshot>,
pub(crate) active_blocks_by_worker: Vec<(usize, usize)>,
pub(crate) active_tokens_by_worker: Vec<(usize, usize)>,
pub(crate) indexer: OfflineIndexerSnapshot,
}
struct SyncReplayIndexer {
block_size: u32,
tree: RadixTree,
......@@ -47,9 +72,30 @@ impl SyncReplayIndexer {
self.tree.find_matches(sequence, false)
}
fn find_matches_for_hashes(&self, local_block_hashes: Vec<LocalBlockHash>) -> OverlapScores {
self.tree.find_matches(local_block_hashes, false)
}
fn apply_event(&mut self, event: RouterEvent) -> Result<()> {
self.tree.apply_event(event).map_err(Into::into)
}
#[cfg(test)]
fn debug_snapshot(&self) -> OfflineIndexerSnapshot {
let mut blocks_by_worker = HashMap::<usize, usize>::new();
for event in self.tree.dump_tree_as_events() {
*blocks_by_worker
.entry(event.worker_id as usize)
.or_default() += 1;
}
let mut cached_blocks_by_worker = blocks_by_worker.into_iter().collect::<Vec<_>>();
cached_blocks_by_worker.sort_unstable_by_key(|(worker_id, _)| *worker_id);
OfflineIndexerSnapshot {
total_cached_blocks: self.tree.current_size(),
cached_blocks_by_worker,
}
}
}
struct PendingRequest {
......@@ -120,7 +166,6 @@ impl PartialOrd for QueueEntry {
pub(crate) struct OfflineReplayRouter {
config: KvRouterConfig,
block_size: u32,
runtime: tokio::runtime::Runtime,
queue_threshold: Option<f64>,
workers_with_configs: HashMap<WorkerId, ReplayWorkerConfig>,
slots: Arc<ActiveSequencesMultiWorker<ReplayNoopPublisher>>,
......@@ -142,10 +187,6 @@ impl OfflineReplayRouter {
let slots = replay_slots(args, &workers_with_configs);
let selector = replay_selector(&config);
let policy = replay_policy(&config, args);
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| anyhow!("failed to create offline replay router runtime: {e}"))?;
let queue_threshold = if num_workers > 1 {
config.router_queue_threshold
} else {
......@@ -155,7 +196,6 @@ impl OfflineReplayRouter {
Ok(Self {
config,
block_size: args.block_size as u32,
runtime,
queue_threshold,
workers_with_configs,
slots,
......@@ -167,12 +207,13 @@ impl OfflineReplayRouter {
})
}
pub(crate) fn submit_request(
pub(crate) fn submit_request_with_hashes(
&mut self,
request: &DirectRequest,
replay_hashes: Option<ReplayRequestHashes>,
now_ms: f64,
) -> Result<Option<usize>> {
let pending = self.build_pending_request(request)?;
let pending = self.build_pending_request(request, replay_hashes)?;
let should_queue = self
.queue_threshold
.is_some_and(|threshold| self.all_workers_busy(threshold));
......@@ -197,15 +238,15 @@ impl OfflineReplayRouter {
}
pub(crate) fn mark_prefill_completed(&mut self, uuid: Uuid) -> Result<Vec<(Uuid, usize)>> {
self.runtime
.block_on(self.slots.mark_prefill_completed(&uuid.to_string()))
self.slots
.mark_prefill_completed_sync(&uuid.to_string())
.map_err(anyhow::Error::from)?;
self.drain_pending()
}
pub(crate) fn free(&mut self, uuid: Uuid) -> Result<Vec<(Uuid, usize)>> {
self.runtime
.block_on(self.slots.free(&uuid.to_string()))
self.slots
.free_sync(&uuid.to_string())
.map_err(anyhow::Error::from)?;
self.drain_pending()
}
......@@ -215,6 +256,58 @@ impl OfflineReplayRouter {
self.pending.len()
}
#[cfg(test)]
pub(crate) fn debug_snapshot(&self) -> OfflineRouterSnapshot {
let mut pending = self
.pending
.iter()
.map(|entry| {
let mut overlap_blocks_by_worker = entry
.request
.overlaps
.scores
.iter()
.map(|(worker, overlap)| (worker.worker_id as usize, *overlap))
.collect::<Vec<_>>();
overlap_blocks_by_worker.sort_unstable_by_key(|(worker_id, _)| *worker_id);
(
entry,
OfflinePendingRequestSnapshot {
uuid: entry.request.uuid,
overlap_blocks_by_worker,
},
)
})
.collect::<Vec<_>>();
pending.sort_unstable_by(|(left_entry, _), (right_entry, _)| {
left_entry.cmp(right_entry).reverse()
});
let mut active_blocks_by_worker = self
.slots
.active_blocks()
.into_iter()
.map(|(worker, blocks)| (worker.worker_id as usize, blocks))
.collect::<Vec<_>>();
active_blocks_by_worker.sort_unstable_by_key(|(worker_id, _)| *worker_id);
let mut active_tokens_by_worker = self
.slots
.active_tokens()
.into_iter()
.map(|(worker, tokens)| (worker.worker_id as usize, tokens))
.collect::<Vec<_>>();
active_tokens_by_worker.sort_unstable_by_key(|(worker_id, _)| *worker_id);
OfflineRouterSnapshot {
pending: pending.into_iter().map(|(_, snapshot)| snapshot).collect(),
active_blocks_by_worker,
active_tokens_by_worker,
indexer: self.indexer.debug_snapshot(),
}
}
pub(crate) fn shutdown(&mut self) {}
fn enqueue_key(&self, now_ms: f64, request: &PendingRequest) -> ReplayQueueKey {
......@@ -225,17 +318,44 @@ impl OfflineReplayRouter {
)
}
fn build_pending_request(&self, request: &DirectRequest) -> Result<PendingRequest> {
fn build_pending_request(
&self,
request: &DirectRequest,
replay_hashes: Option<ReplayRequestHashes>,
) -> Result<PendingRequest> {
let uuid = request
.uuid
.ok_or_else(|| anyhow!("offline replay requires requests to have stable UUIDs"))?;
let overlaps = self.indexer.find_matches_for_request(&request.tokens, None);
let token_seq = self.config.compute_seq_hashes_for_tracking(
&request.tokens,
self.block_size,
None,
None,
);
let (overlaps, token_seq) = match replay_hashes {
Some(replay_hashes) => {
let overlaps = self
.indexer
.find_matches_for_hashes(replay_hashes.local_block_hashes);
let token_seq = if !self.config.router_track_active_blocks {
None
} else if self.config.router_assume_kv_reuse {
Some(replay_hashes.sequence_hashes)
} else {
self.config.compute_seq_hashes_for_tracking(
&request.tokens,
self.block_size,
None,
None,
)
};
(overlaps, token_seq)
}
None => {
let overlaps = self.indexer.find_matches_for_request(&request.tokens, None);
let token_seq = self.config.compute_seq_hashes_for_tracking(
&request.tokens,
self.block_size,
None,
None,
);
(overlaps, token_seq)
}
};
Ok(PendingRequest {
uuid,
......@@ -265,8 +385,8 @@ impl OfflineReplayRouter {
.map_err(|_| anyhow!("selected worker id does not fit into usize"))?;
let request_id = request.request_id();
self.runtime
.block_on(self.slots.add_request(SequenceRequest {
self.slots
.add_request_sync(SequenceRequest {
request_id,
token_sequence: request.token_seq,
isl: request.isl_tokens,
......@@ -274,7 +394,7 @@ impl OfflineReplayRouter {
expected_output_tokens: request.expected_output_tokens,
worker: selection.worker,
lora_name: None,
}))
})
.map_err(anyhow::Error::from)?;
Ok(worker_idx)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment