"docs/vscode:/vscode.git/clone" did not exist on "353ba5dbb5194e28270dc59680cb3fa0d71c0bbe"
Unverified commit d22ca523, authored by jthomson04 and committed by GitHub
Browse files

feat: Various Mocker Perf improvements + fixes (#5808)


Signed-off-by: jthomson04 <jothomson@nvidia.com>
parent e55ebec5
...@@ -22,7 +22,7 @@ The mocker engine now supports a vLLM-style CLI interface with individual argume ...@@ -22,7 +22,7 @@ The mocker engine now supports a vLLM-style CLI interface with individual argume
- `--enable-prefix-caching` / `--no-enable-prefix-caching`: Enable/disable automatic prefix caching (default: True) - `--enable-prefix-caching` / `--no-enable-prefix-caching`: Enable/disable automatic prefix caching (default: True)
- `--enable-chunked-prefill` / `--no-enable-chunked-prefill`: Enable/disable chunked prefill (default: True) - `--enable-chunked-prefill` / `--no-enable-chunked-prefill`: Enable/disable chunked prefill (default: True)
- `--watermark`: KV cache watermark threshold as a fraction (default: 0.01) - `--watermark`: KV cache watermark threshold as a fraction (default: 0.01)
- `--speedup-ratio`: Speed multiplier for token generation (default: 1.0). Higher values make the simulation engines run faster - `--speedup-ratio`: Speed multiplier for token generation (default: 1.0). Higher values make the simulation engines run faster. Use `0` for infinite speedup (no simulation delays)
- `--data-parallel-size`: Number of data parallel workers to simulate (default: 1) - `--data-parallel-size`: Number of data parallel workers to simulate (default: 1)
- `--num-workers`: Number of mocker workers to launch in the same process (default: 1). All workers share the same tokio runtime and thread pool - `--num-workers`: Number of mocker workers to launch in the same process (default: 1). All workers share the same tokio runtime and thread pool
- `--stagger-delay`: Delay in seconds between launching each worker to avoid overwhelming etcd/NATS/frontend. Set to 0 to disable staggering. Use -1 for auto mode (stagger dependent on number of workers). Default: -1 (auto) - `--stagger-delay`: Delay in seconds between launching each worker to avoid overwhelming etcd/NATS/frontend. Set to 0 to disable staggering. Use -1 for auto mode (stagger dependent on number of workers). Default: -1 (auto)
......
...@@ -249,7 +249,7 @@ def parse_args(): ...@@ -249,7 +249,7 @@ def parse_args():
"--speedup-ratio", "--speedup-ratio",
type=float, type=float,
default=None, default=None,
help="Speedup ratio for mock execution (default: 1.0)", help="Speedup ratio for mock execution (default: 1.0). Use 0 for infinite speedup (no simulation delays).",
) )
parser.add_argument( parser.add_argument(
"--data-parallel-size", "--data-parallel-size",
......
...@@ -45,7 +45,7 @@ use derive_getters::Getters; ...@@ -45,7 +45,7 @@ use derive_getters::Getters;
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_tokens::blocks::UniqueBlock; use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{BlockHash, SequenceHash}; use dynamo_tokens::{BlockHash, SequenceHash};
use std::collections::{HashMap, HashSet}; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
#[derive(Getters)] #[derive(Getters)]
...@@ -60,8 +60,6 @@ pub struct KvManager { ...@@ -60,8 +60,6 @@ pub struct KvManager {
inactive_blocks: LRUEvictor<UniqueBlock>, inactive_blocks: LRUEvictor<UniqueBlock>,
all_blocks: HashSet<UniqueBlock>,
kv_event_publisher: Option<Arc<KvEventPublisher>>, kv_event_publisher: Option<Arc<KvEventPublisher>>,
#[getter(copy)] #[getter(copy)]
...@@ -84,7 +82,6 @@ impl KvManager { ...@@ -84,7 +82,6 @@ impl KvManager {
) -> Self { ) -> Self {
let active_blocks = HashMap::new(); let active_blocks = HashMap::new();
let inactive_blocks = LRUEvictor::default(); let inactive_blocks = LRUEvictor::default();
let all_blocks = HashSet::new();
let kv_event_publisher = component.map(|comp| { let kv_event_publisher = component.map(|comp| {
tracing::info!( tracing::info!(
...@@ -101,7 +98,6 @@ impl KvManager { ...@@ -101,7 +98,6 @@ impl KvManager {
block_size, block_size,
active_blocks, active_blocks,
inactive_blocks, inactive_blocks,
all_blocks,
kv_event_publisher, kv_event_publisher,
dp_rank, dp_rank,
next_event_id: 0, next_event_id: 0,
...@@ -204,7 +200,6 @@ impl KvManager { ...@@ -204,7 +200,6 @@ impl KvManager {
"Evicting block from inactive pool: {evicted:?}, dp_rank={}", "Evicting block from inactive pool: {evicted:?}, dp_rank={}",
self.dp_rank self.dp_rank
); );
self.all_blocks.remove(&evicted);
if let UniqueBlock::FullBlock(evicted_full_block) = evicted { if let UniqueBlock::FullBlock(evicted_full_block) = evicted {
self.publish_kv_event(vec![evicted_full_block], &[], None, false); self.publish_kv_event(vec![evicted_full_block], &[], None, false);
} }
...@@ -212,7 +207,6 @@ impl KvManager { ...@@ -212,7 +207,6 @@ impl KvManager {
// Now insert the new block in active blocks with reference count 1 // Now insert the new block in active blocks with reference count 1
self.active_blocks.insert(hash.clone(), 1); self.active_blocks.insert(hash.clone(), 1);
self.all_blocks.insert(hash.clone());
if self.kv_event_publisher.is_some() if self.kv_event_publisher.is_some()
&& let UniqueBlock::FullBlock(stored_full_block) = hash && let UniqueBlock::FullBlock(stored_full_block) = hash
{ {
...@@ -234,8 +228,6 @@ impl KvManager { ...@@ -234,8 +228,6 @@ impl KvManager {
// Process blocks in order (already reversed by caller if needed) // Process blocks in order (already reversed by caller if needed)
for hash in hashes.iter() { for hash in hashes.iter() {
self.active_blocks.remove(hash).unwrap(); self.active_blocks.remove(hash).unwrap();
// Remove from all_blocks when destroyed
assert!(self.all_blocks.remove(hash));
// Track blocks for batch sending // Track blocks for batch sending
if self.kv_event_publisher.is_some() if self.kv_event_publisher.is_some()
...@@ -289,9 +281,6 @@ impl KvManager { ...@@ -289,9 +281,6 @@ impl KvManager {
self.active_blocks self.active_blocks
.insert(hash_block.clone(), hash_ref_count + 1); .insert(hash_block.clone(), hash_ref_count + 1);
assert!(self.all_blocks.remove(&uuid_block));
self.all_blocks.insert(hash_block);
} }
} }
...@@ -299,12 +288,13 @@ impl KvManager { ...@@ -299,12 +288,13 @@ impl KvManager {
true true
} }
/// Get the count of blocks in the input list that aren't in all_blocks /// Get the count of blocks that aren't in active or inactive pools
pub fn probe_new_blocks(&self, blocks: &[UniqueBlock]) -> usize { pub fn probe_new_blocks(&self, blocks: &[UniqueBlock]) -> usize {
blocks blocks
.iter() .iter()
// .filter(|&block| !self.active_blocks.contains_key(block)) .filter(|&block| {
.filter(|&block| !self.all_blocks.contains(block)) !self.active_blocks.contains_key(block) && !self.inactive_blocks.contains(block)
})
.count() .count()
} }
...@@ -349,9 +339,22 @@ impl KvManager { ...@@ -349,9 +339,22 @@ impl KvManager {
/// Check if a sequence can be scheduled and calculate cost if possible /// Check if a sequence can be scheduled and calculate cost if possible
pub fn get_prefill_cost(&self, sequence: &ActiveSequence) -> PrefillCost { pub fn get_prefill_cost(&self, sequence: &ActiveSequence) -> PrefillCost {
let seq_blocks = sequence.unique_blocks(); let seq_blocks = sequence.unique_blocks();
let new_blocks = self.probe_new_blocks(seq_blocks);
let overlap_blocks = seq_blocks.len() - new_blocks; // Find the longest prefix that exists in cache
let new_tokens = sequence.num_input_tokens() - overlap_blocks * self.block_size; // We must stop at the first cache miss since KV states are computed sequentially
let mut overlap_blocks = 0;
for block in seq_blocks {
if !self.active_blocks.contains_key(block) && !self.inactive_blocks.contains(block) {
// First cache miss - can't use anything after this point
break;
}
overlap_blocks += 1;
}
let new_blocks = seq_blocks.len() - overlap_blocks;
// Clamp cached_tokens to handle partial blocks (last block may have < block_size tokens)
let cached_tokens = (overlap_blocks * self.block_size).min(sequence.num_input_tokens());
let new_tokens = sequence.num_input_tokens() - cached_tokens;
PrefillCost { PrefillCost {
new_blocks, new_blocks,
......
...@@ -257,10 +257,10 @@ impl Scheduler { ...@@ -257,10 +257,10 @@ impl Scheduler {
component: Option<dynamo_runtime::component::Component>, component: Option<dynamo_runtime::component::Component>,
cancellation_token: Option<CancellationToken>, cancellation_token: Option<CancellationToken>,
) -> Self { ) -> Self {
// Assert speedup_ratio is greater than 0 // Assert speedup_ratio is non-negative (0 means infinite speedup)
assert!( assert!(
args.speedup_ratio > 0.0, args.speedup_ratio >= 0.0,
"speedup_ratio must be greater than 0, got: {}", "speedup_ratio must be >= 0 (0 means infinite speedup), got: {}",
args.speedup_ratio args.speedup_ratio
); );
...@@ -414,9 +414,11 @@ async fn simulate_prefill( ...@@ -414,9 +414,11 @@ async fn simulate_prefill(
break; break;
} }
} }
if speedup_ratio > 0.0 {
let deadline = start_time + Duration::from_secs_f64(total_time.as_secs_f64() / speedup_ratio); let deadline =
start_time + Duration::from_secs_f64(total_time.as_secs_f64() / speedup_ratio);
tokio::time::sleep_until(deadline).await; tokio::time::sleep_until(deadline).await;
}
total_time total_time
} }
...@@ -434,18 +436,21 @@ async fn simulate_decode( ...@@ -434,18 +436,21 @@ async fn simulate_decode(
let start_time = tokio::time::Instant::now(); let start_time = tokio::time::Instant::now();
// Compute decode timing // Compute decode timing
let active_kv_tokens = kv_manager.num_active_blocks() * block_size; let active_kv_tokens = kv_manager.num_active_blocks() * block_size;
// Compute average context length across all active decode requests // Compute average context length across all active decode requests
let (total_length, count) = state let total_length: usize = state
.decode .decode
.keys() .keys()
.filter_map(|uuid| state.requests.get(uuid)) .map(|uuid| {
.fold((0usize, 0usize), |(sum, cnt), req| { if let Request::Active(seq) = state.requests.get(uuid).unwrap() {
if let Request::Active(seq) = req { seq.len()
(sum + seq.len(), cnt + 1)
} else { } else {
(sum, cnt) 0
} }
}); })
.sum();
let count = state.decode.len();
let context_length = if count > 0 { total_length / count } else { 0 }; let context_length = if count > 0 { total_length / count } else { 0 };
let decoding_time = perf_model.predict_decode_time(active_kv_tokens, context_length); let decoding_time = perf_model.predict_decode_time(active_kv_tokens, context_length);
let total_time = Duration::from_secs_f64(decoding_time / 1000.0); let total_time = Duration::from_secs_f64(decoding_time / 1000.0);
...@@ -472,10 +477,8 @@ async fn simulate_decode( ...@@ -472,10 +477,8 @@ async fn simulate_decode(
// Check completion and send notification // Check completion and send notification
let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens(); let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
let should_output = sequence.generated_tokens() > sequence.already_generated_tokens();
let send_failed = should_output let send_failed = output_tx.as_ref().is_some_and(|tx| {
&& output_tx.as_ref().is_some_and(|tx| {
tx.send(OutputSignal { tx.send(OutputSignal {
uuid, uuid,
completed: is_complete, completed: is_complete,
...@@ -493,9 +496,11 @@ async fn simulate_decode( ...@@ -493,9 +496,11 @@ async fn simulate_decode(
state.complete(&uuid); state.complete(&uuid);
} }
} }
if speedup_ratio > 0.0 {
let deadline = start_time + Duration::from_secs_f64(total_time.as_secs_f64() / speedup_ratio); let deadline =
start_time + Duration::from_secs_f64(total_time.as_secs_f64() / speedup_ratio);
tokio::time::sleep_until(deadline).await; tokio::time::sleep_until(deadline).await;
}
total_time total_time
} }
......
...@@ -49,9 +49,6 @@ pub struct ActiveSequence { ...@@ -49,9 +49,6 @@ pub struct ActiveSequence {
#[getter(copy)] #[getter(copy)]
generated_tokens: usize, generated_tokens: usize,
#[getter(copy)]
already_generated_tokens: usize,
#[getter(copy)] #[getter(copy)]
num_input_tokens: usize, num_input_tokens: usize,
...@@ -85,7 +82,6 @@ impl ActiveSequence { ...@@ -85,7 +82,6 @@ impl ActiveSequence {
block_size, block_size,
max_output_tokens, max_output_tokens,
generated_tokens: 0, generated_tokens: 0,
already_generated_tokens: 0,
num_input_tokens, num_input_tokens,
creation_signal, creation_signal,
enable_prefix_caching, enable_prefix_caching,
...@@ -228,18 +224,13 @@ impl ActiveSequence { ...@@ -228,18 +224,13 @@ impl ActiveSequence {
.collect() .collect()
} }
/// Reset the sequence to its initial state and return the free signals from freeing current blocks /// Move the request to a preempted state and return the free signals from freeing current blocks
/// Upon preemption, the sequence retains the tokens generated during the decode phase (if any).
pub fn reset_with_signal(&mut self) -> Vec<MoveBlock> { pub fn reset_with_signal(&mut self) -> Vec<MoveBlock> {
let free_signal = self.free_signal(); let free_signal = self.free_signal();
self.tokens.truncate(self.num_input_tokens).unwrap(); // Don't reset generated_tokens since we're keeping the tokens in the sequence
self.unique_blocks = create_unique_blocks_from_sequence(
&self.tokens,
self.block_size,
self.enable_prefix_caching,
);
self.already_generated_tokens = self.generated_tokens.max(self.already_generated_tokens);
self.generated_tokens = 0;
self.creation_signal = Some(MoveBlock::Use( self.creation_signal = Some(MoveBlock::Use(
self.unique_blocks.clone(), self.unique_blocks.clone(),
self.block_hashes(), self.block_hashes(),
...@@ -411,7 +402,7 @@ mod tests { ...@@ -411,7 +402,7 @@ mod tests {
let free_signals = seq1.reset_with_signal(); let free_signals = seq1.reset_with_signal();
// 49 - 15 generated tokens // 49 - 15 generated tokens
assert_eq!(seq1.already_generated_tokens, 34); assert_eq!(seq1.generated_tokens(), 34);
// Verify the reset signals include proper cleanup events // Verify the reset signals include proper cleanup events
assert!(!free_signals.is_empty()); assert!(!free_signals.is_empty());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment