Unverified commit 40768e9c, authored by Yan Ru Pei and committed by GitHub
Browse files

chore(mocker): speed up replay hash hot paths (#7698)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 4abafffe
...@@ -2088,6 +2088,7 @@ dependencies = [ ...@@ -2088,6 +2088,7 @@ dependencies = [
"ndarray-npy", "ndarray-npy",
"rand 0.9.2", "rand 0.9.2",
"rstest 0.18.2", "rstest 0.18.2",
"rustc-hash 2.1.1",
"serde", "serde",
"serde_json", "serde_json",
"slotmap", "slotmap",
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Rust-native offline replay benchmark entrypoint.
//!
//! Useful for profiling replay itself without the Python CLI wrapper. This keeps
//! the default mocker perf model unless CLI overrides are provided.
use std::fs::File;
use std::path::PathBuf;
use std::time::Instant;
use anyhow::{Context, Result};
use clap::{Parser, ValueEnum};
use dynamo_mocker::common::protocols::MockEngineArgs;
use dynamo_mocker::replay::{ReplayRouterMode, simulate_trace_file_with_router_mode};
// CLI-facing router mode selector used as the value type of `Args::router_mode`.
// Deriving `ValueEnum` lets clap parse it from the command line; the `From` impl
// below converts it into the replay library's `ReplayRouterMode`.
//
// Plain `//` comments are used here on purpose: clap's derive macros turn `///`
// doc comments into help text, so doc comments would alter the `--help` output.
#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)]
enum RouterModeArg {
// Converted to `ReplayRouterMode::RoundRobin`.
RoundRobin,
// Converted to `ReplayRouterMode::KvRouter`; this is the `--router-mode` default.
KvRouter,
}
/// Translates the CLI-level router selection into the replay library's mode enum.
impl From<RouterModeArg> for ReplayRouterMode {
    /// Maps each CLI variant onto the identically named replay variant.
    fn from(mode: RouterModeArg) -> Self {
        match mode {
            RouterModeArg::RoundRobin => Self::RoundRobin,
            RouterModeArg::KvRouter => Self::KvRouter,
        }
    }
}
// Command-line arguments for the offline replay benchmark, parsed by clap via
// `Args::parse()` in `main`.
//
// The `///` comments on the fields below are picked up by clap's derive macro as
// per-argument help text, so they are part of the program's output. Reviewer
// notes are therefore written as plain `//` comments.
#[derive(Parser, Debug)]
#[command(name = "offline_replay_bench")]
#[command(about = "Run offline replay directly in Rust for benchmarking and profiling")]
struct Args {
// Positional argument (no `#[arg(long)]`); passed by reference to the replay call.
/// Mooncake trace JSONL file
trace_file: PathBuf,
/// Number of aggregated workers
#[arg(long, default_value_t = 4)]
num_workers: usize,
/// Router mode for multi-worker replay
#[arg(long, value_enum, default_value_t = RouterModeArg::KvRouter)]
router_mode: RouterModeArg,
// Handed straight to the replay call; distinct from `speedup_ratio` below, which
// is forwarded to the engine-args builder instead.
/// Compress trace arrival timestamps by this factor
#[arg(long, default_value_t = 4.0)]
arrival_speedup_ratio: f64,
/// Mocker block size; defaults to 512 for Mooncake traces
#[arg(long, default_value_t = 512)]
block_size: usize,
// The remaining `Option` fields are overrides: `build_engine_args` forwards each
// one to the builder only when the flag was supplied.
/// Override max running requests per worker
#[arg(long)]
max_num_seqs: Option<usize>,
/// Override batched token budget per worker pass
#[arg(long)]
max_num_batched_tokens: Option<usize>,
/// Global speedup multiplier for the default perf model
#[arg(long)]
speedup_ratio: Option<f64>,
/// Additional decode-only speedup multiplier
#[arg(long)]
decode_speedup_ratio: Option<f64>,
/// Explicit planner profile NPZ to use for perf-model timing
#[arg(long)]
planner_profile_data: Option<PathBuf>,
/// Optional path to write the full replay report as pretty JSON
#[arg(long)]
report_json: Option<PathBuf>,
// `main` keeps only the report of the last iteration, so a value of 0 leaves
// nothing to report.
/// Number of times to rerun the same replay in-process
#[arg(long, default_value_t = 1)]
iterations: usize,
}
/// Assembles the [`MockEngineArgs`] shared by every replay iteration.
///
/// The block size always comes from the CLI. Each optional override is forwarded to
/// the builder only when its flag was supplied; otherwise the builder is left untouched
/// for that setting.
///
/// # Errors
///
/// Returns an error when the builder fails (with the context
/// "failed to build replay engine args") or when `normalized()` fails.
fn build_engine_args(args: &Args) -> Result<MockEngineArgs> {
    let builder = MockEngineArgs::builder().block_size(args.block_size);

    let builder = match args.max_num_seqs {
        Some(limit) => builder.max_num_seqs(Some(limit)),
        None => builder,
    };
    let builder = match args.max_num_batched_tokens {
        Some(budget) => builder.max_num_batched_tokens(Some(budget)),
        None => builder,
    };
    let builder = match args.speedup_ratio {
        Some(ratio) => builder.speedup_ratio(ratio),
        None => builder,
    };
    let builder = match args.decode_speedup_ratio {
        Some(ratio) => builder.decode_speedup_ratio(ratio),
        None => builder,
    };
    let builder = match args.planner_profile_data.clone() {
        Some(profile_path) => builder.planner_profile_data(Some(profile_path)),
        None => builder,
    };

    let engine_args = builder
        .build()
        .context("failed to build replay engine args")?;
    engine_args.normalized()
}
fn main() -> Result<()> {
let args = Args::parse();
let engine_args = build_engine_args(&args)?;
let started_at = Instant::now();
let mut last_report = None;
for _ in 0..args.iterations {
last_report = Some(simulate_trace_file_with_router_mode(
engine_args.clone(),
None,
&args.trace_file,
args.num_workers,
args.arrival_speedup_ratio,
args.router_mode.into(),
)?);
}
let report = last_report.expect("iterations must be at least 1");
let process_wall_time_ms = started_at.elapsed().as_secs_f64() * 1000.0;
if let Some(report_path) = args.report_json.as_ref() {
let file = File::create(report_path)
.with_context(|| format!("failed to create report file at {:?}", report_path))?;
serde_json::to_writer_pretty(file, &report)
.with_context(|| format!("failed to write report JSON to {:?}", report_path))?;
println!("Saved report to {}", report_path.display());
}
println!("Offline replay report");
println!(
" completed_requests: {}",
report.request_counts.completed_requests
);
println!(
" request_throughput_rps: {:.6}",
report.throughput.request_throughput_rps
);
println!(
" output_throughput_tok_s: {:.6}",
report.throughput.output_throughput_tok_s
);
println!(" mean_ttft_ms: {:.6}", report.latency.ttft.mean_ms);
println!(" mean_e2e_latency_ms: {:.6}", report.latency.e2e.mean_ms);
println!(
" prefix_cache_reused_ratio: {:.6}",
report.prefix_cache_reused_ratio
);
println!(" wall_time_ms: {:.6}", report.throughput.wall_time_ms);
println!(" process_wall_time_ms: {:.6}", process_wall_time_ms);
Ok(())
}
...@@ -1712,6 +1712,7 @@ dependencies = [ ...@@ -1712,6 +1712,7 @@ dependencies = [
"ndarray-interp", "ndarray-interp",
"ndarray-npy", "ndarray-npy",
"rand 0.9.2", "rand 0.9.2",
"rustc-hash 2.1.1",
"serde", "serde",
"serde_json", "serde_json",
"slotmap", "slotmap",
......
...@@ -1728,6 +1728,7 @@ dependencies = [ ...@@ -1728,6 +1728,7 @@ dependencies = [
"ndarray-interp", "ndarray-interp",
"ndarray-npy", "ndarray-npy",
"rand 0.9.2", "rand 0.9.2",
"rustc-hash 2.1.1",
"serde", "serde",
"serde_json", "serde_json",
"slotmap", "slotmap",
......
...@@ -36,6 +36,7 @@ ndarray = "0.16" ...@@ -36,6 +36,7 @@ ndarray = "0.16"
slotmap = "1" slotmap = "1"
ndarray-npy = "0.9" ndarray-npy = "0.9"
ndarray-interp = "0.5" ndarray-interp = "0.5"
rustc-hash = "2"
[target.'cfg(target_os = "linux")'.dependencies] [target.'cfg(target_os = "linux")'.dependencies]
tokio-timerfd = "0.2" tokio-timerfd = "0.2"
......
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
use crate::common::evictor::LRUEvictor; use crate::common::evictor::LRUEvictor;
use dynamo_tokens::blocks::UniqueBlock; use dynamo_tokens::blocks::UniqueBlock;
use std::collections::HashMap; use rustc_hash::FxHashMap;
/// Hash-based KV cache with O(1) block lookups, maintaining active (ref-counted) and /// Hash-based KV cache with O(1) block lookups, maintaining active (ref-counted) and
/// inactive (LRU-evictable) pools. /// inactive (LRU-evictable) pools.
pub struct HashCache { pub struct HashCache {
active_blocks: HashMap<UniqueBlock, usize>, active_blocks: FxHashMap<UniqueBlock, usize>,
inactive_blocks: LRUEvictor<UniqueBlock>, inactive_blocks: LRUEvictor<UniqueBlock>,
max_capacity: usize, max_capacity: usize,
} }
...@@ -17,7 +17,7 @@ impl HashCache { ...@@ -17,7 +17,7 @@ impl HashCache {
/// Create a new HashCache with the given maximum block capacity. /// Create a new HashCache with the given maximum block capacity.
pub fn new(max_capacity: usize) -> Self { pub fn new(max_capacity: usize) -> Self {
Self { Self {
active_blocks: HashMap::new(), active_blocks: FxHashMap::default(),
inactive_blocks: LRUEvictor::default(), inactive_blocks: LRUEvictor::default(),
max_capacity, max_capacity,
} }
...@@ -148,7 +148,7 @@ impl HashCache { ...@@ -148,7 +148,7 @@ impl HashCache {
} }
/// Direct access to active blocks map (for tests that check ref counts). /// Direct access to active blocks map (for tests that check ref counts).
pub fn active_blocks(&self) -> &HashMap<UniqueBlock, usize> { pub fn active_blocks(&self) -> &FxHashMap<UniqueBlock, usize> {
&self.active_blocks &self.active_blocks
} }
} }
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::cmp::{Eq, Ordering}; use std::cmp::{Eq, Ordering};
use std::collections::{BTreeSet, HashMap}; use std::collections::BTreeSet;
use std::hash::Hash; use std::hash::Hash;
use rustc_hash::FxHashMap;
/// A wrapper for (T, counter) that implements Ord based only on counter /// A wrapper for (T, counter) that implements Ord based only on counter
#[derive(Debug, Clone, Eq, PartialEq)] #[derive(Debug, Clone, Eq, PartialEq)]
struct PriorityItem<T> { struct PriorityItem<T> {
...@@ -28,7 +30,7 @@ impl<T: Eq> PartialOrd for PriorityItem<T> { ...@@ -28,7 +30,7 @@ impl<T: Eq> PartialOrd for PriorityItem<T> {
/// priority counter. Lower counter values are evicted first. /// priority counter. Lower counter values are evicted first.
#[derive(Debug)] #[derive(Debug)]
pub struct LRUEvictor<T: Clone + Eq + Hash> { pub struct LRUEvictor<T: Clone + Eq + Hash> {
free_table: HashMap<T, i64>, free_table: FxHashMap<T, i64>,
priority_queue: BTreeSet<PriorityItem<T>>, priority_queue: BTreeSet<PriorityItem<T>>,
positive_counter: i64, positive_counter: i64,
negative_counter: i64, negative_counter: i64,
...@@ -37,7 +39,7 @@ pub struct LRUEvictor<T: Clone + Eq + Hash> { ...@@ -37,7 +39,7 @@ pub struct LRUEvictor<T: Clone + Eq + Hash> {
impl<T: Clone + Eq + Hash> Default for LRUEvictor<T> { impl<T: Clone + Eq + Hash> Default for LRUEvictor<T> {
fn default() -> Self { fn default() -> Self {
Self { Self {
free_table: HashMap::new(), free_table: FxHashMap::default(),
priority_queue: BTreeSet::new(), priority_queue: BTreeSet::new(),
positive_counter: 0, positive_counter: 0,
negative_counter: 0, negative_counter: 0,
......
...@@ -44,7 +44,7 @@ use dynamo_kv_router::protocols::{ ...@@ -44,7 +44,7 @@ use dynamo_kv_router::protocols::{
}; };
use dynamo_tokens::blocks::UniqueBlock; use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{BlockHash, SequenceHash}; use dynamo_tokens::{BlockHash, SequenceHash};
use std::collections::HashMap; use rustc_hash::FxHashMap;
pub struct KvManager { pub struct KvManager {
cache: HashCache, cache: HashCache,
...@@ -351,7 +351,7 @@ impl KvManager { ...@@ -351,7 +351,7 @@ impl KvManager {
} }
/// Direct access to active blocks map (for tests). /// Direct access to active blocks map (for tests).
pub fn active_blocks(&self) -> &HashMap<UniqueBlock, usize> { pub fn active_blocks(&self) -> &FxHashMap<UniqueBlock, usize> {
self.cache.active_blocks() self.cache.active_blocks()
} }
......
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::cmp::Ordering; use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap}; use std::collections::BinaryHeap;
use anyhow::{Result, anyhow, bail}; use anyhow::{Result, anyhow, bail};
use rustc_hash::FxHashMap;
use uuid::Uuid; use uuid::Uuid;
use super::types::{ReadyTurn, ReplayRequestHashes, Trace}; use super::types::{ReadyTurn, ReplayRequestHashes, Trace};
...@@ -76,7 +77,7 @@ impl PartialOrd for ReadySession { ...@@ -76,7 +77,7 @@ impl PartialOrd for ReadySession {
pub struct WorkloadDriver { pub struct WorkloadDriver {
mode: DriverMode, mode: DriverMode,
sessions: Vec<SessionRuntime>, sessions: Vec<SessionRuntime>,
in_flight: HashMap<Uuid, InFlightTurn>, in_flight: FxHashMap<Uuid, InFlightTurn>,
ready_sessions: BinaryHeap<ReadySession>, ready_sessions: BinaryHeap<ReadySession>,
} }
...@@ -136,7 +137,7 @@ impl WorkloadDriver { ...@@ -136,7 +137,7 @@ impl WorkloadDriver {
Ok(Self { Ok(Self {
mode, mode,
sessions, sessions,
in_flight: HashMap::new(), in_flight: FxHashMap::default(),
ready_sessions, ready_sessions,
}) })
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap; use rustc_hash::FxHashMap;
use serde::Serialize; use serde::Serialize;
use serde::ser::{SerializeMap, Serializer}; use serde::ser::{SerializeMap, Serializer};
use uuid::Uuid; use uuid::Uuid;
...@@ -186,7 +185,7 @@ pub(crate) struct TraceRequestStatsSnapshot { ...@@ -186,7 +185,7 @@ pub(crate) struct TraceRequestStatsSnapshot {
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub(crate) struct TraceCollector { pub(crate) struct TraceCollector {
requests: HashMap<Uuid, TraceRequestStats>, requests: FxHashMap<Uuid, TraceRequestStats>,
} }
impl TraceRequestStats { impl TraceRequestStats {
...@@ -259,11 +258,12 @@ impl TraceCollector { ...@@ -259,11 +258,12 @@ impl TraceCollector {
pub(crate) fn finish(self) -> TraceSimulationReport { pub(crate) fn finish(self) -> TraceSimulationReport {
let requests = self.requests; let requests = self.requests;
let mut ttfts = Vec::new(); let request_count = requests.len();
let mut ttsts = Vec::new(); let mut ttfts = Vec::with_capacity(request_count);
let mut tpots = Vec::new(); let mut ttsts = Vec::with_capacity(request_count);
let mut tpots = Vec::with_capacity(request_count);
let mut itls = Vec::new(); let mut itls = Vec::new();
let mut e2e_latencies = Vec::new(); let mut e2e_latencies = Vec::with_capacity(request_count);
let mut output_token_throughput_per_user = Vec::new(); let mut output_token_throughput_per_user = Vec::new();
let mut duration_ms = 0.0_f64; let mut duration_ms = 0.0_f64;
let mut total_input_tokens = 0usize; let mut total_input_tokens = 0usize;
...@@ -309,10 +309,10 @@ impl TraceCollector { ...@@ -309,10 +309,10 @@ impl TraceCollector {
} }
let duration_s = (duration_ms / 1000.0).max(1e-9); let duration_s = (duration_ms / 1000.0).max(1e-9);
let itl_distribution = build_distribution_stats(&itls); let itl_distribution = build_distribution_stats(itls);
TraceSimulationReport { TraceSimulationReport {
request_counts: TraceRequestCounts { request_counts: TraceRequestCounts {
num_requests: requests.len(), num_requests: request_count,
completed_requests, completed_requests,
total_input_tokens, total_input_tokens,
total_output_tokens, total_output_tokens,
...@@ -332,16 +332,16 @@ impl TraceCollector { ...@@ -332,16 +332,16 @@ impl TraceCollector {
total_reused_tokens as f64 / total_input_tokens as f64 total_reused_tokens as f64 / total_input_tokens as f64
}, },
latency: TraceLatencyStats { latency: TraceLatencyStats {
ttft: build_distribution_stats(&ttfts), ttft: build_distribution_stats(ttfts),
ttst: build_distribution_stats(&ttsts), ttst: build_distribution_stats(ttsts),
tpot: build_distribution_stats(&tpots), tpot: build_distribution_stats(tpots),
itl: TraceInterTokenLatencyStats { itl: TraceInterTokenLatencyStats {
max_ms: itl_distribution.max_ms, max_ms: itl_distribution.max_ms,
distribution: itl_distribution, distribution: itl_distribution,
}, },
e2e: build_distribution_stats(&e2e_latencies), e2e: build_distribution_stats(e2e_latencies),
output_token_throughput_per_user: build_distribution_stats( output_token_throughput_per_user: build_distribution_stats(
&output_token_throughput_per_user, output_token_throughput_per_user,
), ),
}, },
} }
...@@ -387,7 +387,7 @@ fn mean(values: &[f64]) -> f64 { ...@@ -387,7 +387,7 @@ fn mean(values: &[f64]) -> f64 {
} }
} }
fn build_distribution_stats(values: &[f64]) -> TraceDistributionStats { fn build_distribution_stats(mut values: Vec<f64>) -> TraceDistributionStats {
if values.is_empty() { if values.is_empty() {
return TraceDistributionStats { return TraceDistributionStats {
mean_ms: 0.0, mean_ms: 0.0,
...@@ -402,24 +402,39 @@ fn build_distribution_stats(values: &[f64]) -> TraceDistributionStats { ...@@ -402,24 +402,39 @@ fn build_distribution_stats(values: &[f64]) -> TraceDistributionStats {
}; };
} }
let mut sorted = values.to_vec(); let min_ms = values
sorted.sort_by(|left, right| left.total_cmp(right)); .iter()
.copied()
.min_by(|left, right| left.total_cmp(right))
.expect("non-empty values must have a minimum");
let max_ms = values
.iter()
.copied()
.max_by(|left, right| left.total_cmp(right))
.expect("non-empty values must have a maximum");
TraceDistributionStats { TraceDistributionStats {
mean_ms: mean(values), mean_ms: mean(&values),
min_ms: sorted[0], min_ms,
max_ms: *sorted.last().expect("sorted values must be non-empty"), max_ms,
median_ms: percentile_sorted(&sorted, 50.0), median_ms: percentile_in_place(&mut values, 50.0),
p75_ms: percentile_sorted(&sorted, 75.0), p75_ms: percentile_in_place(&mut values, 75.0),
p90_ms: percentile_sorted(&sorted, 90.0), p90_ms: percentile_in_place(&mut values, 90.0),
p95_ms: percentile_sorted(&sorted, 95.0), p95_ms: percentile_in_place(&mut values, 95.0),
p99_ms: percentile_sorted(&sorted, 99.0), p99_ms: percentile_in_place(&mut values, 99.0),
std_ms: std_dev(values), std_ms: std_dev(&values),
} }
} }
fn percentile_sorted(sorted: &[f64], percentile: f64) -> f64 { fn percentile_in_place(values: &mut [f64], percentile: f64) -> f64 {
let rank = ((sorted.len() - 1) as f64 * percentile / 100.0).round() as usize; let rank = percentile_rank(values.len(), percentile);
sorted[rank.min(sorted.len() - 1)] let (_, selected, _) = values.select_nth_unstable_by(rank, |left, right| left.total_cmp(right));
*selected
}
fn percentile_rank(len: usize, percentile: f64) -> usize {
let rank = ((len - 1) as f64 * percentile / 100.0).round() as usize;
rank.min(len - 1)
} }
fn std_dev(values: &[f64]) -> f64 { fn std_dev(values: &[f64]) -> f64 {
...@@ -438,3 +453,58 @@ fn std_dev(values: &[f64]) -> f64 { ...@@ -438,3 +453,58 @@ fn std_dev(values: &[f64]) -> f64 {
/ values.len() as f64; / values.len() as f64;
variance.sqrt() variance.sqrt()
} }
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for statistics accumulated in floating point.
    ///
    /// `build_distribution_stats` computes `std_ms` after its in-place percentile
    /// selection may have permuted the buffer, so the summation order can differ
    /// from the baseline and the result need not be bit-identical.
    const FLOAT_TOLERANCE: f64 = 1e-9;

    /// Reference implementation: fully sorts a copy of the input and reads every
    /// order statistic by index, using the same rank rule (`percentile_rank`) as
    /// the production code.
    fn build_distribution_stats_sorted(values: &[f64]) -> TraceDistributionStats {
        if values.is_empty() {
            return TraceDistributionStats {
                mean_ms: 0.0,
                min_ms: 0.0,
                max_ms: 0.0,
                median_ms: 0.0,
                p75_ms: 0.0,
                p90_ms: 0.0,
                p95_ms: 0.0,
                p99_ms: 0.0,
                std_ms: 0.0,
            };
        }
        let mut sorted = values.to_vec();
        sorted.sort_by(|left, right| left.total_cmp(right));
        TraceDistributionStats {
            mean_ms: mean(values),
            min_ms: sorted[0],
            max_ms: *sorted.last().expect("sorted values must be non-empty"),
            median_ms: sorted[percentile_rank(sorted.len(), 50.0)],
            p75_ms: sorted[percentile_rank(sorted.len(), 75.0)],
            p90_ms: sorted[percentile_rank(sorted.len(), 90.0)],
            p95_ms: sorted[percentile_rank(sorted.len(), 95.0)],
            p99_ms: sorted[percentile_rank(sorted.len(), 99.0)],
            std_ms: std_dev(values),
        }
    }

    /// Asserts two accumulated statistics agree within [`FLOAT_TOLERANCE`].
    fn assert_close(actual: f64, expected: f64, label: &str) {
        assert!(
            (actual - expected).abs() <= FLOAT_TOLERANCE,
            "{label}: expected {expected}, got {actual}"
        );
    }

    /// Runs both implementations on `values` and compares every field.
    ///
    /// Order statistics are elements of the input and must match exactly; the mean
    /// and standard deviation are compared with a tolerance.
    fn assert_matches_sorted_baseline(values: Vec<f64>) {
        let expected = build_distribution_stats_sorted(&values);
        let actual = build_distribution_stats(values);
        assert_close(actual.mean_ms, expected.mean_ms, "mean_ms");
        assert_eq!(actual.min_ms, expected.min_ms);
        assert_eq!(actual.max_ms, expected.max_ms);
        assert_eq!(actual.median_ms, expected.median_ms);
        assert_eq!(actual.p75_ms, expected.p75_ms);
        assert_eq!(actual.p90_ms, expected.p90_ms);
        assert_eq!(actual.p95_ms, expected.p95_ms);
        assert_eq!(actual.p99_ms, expected.p99_ms);
        assert_close(actual.std_ms, expected.std_ms, "std_ms");
    }

    #[test]
    fn build_distribution_stats_matches_sorted_baseline() {
        assert_matches_sorted_baseline(vec![
            0.0, 1.0, 1.0, 2.5, 4.0, 4.0, 7.25, 9.5, 15.0, 22.0, 22.0, 100.0,
        ]);
    }

    /// Same samples as the ascending fixture but out of order, so the in-place
    /// selection has to do real work instead of reading an already sorted buffer.
    #[test]
    fn build_distribution_stats_matches_sorted_baseline_for_unordered_input() {
        assert_matches_sorted_baseline(vec![
            22.0, 0.0, 100.0, 4.0, 1.0, 9.5, 22.0, 2.5, 15.0, 1.0, 7.25, 4.0,
        ]);
    }

    /// A single sample is the smallest non-empty input: every percentile rank is 0.
    #[test]
    fn build_distribution_stats_matches_sorted_baseline_for_single_value() {
        assert_matches_sorted_baseline(vec![3.5]);
    }
}
...@@ -19,7 +19,10 @@ use crate::scheduler::RouterEventVisibility; ...@@ -19,7 +19,10 @@ use crate::scheduler::RouterEventVisibility;
use anyhow::bail; use anyhow::bail;
use dynamo_kv_router::config::KvRouterConfig; use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::RouterEvent; use dynamo_kv_router::protocols::RouterEvent;
use std::collections::{BinaryHeap, HashMap, VecDeque}; use rustc_hash::FxHashMap;
#[cfg(test)]
use std::collections::HashMap;
use std::collections::{BinaryHeap, VecDeque};
use uuid::Uuid; use uuid::Uuid;
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
...@@ -65,7 +68,7 @@ pub(super) struct AggRuntime { ...@@ -65,7 +68,7 @@ pub(super) struct AggRuntime {
next_worker_idx: usize, next_worker_idx: usize,
next_event_seq: u64, next_event_seq: u64,
admission: AdmissionSource, admission: AdmissionSource,
requests: HashMap<Uuid, AggRequestState>, requests: FxHashMap<Uuid, AggRequestState>,
workers: Vec<OfflineWorkerState>, workers: Vec<OfflineWorkerState>,
collector: TraceCollector, collector: TraceCollector,
events: BinaryHeap<SimulationEvent>, events: BinaryHeap<SimulationEvent>,
...@@ -140,7 +143,7 @@ impl AggRuntime { ...@@ -140,7 +143,7 @@ impl AggRuntime {
next_worker_idx: 0, next_worker_idx: 0,
next_event_seq: 0, next_event_seq: 0,
admission, admission,
requests: HashMap::new(), requests: FxHashMap::default(),
workers: (0..num_workers) workers: (0..num_workers)
.map(|worker_idx| { .map(|worker_idx| {
OfflineWorkerState::new(worker_idx, args.clone(), capture_kv_events) OfflineWorkerState::new(worker_idx, args.clone(), capture_kv_events)
......
...@@ -5,6 +5,7 @@ use std::collections::HashMap; ...@@ -5,6 +5,7 @@ use std::collections::HashMap;
use std::future; use std::future;
use std::sync::Arc; use std::sync::Arc;
use crate::common::protocols::MockEngineArgs;
use dynamo_kv_router::config::KvRouterConfig; use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::protocols::{ use dynamo_kv_router::protocols::{
ActiveLoad, ActiveSequenceEvent, WorkerConfigLike, WorkerId, WorkerWithDpRank, ActiveLoad, ActiveSequenceEvent, WorkerConfigLike, WorkerId, WorkerWithDpRank,
...@@ -15,8 +16,6 @@ use dynamo_kv_router::{ ...@@ -15,8 +16,6 @@ use dynamo_kv_router::{
SequencePublisher, SequencePublisher,
}; };
use crate::common::protocols::MockEngineArgs;
#[derive(Clone, Copy, Debug, Default)] #[derive(Clone, Copy, Debug, Default)]
pub(super) struct ReplayNoopPublisher; pub(super) struct ReplayNoopPublisher;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::{HashMap, HashSet, VecDeque}; use std::collections::VecDeque;
use std::time::Duration; use std::time::Duration;
use dynamo_kv_router::protocols::WorkerId; use dynamo_kv_router::protocols::WorkerId;
use dynamo_tokens::blocks::UniqueBlock; use dynamo_tokens::blocks::UniqueBlock;
use rustc_hash::{FxHashMap, FxHashSet};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use uuid::Uuid; use uuid::Uuid;
...@@ -39,10 +40,10 @@ pub(crate) struct VllmRequestState { ...@@ -39,10 +40,10 @@ pub(crate) struct VllmRequestState {
#[derive(Default)] #[derive(Default)]
pub(crate) struct SchedulerState { pub(crate) struct SchedulerState {
pub(crate) waiting: VecDeque<Uuid>, pub(crate) waiting: VecDeque<Uuid>,
waiting_members: HashSet<Uuid>, waiting_members: FxHashSet<Uuid>,
pub(crate) running: VecDeque<Uuid>, pub(crate) running: VecDeque<Uuid>,
running_members: HashSet<Uuid>, running_members: FxHashSet<Uuid>,
pub(crate) requests: HashMap<Uuid, VllmRequestState>, pub(crate) requests: FxHashMap<Uuid, VllmRequestState>,
} }
struct PreemptedRequest { struct PreemptedRequest {
...@@ -292,7 +293,7 @@ impl VllmCore { ...@@ -292,7 +293,7 @@ impl VllmCore {
let requests_before = self.state.requests.len(); let requests_before = self.state.requests.len();
self.state.compact_running(); self.state.compact_running();
let mut token_budget = self.args.max_num_batched_tokens.unwrap_or(usize::MAX); let mut token_budget = self.args.max_num_batched_tokens.unwrap_or(usize::MAX);
let mut scheduled = HashMap::new(); let mut scheduled = FxHashMap::default();
let mut batch_count = 0usize; let mut batch_count = 0usize;
let mut batch_total_isl = 0usize; let mut batch_total_isl = 0usize;
let mut batch_total_prefix = 0usize; let mut batch_total_prefix = 0usize;
...@@ -411,7 +412,7 @@ impl VllmCore { ...@@ -411,7 +412,7 @@ impl VllmCore {
uuid: Uuid, uuid: Uuid,
from_waiting: bool, from_waiting: bool,
token_budget: &mut usize, token_budget: &mut usize,
scheduled: &mut HashMap<Uuid, ScheduledWork>, scheduled: &mut FxHashMap<Uuid, ScheduledWork>,
batch_count: &mut usize, batch_count: &mut usize,
batch_total_isl: &mut usize, batch_total_isl: &mut usize,
batch_total_prefix: &mut usize, batch_total_prefix: &mut usize,
...@@ -679,7 +680,7 @@ impl VllmCore { ...@@ -679,7 +680,7 @@ impl VllmCore {
} }
} }
fn request_sequence_len(requests: &HashMap<Uuid, VllmRequestState>, uuid: Uuid) -> usize { fn request_sequence_len(requests: &FxHashMap<Uuid, VllmRequestState>, uuid: Uuid) -> usize {
requests requests
.get(&uuid) .get(&uuid)
.map(|request| request.sequence.len()) .map(|request| request.sequence.len())
...@@ -716,7 +717,7 @@ fn debug_assert_vllm_request_progress(uuid: Uuid, request: &VllmRequestState) { ...@@ -716,7 +717,7 @@ fn debug_assert_vllm_request_progress(uuid: Uuid, request: &VllmRequestState) {
} }
} }
fn debug_assert_vllm_ready_to_decode(requests: &HashMap<Uuid, VllmRequestState>, uuid: Uuid) { fn debug_assert_vllm_ready_to_decode(requests: &FxHashMap<Uuid, VllmRequestState>, uuid: Uuid) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
let Some(request) = requests.get(&uuid) else { let Some(request) = requests.get(&uuid) else {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment