Unverified Commit c3908a36 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(mocker): add offline trace replay mode [DYN-2502] (#7543)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 6e56bad6
...@@ -8,29 +8,31 @@ use dynamo_tokens::{TokenBlockSequence, Tokens}; ...@@ -8,29 +8,31 @@ use dynamo_tokens::{TokenBlockSequence, Tokens};
use rand::random; use rand::random;
use validator::Validate; use validator::Validate;
/// Create unique blocks from a TokenBlockSequence /// Create unique blocks and block hashes from a TokenBlockSequence.
fn create_unique_blocks_from_sequence( fn create_sequence_cache(
tokens: &TokenBlockSequence, tokens: &TokenBlockSequence,
block_size: usize, block_size: usize,
enable_prefix_caching: bool, enable_prefix_caching: bool,
) -> Vec<UniqueBlock> { ) -> (Vec<UniqueBlock>, Vec<u64>) {
let mut unique_blocks: Vec<UniqueBlock> = tokens let mut unique_blocks = Vec::with_capacity(tokens.blocks().len() + 1);
.blocks() let mut block_hashes = Vec::with_capacity(tokens.blocks().len());
.iter()
.map(|block| { for block in tokens.blocks() {
block_hashes.push(block.block_hash());
unique_blocks.push({
if enable_prefix_caching { if enable_prefix_caching {
UniqueBlock::FullBlock(block.sequence_hash()) UniqueBlock::FullBlock(block.sequence_hash())
} else { } else {
UniqueBlock::FullBlock(random::<u64>()) UniqueBlock::FullBlock(random::<u64>())
} }
}) });
.collect(); }
// Only push the partial block if tokens count isn't a multiple of block_size // Only push the partial block if tokens count isn't a multiple of block_size
if !tokens.total_tokens().is_multiple_of(block_size) { if !tokens.total_tokens().is_multiple_of(block_size) {
unique_blocks.push(UniqueBlock::default()); unique_blocks.push(UniqueBlock::default());
} }
unique_blocks (unique_blocks, block_hashes)
} }
/// A sequence that is actively being built, with the ability to add tokens and commit to hashes /// A sequence that is actively being built, with the ability to add tokens and commit to hashes
...@@ -38,6 +40,7 @@ fn create_unique_blocks_from_sequence( ...@@ -38,6 +40,7 @@ fn create_unique_blocks_from_sequence(
#[derive(Debug, Getters, Validate)] #[derive(Debug, Getters, Validate)]
pub struct ActiveSequence { pub struct ActiveSequence {
unique_blocks: Vec<UniqueBlock>, unique_blocks: Vec<UniqueBlock>,
block_hashes: Vec<u64>,
tokens: TokenBlockSequence, tokens: TokenBlockSequence,
...@@ -77,11 +80,12 @@ impl ActiveSequence { ...@@ -77,11 +80,12 @@ impl ActiveSequence {
let num_input_tokens = tokens.len(); let num_input_tokens = tokens.len();
let tokens = Tokens::from(tokens).into_sequence(block_size as u32, Some(1337)); let tokens = Tokens::from(tokens).into_sequence(block_size as u32, Some(1337));
let unique_blocks = let (unique_blocks, block_hashes) =
create_unique_blocks_from_sequence(&tokens, block_size, enable_prefix_caching); create_sequence_cache(&tokens, block_size, enable_prefix_caching);
let seq = Self { let seq = Self {
unique_blocks, unique_blocks,
block_hashes,
tokens, tokens,
block_size, block_size,
max_output_tokens, max_output_tokens,
...@@ -125,11 +129,9 @@ impl ActiveSequence { ...@@ -125,11 +129,9 @@ impl ActiveSequence {
let range = prev_blocks..target_blocks; let range = prev_blocks..target_blocks;
let blocks = self.unique_blocks[range.clone()].to_vec(); let blocks = self.unique_blocks[range.clone()].to_vec();
let all_hashes = self.block_hashes(); let hash_start = prev_blocks.min(self.block_hashes.len());
let num_full = all_hashes.len(); let hash_end = target_blocks.min(self.block_hashes.len());
let hash_start = prev_blocks.min(num_full); let hashes = self.block_hashes[hash_start..hash_end].to_vec();
let hash_end = target_blocks.min(num_full);
let hashes = all_hashes[hash_start..hash_end].to_vec();
let token_ids = if self.emit_token_ids && hash_start < hash_end { let token_ids = if self.emit_token_ids && hash_start < hash_end {
let all_token_ids: Vec<Vec<u32>> = self let all_token_ids: Vec<Vec<u32>> = self
...@@ -168,14 +170,6 @@ impl ActiveSequence { ...@@ -168,14 +170,6 @@ impl ActiveSequence {
self.allocate_blocks_for_chunk(self.len()) self.allocate_blocks_for_chunk(self.len())
} }
pub fn block_hashes(&self) -> Vec<u64> {
self.tokens
.blocks()
.iter()
.map(|block| block.block_hash())
.collect()
}
/// Create a new ActiveSequence instance and return the creation signal /// Create a new ActiveSequence instance and return the creation signal
pub fn new_with_signal( pub fn new_with_signal(
tokens: Vec<u32>, tokens: Vec<u32>,
...@@ -221,6 +215,7 @@ impl ActiveSequence { ...@@ -221,6 +215,7 @@ impl ActiveSequence {
} else { } else {
None None
}; };
self.block_hashes.push(last_block_hash);
self.unique_blocks.pop(); self.unique_blocks.pop();
// After pop, the last element is the parent block // After pop, the last element is the parent block
...@@ -310,7 +305,13 @@ impl ActiveSequence { ...@@ -310,7 +305,13 @@ impl ActiveSequence {
free_signal free_signal
} }
/// Pops last token in the sequence. /// Pops the last token in the sequence.
///
/// This is only used to undo a freshly generated decode token after a failed
/// allocation/preemption path. Under that invariant, the token being removed
/// must be in the current partial block, so we only need to drop the trailing
/// partial `UniqueBlock` when the sequence length returns to an exact block
/// boundary. Using this to unwind arbitrary prompt history would be incorrect.
pub fn pop(&mut self) { pub fn pop(&mut self) {
self.tokens.pop(); self.tokens.pop();
self.generated_tokens = self.generated_tokens.saturating_sub(1); self.generated_tokens = self.generated_tokens.saturating_sub(1);
...@@ -326,157 +327,193 @@ impl ActiveSequence { ...@@ -326,157 +327,193 @@ impl ActiveSequence {
mod tests { mod tests {
use super::*; use super::*;
#[test] fn block_hashes_from_tokens(seq: &ActiveSequence) -> Vec<u64> {
fn test_active_sequence_push() { seq.tokens
// Create a sequence with block size 16 initialized with tokens [0..15] .blocks()
let initial_tokens: Vec<u32> = (0..15).collect(); .iter()
let (mut seq1, signal1) = .map(|block| block.block_hash())
ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), true); .collect()
assert_eq!(seq1.num_input_tokens(), 15); }
assert_eq!(seq1.len(), 15);
// Check that we got a Use signal
assert!(signal1.is_some());
match &signal1 {
Some(MoveBlock::Use(blocks, _hashes, ..)) => {
assert_eq!(blocks.len(), 1);
}
_ => panic!("Expected Use signal"),
}
// Push token 15 which should complete the block (no signals yet) fn assert_cached_hashes_match_promoted_blocks(seq: &ActiveSequence) {
let signal_15 = seq1.push(15); let num_full_unique_blocks = seq
assert!( .unique_blocks()
signal_15.is_none(), .iter()
"Completing a block should not trigger signals" .filter(|block| matches!(block, UniqueBlock::FullBlock(_)))
.count();
assert_eq!(
seq.block_hashes().as_slice(),
&block_hashes_from_tokens(seq)[..num_full_unique_blocks],
"cached block hashes should match the promoted full blocks"
); );
}
// Push token 16 which should trigger both Promote and Use signals fn assert_use_signal(
let signal_16 = seq1.push(16); signal: &MoveBlock,
assert!(signal_16.is_some()); expected_blocks: &[UniqueBlock],
let signal_16 = signal_16.unwrap(); expected_hashes: &[u64],
assert_eq!(signal_16.len(), 2); ) {
match signal {
MoveBlock::Use(blocks, hashes, ..) => {
assert_eq!(blocks, expected_blocks);
assert_eq!(hashes, expected_hashes);
}
_ => panic!("Expected MoveBlock::Use"),
}
}
fn assert_single_partial_use(signal: &MoveBlock) {
match signal {
MoveBlock::Use(blocks, hashes, ..) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
assert!(hashes.is_empty());
}
_ => panic!("Expected MoveBlock::Use with a single partial block"),
}
}
// First signal should be Promote for the previous block fn assert_promote_parent(signal: &MoveBlock, expected_parent: Option<u64>) {
match &signal_16[0] { match signal {
MoveBlock::Promote(_, _, parent_hash, _hash, ..) => { MoveBlock::Promote(_, _, parent_hash, _hash, ..) => {
assert_eq!(*parent_hash, None); assert_eq!(*parent_hash, expected_parent);
} }
_ => panic!("Expected Promote signal as second signal"), _ => panic!("Expected MoveBlock::Promote"),
} }
}
// Second signal should be Use for new partial block fn assert_destroy_partial(signal: &MoveBlock) {
match &signal_16[1] { match signal {
MoveBlock::Use(blocks, _hashes, ..) => { MoveBlock::Destroy(blocks) => {
assert_eq!(blocks.len(), 1); assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_))); assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
} }
_ => panic!("Expected Use signal as first signal"), _ => panic!("Expected MoveBlock::Destroy for partial block"),
}
}
fn assert_deref_full(signal: &MoveBlock) {
match signal {
MoveBlock::Deref(blocks) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::FullBlock(_)));
}
_ => panic!("Expected MoveBlock::Deref for full block"),
} }
}
#[test]
fn test_new_with_signal_creates_initial_partial_block() {
let initial_tokens: Vec<u32> = (0..15).collect();
let (seq, signal) = ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), true);
assert_eq!(seq.num_input_tokens(), 15);
assert_eq!(seq.len(), 15);
assert_single_partial_use(signal.as_ref().expect("Expected initial Use signal"));
}
#[test]
fn test_push_across_block_boundary_promotes_and_allocates_partial() {
let initial_tokens: Vec<u32> = (0..15).collect();
let (mut seq, _) = ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), true);
let signal_15 = seq.push(15);
assert!(
signal_15.is_none(),
"Completing a block should not trigger signals"
);
let signal_16 = seq.push(16).expect("Expected boundary crossing signals");
assert_eq!(signal_16.len(), 2);
assert_promote_parent(&signal_16[0], None);
assert_single_partial_use(&signal_16[1]);
assert_eq!(
seq.unique_blocks().len(),
2,
"sequence should have one full block and one partial block"
);
assert_eq!(
seq.len() % seq.block_size(),
1,
"sequence should have one token in the new partial block"
);
}
// Verify state after pushing tokens #[test]
assert_eq!(seq1.unique_blocks().len(), 2); // One full block and one partial block fn test_equivalent_histories_preserve_full_block_identity() {
assert_eq!(seq1.len(), 17); let initial_tokens: Vec<u32> = (0..15).collect();
assert_eq!(seq1.len() % seq1.block_size(), 1); let (mut seq1, _) = ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), true);
seq1.push(15);
seq1.push(16);
// Create another sequence with block size 16 initialized with tokens [0..17]
let extended_tokens: Vec<u32> = (0..16).collect(); let extended_tokens: Vec<u32> = (0..16).collect();
let (mut seq2, _) = ActiveSequence::new_with_signal(extended_tokens, 100, Some(16), true); let (mut seq2, _) = ActiveSequence::new_with_signal(extended_tokens, 100, Some(16), true);
seq2.push(16); seq2.push(16);
seq2.pop(); seq2.pop();
seq2.push(16); seq2.push(16);
// Simplified assertions assert_eq!(seq1.unique_blocks()[0], seq2.unique_blocks()[0]);
assert_eq!( assert_ne!(seq1.unique_blocks()[1], seq2.unique_blocks()[1]);
seq1.unique_blocks()[0], }
seq2.unique_blocks()[0],
"First blocks should be the same"
);
assert_ne!( #[test]
seq1.unique_blocks()[1], fn test_promote_uses_previous_full_block_as_parent() {
seq2.unique_blocks()[1], let initial_tokens: Vec<u32> = (0..15).collect();
"Second blocks should be different" let (mut seq, _) = ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), true);
); seq.push(15);
seq.push(16);
// Reset partial block on seq1 and push back token 16 seq.push(17);
seq1.push(17); seq.pop();
seq1.pop(); seq.pop();
seq1.pop(); seq.push(16);
seq1.push(16);
// Now push tokens 17..32 to both sequences let extended_tokens: Vec<u32> = (0..16).collect();
let (mut seq_equiv, _) =
ActiveSequence::new_with_signal(extended_tokens, 100, Some(16), true);
seq_equiv.push(16);
seq_equiv.pop();
seq_equiv.push(16);
for token in 17..33 { for token in 17..33 {
seq1.push(token); seq.push(token);
seq2.push(token); seq_equiv.push(token);
} }
// Both sequences should now have 2 blocks:
// 1. FullBlock for tokens 0-15
// 2. FullBlock for tokens 16-31
// 3. No partial block since there are no remaining tokens
assert_eq!(
seq1.unique_blocks().len(),
3,
"seq1 should have exactly 3 blocks"
);
assert_eq!(
seq2.unique_blocks().len(),
3,
"seq2 should have exactly 3 blocks"
);
assert_eq!( assert_eq!(
seq1.len() % seq1.block_size(), &seq.unique_blocks()[0..2],
1, &seq_equiv.unique_blocks()[0..2],
"seq1 should have 1 partial token" "first two full blocks should remain identical"
);
assert_eq!(
seq2.len() % seq2.block_size(),
1,
"seq2 should have 1 partial token"
); );
// Verify that both sequences have identical blocks up to the second position
assert_eq!(
&seq1.unique_blocks()[0..2],
&seq2.unique_blocks()[0..2],
"First two blocks should be identical"
);
// Push tokens 34..47 to seq1
for token in 33..48 { for token in 33..48 {
seq1.push(token); seq.push(token);
} }
// Push token 48 and get the signal - this completes the block and triggers signals let signal = seq
let signal = seq1.push(48); .push(48)
let signal = signal.unwrap(); .expect("Expected promote when opening next partial");
// Check that signal[0] is promote let UniqueBlock::FullBlock(expected_hash) = seq.unique_blocks()[1] else {
match &signal[0] { panic!("unique_blocks[1] should be a full block");
MoveBlock::Promote(_, _, parent_hash, _hash, ..) => { };
// Check that the parent_hash matches unique_blocks[1], which should be a full block assert_promote_parent(&signal[0], Some(expected_hash));
if let UniqueBlock::FullBlock(expected_hash) = seq1.unique_blocks()[1] { assert_single_partial_use(&signal[1]);
assert_eq!( }
*parent_hash,
Some(expected_hash),
"Parent hash should match unique_blocks[1]"
);
} else {
panic!("unique_blocks[1] should be a full block");
}
}
_ => panic!("Expected Promote signal as first signal"),
}
// Reset seq1 and check that it equals the original clone #[test]
let free_signals = seq1.reset_with_signal(); fn test_reset_with_signal_frees_blocks_and_resets_allocation() {
let initial_tokens: Vec<u32> = (0..15).collect();
let (mut seq, _) = ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), true);
seq.push(15);
seq.push(16);
seq.commit_allocation(seq.len());
// 49 - 15 generated tokens let free_signals = seq.reset_with_signal();
assert_eq!(seq1.generated_tokens(), 34);
// Verify the reset signals include proper cleanup events
assert!(!free_signals.is_empty()); assert!(!free_signals.is_empty());
assert_eq!(seq.num_allocated_tokens(), 0);
assert_eq!(seq.generated_tokens(), 2);
} }
#[test] #[test]
...@@ -486,14 +523,7 @@ mod tests { ...@@ -486,14 +523,7 @@ mod tests {
let (mut seq, signal) = ActiveSequence::new_with_signal(initial_tokens, 5, Some(16), true); let (mut seq, signal) = ActiveSequence::new_with_signal(initial_tokens, 5, Some(16), true);
// Initial signal - should have received a Use signal for the partial block // Initial signal - should have received a Use signal for the partial block
assert!(signal.is_some()); assert_single_partial_use(signal.as_ref().expect("Expected initial Use signal"));
match signal {
Some(MoveBlock::Use(blocks, _hashes, ..)) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
_ => panic!("Expected Use signal for the initial partial block"),
}
// Generate first two tokens - should not trigger new signals // Generate first two tokens - should not trigger new signals
seq.generate(); seq.generate();
...@@ -505,21 +535,10 @@ mod tests { ...@@ -505,21 +535,10 @@ mod tests {
assert_eq!(signals_second.len(), 2); assert_eq!(signals_second.len(), 2);
// First signal should be Promote // First signal should be Promote
match &signals_second[0] { assert_promote_parent(&signals_second[0], None);
MoveBlock::Promote(_, _, parent_hash, _hash, ..) => {
assert_eq!(*parent_hash, None);
}
_ => panic!("Expected Promote signal as first signal after second token"),
}
// Second signal should be Use for new partial block // Second signal should be Use for new partial block
match &signals_second[1] { assert_single_partial_use(&signals_second[1]);
MoveBlock::Use(blocks, _hashes, ..) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
_ => panic!("Expected Use signal as second signal after second token"),
}
// Generate fourth token - should not trigger new signals as it's adding to partial block // Generate fourth token - should not trigger new signals as it's adding to partial block
let signals_third = seq.generate(); let signals_third = seq.generate();
...@@ -530,21 +549,70 @@ mod tests { ...@@ -530,21 +549,70 @@ mod tests {
assert_eq!(signals_last.len(), 2); assert_eq!(signals_last.len(), 2);
// First signal should be Destroy for the partial block // First signal should be Destroy for the partial block
match &signals_last[0] { assert_destroy_partial(&signals_last[0]);
MoveBlock::Destroy(blocks) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
_ => panic!("Expected Destroy signal for partial block after fourth token"),
}
// Second signal should be Deref for the full block // Second signal should be Deref for the full block
match &signals_last[1] { assert_deref_full(&signals_last[1]);
MoveBlock::Deref(blocks) => { }
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::FullBlock(_))); #[test]
} fn test_prepare_allocation_slices_full_and_partial_blocks() {
_ => panic!("Expected Deref signal for full block after fourth token"), let tokens: Vec<u32> = (0..10).collect();
} let seq = ActiveSequence::new(tokens, 4, Some(4), true, false);
let first = seq.prepare_allocation(4).unwrap();
assert_use_signal(
&first,
&seq.unique_blocks()[0..1],
&seq.block_hashes()[0..1],
);
let second = seq.prepare_allocation(8).unwrap();
assert_use_signal(
&second,
&seq.unique_blocks()[0..2],
&seq.block_hashes()[0..2],
);
let third = seq.prepare_allocation(10).unwrap();
assert_use_signal(
&third,
&seq.unique_blocks()[0..3],
&seq.block_hashes()[0..2],
);
}
#[test]
fn test_prepare_allocation_is_stable_until_commit() {
let tokens: Vec<u32> = (0..10).collect();
let mut seq = ActiveSequence::new(tokens, 4, Some(4), true, false);
let first = seq.prepare_allocation(4).unwrap();
let second = seq.prepare_allocation(4).unwrap();
assert_eq!(first, second);
seq.commit_allocation(4);
let next = seq.prepare_allocation(8).unwrap();
assert_use_signal(&next, &seq.unique_blocks()[1..2], &seq.block_hashes()[1..2]);
}
#[test]
fn test_block_hash_cache_stays_in_sync_after_promote_and_pop() {
let initial_tokens: Vec<u32> = (0..15).collect();
let (mut seq, _) = ActiveSequence::new_with_signal(initial_tokens, 4, Some(16), true);
assert_cached_hashes_match_promoted_blocks(&seq);
seq.push(15);
assert_cached_hashes_match_promoted_blocks(&seq);
let promote_signals = seq.push(16).unwrap();
assert_eq!(promote_signals.len(), 2);
assert_cached_hashes_match_promoted_blocks(&seq);
// `pop()` is only valid for undoing a freshly generated token from the
// current partial block; this is the replay/preemption path we rely on.
seq.pop();
assert_cached_hashes_match_promoted_blocks(&seq);
} }
} }
...@@ -12,3 +12,4 @@ pub mod common; ...@@ -12,3 +12,4 @@ pub mod common;
pub mod engine; pub mod engine;
pub mod kv_manager; pub mod kv_manager;
pub mod scheduler; pub mod scheduler;
pub mod simulation;
...@@ -74,6 +74,7 @@ pub(crate) mod test_utils { ...@@ -74,6 +74,7 @@ pub(crate) mod test_utils {
max_output_tokens, max_output_tokens,
uuid: None, uuid: None,
dp_rank: 0, dp_rank: 0,
arrival_timestamp_ms: None,
}); });
} }
......
...@@ -705,6 +705,7 @@ mod tests { ...@@ -705,6 +705,7 @@ mod tests {
max_output_tokens: max_output, max_output_tokens: max_output,
uuid: None, uuid: None,
dp_rank: 0, dp_rank: 0,
arrival_timestamp_ms: None,
}); });
} }
...@@ -753,6 +754,7 @@ mod tests { ...@@ -753,6 +754,7 @@ mod tests {
max_output_tokens: 5, max_output_tokens: 5,
uuid: None, uuid: None,
dp_rank: 0, dp_rank: 0,
arrival_timestamp_ms: None,
}); });
} }
......
...@@ -36,6 +36,7 @@ use crate::common::running_mean::RunningMean; ...@@ -36,6 +36,7 @@ use crate::common::running_mean::RunningMean;
use crate::common::sequence::ActiveSequence; use crate::common::sequence::ActiveSequence;
use crate::common::utils::sleep_until_precise; use crate::common::utils::sleep_until_precise;
use crate::kv_manager::KvManager; use crate::kv_manager::KvManager;
use crate::simulation::{TraceCollector, TraceSimulationReport};
use dynamo_kv_router::protocols::DpRank; use dynamo_kv_router::protocols::DpRank;
use dynamo_tokens::blocks::UniqueBlock; use dynamo_tokens::blocks::UniqueBlock;
use std::collections::{HashMap, VecDeque}; use std::collections::{HashMap, VecDeque};
...@@ -83,13 +84,11 @@ impl SchedulerState { ...@@ -83,13 +84,11 @@ impl SchedulerState {
/// Try to admit one request from waiting → prefill. /// Try to admit one request from waiting → prefill.
/// Converts DirectRequest → ActiveSequence if needed. PrefillCost is computed /// Converts DirectRequest → ActiveSequence if needed. PrefillCost is computed
/// later in simulate_prefill when the request reaches the front of the queue. /// later in simulate_prefill when the request reaches the front of the queue.
fn admit_one(&mut self, args: &MockEngineArgs) -> bool { fn admit_one(&mut self, args: &MockEngineArgs) -> Option<Uuid> {
let Some(&uuid) = self.waiting.front() else { let &uuid = self.waiting.front()?;
return false;
};
let num_active = self.prefill.len() + self.decode.len(); let num_active = self.prefill.len() + self.decode.len();
if args.max_num_seqs.is_some_and(|limit| num_active >= limit) { if args.max_num_seqs.is_some_and(|limit| num_active >= limit) {
return false; return None;
} }
self.waiting.pop_front(); self.waiting.pop_front();
...@@ -112,7 +111,7 @@ impl SchedulerState { ...@@ -112,7 +111,7 @@ impl SchedulerState {
} }
self.prefill.push_back(uuid); self.prefill.push_back(uuid);
true Some(uuid)
} }
fn run(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> { fn run(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> {
...@@ -308,6 +307,49 @@ async fn simulate_prefill( ...@@ -308,6 +307,49 @@ async fn simulate_prefill(
args: &MockEngineArgs, args: &MockEngineArgs,
) -> Duration { ) -> Duration {
let start_time = Instant::now(); let start_time = Instant::now();
let total_time = simulate_prefill_step(state, kv_manager, hit_rates, args, None, 0.0, false);
if args.speedup_ratio > 0.0 && total_time > Duration::ZERO {
let sleep_duration = Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
let deadline = start_time + sleep_duration;
sleep_until_precise(deadline).await;
}
total_time
}
/// Simulate decode phase for all active decode requests.
/// Returns the total decode compute time.
async fn simulate_decode(
state: &mut SchedulerState,
kv_manager: &mut KvManager,
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
args: &MockEngineArgs,
) -> Duration {
let start_time = Instant::now();
let total_time = simulate_decode_step(state, kv_manager, output_tx, args, None, 0.0, false);
let effective_ratio = args.speedup_ratio * args.decode_speedup_ratio;
if effective_ratio > 0.0 && total_time > Duration::ZERO {
let sleep_duration = Duration::from_secs_f64(total_time.as_secs_f64() / effective_ratio);
let deadline = start_time + sleep_duration;
sleep_until_precise(deadline).await;
}
total_time
}
fn simulate_prefill_step(
state: &mut SchedulerState,
kv_manager: &mut KvManager,
hit_rates: &mut RunningMean<f32>,
args: &MockEngineArgs,
mut collector: Option<&mut TraceCollector>,
current_time_ms: f64,
apply_speedup: bool,
) -> Duration {
let mut total_time = Duration::ZERO; let mut total_time = Duration::ZERO;
let mut token_budget = args let mut token_budget = args
...@@ -315,9 +357,19 @@ async fn simulate_prefill( ...@@ -315,9 +357,19 @@ async fn simulate_prefill(
.map_or(usize::MAX, |t| t.saturating_sub(state.decode.len())); .map_or(usize::MAX, |t| t.saturating_sub(state.decode.len()));
'prefill: while token_budget > 0 { 'prefill: while token_budget > 0 {
// Drain prefill first, then pull from waiting one at a time // Drain prefill first, then pull from waiting one at a time.
if state.prefill.is_empty() && !state.admit_one(args) { if state.prefill.is_empty() {
break; let Some(admitted_uuid) = state.admit_one(args) else {
break;
};
if let Some(collector) = collector.as_deref_mut() {
let Some(Request::Active(seq)) = state.requests.get(&admitted_uuid) else {
panic!("Request does not exist.");
};
let prefill_cost = kv_manager.get_prefill_cost(seq);
let reused_input_tokens = seq.len().saturating_sub(prefill_cost.new_tokens);
collector.on_admit(admitted_uuid, current_time_ms, reused_input_tokens);
}
} }
let uuid = state.prefill[0]; let uuid = state.prefill[0];
...@@ -329,7 +381,7 @@ async fn simulate_prefill( ...@@ -329,7 +381,7 @@ async fn simulate_prefill(
let allocated_tokens = seq.num_allocated_tokens(); let allocated_tokens = seq.num_allocated_tokens();
let remaining = prefill_cost.new_tokens; let remaining = prefill_cost.new_tokens;
// Token budget check // Token budget check.
let tokens_left = sequence_len - allocated_tokens; let tokens_left = sequence_len - allocated_tokens;
if !args.enable_chunked_prefill && tokens_left > token_budget { if !args.enable_chunked_prefill && tokens_left > token_budget {
break; break;
...@@ -338,7 +390,7 @@ async fn simulate_prefill( ...@@ -338,7 +390,7 @@ async fn simulate_prefill(
let cumulative = allocated_tokens + chunk; let cumulative = allocated_tokens + chunk;
// Allocate blocks. process() returns the number of blocks committed. // Allocate blocks. process() returns the number of blocks committed.
// On partial success, preempt a decode request and retry the next // On partial success, preempt a decode request and retry; the next
// loop iteration re-prepares from the updated num_allocated_tokens. // loop iteration re-prepares from the updated num_allocated_tokens.
let Some(Request::Active(seq)) = state.requests.get_mut(&uuid) else { let Some(Request::Active(seq)) = state.requests.get_mut(&uuid) else {
panic!("Request does not exist."); panic!("Request does not exist.");
...@@ -349,11 +401,11 @@ async fn simulate_prefill( ...@@ -349,11 +401,11 @@ async fn simulate_prefill(
_ => unreachable!(), _ => unreachable!(),
}; };
let allocated = kv_manager.process(&signal); let allocated = kv_manager.process(&signal);
// Commit the blocks that were actually allocated // Commit the blocks that were actually allocated.
let committed_tokens = if allocated == expected { let committed_tokens = if allocated == expected {
cumulative cumulative
} else { } else {
// Partial: compute token boundary from block count // Partial success: compute token boundary from block count.
let prev_blocks = allocated_tokens let prev_blocks = allocated_tokens
.div_ceil(seq.block_size()) .div_ceil(seq.block_size())
.min(seq.unique_blocks().len()); .min(seq.unique_blocks().len());
...@@ -368,13 +420,13 @@ async fn simulate_prefill( ...@@ -368,13 +420,13 @@ async fn simulate_prefill(
for signal in state.preempt(args.preemption_mode) { for signal in state.preempt(args.preemption_mode) {
kv_manager.process(&signal); kv_manager.process(&signal);
} }
continue 'prefill; // retry with freed capacity continue 'prefill; // Retry with freed capacity.
} }
} else { } else {
seq.commit_allocation(cumulative); seq.commit_allocation(cumulative);
} }
// Accumulate prefill compute time (only for the new tokens in this chunk) // Accumulate prefill compute time only for the new tokens in this chunk.
let new_tokens_in_chunk = chunk.min(remaining); let new_tokens_in_chunk = chunk.min(remaining);
if args.worker_type != WorkerType::Decode && new_tokens_in_chunk > 0 { if args.worker_type != WorkerType::Decode && new_tokens_in_chunk > 0 {
total_time += Duration::from_secs_f64( total_time += Duration::from_secs_f64(
...@@ -383,7 +435,7 @@ async fn simulate_prefill( ...@@ -383,7 +435,7 @@ async fn simulate_prefill(
); );
} }
// Hit rate: fraction of tokens that were already cached // Hit rate: fraction of tokens that were already cached.
let hit_rate = if sequence_len > 0 { let hit_rate = if sequence_len > 0 {
1.0 - (remaining as f32 / sequence_len as f32) 1.0 - (remaining as f32 / sequence_len as f32)
} else { } else {
...@@ -394,65 +446,68 @@ async fn simulate_prefill( ...@@ -394,65 +446,68 @@ async fn simulate_prefill(
token_budget -= chunk; token_budget -= chunk;
if cumulative >= sequence_len { if cumulative >= sequence_len {
// Fully prefilled promote to decode queue // Fully prefilled: promote to decode queue.
state.prefill.pop_front(); state.prefill.pop_front();
state.decode.push_back(uuid); state.decode.push_back(uuid);
} else { } else {
// Partially prefilled resume next iteration with updated allocated_tokens // Partially prefilled: resume next iteration with updated allocation state.
break; break;
} }
} }
if args.speedup_ratio > 0.0 && total_time > Duration::ZERO { if !apply_speedup || args.speedup_ratio <= 0.0 || total_time <= Duration::ZERO {
let sleep_duration = Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio); return total_time;
let deadline = start_time + sleep_duration;
sleep_until_precise(deadline).await;
} }
total_time Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio)
} }
/// Simulate decode phase for all active decode requests. fn simulate_decode_step(
/// Returns the total decode compute time.
async fn simulate_decode(
state: &mut SchedulerState, state: &mut SchedulerState,
kv_manager: &mut KvManager, kv_manager: &mut KvManager,
output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>, output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
args: &MockEngineArgs, args: &MockEngineArgs,
mut collector: Option<&mut TraceCollector>,
current_time_ms: f64,
apply_speedup: bool,
) -> Duration { ) -> Duration {
let start_time = Instant::now(); if state.decode.is_empty() {
return Duration::ZERO;
}
// Compute decode timing let decode_start_ms = current_time_ms;
let active_kv_tokens = kv_manager.num_active_blocks() * args.block_size;
// Compute average context length across all active decode requests let decode_lengths = state
let total_length: usize = state
.decode .decode
.iter() .iter()
.map(|uuid| { .filter_map(|uuid| match state.requests.get(uuid).unwrap() {
if let Request::Active(seq) = state.requests.get(uuid).unwrap() { Request::Active(seq) => Some(seq.len()),
seq.len() Request::Direct(_) => None,
} else {
0
}
}) })
.sum(); .collect::<Vec<_>>();
let count = state.decode.len(); if decode_lengths.is_empty() {
return Duration::ZERO;
}
let context_length = if count > 0 { total_length / count } else { 0 }; let active_kv_tokens = kv_manager.num_active_blocks() * args.block_size;
let total_length: usize = decode_lengths.iter().sum();
let context_length = total_length / decode_lengths.len();
let decoding_time = args let decoding_time = args
.perf_model .perf_model
.predict_decode_time(active_kv_tokens, context_length); .predict_decode_time(active_kv_tokens, context_length);
let total_time = Duration::from_secs_f64(decoding_time / 1000.0); let unscaled_time = Duration::from_secs_f64(decoding_time / 1000.0);
let effective_ratio = args.speedup_ratio * args.decode_speedup_ratio;
// Process decoding let total_time = if apply_speedup && effective_ratio > 0.0 && unscaled_time > Duration::ZERO {
Duration::from_secs_f64(unscaled_time.as_secs_f64() / effective_ratio)
} else {
unscaled_time
};
let decode_end_ms = decode_start_ms + total_time.as_secs_f64() * 1000.0;
// Process decoding.
let uuids: Vec<Uuid> = state.decode.iter().copied().collect(); let uuids: Vec<Uuid> = state.decode.iter().copied().collect();
let mut emitted_any = false;
for uuid in uuids { for uuid in uuids {
// Try to generate; if allocation fails, preempt until it succeeds
// or nothing is left to preempt (matches vLLM v1 scheduler loop).
// Reborrow sequence each iteration so the mutable ref doesn't
// conflict with state.preempt().
let mut allocated = false; let mut allocated = false;
loop { loop {
let Some(sequence) = state.run(uuid) else { let Some(sequence) = state.run(uuid) else {
...@@ -487,8 +542,12 @@ async fn simulate_decode( ...@@ -487,8 +542,12 @@ async fn simulate_decode(
let Some(sequence) = state.run(uuid) else { let Some(sequence) = state.run(uuid) else {
continue; continue;
}; };
emitted_any = true;
if let Some(collector) = collector.as_deref_mut() {
collector.on_token(uuid, decode_end_ms);
}
// Check completion and send notification // Check completion and send notification.
let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens(); let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
let send_failed = output_tx.as_ref().is_some_and(|tx| { let send_failed = output_tx.as_ref().is_some_and(|tx| {
...@@ -510,17 +569,199 @@ async fn simulate_decode( ...@@ -510,17 +569,199 @@ async fn simulate_decode(
} }
} }
let effective_ratio = args.speedup_ratio * args.decode_speedup_ratio; if !emitted_any {
if effective_ratio > 0.0 && total_time > Duration::ZERO { return Duration::ZERO;
let sleep_duration = Duration::from_secs_f64(total_time.as_secs_f64() / effective_ratio);
let deadline = start_time + sleep_duration;
sleep_until_precise(deadline).await;
} }
total_time total_time
} }
pub fn simulate_trace(
args: MockEngineArgs,
mut requests: Vec<DirectRequest>,
) -> anyhow::Result<TraceSimulationReport> {
args.validate()?;
requests.sort_by(|left, right| {
let left_ts = left
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
let right_ts = right
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
left_ts.total_cmp(&right_ts)
});
let first_arrival_ms = requests
.first()
.and_then(|request| request.arrival_timestamp_ms)
.ok_or_else(|| anyhow::anyhow!("trace replay requires at least one timestamped request"))?;
let mut pending = VecDeque::from(
requests
.into_iter()
.map(|mut request| {
let arrival_timestamp_ms = request
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp")
- first_arrival_ms;
request.arrival_timestamp_ms = Some(arrival_timestamp_ms);
request
})
.collect::<Vec<_>>(),
);
let mut state = SchedulerState::default();
let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
let mut hit_rates = RunningMean::new(1000);
let mut collector = TraceCollector::default();
let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;
let mut current_time_ms = 0.0;
while !pending.is_empty() || !state.is_empty() {
enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);
if state.is_empty() {
let Some(next_arrival_ms) = pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
else {
break;
};
current_time_ms = next_arrival_ms;
enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);
continue;
}
let prefill_time = simulate_prefill_step(
&mut state,
&mut kv_manager,
&mut hit_rates,
&args,
Some(&mut collector),
current_time_ms,
true,
);
current_time_ms += prefill_time.as_secs_f64() * 1000.0;
enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);
let decode_time = simulate_decode_step(
&mut state,
&mut kv_manager,
&output_tx,
&args,
Some(&mut collector),
current_time_ms,
true,
);
current_time_ms += decode_time.as_secs_f64() * 1000.0;
}
Ok(collector.finish())
}
pub fn simulate_concurrency(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
) -> anyhow::Result<TraceSimulationReport> {
args.validate()?;
let mut pending = VecDeque::from(requests);
let mut state = SchedulerState::default();
let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
let mut hit_rates = RunningMean::new(1000);
let mut collector = TraceCollector::default();
let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;
let mut current_time_ms = 0.0;
while !pending.is_empty() || !state.is_empty() {
enqueue_concurrency_arrivals(
&mut pending,
&mut state,
&mut collector,
current_time_ms,
max_in_flight,
);
if state.is_empty() {
break;
}
let prefill_time = simulate_prefill_step(
&mut state,
&mut kv_manager,
&mut hit_rates,
&args,
Some(&mut collector),
current_time_ms,
true,
);
current_time_ms += prefill_time.as_secs_f64() * 1000.0;
let decode_time = simulate_decode_step(
&mut state,
&mut kv_manager,
&output_tx,
&args,
Some(&mut collector),
current_time_ms,
true,
);
current_time_ms += decode_time.as_secs_f64() * 1000.0;
}
Ok(collector.finish())
}
fn enqueue_trace_arrivals(
pending: &mut VecDeque<DirectRequest>,
state: &mut SchedulerState,
collector: &mut TraceCollector,
current_time_ms: f64,
) {
loop {
let Some(next_arrival_ms) = pending
.front()
.and_then(|request| request.arrival_timestamp_ms)
else {
break;
};
if next_arrival_ms > current_time_ms {
break;
}
let request = pending
.pop_front()
.expect("front request must exist when arrival is available");
let arrival_ms = request
.arrival_timestamp_ms
.expect("trace replay requests must have an arrival timestamp");
let input_length = request.tokens.len();
let output_length = request.max_output_tokens;
let uuid = state.receive(request);
collector.on_arrival(uuid, arrival_ms, input_length, output_length);
}
}
fn enqueue_concurrency_arrivals(
pending: &mut VecDeque<DirectRequest>,
state: &mut SchedulerState,
collector: &mut TraceCollector,
current_time_ms: f64,
max_in_flight: usize,
) {
while state.requests.len() < max_in_flight {
let Some(mut request) = pending.pop_front() else {
break;
};
request.arrival_timestamp_ms = Some(current_time_ms);
let input_length = request.tokens.len();
let output_length = request.max_output_tokens;
let uuid = state.receive(request);
collector.on_arrival(uuid, current_time_ms, input_length, output_length);
}
}
/// Processes MoveBlock signals with the KvManager. /// Processes MoveBlock signals with the KvManager.
/// ///
/// When a signal fails, this function verifies that the failure is for an expected case: /// When a signal fails, this function verifies that the failure is for an expected case:
...@@ -565,7 +806,9 @@ fn process_signals(kv_manager: &mut KvManager, signals: &[MoveBlock]) -> bool { ...@@ -565,7 +806,9 @@ fn process_signals(kv_manager: &mut KvManager, signals: &[MoveBlock]) -> bool {
mod tests { mod tests {
use super::*; use super::*;
use crate::scheduler::SchedulerHandle; use crate::scheduler::SchedulerHandle;
use crate::simulation::{TraceCollector, TraceRequestStatsSnapshot};
use rstest::rstest; use rstest::rstest;
use std::collections::HashMap;
use std::time::Duration; use std::time::Duration;
use tokio::time::interval; use tokio::time::interval;
...@@ -649,6 +892,7 @@ mod tests { ...@@ -649,6 +892,7 @@ mod tests {
max_output_tokens, max_output_tokens,
uuid: None, uuid: None,
dp_rank: 0, dp_rank: 0,
arrival_timestamp_ms: None,
}; };
scheduler.receive(request); scheduler.receive(request);
// Sleep for 0.1 second after each request // Sleep for 0.1 second after each request
...@@ -735,18 +979,21 @@ mod tests { ...@@ -735,18 +979,21 @@ mod tests {
max_output_tokens: 2, max_output_tokens: 2,
uuid: Some(r1_uuid), uuid: Some(r1_uuid),
dp_rank: 0, dp_rank: 0,
arrival_timestamp_ms: None,
}); });
state.receive(DirectRequest { state.receive(DirectRequest {
tokens: (100..108).collect(), tokens: (100..108).collect(),
max_output_tokens: 2, max_output_tokens: 2,
uuid: Some(r2_uuid), uuid: Some(r2_uuid),
dp_rank: 0, dp_rank: 0,
arrival_timestamp_ms: None,
}); });
state.receive(DirectRequest { state.receive(DirectRequest {
tokens: (200..212).collect(), tokens: (200..212).collect(),
max_output_tokens: 2, max_output_tokens: 2,
uuid: Some(r3_uuid), uuid: Some(r3_uuid),
dp_rank: 0, dp_rank: 0,
arrival_timestamp_ms: None,
}); });
assert_eq!(state.waiting.len(), 3); assert_eq!(state.waiting.len(), 3);
...@@ -881,6 +1128,7 @@ mod tests { ...@@ -881,6 +1128,7 @@ mod tests {
max_output_tokens, max_output_tokens,
uuid: None, uuid: None,
dp_rank: 0, dp_rank: 0,
arrival_timestamp_ms: None,
}; };
scheduler.receive(request); scheduler.receive(request);
...@@ -907,4 +1155,466 @@ mod tests { ...@@ -907,4 +1155,466 @@ mod tests {
assert_scheduler_idle(&metrics); assert_scheduler_idle(&metrics);
} }
#[derive(Debug)]
struct ManualReplayResult {
report: TraceSimulationReport,
snapshots: HashMap<Uuid, TraceRequestStatsSnapshot>,
idle_jump_ms: f64,
first_decode_end_ms: f64,
}
#[derive(Debug)]
struct ManualConcurrencyResult {
report: TraceSimulationReport,
snapshots: HashMap<Uuid, TraceRequestStatsSnapshot>,
}
fn replay_args(enable_prefix_caching: bool, enable_chunked_prefill: bool) -> MockEngineArgs {
MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(32)
.max_num_batched_tokens(Some(8))
.max_num_seqs(Some(2))
.enable_prefix_caching(enable_prefix_caching)
.enable_chunked_prefill(enable_chunked_prefill)
.speedup_ratio(0.0)
.build()
.unwrap()
}
fn replay_fixture() -> Vec<DirectRequest> {
vec![
DirectRequest {
tokens: vec![1, 1, 1, 1, 2, 2, 2, 2],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(11)),
dp_rank: 0,
arrival_timestamp_ms: Some(100.0),
},
DirectRequest {
tokens: vec![1, 1, 1, 1, 2, 2, 2, 2],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(22)),
dp_rank: 0,
arrival_timestamp_ms: Some(101.0),
},
DirectRequest {
tokens: vec![9, 9, 9, 9, 8, 8, 8, 8],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(33)),
dp_rank: 0,
arrival_timestamp_ms: Some(500.0),
},
]
}
fn run_trace_manually(
args: &MockEngineArgs,
requests: Vec<DirectRequest>,
) -> ManualReplayResult {
let mut requests = requests;
requests.sort_by(|left, right| {
let left_ts = left.arrival_timestamp_ms.unwrap();
let right_ts = right.arrival_timestamp_ms.unwrap();
left_ts.total_cmp(&right_ts)
});
let first_arrival_ms = requests.first().unwrap().arrival_timestamp_ms.unwrap();
let mut pending = VecDeque::from(
requests
.into_iter()
.map(|mut request| {
request.arrival_timestamp_ms =
Some(request.arrival_timestamp_ms.unwrap() - first_arrival_ms);
request
})
.collect::<Vec<_>>(),
);
let mut state = SchedulerState::default();
let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
let mut hit_rates = RunningMean::new(1000);
let mut collector = TraceCollector::default();
let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;
let mut current_time_ms = 0.0;
let mut idle_jump_ms = 0.0;
let mut first_decode_end_ms = 0.0;
while !pending.is_empty() || !state.is_empty() {
enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);
if state.is_empty() {
let next_arrival_ms = pending.front().unwrap().arrival_timestamp_ms.unwrap();
current_time_ms = next_arrival_ms;
if idle_jump_ms == 0.0 && current_time_ms > 0.0 {
idle_jump_ms = current_time_ms;
}
enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);
continue;
}
let prefill_time = simulate_prefill_step(
&mut state,
&mut kv_manager,
&mut hit_rates,
args,
Some(&mut collector),
current_time_ms,
true,
);
current_time_ms += prefill_time.as_secs_f64() * 1000.0;
enqueue_trace_arrivals(&mut pending, &mut state, &mut collector, current_time_ms);
let decode_time = simulate_decode_step(
&mut state,
&mut kv_manager,
&output_tx,
args,
Some(&mut collector),
current_time_ms,
true,
);
if first_decode_end_ms == 0.0 && decode_time > Duration::ZERO {
first_decode_end_ms = current_time_ms + decode_time.as_secs_f64() * 1000.0;
}
current_time_ms += decode_time.as_secs_f64() * 1000.0;
}
let snapshots = [
Uuid::from_u128(11),
Uuid::from_u128(22),
Uuid::from_u128(33),
]
.into_iter()
.map(|uuid| (uuid, collector.snapshot(uuid).unwrap()))
.collect();
ManualReplayResult {
report: collector.finish(),
snapshots,
idle_jump_ms,
first_decode_end_ms,
}
}
fn run_concurrency_manually(
args: &MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
) -> ManualConcurrencyResult {
let mut pending = VecDeque::from(requests);
let mut state = SchedulerState::default();
let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
let mut hit_rates = RunningMean::new(1000);
let mut collector = TraceCollector::default();
let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;
let mut current_time_ms = 0.0;
while !pending.is_empty() || !state.is_empty() {
enqueue_concurrency_arrivals(
&mut pending,
&mut state,
&mut collector,
current_time_ms,
max_in_flight,
);
if state.is_empty() {
break;
}
let prefill_time = simulate_prefill_step(
&mut state,
&mut kv_manager,
&mut hit_rates,
args,
Some(&mut collector),
current_time_ms,
true,
);
current_time_ms += prefill_time.as_secs_f64() * 1000.0;
let decode_time = simulate_decode_step(
&mut state,
&mut kv_manager,
&output_tx,
args,
Some(&mut collector),
current_time_ms,
true,
);
current_time_ms += decode_time.as_secs_f64() * 1000.0;
}
let snapshots = [
Uuid::from_u128(11),
Uuid::from_u128(22),
Uuid::from_u128(33),
]
.into_iter()
.map(|uuid| (uuid, collector.snapshot(uuid).unwrap()))
.collect();
ManualConcurrencyResult {
report: collector.finish(),
snapshots,
}
}
fn assert_report_close(left: &TraceSimulationReport, right: &TraceSimulationReport) {
let epsilon = 1e-9;
assert_eq!(
left.request_counts.num_requests,
right.request_counts.num_requests
);
assert_eq!(
left.request_counts.completed_requests,
right.request_counts.completed_requests
);
assert_eq!(
left.request_counts.total_input_tokens,
right.request_counts.total_input_tokens
);
assert_eq!(
left.request_counts.total_output_tokens,
right.request_counts.total_output_tokens
);
assert!((left.throughput.duration_ms - right.throughput.duration_ms).abs() <= epsilon);
assert!(
(left.throughput.request_throughput_rps - right.throughput.request_throughput_rps)
.abs()
<= epsilon
);
assert!(
(left.throughput.input_throughput_tok_s - right.throughput.input_throughput_tok_s)
.abs()
<= epsilon
);
assert!(
(left.throughput.output_throughput_tok_s - right.throughput.output_throughput_tok_s)
.abs()
<= epsilon
);
assert!(
(left.throughput.total_throughput_tok_s - right.throughput.total_throughput_tok_s)
.abs()
<= epsilon
);
assert!(
(left.prefix_cache_reused_ratio - right.prefix_cache_reused_ratio).abs() <= epsilon
);
assert!((left.latency.ttft.mean_ms - right.latency.ttft.mean_ms).abs() <= epsilon);
assert!((left.latency.ttft.min_ms - right.latency.ttft.min_ms).abs() <= epsilon);
assert!((left.latency.ttft.max_ms - right.latency.ttft.max_ms).abs() <= epsilon);
assert!((left.latency.ttft.median_ms - right.latency.ttft.median_ms).abs() <= epsilon);
assert!((left.latency.ttft.p75_ms - right.latency.ttft.p75_ms).abs() <= epsilon);
assert!((left.latency.ttft.p90_ms - right.latency.ttft.p90_ms).abs() <= epsilon);
assert!((left.latency.ttft.p95_ms - right.latency.ttft.p95_ms).abs() <= epsilon);
assert!((left.latency.ttft.p99_ms - right.latency.ttft.p99_ms).abs() <= epsilon);
assert!((left.latency.ttft.std_ms - right.latency.ttft.std_ms).abs() <= epsilon);
assert!((left.latency.ttst.mean_ms - right.latency.ttst.mean_ms).abs() <= epsilon);
assert!((left.latency.ttst.min_ms - right.latency.ttst.min_ms).abs() <= epsilon);
assert!((left.latency.ttst.max_ms - right.latency.ttst.max_ms).abs() <= epsilon);
assert!((left.latency.ttst.median_ms - right.latency.ttst.median_ms).abs() <= epsilon);
assert!((left.latency.ttst.p75_ms - right.latency.ttst.p75_ms).abs() <= epsilon);
assert!((left.latency.ttst.p90_ms - right.latency.ttst.p90_ms).abs() <= epsilon);
assert!((left.latency.ttst.p95_ms - right.latency.ttst.p95_ms).abs() <= epsilon);
assert!((left.latency.ttst.p99_ms - right.latency.ttst.p99_ms).abs() <= epsilon);
assert!((left.latency.ttst.std_ms - right.latency.ttst.std_ms).abs() <= epsilon);
assert!((left.latency.tpot.mean_ms - right.latency.tpot.mean_ms).abs() <= epsilon);
assert!((left.latency.tpot.min_ms - right.latency.tpot.min_ms).abs() <= epsilon);
assert!((left.latency.tpot.max_ms - right.latency.tpot.max_ms).abs() <= epsilon);
assert!((left.latency.tpot.median_ms - right.latency.tpot.median_ms).abs() <= epsilon);
assert!((left.latency.tpot.p75_ms - right.latency.tpot.p75_ms).abs() <= epsilon);
assert!((left.latency.tpot.p90_ms - right.latency.tpot.p90_ms).abs() <= epsilon);
assert!((left.latency.tpot.p95_ms - right.latency.tpot.p95_ms).abs() <= epsilon);
assert!((left.latency.tpot.p99_ms - right.latency.tpot.p99_ms).abs() <= epsilon);
assert!((left.latency.tpot.std_ms - right.latency.tpot.std_ms).abs() <= epsilon);
assert!(
(left.latency.itl.distribution.mean_ms - right.latency.itl.distribution.mean_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.min_ms - right.latency.itl.distribution.min_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.max_ms - right.latency.itl.distribution.max_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.median_ms - right.latency.itl.distribution.median_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.p75_ms - right.latency.itl.distribution.p75_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.p90_ms - right.latency.itl.distribution.p90_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.p95_ms - right.latency.itl.distribution.p95_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.p99_ms - right.latency.itl.distribution.p99_ms).abs()
<= epsilon
);
assert!(
(left.latency.itl.distribution.std_ms - right.latency.itl.distribution.std_ms).abs()
<= epsilon
);
assert!((left.latency.itl.max_ms - right.latency.itl.max_ms).abs() <= epsilon);
assert!((left.latency.e2e.mean_ms - right.latency.e2e.mean_ms).abs() <= epsilon);
assert!((left.latency.e2e.min_ms - right.latency.e2e.min_ms).abs() <= epsilon);
assert!((left.latency.e2e.max_ms - right.latency.e2e.max_ms).abs() <= epsilon);
assert!((left.latency.e2e.median_ms - right.latency.e2e.median_ms).abs() <= epsilon);
assert!((left.latency.e2e.p75_ms - right.latency.e2e.p75_ms).abs() <= epsilon);
assert!((left.latency.e2e.p90_ms - right.latency.e2e.p90_ms).abs() <= epsilon);
assert!((left.latency.e2e.p95_ms - right.latency.e2e.p95_ms).abs() <= epsilon);
assert!((left.latency.e2e.p99_ms - right.latency.e2e.p99_ms).abs() <= epsilon);
assert!((left.latency.e2e.std_ms - right.latency.e2e.std_ms).abs() <= epsilon);
assert!(
(left.latency.output_token_throughput_per_user.mean_ms
- right.latency.output_token_throughput_per_user.mean_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.min_ms
- right.latency.output_token_throughput_per_user.min_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.max_ms
- right.latency.output_token_throughput_per_user.max_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.median_ms
- right.latency.output_token_throughput_per_user.median_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.p75_ms
- right.latency.output_token_throughput_per_user.p75_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.p90_ms
- right.latency.output_token_throughput_per_user.p90_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.p95_ms
- right.latency.output_token_throughput_per_user.p95_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.p99_ms
- right.latency.output_token_throughput_per_user.p99_ms)
.abs()
<= epsilon
);
assert!(
(left.latency.output_token_throughput_per_user.std_ms
- right.latency.output_token_throughput_per_user.std_ms)
.abs()
<= epsilon
);
}
#[rstest]
#[case(false, false)]
#[case(false, true)]
#[case(true, false)]
#[case(true, true)]
fn test_trace_replay_matches_manual_steps(
#[case] enable_prefix_caching: bool,
#[case] enable_chunked_prefill: bool,
) {
let args = replay_args(enable_prefix_caching, enable_chunked_prefill);
let manual = run_trace_manually(&args, replay_fixture());
let replay_report = simulate_trace(args, replay_fixture()).unwrap();
let request_1 = manual.snapshots.get(&Uuid::from_u128(11)).unwrap();
let request_2 = manual.snapshots.get(&Uuid::from_u128(22)).unwrap();
let request_3 = manual.snapshots.get(&Uuid::from_u128(33)).unwrap();
assert_eq!(request_1.arrival_time_ms, 0.0);
assert_eq!(request_2.arrival_time_ms, 1.0);
assert_eq!(request_3.arrival_time_ms, 400.0);
assert_eq!(manual.idle_jump_ms, 400.0);
assert_eq!(
request_1.first_token_ms.unwrap(),
manual.first_decode_end_ms,
);
assert!(request_2.first_admit_ms.unwrap() >= request_2.arrival_time_ms);
assert!(request_3.first_admit_ms.unwrap() >= request_3.arrival_time_ms);
assert!(manual.report.latency.e2e.mean_ms >= manual.report.latency.ttft.mean_ms);
if enable_prefix_caching {
assert!(request_2.reused_input_tokens > 0);
assert!(manual.report.prefix_cache_reused_ratio > 0.0);
} else {
assert_eq!(request_2.reused_input_tokens, 0);
assert_eq!(manual.report.prefix_cache_reused_ratio, 0.0);
}
assert_report_close(&replay_report, &manual.report);
}
#[test]
fn test_concurrency_replay_matches_manual_steps() {
let args = replay_args(false, false);
let requests = vec![
DirectRequest {
tokens: vec![1, 2, 3, 4, 5, 6, 7, 8],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(11)),
dp_rank: 0,
arrival_timestamp_ms: Some(900.0),
},
DirectRequest {
tokens: vec![1, 2, 3, 4, 5, 9, 10, 11],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(22)),
dp_rank: 0,
arrival_timestamp_ms: Some(1000.0),
},
DirectRequest {
tokens: vec![12, 13, 14, 15, 16, 17, 18, 19],
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(33)),
dp_rank: 0,
arrival_timestamp_ms: Some(100.0),
},
];
let manual = run_concurrency_manually(&args, requests.clone(), 2);
let replay_report = simulate_concurrency(args, requests, 2).unwrap();
let request_1 = manual.snapshots.get(&Uuid::from_u128(11)).unwrap();
let request_2 = manual.snapshots.get(&Uuid::from_u128(22)).unwrap();
let request_3 = manual.snapshots.get(&Uuid::from_u128(33)).unwrap();
assert_eq!(request_1.arrival_time_ms, 0.0);
assert_eq!(request_2.arrival_time_ms, 0.0);
assert_eq!(request_3.arrival_time_ms, request_1.last_token_ms.unwrap());
assert!(request_3.arrival_time_ms < request_2.last_token_ms.unwrap());
assert_eq!(manual.report.request_counts.completed_requests, 3);
assert_eq!(manual.report.request_counts.total_input_tokens, 24);
assert_eq!(manual.report.request_counts.total_output_tokens, 6);
assert_report_close(&replay_report, &manual.report);
}
} }
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::time::Instant;
use anyhow::{Context, Result, anyhow, bail};
use serde::ser::{SerializeMap, Serializer};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::common::protocols::{DirectRequest, EngineType, MockEngineArgs, WorkerType};
#[derive(Debug, Clone)]
pub struct TraceSimulationReport {
pub request_counts: TraceRequestCounts,
pub throughput: TraceThroughputStats,
pub prefix_cache_reused_ratio: f64,
pub latency: TraceLatencyStats,
}
#[derive(Debug, Clone)]
pub struct TraceRequestCounts {
pub num_requests: usize,
pub completed_requests: usize,
pub total_input_tokens: usize,
pub total_output_tokens: usize,
}
#[derive(Debug, Clone)]
pub struct TraceThroughputStats {
pub duration_ms: f64,
pub wall_time_ms: f64,
pub request_throughput_rps: f64,
pub input_throughput_tok_s: f64,
pub output_throughput_tok_s: f64,
pub total_throughput_tok_s: f64,
}
#[derive(Debug, Clone)]
pub struct TraceDistributionStats {
pub mean_ms: f64,
pub min_ms: f64,
pub max_ms: f64,
pub median_ms: f64,
pub p75_ms: f64,
pub p90_ms: f64,
pub p95_ms: f64,
pub p99_ms: f64,
pub std_ms: f64,
}
#[derive(Debug, Clone)]
pub struct TraceLatencyStats {
pub ttft: TraceDistributionStats,
pub ttst: TraceDistributionStats,
pub tpot: TraceDistributionStats,
pub itl: TraceInterTokenLatencyStats,
pub e2e: TraceDistributionStats,
pub output_token_throughput_per_user: TraceDistributionStats,
}
#[derive(Debug, Clone)]
pub struct TraceInterTokenLatencyStats {
pub distribution: TraceDistributionStats,
pub max_ms: f64,
}
impl TraceSimulationReport {
pub fn with_wall_time_ms(mut self, wall_time_ms: f64) -> Self {
self.throughput.wall_time_ms = wall_time_ms;
self
}
}
impl Serialize for TraceSimulationReport {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut map = serializer.serialize_map(Some(59))?;
map.serialize_entry("num_requests", &self.request_counts.num_requests)?;
map.serialize_entry(
"completed_requests",
&self.request_counts.completed_requests,
)?;
map.serialize_entry(
"total_input_tokens",
&self.request_counts.total_input_tokens,
)?;
map.serialize_entry(
"total_output_tokens",
&self.request_counts.total_output_tokens,
)?;
map.serialize_entry("duration_ms", &self.throughput.duration_ms)?;
map.serialize_entry("wall_time_ms", &self.throughput.wall_time_ms)?;
map.serialize_entry(
"request_throughput_rps",
&self.throughput.request_throughput_rps,
)?;
map.serialize_entry(
"input_throughput_tok_s",
&self.throughput.input_throughput_tok_s,
)?;
map.serialize_entry(
"output_throughput_tok_s",
&self.throughput.output_throughput_tok_s,
)?;
map.serialize_entry(
"total_throughput_tok_s",
&self.throughput.total_throughput_tok_s,
)?;
map.serialize_entry("prefix_cache_reused_ratio", &self.prefix_cache_reused_ratio)?;
serialize_distribution(&mut map, "ttft", &self.latency.ttft)?;
serialize_distribution(&mut map, "ttst", &self.latency.ttst)?;
serialize_distribution(&mut map, "tpot", &self.latency.tpot)?;
serialize_distribution(&mut map, "itl", &self.latency.itl.distribution)?;
map.serialize_entry("max_itl_ms", &self.latency.itl.max_ms)?;
serialize_distribution(&mut map, "e2e_latency", &self.latency.e2e)?;
serialize_rate_distribution(
&mut map,
"output_token_throughput_per_user",
&self.latency.output_token_throughput_per_user,
)?;
map.end()
}
}
fn serialize_distribution<S>(
map: &mut S,
prefix: &str,
stats: &TraceDistributionStats,
) -> Result<(), S::Error>
where
S: SerializeMap,
{
map.serialize_entry(&format!("mean_{prefix}_ms"), &stats.mean_ms)?;
map.serialize_entry(&format!("min_{prefix}_ms"), &stats.min_ms)?;
map.serialize_entry(&format!("max_{prefix}_ms"), &stats.max_ms)?;
map.serialize_entry(&format!("median_{prefix}_ms"), &stats.median_ms)?;
map.serialize_entry(&format!("p75_{prefix}_ms"), &stats.p75_ms)?;
map.serialize_entry(&format!("p90_{prefix}_ms"), &stats.p90_ms)?;
map.serialize_entry(&format!("p95_{prefix}_ms"), &stats.p95_ms)?;
map.serialize_entry(&format!("p99_{prefix}_ms"), &stats.p99_ms)?;
map.serialize_entry(&format!("std_{prefix}_ms"), &stats.std_ms)?;
Ok(())
}
fn serialize_rate_distribution<S>(
map: &mut S,
prefix: &str,
stats: &TraceDistributionStats,
) -> Result<(), S::Error>
where
S: SerializeMap,
{
map.serialize_entry(&format!("mean_{prefix}"), &stats.mean_ms)?;
map.serialize_entry(&format!("min_{prefix}"), &stats.min_ms)?;
map.serialize_entry(&format!("max_{prefix}"), &stats.max_ms)?;
map.serialize_entry(&format!("median_{prefix}"), &stats.median_ms)?;
map.serialize_entry(&format!("p75_{prefix}"), &stats.p75_ms)?;
map.serialize_entry(&format!("p90_{prefix}"), &stats.p90_ms)?;
map.serialize_entry(&format!("p95_{prefix}"), &stats.p95_ms)?;
map.serialize_entry(&format!("p99_{prefix}"), &stats.p99_ms)?;
map.serialize_entry(&format!("std_{prefix}"), &stats.std_ms)?;
Ok(())
}
#[derive(Debug)]
struct TraceRequestStats {
arrival_time_ms: f64,
first_admit_ms: Option<f64>,
token_times_ms: Vec<f64>,
input_length: usize,
output_length: usize,
reused_input_tokens: usize,
}
#[cfg(test)]
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct TraceRequestStatsSnapshot {
pub arrival_time_ms: f64,
pub first_admit_ms: Option<f64>,
pub first_token_ms: Option<f64>,
pub last_token_ms: Option<f64>,
pub input_length: usize,
pub output_length: usize,
pub reused_input_tokens: usize,
}
#[derive(Debug, Default)]
pub(crate) struct TraceCollector {
requests: HashMap<Uuid, TraceRequestStats>,
}
impl TraceRequestStats {
fn first_token_ms(&self) -> Option<f64> {
self.token_times_ms.first().copied()
}
fn last_token_ms(&self) -> Option<f64> {
self.token_times_ms.last().copied()
}
fn mean_tpot_ms(&self) -> Option<f64> {
let num_gaps = self.token_times_ms.len().saturating_sub(1);
if num_gaps == 0 {
return None;
}
let first_token_ms = self.first_token_ms()?;
let last_token_ms = self.last_token_ms()?;
Some((last_token_ms - first_token_ms).max(0.0) / num_gaps as f64)
}
fn itls_ms(&self) -> impl Iterator<Item = f64> + '_ {
self.token_times_ms
.windows(2)
.map(|window| (window[1] - window[0]).max(0.0))
}
fn ttst_ms(&self) -> Option<f64> {
let [first_token_ms, second_token_ms, ..] = self.token_times_ms.as_slice() else {
return None;
};
Some((second_token_ms - first_token_ms).max(0.0))
}
}
impl TraceCollector {
pub(crate) fn on_arrival(
&mut self,
uuid: Uuid,
arrival_time_ms: f64,
input_length: usize,
output_length: usize,
) {
self.requests.insert(
uuid,
TraceRequestStats {
arrival_time_ms,
first_admit_ms: None,
token_times_ms: Vec::with_capacity(output_length),
input_length,
output_length,
reused_input_tokens: 0,
},
);
}
pub(crate) fn on_admit(&mut self, uuid: Uuid, admit_time_ms: f64, reused_input_tokens: usize) {
if let Some(stats) = self.requests.get_mut(&uuid) {
stats.first_admit_ms.get_or_insert(admit_time_ms);
stats.reused_input_tokens = stats.reused_input_tokens.max(reused_input_tokens);
}
}
pub(crate) fn on_token(&mut self, uuid: Uuid, token_time_ms: f64) {
if let Some(stats) = self.requests.get_mut(&uuid) {
stats.token_times_ms.push(token_time_ms);
}
}
pub(crate) fn finish(self) -> TraceSimulationReport {
let requests = self.requests;
let mut ttfts = Vec::new();
let mut ttsts = Vec::new();
let mut tpots = Vec::new();
let mut itls = Vec::new();
let mut e2e_latencies = Vec::new();
let mut output_token_throughput_per_user = Vec::new();
let mut duration_ms = 0.0_f64;
let mut total_input_tokens = 0usize;
let mut total_output_tokens = 0usize;
let mut completed_requests = 0usize;
let mut total_reused_tokens = 0usize;
for stats in requests.values() {
if stats.first_admit_ms.is_none() {
continue;
}
let Some(first_token_ms) = stats.first_token_ms() else {
continue;
};
let Some(last_token_ms) = stats.last_token_ms() else {
continue;
};
completed_requests += 1;
total_input_tokens += stats.input_length;
total_output_tokens += stats.output_length;
total_reused_tokens += stats.reused_input_tokens;
duration_ms = duration_ms.max(last_token_ms);
let ttft_ms = (first_token_ms - stats.arrival_time_ms).max(0.0);
let e2e_ms = (last_token_ms - stats.arrival_time_ms).max(0.0);
ttfts.push(ttft_ms);
e2e_latencies.push(e2e_ms);
if let Some(ttst_ms) = stats.ttst_ms() {
ttsts.push(ttst_ms);
}
if let Some(tpot_ms) = stats.mean_tpot_ms() {
tpots.push(tpot_ms);
for itl_ms in stats.itls_ms() {
if itl_ms > 0.0 {
output_token_throughput_per_user.push(1000.0 / itl_ms);
}
itls.push(itl_ms);
}
}
}
let duration_s = (duration_ms / 1000.0).max(1e-9);
TraceSimulationReport {
request_counts: TraceRequestCounts {
num_requests: requests.len(),
completed_requests,
total_input_tokens,
total_output_tokens,
},
throughput: TraceThroughputStats {
duration_ms,
wall_time_ms: 0.0,
request_throughput_rps: completed_requests as f64 / duration_s,
input_throughput_tok_s: total_input_tokens as f64 / duration_s,
output_throughput_tok_s: total_output_tokens as f64 / duration_s,
total_throughput_tok_s: (total_input_tokens + total_output_tokens) as f64
/ duration_s,
},
prefix_cache_reused_ratio: if total_input_tokens == 0 {
0.0
} else {
total_reused_tokens as f64 / total_input_tokens as f64
},
latency: TraceLatencyStats {
ttft: build_distribution_stats(&ttfts),
ttst: build_distribution_stats(&ttsts),
tpot: build_distribution_stats(&tpots),
itl: TraceInterTokenLatencyStats {
distribution: build_distribution_stats(&itls),
max_ms: max_value(&itls),
},
e2e: build_distribution_stats(&e2e_latencies),
output_token_throughput_per_user: build_distribution_stats(
&output_token_throughput_per_user,
),
},
}
}
#[cfg(test)]
pub(crate) fn snapshot(&self, uuid: Uuid) -> Option<TraceRequestStatsSnapshot> {
self.requests
.get(&uuid)
.map(|stats| TraceRequestStatsSnapshot {
arrival_time_ms: stats.arrival_time_ms,
first_admit_ms: stats.first_admit_ms,
first_token_ms: stats.first_token_ms(),
last_token_ms: stats.last_token_ms(),
input_length: stats.input_length,
output_length: stats.output_length,
reused_input_tokens: stats.reused_input_tokens,
})
}
}
#[derive(Debug, Deserialize)]
struct RawTraceRecord {
#[serde(default)]
timestamp: Option<f64>,
#[serde(default)]
created_time: Option<f64>,
#[serde(default, alias = "input_tokens")]
input_length: Option<usize>,
#[serde(default, alias = "output_tokens")]
output_length: Option<usize>,
#[serde(default)]
hash_ids: Option<Vec<u64>>,
}
pub fn simulate_trace_file(
args: MockEngineArgs,
trace_path: &Path,
num_workers: usize,
) -> Result<TraceSimulationReport> {
validate_offline_replay_args(&args, num_workers)?;
let requests = load_trace_requests(trace_path, args.block_size, true)?;
let started_at = Instant::now();
let report = crate::scheduler::vllm::simulate_trace(args, requests)?;
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
}
pub fn simulate_concurrency_file(
args: MockEngineArgs,
trace_path: &Path,
max_in_flight: usize,
num_workers: usize,
) -> Result<TraceSimulationReport> {
let requests = load_trace_requests(trace_path, args.block_size, false)?;
let started_at = Instant::now();
let report = simulate_concurrency_requests(args, requests, max_in_flight, num_workers)?;
Ok(report.with_wall_time_ms(started_at.elapsed().as_secs_f64() * 1000.0))
}
pub fn simulate_concurrency_requests(
args: MockEngineArgs,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
) -> Result<TraceSimulationReport> {
validate_offline_concurrency_args(&args, num_workers, max_in_flight)?;
if requests.is_empty() {
bail!("concurrency replay requires at least one request");
}
crate::scheduler::vllm::simulate_concurrency(args, requests, max_in_flight)
}
fn validate_offline_replay_args(args: &MockEngineArgs, num_workers: usize) -> Result<()> {
if num_workers != 1 {
bail!(
"trace replay only supports num_workers=1, got {}",
num_workers
);
}
if args.engine_type != EngineType::Vllm {
bail!(
"trace replay only supports engine_type=vllm, got {:?}",
args.engine_type
);
}
if args.worker_type != WorkerType::Aggregated {
bail!(
"trace replay only supports aggregated workers, got {:?}",
args.worker_type
);
}
if args.dp_size != 1 {
bail!(
"trace replay only supports data_parallel_size=1, got {}",
args.dp_size
);
}
Ok(())
}
fn validate_offline_concurrency_args(
args: &MockEngineArgs,
num_workers: usize,
max_in_flight: usize,
) -> Result<()> {
if max_in_flight == 0 {
bail!("concurrency replay requires max_in_flight >= 1");
}
validate_offline_replay_args(args, num_workers)
}
fn load_trace_requests(
trace_path: &Path,
trace_block_size: usize,
timestamps_required: bool,
) -> Result<Vec<DirectRequest>> {
let file = File::open(trace_path)
.with_context(|| format!("failed to open trace file {}", trace_path.display()))?;
let reader = BufReader::new(file);
let mut requests = Vec::new();
for (line_idx, line) in reader.lines().enumerate() {
let line = line.with_context(|| {
format!(
"failed to read line {} from {}",
line_idx + 1,
trace_path.display()
)
})?;
if line.trim().is_empty() {
continue;
}
let raw: RawTraceRecord = serde_json::from_str(&line).with_context(|| {
format!(
"failed to parse line {} from {} as JSON",
line_idx + 1,
trace_path.display()
)
})?;
let input_length = raw
.input_length
.ok_or_else(|| anyhow!("trace line {} is missing input_length", line_idx + 1))?;
let output_length = raw
.output_length
.ok_or_else(|| anyhow!("trace line {} is missing output_length", line_idx + 1))?;
let hash_ids = raw
.hash_ids
.ok_or_else(|| anyhow!("trace line {} is missing hash_ids", line_idx + 1))?;
let arrival_timestamp_ms = if timestamps_required {
match raw.timestamp.or(raw.created_time) {
Some(timestamp_ms) => Some(timestamp_ms),
None => return Err(anyhow!("trace line {} is missing timestamp", line_idx + 1)),
}
} else {
None
};
let tokens = synthesize_tokens_from_hash_ids(&hash_ids, input_length, trace_block_size)
.with_context(|| {
format!(
"failed to synthesize tokens from hash_ids on line {}",
line_idx + 1
)
})?;
requests.push(DirectRequest {
tokens,
max_output_tokens: output_length,
uuid: Some(Uuid::new_v4()),
dp_rank: 0,
arrival_timestamp_ms,
});
}
if requests.is_empty() {
bail!(
"trace file {} did not contain any requests",
trace_path.display()
);
}
Ok(requests)
}
fn synthesize_tokens_from_hash_ids(
hash_ids: &[u64],
input_length: usize,
trace_block_size: usize,
) -> Result<Vec<u32>> {
let mut tokens = Vec::with_capacity(input_length);
for &hash_id in hash_ids {
let token_id = u32::try_from(hash_id)
.map_err(|_| anyhow!("hash_id {hash_id} exceeds u32::MAX for token synthesis"))?;
// TODO: Replace this repeated-token expansion with a hash-native prompt representation.
tokens.extend((0..trace_block_size).map(|_| token_id));
if tokens.len() >= input_length {
tokens.truncate(input_length);
return Ok(tokens);
}
}
bail!(
"input_length {} exceeds synthesized capacity {} from {} hash_ids and block_size {}",
input_length,
hash_ids.len() * trace_block_size,
hash_ids.len(),
trace_block_size
);
}
fn mean(values: &[f64]) -> f64 {
if values.is_empty() {
0.0
} else {
values.iter().sum::<f64>() / values.len() as f64
}
}
fn max_value(values: &[f64]) -> f64 {
values.iter().copied().reduce(f64::max).unwrap_or(0.0)
}
fn build_distribution_stats(values: &[f64]) -> TraceDistributionStats {
TraceDistributionStats {
mean_ms: mean(values),
min_ms: min_value(values),
max_ms: max_value(values),
median_ms: percentile(values, 50.0),
p75_ms: percentile(values, 75.0),
p90_ms: percentile(values, 90.0),
p95_ms: percentile(values, 95.0),
p99_ms: percentile(values, 99.0),
std_ms: std_dev(values),
}
}
fn percentile(values: &[f64], percentile: f64) -> f64 {
if values.is_empty() {
return 0.0;
}
let mut sorted = values.to_vec();
sorted.sort_by(|left, right| left.total_cmp(right));
let rank = ((sorted.len() - 1) as f64 * percentile / 100.0).round() as usize;
sorted[rank.min(sorted.len() - 1)]
}
fn min_value(values: &[f64]) -> f64 {
values.iter().copied().reduce(f64::min).unwrap_or(0.0)
}
fn std_dev(values: &[f64]) -> f64 {
if values.is_empty() {
return 0.0;
}
let mean = mean(values);
let variance = values
.iter()
.map(|value| {
let centered = value - mean;
centered * centered
})
.sum::<f64>()
/ values.len() as f64;
variance.sqrt()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_replay_itl_uses_per_token_gaps() {
let mut collector = TraceCollector::default();
let uuid = Uuid::from_u128(11);
collector.on_arrival(uuid, 0.0, 4, 4);
collector.on_admit(uuid, 0.0, 0);
collector.on_token(uuid, 10.0);
collector.on_token(uuid, 11.0);
collector.on_token(uuid, 12.0);
collector.on_token(uuid, 110.0);
let report = collector.finish();
assert!((report.latency.tpot.mean_ms - (100.0 / 3.0)).abs() < 1e-9);
assert!((report.latency.itl.distribution.mean_ms - (100.0 / 3.0)).abs() < 1e-9);
assert_eq!(report.latency.itl.distribution.median_ms, 1.0);
assert_eq!(report.latency.itl.distribution.p75_ms, 98.0);
assert_eq!(report.latency.itl.distribution.p90_ms, 98.0);
assert_eq!(report.latency.itl.distribution.p95_ms, 98.0);
assert_eq!(report.latency.itl.max_ms, 98.0);
assert_eq!(report.latency.ttst.min_ms, 1.0);
assert_eq!(report.latency.ttst.max_ms, 1.0);
assert_eq!(
report.latency.output_token_throughput_per_user.min_ms,
1000.0 / 98.0
);
assert_eq!(
report.latency.output_token_throughput_per_user.max_ms,
1000.0
);
}
}
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import os
import subprocess
import sys
from pathlib import Path
import pytest
from tests.utils.constants import ROUTER_MODEL_NAME
MODEL_NAME = ROUTER_MODEL_NAME
MOONCAKE_TRACE_BLOCK_SIZE = 512
MOONCAKE_TRACE_SAMPLE_LINES = [
'{"timestamp": 0, "input_length": 6755, "output_length": 500, "hash_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]}',
'{"timestamp": 0, "input_length": 7319, "output_length": 490, "hash_ids": [0, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]}',
'{"timestamp": 0, "input_length": 7234, "output_length": 794, "hash_ids": [0, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41]}',
'{"timestamp": 0, "input_length": 2287, "output_length": 316, "hash_ids": [0, 42, 43, 44, 45]}',
'{"timestamp": 0, "input_length": 9013, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]}',
'{"timestamp": 0, "input_length": 6506, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 64]}',
'{"timestamp": 0, "input_length": 4824, "output_length": 173, "hash_ids": [0, 65, 66, 67, 68, 69, 70, 71, 72, 73]}',
'{"timestamp": 0, "input_length": 3119, "output_length": 20, "hash_ids": [74, 75, 76, 77, 78, 79, 80]}',
'{"timestamp": 0, "input_length": 23090, "output_length": 453, "hash_ids": [0, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125]}',
'{"timestamp": 0, "input_length": 3135, "output_length": 19, "hash_ids": [74, 75, 76, 77, 78, 126, 127]}',
'{"timestamp": 0, "input_length": 26874, "output_length": 458, "hash_ids": [0, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179]}',
'{"timestamp": 0, "input_length": 10487, "output_length": 402, "hash_ids": [0, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]}',
'{"timestamp": 0, "input_length": 17448, "output_length": 610, "hash_ids": [0, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233]}',
'{"timestamp": 0, "input_length": 6253, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 234]}',
'{"timestamp": 0, "input_length": 6725, "output_length": 32, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 235, 236]}',
'{"timestamp": 3052, "input_length": 13538, "output_length": 71, "hash_ids": [0, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262]}',
'{"timestamp": 3052, "input_length": 87162, "output_length": 402, "hash_ids": [0, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432]}',
'{"timestamp": 3052, "input_length": 6166, "output_length": 24, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 433]}',
'{"timestamp": 3052, "input_length": 6320, "output_length": 548, "hash_ids": [0, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445]}',
'{"timestamp": 3052, "input_length": 2007, "output_length": 354, "hash_ids": [0, 446, 447, 448]}',
]
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.gpu_0,
pytest.mark.integration,
pytest.mark.parallel,
pytest.mark.router,
pytest.mark.model(MODEL_NAME),
]
@pytest.mark.timeout(120)
def test_mocker_trace_file_replay(tmp_path):
repo_root = Path.cwd()
trace_file = tmp_path / "mooncake_trace.jsonl"
trace_file.write_text(
"\n".join(MOONCAKE_TRACE_SAMPLE_LINES) + "\n", encoding="utf-8"
)
replay_report = trace_file.with_name(f"{trace_file.stem}.replay.json")
pythonpath_entries = [
str(repo_root / "components/src"),
str(repo_root / "lib/bindings/python/src"),
]
existing_pythonpath = os.environ.get("PYTHONPATH")
if existing_pythonpath:
pythonpath_entries.append(existing_pythonpath)
env = os.environ.copy()
env["PYTHONPATH"] = os.pathsep.join(pythonpath_entries)
result = subprocess.run(
[
sys.executable,
"-m",
"dynamo.mocker",
"--trace-file",
str(trace_file),
"--model-path",
MODEL_NAME,
"--num-workers",
"1",
"--block-size",
str(MOONCAKE_TRACE_BLOCK_SIZE),
"--speedup-ratio",
"0",
],
cwd=repo_root,
env=env,
capture_output=True,
text=True,
timeout=120,
check=False,
)
assert result.returncode == 0, (
f"dynamo.mocker trace replay failed with exit code {result.returncode}\n"
f"stdout:\n{result.stdout}\n"
f"stderr:\n{result.stderr}"
)
assert replay_report.exists(), (
"Expected default replay report next to the temp trace file, "
f"but {replay_report} was not created.\nstdout:\n{result.stdout}\n"
f"stderr:\n{result.stderr}"
)
assert "Replay Summary" in result.stdout
assert f"JSON report: {replay_report}" in result.stdout
report = json.loads(replay_report.read_text(encoding="utf-8"))
assert report["num_requests"] == len(MOONCAKE_TRACE_SAMPLE_LINES)
assert report["completed_requests"] == len(MOONCAKE_TRACE_SAMPLE_LINES)
assert report["total_input_tokens"] > 0
assert report["total_output_tokens"] > 0
assert report["duration_ms"] > 0
assert report["wall_time_ms"] >= 0
assert report["request_throughput_rps"] > 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment