Unverified Commit 40768e9c authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore(mocker): speed up replay hash hot paths (#7698)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 4abafffe
......@@ -2088,6 +2088,7 @@ dependencies = [
"ndarray-npy",
"rand 0.9.2",
"rstest 0.18.2",
"rustc-hash 2.1.1",
"serde",
"serde_json",
"slotmap",
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Rust-native offline replay benchmark entrypoint.
//!
//! Useful for profiling replay itself without the Python CLI wrapper. This keeps
//! the default mocker perf model unless CLI overrides are provided.
use std::fs::File;
use std::path::PathBuf;
use std::time::Instant;
use anyhow::{Context, Result};
use clap::{Parser, ValueEnum};
use dynamo_mocker::common::protocols::MockEngineArgs;
use dynamo_mocker::replay::{ReplayRouterMode, simulate_trace_file_with_router_mode};
#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)]
enum RouterModeArg {
RoundRobin,
KvRouter,
}
impl From<RouterModeArg> for ReplayRouterMode {
fn from(value: RouterModeArg) -> Self {
match value {
RouterModeArg::RoundRobin => ReplayRouterMode::RoundRobin,
RouterModeArg::KvRouter => ReplayRouterMode::KvRouter,
}
}
}
#[derive(Parser, Debug)]
#[command(name = "offline_replay_bench")]
#[command(about = "Run offline replay directly in Rust for benchmarking and profiling")]
struct Args {
/// Mooncake trace JSONL file
trace_file: PathBuf,
/// Number of aggregated workers
#[arg(long, default_value_t = 4)]
num_workers: usize,
/// Router mode for multi-worker replay
#[arg(long, value_enum, default_value_t = RouterModeArg::KvRouter)]
router_mode: RouterModeArg,
/// Compress trace arrival timestamps by this factor
#[arg(long, default_value_t = 4.0)]
arrival_speedup_ratio: f64,
/// Mocker block size; defaults to 512 for Mooncake traces
#[arg(long, default_value_t = 512)]
block_size: usize,
/// Override max running requests per worker
#[arg(long)]
max_num_seqs: Option<usize>,
/// Override batched token budget per worker pass
#[arg(long)]
max_num_batched_tokens: Option<usize>,
/// Global speedup multiplier for the default perf model
#[arg(long)]
speedup_ratio: Option<f64>,
/// Additional decode-only speedup multiplier
#[arg(long)]
decode_speedup_ratio: Option<f64>,
/// Explicit planner profile NPZ to use for perf-model timing
#[arg(long)]
planner_profile_data: Option<PathBuf>,
/// Optional path to write the full replay report as pretty JSON
#[arg(long)]
report_json: Option<PathBuf>,
/// Number of times to rerun the same replay in-process
#[arg(long, default_value_t = 1)]
iterations: usize,
}
fn build_engine_args(args: &Args) -> Result<MockEngineArgs> {
let mut builder = MockEngineArgs::builder();
builder = builder.block_size(args.block_size);
if let Some(max_num_seqs) = args.max_num_seqs {
builder = builder.max_num_seqs(Some(max_num_seqs));
}
if let Some(max_num_batched_tokens) = args.max_num_batched_tokens {
builder = builder.max_num_batched_tokens(Some(max_num_batched_tokens));
}
if let Some(speedup_ratio) = args.speedup_ratio {
builder = builder.speedup_ratio(speedup_ratio);
}
if let Some(decode_speedup_ratio) = args.decode_speedup_ratio {
builder = builder.decode_speedup_ratio(decode_speedup_ratio);
}
if let Some(planner_profile_data) = args.planner_profile_data.as_ref() {
builder = builder.planner_profile_data(Some(planner_profile_data.clone()));
}
builder
.build()
.context("failed to build replay engine args")?
.normalized()
}
fn main() -> Result<()> {
let args = Args::parse();
let engine_args = build_engine_args(&args)?;
let started_at = Instant::now();
let mut last_report = None;
for _ in 0..args.iterations {
last_report = Some(simulate_trace_file_with_router_mode(
engine_args.clone(),
None,
&args.trace_file,
args.num_workers,
args.arrival_speedup_ratio,
args.router_mode.into(),
)?);
}
let report = last_report.expect("iterations must be at least 1");
let process_wall_time_ms = started_at.elapsed().as_secs_f64() * 1000.0;
if let Some(report_path) = args.report_json.as_ref() {
let file = File::create(report_path)
.with_context(|| format!("failed to create report file at {:?}", report_path))?;
serde_json::to_writer_pretty(file, &report)
.with_context(|| format!("failed to write report JSON to {:?}", report_path))?;
println!("Saved report to {}", report_path.display());
}
println!("Offline replay report");
println!(
" completed_requests: {}",
report.request_counts.completed_requests
);
println!(
" request_throughput_rps: {:.6}",
report.throughput.request_throughput_rps
);
println!(
" output_throughput_tok_s: {:.6}",
report.throughput.output_throughput_tok_s
);
println!(" mean_ttft_ms: {:.6}", report.latency.ttft.mean_ms);
println!(" mean_e2e_latency_ms: {:.6}", report.latency.e2e.mean_ms);
println!(
" prefix_cache_reused_ratio: {:.6}",
report.prefix_cache_reused_ratio
);
println!(" wall_time_ms: {:.6}", report.throughput.wall_time_ms);
println!(" process_wall_time_ms: {:.6}", process_wall_time_ms);
Ok(())
}
......@@ -1712,6 +1712,7 @@ dependencies = [
"ndarray-interp",
"ndarray-npy",
"rand 0.9.2",
"rustc-hash 2.1.1",
"serde",
"serde_json",
"slotmap",
......
......@@ -1728,6 +1728,7 @@ dependencies = [
"ndarray-interp",
"ndarray-npy",
"rand 0.9.2",
"rustc-hash 2.1.1",
"serde",
"serde_json",
"slotmap",
......
......@@ -36,6 +36,7 @@ ndarray = "0.16"
slotmap = "1"
ndarray-npy = "0.9"
ndarray-interp = "0.5"
rustc-hash = "2"
[target.'cfg(target_os = "linux")'.dependencies]
tokio-timerfd = "0.2"
......
......@@ -3,12 +3,12 @@
use crate::common::evictor::LRUEvictor;
use dynamo_tokens::blocks::UniqueBlock;
use std::collections::HashMap;
use rustc_hash::FxHashMap;
/// Hash-based KV cache with O(1) block lookups, maintaining active (ref-counted) and
/// inactive (LRU-evictable) pools.
pub struct HashCache {
active_blocks: HashMap<UniqueBlock, usize>,
active_blocks: FxHashMap<UniqueBlock, usize>,
inactive_blocks: LRUEvictor<UniqueBlock>,
max_capacity: usize,
}
......@@ -17,7 +17,7 @@ impl HashCache {
/// Create a new HashCache with the given maximum block capacity.
pub fn new(max_capacity: usize) -> Self {
Self {
active_blocks: HashMap::new(),
active_blocks: FxHashMap::default(),
inactive_blocks: LRUEvictor::default(),
max_capacity,
}
......@@ -148,7 +148,7 @@ impl HashCache {
}
/// Direct access to active blocks map (for tests that check ref counts).
pub fn active_blocks(&self) -> &HashMap<UniqueBlock, usize> {
pub fn active_blocks(&self) -> &FxHashMap<UniqueBlock, usize> {
&self.active_blocks
}
}
......@@ -2,9 +2,11 @@
// SPDX-License-Identifier: Apache-2.0
use std::cmp::{Eq, Ordering};
use std::collections::{BTreeSet, HashMap};
use std::collections::BTreeSet;
use std::hash::Hash;
use rustc_hash::FxHashMap;
/// A wrapper for (T, counter) that implements Ord based only on counter
#[derive(Debug, Clone, Eq, PartialEq)]
struct PriorityItem<T> {
......@@ -28,7 +30,7 @@ impl<T: Eq> PartialOrd for PriorityItem<T> {
/// priority counter. Lower counter values are evicted first.
#[derive(Debug)]
pub struct LRUEvictor<T: Clone + Eq + Hash> {
free_table: HashMap<T, i64>,
free_table: FxHashMap<T, i64>,
priority_queue: BTreeSet<PriorityItem<T>>,
positive_counter: i64,
negative_counter: i64,
......@@ -37,7 +39,7 @@ pub struct LRUEvictor<T: Clone + Eq + Hash> {
impl<T: Clone + Eq + Hash> Default for LRUEvictor<T> {
fn default() -> Self {
Self {
free_table: HashMap::new(),
free_table: FxHashMap::default(),
priority_queue: BTreeSet::new(),
positive_counter: 0,
negative_counter: 0,
......
......@@ -44,7 +44,7 @@ use dynamo_kv_router::protocols::{
};
use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{BlockHash, SequenceHash};
use std::collections::HashMap;
use rustc_hash::FxHashMap;
pub struct KvManager {
cache: HashCache,
......@@ -351,7 +351,7 @@ impl KvManager {
}
/// Direct access to active blocks map (for tests).
pub fn active_blocks(&self) -> &HashMap<UniqueBlock, usize> {
pub fn active_blocks(&self) -> &FxHashMap<UniqueBlock, usize> {
self.cache.active_blocks()
}
......
......@@ -2,9 +2,10 @@
// SPDX-License-Identifier: Apache-2.0
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
use std::collections::BinaryHeap;
use anyhow::{Result, anyhow, bail};
use rustc_hash::FxHashMap;
use uuid::Uuid;
use super::types::{ReadyTurn, ReplayRequestHashes, Trace};
......@@ -76,7 +77,7 @@ impl PartialOrd for ReadySession {
pub struct WorkloadDriver {
mode: DriverMode,
sessions: Vec<SessionRuntime>,
in_flight: HashMap<Uuid, InFlightTurn>,
in_flight: FxHashMap<Uuid, InFlightTurn>,
ready_sessions: BinaryHeap<ReadySession>,
}
......@@ -136,7 +137,7 @@ impl WorkloadDriver {
Ok(Self {
mode,
sessions,
in_flight: HashMap::new(),
in_flight: FxHashMap::default(),
ready_sessions,
})
}
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use rustc_hash::FxHashMap;
use serde::Serialize;
use serde::ser::{SerializeMap, Serializer};
use uuid::Uuid;
......@@ -186,7 +185,7 @@ pub(crate) struct TraceRequestStatsSnapshot {
#[derive(Debug, Default)]
pub(crate) struct TraceCollector {
requests: HashMap<Uuid, TraceRequestStats>,
requests: FxHashMap<Uuid, TraceRequestStats>,
}
impl TraceRequestStats {
......@@ -259,11 +258,12 @@ impl TraceCollector {
pub(crate) fn finish(self) -> TraceSimulationReport {
let requests = self.requests;
let mut ttfts = Vec::new();
let mut ttsts = Vec::new();
let mut tpots = Vec::new();
let request_count = requests.len();
let mut ttfts = Vec::with_capacity(request_count);
let mut ttsts = Vec::with_capacity(request_count);
let mut tpots = Vec::with_capacity(request_count);
let mut itls = Vec::new();
let mut e2e_latencies = Vec::new();
let mut e2e_latencies = Vec::with_capacity(request_count);
let mut output_token_throughput_per_user = Vec::new();
let mut duration_ms = 0.0_f64;
let mut total_input_tokens = 0usize;
......@@ -309,10 +309,10 @@ impl TraceCollector {
}
let duration_s = (duration_ms / 1000.0).max(1e-9);
let itl_distribution = build_distribution_stats(&itls);
let itl_distribution = build_distribution_stats(itls);
TraceSimulationReport {
request_counts: TraceRequestCounts {
num_requests: requests.len(),
num_requests: request_count,
completed_requests,
total_input_tokens,
total_output_tokens,
......@@ -332,16 +332,16 @@ impl TraceCollector {
total_reused_tokens as f64 / total_input_tokens as f64
},
latency: TraceLatencyStats {
ttft: build_distribution_stats(&ttfts),
ttst: build_distribution_stats(&ttsts),
tpot: build_distribution_stats(&tpots),
ttft: build_distribution_stats(ttfts),
ttst: build_distribution_stats(ttsts),
tpot: build_distribution_stats(tpots),
itl: TraceInterTokenLatencyStats {
max_ms: itl_distribution.max_ms,
distribution: itl_distribution,
},
e2e: build_distribution_stats(&e2e_latencies),
e2e: build_distribution_stats(e2e_latencies),
output_token_throughput_per_user: build_distribution_stats(
&output_token_throughput_per_user,
output_token_throughput_per_user,
),
},
}
......@@ -387,7 +387,7 @@ fn mean(values: &[f64]) -> f64 {
}
}
fn build_distribution_stats(values: &[f64]) -> TraceDistributionStats {
fn build_distribution_stats(mut values: Vec<f64>) -> TraceDistributionStats {
if values.is_empty() {
return TraceDistributionStats {
mean_ms: 0.0,
......@@ -402,24 +402,39 @@ fn build_distribution_stats(values: &[f64]) -> TraceDistributionStats {
};
}
let mut sorted = values.to_vec();
sorted.sort_by(|left, right| left.total_cmp(right));
let min_ms = values
.iter()
.copied()
.min_by(|left, right| left.total_cmp(right))
.expect("non-empty values must have a minimum");
let max_ms = values
.iter()
.copied()
.max_by(|left, right| left.total_cmp(right))
.expect("non-empty values must have a maximum");
TraceDistributionStats {
mean_ms: mean(values),
min_ms: sorted[0],
max_ms: *sorted.last().expect("sorted values must be non-empty"),
median_ms: percentile_sorted(&sorted, 50.0),
p75_ms: percentile_sorted(&sorted, 75.0),
p90_ms: percentile_sorted(&sorted, 90.0),
p95_ms: percentile_sorted(&sorted, 95.0),
p99_ms: percentile_sorted(&sorted, 99.0),
std_ms: std_dev(values),
mean_ms: mean(&values),
min_ms,
max_ms,
median_ms: percentile_in_place(&mut values, 50.0),
p75_ms: percentile_in_place(&mut values, 75.0),
p90_ms: percentile_in_place(&mut values, 90.0),
p95_ms: percentile_in_place(&mut values, 95.0),
p99_ms: percentile_in_place(&mut values, 99.0),
std_ms: std_dev(&values),
}
}
fn percentile_sorted(sorted: &[f64], percentile: f64) -> f64 {
let rank = ((sorted.len() - 1) as f64 * percentile / 100.0).round() as usize;
sorted[rank.min(sorted.len() - 1)]
fn percentile_in_place(values: &mut [f64], percentile: f64) -> f64 {
let rank = percentile_rank(values.len(), percentile);
let (_, selected, _) = values.select_nth_unstable_by(rank, |left, right| left.total_cmp(right));
*selected
}
fn percentile_rank(len: usize, percentile: f64) -> usize {
let rank = ((len - 1) as f64 * percentile / 100.0).round() as usize;
rank.min(len - 1)
}
fn std_dev(values: &[f64]) -> f64 {
......@@ -438,3 +453,58 @@ fn std_dev(values: &[f64]) -> f64 {
/ values.len() as f64;
variance.sqrt()
}
#[cfg(test)]
mod tests {
use super::*;
fn build_distribution_stats_sorted(values: &[f64]) -> TraceDistributionStats {
if values.is_empty() {
return TraceDistributionStats {
mean_ms: 0.0,
min_ms: 0.0,
max_ms: 0.0,
median_ms: 0.0,
p75_ms: 0.0,
p90_ms: 0.0,
p95_ms: 0.0,
p99_ms: 0.0,
std_ms: 0.0,
};
}
let mut sorted = values.to_vec();
sorted.sort_by(|left, right| left.total_cmp(right));
TraceDistributionStats {
mean_ms: mean(values),
min_ms: sorted[0],
max_ms: *sorted.last().expect("sorted values must be non-empty"),
median_ms: sorted[percentile_rank(sorted.len(), 50.0)],
p75_ms: sorted[percentile_rank(sorted.len(), 75.0)],
p90_ms: sorted[percentile_rank(sorted.len(), 90.0)],
p95_ms: sorted[percentile_rank(sorted.len(), 95.0)],
p99_ms: sorted[percentile_rank(sorted.len(), 99.0)],
std_ms: std_dev(values),
}
}
#[test]
fn build_distribution_stats_matches_sorted_baseline() {
let values = vec![
0.0, 1.0, 1.0, 2.5, 4.0, 4.0, 7.25, 9.5, 15.0, 22.0, 22.0, 100.0,
];
let expected = build_distribution_stats_sorted(&values);
let actual = build_distribution_stats(values);
assert_eq!(actual.mean_ms, expected.mean_ms);
assert_eq!(actual.min_ms, expected.min_ms);
assert_eq!(actual.max_ms, expected.max_ms);
assert_eq!(actual.median_ms, expected.median_ms);
assert_eq!(actual.p75_ms, expected.p75_ms);
assert_eq!(actual.p90_ms, expected.p90_ms);
assert_eq!(actual.p95_ms, expected.p95_ms);
assert_eq!(actual.p99_ms, expected.p99_ms);
assert_eq!(actual.std_ms, expected.std_ms);
}
}
......@@ -19,7 +19,10 @@ use crate::scheduler::RouterEventVisibility;
use anyhow::bail;
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::RouterEvent;
use std::collections::{BinaryHeap, HashMap, VecDeque};
use rustc_hash::FxHashMap;
#[cfg(test)]
use std::collections::HashMap;
use std::collections::{BinaryHeap, VecDeque};
use uuid::Uuid;
#[derive(Debug, Clone, Copy)]
......@@ -65,7 +68,7 @@ pub(super) struct AggRuntime {
next_worker_idx: usize,
next_event_seq: u64,
admission: AdmissionSource,
requests: HashMap<Uuid, AggRequestState>,
requests: FxHashMap<Uuid, AggRequestState>,
workers: Vec<OfflineWorkerState>,
collector: TraceCollector,
events: BinaryHeap<SimulationEvent>,
......@@ -140,7 +143,7 @@ impl AggRuntime {
next_worker_idx: 0,
next_event_seq: 0,
admission,
requests: HashMap::new(),
requests: FxHashMap::default(),
workers: (0..num_workers)
.map(|worker_idx| {
OfflineWorkerState::new(worker_idx, args.clone(), capture_kv_events)
......
......@@ -5,6 +5,7 @@ use std::collections::HashMap;
use std::future;
use std::sync::Arc;
use crate::common::protocols::MockEngineArgs;
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::{
ActiveLoad, ActiveSequenceEvent, WorkerConfigLike, WorkerId, WorkerWithDpRank,
......@@ -15,8 +16,6 @@ use dynamo_kv_router::{
SequencePublisher,
};
use crate::common::protocols::MockEngineArgs;
#[derive(Clone, Copy, Debug, Default)]
pub(super) struct ReplayNoopPublisher;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{HashMap, HashSet, VecDeque};
use std::collections::VecDeque;
use std::time::Duration;
use dynamo_kv_router::protocols::WorkerId;
use dynamo_tokens::blocks::UniqueBlock;
use rustc_hash::{FxHashMap, FxHashSet};
use tokio::sync::mpsc;
use uuid::Uuid;
......@@ -39,10 +40,10 @@ pub(crate) struct VllmRequestState {
#[derive(Default)]
pub(crate) struct SchedulerState {
pub(crate) waiting: VecDeque<Uuid>,
waiting_members: HashSet<Uuid>,
waiting_members: FxHashSet<Uuid>,
pub(crate) running: VecDeque<Uuid>,
running_members: HashSet<Uuid>,
pub(crate) requests: HashMap<Uuid, VllmRequestState>,
running_members: FxHashSet<Uuid>,
pub(crate) requests: FxHashMap<Uuid, VllmRequestState>,
}
struct PreemptedRequest {
......@@ -292,7 +293,7 @@ impl VllmCore {
let requests_before = self.state.requests.len();
self.state.compact_running();
let mut token_budget = self.args.max_num_batched_tokens.unwrap_or(usize::MAX);
let mut scheduled = HashMap::new();
let mut scheduled = FxHashMap::default();
let mut batch_count = 0usize;
let mut batch_total_isl = 0usize;
let mut batch_total_prefix = 0usize;
......@@ -411,7 +412,7 @@ impl VllmCore {
uuid: Uuid,
from_waiting: bool,
token_budget: &mut usize,
scheduled: &mut HashMap<Uuid, ScheduledWork>,
scheduled: &mut FxHashMap<Uuid, ScheduledWork>,
batch_count: &mut usize,
batch_total_isl: &mut usize,
batch_total_prefix: &mut usize,
......@@ -679,7 +680,7 @@ impl VllmCore {
}
}
fn request_sequence_len(requests: &HashMap<Uuid, VllmRequestState>, uuid: Uuid) -> usize {
fn request_sequence_len(requests: &FxHashMap<Uuid, VllmRequestState>, uuid: Uuid) -> usize {
requests
.get(&uuid)
.map(|request| request.sequence.len())
......@@ -716,7 +717,7 @@ fn debug_assert_vllm_request_progress(uuid: Uuid, request: &VllmRequestState) {
}
}
fn debug_assert_vllm_ready_to_decode(requests: &HashMap<Uuid, VllmRequestState>, uuid: Uuid) {
fn debug_assert_vllm_ready_to_decode(requests: &FxHashMap<Uuid, VllmRequestState>, uuid: Uuid) {
#[cfg(debug_assertions)]
{
let Some(request) = requests.get(&uuid) else {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment