Unverified Commit 95a750f4 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore(replay): refactor offline components into cleaner lanes (#7866)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 210bbf5d
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use std::time::Duration;
use tokio::time::Instant;
use super::single::RequestId;
#[derive(Debug, Clone, Copy)]
pub(super) struct PrefillLoadState {
pub(super) initial_effective_prefill_tokens: usize,
pub(super) expected_prefill_duration: Option<Duration>,
}
#[derive(Debug, Default)]
pub(super) struct PrefillLoadTracker {
pub(super) prefill_order: VecDeque<RequestId>,
pub(super) prefill_full_tokens_sum: usize,
pub(super) anchored_prefill: Option<(RequestId, Instant)>,
}
impl PrefillLoadTracker {
pub(super) fn insert(
&mut self,
request_id: &RequestId,
prefill: PrefillLoadState,
decay_now: Instant,
) {
self.prefill_full_tokens_sum += prefill.initial_effective_prefill_tokens;
let should_anchor = self.anchored_prefill.is_none();
self.prefill_order.push_back(request_id.clone());
if should_anchor {
self.anchored_prefill = Some((request_id.clone(), decay_now));
}
}
pub(super) fn remove(
&mut self,
request_id: &RequestId,
prefill: PrefillLoadState,
decay_now: Instant,
) {
self.prefill_full_tokens_sum = self
.prefill_full_tokens_sum
.checked_sub(prefill.initial_effective_prefill_tokens)
.expect("prefill_full_tokens_sum underflow");
let removed_front = self.prefill_order.front() == Some(request_id);
if removed_front {
let removed = self.prefill_order.pop_front();
debug_assert_eq!(removed.as_ref(), Some(request_id));
} else {
self.prefill_order
.retain(|queued_request_id| queued_request_id != request_id);
}
if self
.anchored_prefill
.as_ref()
.is_some_and(|(anchored_request_id, _)| anchored_request_id == request_id)
{
self.set_anchor_to_front(decay_now);
}
}
pub(super) fn set_anchor_to_front(&mut self, now: Instant) {
self.anchored_prefill = self
.prefill_order
.front()
.cloned()
.map(|request_id| (request_id, now));
}
}
...@@ -26,6 +26,10 @@ use std::time::Duration; ...@@ -26,6 +26,10 @@ use std::time::Duration;
use tokio::time::Instant; use tokio::time::Instant;
use uuid::Uuid; use uuid::Uuid;
use super::block_tracker::BlockTracker;
use super::prefill_tracker::{PrefillLoadState, PrefillLoadTracker};
use crate::protocols::PrefillLoadHint;
/// Duration after which stale requests may be expired (5 minutes). /// Duration after which stale requests may be expired (5 minutes).
const EXPIRY_DURATION: Duration = Duration::from_secs(300); const EXPIRY_DURATION: Duration = Duration::from_secs(300);
...@@ -36,106 +40,183 @@ const CHECK_EXPIRY_FREQUENCY: Duration = Duration::from_secs(30); ...@@ -36,106 +40,183 @@ const CHECK_EXPIRY_FREQUENCY: Duration = Duration::from_secs(30);
// TODO: use the common request_id if it exists in the repo // TODO: use the common request_id if it exists in the repo
pub type RequestId = String; pub type RequestId = String;
#[derive(Debug)]
pub(super) struct RequestState {
blocks: Vec<(SequenceHash, Arc<()>)>,
started_at: Instant,
prefill: Option<PrefillLoadState>,
expected_output_tokens: Option<u32>,
}
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache /// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)] #[derive(Debug, Getters)]
pub struct ActiveSequences { pub struct ActiveSequences {
active_seqs: HashMap<RequestId, Vec<(SequenceHash, Arc<()>)>>, requests: HashMap<RequestId, RequestState>,
prefill: PrefillLoadTracker,
prefill_tokens: HashMap<RequestId, usize>, blocks: BlockTracker,
/// Expected output tokens per request (used for resource estimation)
expected_output_tokens: HashMap<RequestId, u32>,
unique_blocks: HashMap<SequenceHash, std::sync::Weak<()>>,
/// Fractional block counts for blocks that are partially cached
/// When a block is in both unique_blocks and fractional_blocks,
/// it contributes the fractional value instead of 1 to active_blocks()
fractional_blocks: HashMap<SequenceHash, f64>,
#[getter(copy)] #[getter(copy)]
block_size: usize, block_size: usize,
#[getter(copy)]
active_tokens: usize,
// Request timestamps, for expiration.
request_timestamps: HashMap<RequestId, Instant>,
last_expiry_check_time: Instant, last_expiry_check_time: Instant,
} }
impl ActiveSequences { impl ActiveSequences {
/// Create a new SharedSequenceManager instance /// Create a new SharedSequenceManager instance
pub fn new(block_size: usize) -> Self { pub fn new(block_size: usize) -> Self {
// TODO: make this not a hard req
assert!(block_size > 1, "block_size must be greater than 1"); assert!(block_size > 1, "block_size must be greater than 1");
Self { Self {
active_seqs: HashMap::new(), requests: HashMap::new(),
prefill_tokens: HashMap::new(), prefill: PrefillLoadTracker::default(),
expected_output_tokens: HashMap::new(), blocks: BlockTracker::default(),
unique_blocks: HashMap::new(),
fractional_blocks: HashMap::new(),
block_size, block_size,
active_tokens: 0,
request_timestamps: HashMap::new(),
last_expiry_check_time: Instant::now(), last_expiry_check_time: Instant::now(),
} }
} }
fn touch_block(&mut self, block: &SequenceHash) -> Arc<()> { #[cfg(any(test, debug_assertions))]
if let Some(weak) = self.unique_blocks.get(block) fn assert_consistent(&self) {
&& let Some(rc) = weak.upgrade() let active_prefills: HashSet<RequestId> = self
{ .requests
return rc; .iter()
.filter(|(_, state)| state.prefill.is_some())
.map(|(request_id, _)| request_id.clone())
.collect();
let ordered_prefills: HashSet<RequestId> =
self.prefill.prefill_order.iter().cloned().collect();
let recomputed_prefill_sum: usize = self
.requests
.values()
.filter_map(|state| state.prefill)
.map(|prefill| prefill.initial_effective_prefill_tokens)
.sum();
assert_eq!(
ordered_prefills.len(),
self.prefill.prefill_order.len(),
"prefill_order contains duplicate request ids",
);
assert_eq!(
ordered_prefills, active_prefills,
"prefill_order must match requests with active prefill load",
);
assert_eq!(
self.prefill.prefill_full_tokens_sum, recomputed_prefill_sum,
"prefill_full_tokens_sum drifted from request state",
);
if let Some(oldest_request_id) = self.prefill.prefill_order.front() {
let Some((anchored_request_id, _)) = self.prefill.anchored_prefill.as_ref() else {
panic!("anchored_prefill must exist when prefill_order is non-empty");
};
assert!(
self.requests
.get(oldest_request_id)
.is_some_and(|state| state.prefill.is_some()),
"prefill_order front must point to an active prefill request",
);
assert_eq!(
anchored_request_id, oldest_request_id,
"anchored_prefill must match prefill_order.front()",
);
} else {
assert!(
self.prefill.anchored_prefill.is_none(),
"anchored_prefill must be absent when no active prefills remain",
);
} }
assert!(
let rc = Arc::new(()); self.blocks
self.unique_blocks.insert(*block, Arc::downgrade(&rc)); .fractional_blocks
rc .keys()
.all(|hash| self.blocks.unique_blocks.contains_key(hash)),
"fractional_blocks cannot reference non-active blocks",
);
} }
fn try_remove_block(&mut self, block: &SequenceHash) { #[inline]
if let Some(weak) = self.unique_blocks.get(block) fn validate_state(&self) {
&& weak.strong_count() == 0 #[cfg(any(test, debug_assertions))]
{ self.assert_consistent();
self.unique_blocks.remove(block);
self.fractional_blocks.remove(block);
}
} }
pub fn active_blocks(&self) -> usize { pub fn active_blocks(&self) -> usize {
let mut count = self.unique_blocks.len() as f64; self.blocks.active_blocks()
for (hash, frac) in &self.fractional_blocks { }
if self.unique_blocks.contains_key(hash) {
// Subtract 1 (the full block) and add the fractional value fn insert_prefill_load(
count = count - 1.0 + frac; &mut self,
request_id: &RequestId,
prefill: PrefillLoadState,
decay_now: Instant,
) {
self.prefill.insert(request_id, prefill, decay_now);
}
fn remove_prefill_load(
&mut self,
request_id: &RequestId,
decay_now: Instant,
) -> Option<PrefillLoadState> {
let prefill = {
let state = self.requests.get_mut(request_id)?;
state.prefill.take()?
};
self.prefill.remove(request_id, prefill, decay_now);
Some(prefill)
}
fn active_prefill_tokens_at(&self, now: Instant) -> usize {
let Some((oldest_request_id, oldest_since)) = self.prefill.anchored_prefill.as_ref() else {
return 0;
};
let prefill = self
.requests
.get(oldest_request_id)
.and_then(|state| state.prefill)
.expect("prefill_order front missing prefill load");
let oldest_full = prefill.initial_effective_prefill_tokens;
let oldest_remaining = match prefill.expected_prefill_duration {
None => oldest_full,
Some(expected_prefill_duration) if expected_prefill_duration.is_zero() => 0,
Some(expected_prefill_duration) => {
let elapsed = now.saturating_duration_since(*oldest_since);
let remaining_fraction = (1.0
- (elapsed.as_secs_f64() / expected_prefill_duration.as_secs_f64()))
.clamp(0.0, 1.0);
((oldest_full as f64) * remaining_fraction).ceil() as usize
} }
} };
count.round() as usize
self.prefill
.prefill_full_tokens_sum
.checked_sub(oldest_full)
.expect("prefill_full_tokens_sum smaller than oldest load")
+ oldest_remaining
}
pub fn active_tokens(&self, decay_now: Instant) -> usize {
self.active_prefill_tokens_at(decay_now)
} }
/// Find all blocks in a request that have only a single strong reference (only used by this request) /// Find all blocks in a request that have only a single strong reference (only used by this request)
/// and insert them into fractional_blocks with the given fraction value. /// and insert them into fractional_blocks with the given fraction value.
pub fn set_single_ref_blocks_as_fractional(&mut self, request_id: &RequestId, fraction: f64) { pub fn set_single_ref_blocks_as_fractional(&mut self, request_id: &RequestId, fraction: f64) {
let Some(blocks) = self.active_seqs.get(request_id) else { let Some(request_state) = self.requests.get(request_id) else {
tracing::warn!( tracing::warn!(
"Request {request_id} not found for set_single_ref_blocks_as_fractional" "Request {request_id} not found for set_single_ref_blocks_as_fractional"
); );
return; return;
}; };
for (hash, rc) in blocks { for (hash, rc) in &request_state.blocks {
// A block with strong_count == 1 means only this request holds a reference
if Arc::strong_count(rc) == 1 { if Arc::strong_count(rc) == 1 {
self.fractional_blocks.insert(*hash, fraction); self.blocks.fractional_blocks.insert(*hash, fraction);
} }
} }
} }
/// Add a new request with its initial tokens /// Add a new request with its initial tokens.
/// Returns the set of expired request IDs that were removed during cleanup /// Returns the set of expired request IDs that were removed during cleanup.
pub fn add_request( pub fn add_request(
&mut self, &mut self,
request_id: RequestId, request_id: RequestId,
...@@ -143,6 +224,7 @@ impl ActiveSequences { ...@@ -143,6 +224,7 @@ impl ActiveSequences {
isl: usize, isl: usize,
overlap: u32, overlap: u32,
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
decay_now: Instant,
) -> HashSet<RequestId> { ) -> HashSet<RequestId> {
self.add_request_with_prefill_tracking( self.add_request_with_prefill_tracking(
request_id, request_id,
...@@ -151,11 +233,14 @@ impl ActiveSequences { ...@@ -151,11 +233,14 @@ impl ActiveSequences {
overlap, overlap,
expected_output_tokens, expected_output_tokens,
true, true,
None,
decay_now,
) )
} }
/// Add a new request with optional prompt-token load accounting. /// Add a new request with optional prompt-token load accounting.
/// Returns the set of expired request IDs that were removed during cleanup. /// Returns the set of expired request IDs that were removed during cleanup.
#[allow(clippy::too_many_arguments)]
pub fn add_request_with_prefill_tracking( pub fn add_request_with_prefill_tracking(
&mut self, &mut self,
request_id: RequestId, request_id: RequestId,
...@@ -164,68 +249,76 @@ impl ActiveSequences { ...@@ -164,68 +249,76 @@ impl ActiveSequences {
overlap: u32, overlap: u32,
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
track_prefill_tokens: bool, track_prefill_tokens: bool,
prefill_load_hint: Option<PrefillLoadHint>,
decay_now: Instant,
) -> HashSet<RequestId> { ) -> HashSet<RequestId> {
// Check for double-add and log error, returning early if self.requests.contains_key(&request_id) {
if self.active_seqs.contains_key(&request_id) {
tracing::error!("Request {request_id} is already active. Ignoring duplicate add."); tracing::error!("Request {request_id} is already active. Ignoring duplicate add.");
return HashSet::new(); return HashSet::new();
} }
// Lazily check and clean up expired requests, capturing removed IDs
let removed_requests = self.force_expiry(); let removed_requests = self.force_expiry();
let started_at = Instant::now();
let blocks = match token_sequence {
Some(sequence) => sequence
.into_iter()
.map(|block| {
let rc = self.blocks.touch_block(&block);
(block, rc)
})
.collect(),
None => Vec::new(),
};
let prefill_tokens = if track_prefill_tokens { let prefill = if track_prefill_tokens {
self.new_tokens(isl, overlap) let default_tokens = self.new_tokens(isl, overlap);
let hint = prefill_load_hint.unwrap_or(PrefillLoadHint {
initial_effective_prefill_tokens: default_tokens,
expected_prefill_duration: None,
});
(hint.initial_effective_prefill_tokens > 0).then_some(PrefillLoadState {
initial_effective_prefill_tokens: hint.initial_effective_prefill_tokens,
expected_prefill_duration: hint.expected_prefill_duration,
})
} else { } else {
0 None
}; };
self.prefill_tokens
.insert(request_id.clone(), prefill_tokens);
self.active_tokens += prefill_tokens;
// Store expected output tokens if provided
if let Some(tokens) = expected_output_tokens {
self.expected_output_tokens
.insert(request_id.clone(), tokens);
}
if let Some(sequence) = token_sequence { self.requests.insert(
let sequence_with_refs: Vec<(SequenceHash, Arc<()>)> = sequence request_id.clone(),
.iter() RequestState {
.map(|block| (*block, self.touch_block(block))) blocks,
.collect(); started_at,
self.active_seqs prefill,
.insert(request_id.clone(), sequence_with_refs); expected_output_tokens,
} else { },
// dummy empty sequence );
self.active_seqs.insert(request_id.clone(), Vec::new());
if let Some(prefill) = prefill {
self.insert_prefill_load(&request_id, prefill, decay_now);
} }
self.request_timestamps
.insert(request_id.clone(), Instant::now());
self.validate_state();
removed_requests removed_requests
} }
/// Mark prefill as completed for a request, removing it from prefill_tokens tracking /// Mark prefill as completed for a request, removing it from prompt-load tracking.
pub fn mark_prefill_completed(&mut self, request_id: &RequestId) { pub fn mark_prefill_completed(&mut self, request_id: &RequestId, decay_now: Instant) {
if let Some(tokens) = self.prefill_tokens.remove(request_id) { let _ = self.remove_prefill_load(request_id, decay_now);
self.active_tokens = self self.validate_state();
.active_tokens
.checked_sub(tokens)
.expect("active_tokens underflow");
}
} }
pub fn new_tokens(&self, isl: usize, overlap: u32) -> usize { pub fn new_tokens(&self, isl: usize, overlap: u32) -> usize {
let cached_tokens = (overlap as usize) * self.block_size; let cached_tokens = (overlap as usize) * self.block_size;
isl.checked_sub(cached_tokens) isl.checked_sub(cached_tokens).unwrap_or_else(|| {
.unwrap_or_else(|| { tracing::error!(
tracing::error!( "prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens} (overlap {overlap} * block_size {}), returning 0",
"prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens} (overlap {overlap} * block_size {}), returning 0", self.block_size
self.block_size );
); 0
0 })
})
} }
pub fn potential_blocks_and_tokens( pub fn potential_blocks_and_tokens(
...@@ -233,8 +326,15 @@ impl ActiveSequences { ...@@ -233,8 +326,15 @@ impl ActiveSequences {
token_sequence: Option<&[SequenceHash]>, token_sequence: Option<&[SequenceHash]>,
isl: usize, isl: usize,
overlap: u32, overlap: u32,
decay_now: Instant,
) -> (usize, usize) { ) -> (usize, usize) {
self.potential_blocks_and_tokens_with_prefill_tracking(token_sequence, isl, overlap, true) self.potential_blocks_and_tokens_with_prefill_tracking(
token_sequence,
isl,
overlap,
true,
decay_now,
)
} }
pub fn potential_blocks_and_tokens_with_prefill_tracking( pub fn potential_blocks_and_tokens_with_prefill_tracking(
...@@ -243,17 +343,20 @@ impl ActiveSequences { ...@@ -243,17 +343,20 @@ impl ActiveSequences {
isl: usize, isl: usize,
overlap: u32, overlap: u32,
track_prefill_tokens: bool, track_prefill_tokens: bool,
decay_now: Instant,
) -> (usize, usize) { ) -> (usize, usize) {
let potential_blocks = if let Some(token_seq) = token_sequence { let potential_blocks = if let Some(token_seq) = token_sequence {
self.new_blocks(token_seq) + self.active_blocks() self.new_blocks(token_seq) + self.active_blocks()
} else { } else {
self.active_blocks() self.active_blocks()
}; };
let active_tokens = self.active_tokens(decay_now);
let potential_tokens = if track_prefill_tokens { let potential_tokens = if track_prefill_tokens {
self.new_tokens(isl, overlap) + self.active_tokens self.new_tokens(isl, overlap) + active_tokens
} else { } else {
self.active_tokens active_tokens
}; };
(potential_blocks, potential_tokens) (potential_blocks, potential_tokens)
} }
...@@ -261,12 +364,11 @@ impl ActiveSequences { ...@@ -261,12 +364,11 @@ impl ActiveSequences {
pub fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize { pub fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
token_sequence token_sequence
.iter() .iter()
.filter(|block| !self.unique_blocks.contains_key(block)) .filter(|block| !self.blocks.unique_blocks.contains_key(block))
.count() .count()
} }
/// Return the total number of blocks that would be used if the token sequence was added /// Return the total number of blocks that would be used if the token sequence was added.
/// This is the sum of new blocks that would be added plus the current active blocks
pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize { pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
self.new_blocks(token_sequence) + self.active_blocks() self.new_blocks(token_sequence) + self.active_blocks()
} }
...@@ -275,96 +377,77 @@ impl ActiveSequences { ...@@ -275,96 +377,77 @@ impl ActiveSequences {
/// ///
/// This implicitly calls [`Self::mark_prefill_completed`] first, so callers do not need /// This implicitly calls [`Self::mark_prefill_completed`] first, so callers do not need
/// to invoke both when the request is finishing. /// to invoke both when the request is finishing.
pub fn free(&mut self, request_id: &RequestId) -> usize { pub fn free(&mut self, request_id: &RequestId, decay_now: Instant) -> usize {
self.mark_prefill_completed(request_id); self.mark_prefill_completed(request_id, decay_now);
// Remove expected output tokens tracking let Some(request_state) = self.requests.remove(request_id) else {
self.expected_output_tokens.remove(request_id); tracing::warn!("Trying to free non-existent request {request_id}");
return self.active_blocks();
// Remove from active_seqs and get the token sequence
self.request_timestamps.remove(request_id);
let token_seq = match self.active_seqs.remove(request_id) {
Some(seq) => seq,
None => {
tracing::warn!("Trying to free non-existent request {request_id}");
return self.active_blocks();
}
}; };
// Drop each Rc reference, then clean up the corresponding weak reference let _ = request_state.expected_output_tokens;
for (block_hash, rc) in token_seq { for (block_hash, rc) in request_state.blocks {
drop(rc); drop(rc);
self.try_remove_block(&block_hash); self.blocks.try_remove_block(&block_hash);
} }
self.validate_state();
self.active_blocks() self.active_blocks()
} }
/// Add an output block with a random hash and optional fractional decay weight. /// Add an output block with a random hash and optional fractional decay weight.
/// ///
/// This is used during generation to track output blocks as they are created. /// This is used during generation to track output blocks as they are created.
/// The decay_fraction (if provided) represents how "temporary" the block is:
/// - 1.0 means fully counted (early in generation)
/// - 0.0 means not counted (near end of expected output)
/// - Computed as: 1 - (current_osl / expected_output_tokens)
///
/// Returns true if the block was added, false if the request was not found.
pub fn add_output_block( pub fn add_output_block(
&mut self, &mut self,
request_id: &RequestId, request_id: &RequestId,
decay_fraction: Option<f64>, decay_fraction: Option<f64>,
) -> bool { ) -> bool {
// Check if request exists first (immutable borrow) if !self.requests.contains_key(request_id) {
if !self.active_seqs.contains_key(request_id) {
tracing::warn!("Request {request_id} not found for add_output_block"); tracing::warn!("Request {request_id} not found for add_output_block");
return false; return false;
} }
// Generate a random block hash using UUID
let random_hash: SequenceHash = Uuid::new_v4().as_u64_pair().0; let random_hash: SequenceHash = Uuid::new_v4().as_u64_pair().0;
let rc = self.blocks.touch_block(&random_hash);
// Touch the block (adds to unique_blocks) self.requests
let rc = self.touch_block(&random_hash);
// Now we can safely get_mut and push
self.active_seqs
.get_mut(request_id) .get_mut(request_id)
.unwrap() .expect("request existence was checked above")
.blocks
.push((random_hash, rc)); .push((random_hash, rc));
// Apply fractional decay to all single-ref blocks in this request if provided
if let Some(frac) = decay_fraction { if let Some(frac) = decay_fraction {
self.set_single_ref_blocks_as_fractional(request_id, frac); self.set_single_ref_blocks_as_fractional(request_id, frac);
} }
self.validate_state();
true true
} }
/// Force expiry of stale requests if the timer has elapsed /// Force expiry of stale requests if the timer has elapsed.
/// Returns the set of expired request IDs that were removed /// Returns the set of expired request IDs that were removed.
pub fn force_expiry(&mut self) -> HashSet<RequestId> { pub fn force_expiry(&mut self) -> HashSet<RequestId> {
let now = Instant::now(); let now = Instant::now();
// Early return if timer hasn't expired yet.
if now < self.last_expiry_check_time + CHECK_EXPIRY_FREQUENCY { if now < self.last_expiry_check_time + CHECK_EXPIRY_FREQUENCY {
return HashSet::new(); return HashSet::new();
} }
self.last_expiry_check_time = now; self.last_expiry_check_time = now;
let expired_requests_time = now - EXPIRY_DURATION; let expired_requests_time = now - EXPIRY_DURATION;
let expired_requests: HashSet<RequestId> = self
let mut expired_requests: HashSet<RequestId> = HashSet::new(); .requests
for (request_id, timestamp) in self.request_timestamps.iter() { .iter()
if *timestamp < expired_requests_time { .filter(|(_, state)| state.started_at < expired_requests_time)
expired_requests.insert(request_id.clone()); .map(|(request_id, _)| request_id.clone())
} .collect();
}
for request_id in &expired_requests { for request_id in &expired_requests {
tracing::warn!("Expiring stale request: {}", request_id); tracing::warn!("Expiring stale request: {}", request_id);
self.free(request_id); self.free(request_id, now);
} }
self.validate_state();
expired_requests expired_requests
} }
} }
...@@ -372,103 +455,131 @@ impl ActiveSequences { ...@@ -372,103 +455,131 @@ impl ActiveSequences {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use std::collections::VecDeque;
fn prefill_hint(tokens: usize, duration_secs: u64) -> PrefillLoadHint {
PrefillLoadHint {
initial_effective_prefill_tokens: tokens,
expected_prefill_duration: Some(Duration::from_secs(duration_secs)),
}
}
#[test] #[test]
fn test_active_sequences_shared_blocks() { fn test_active_sequences_shared_blocks() {
let block_size = 4; let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size); let mut seq_manager = ActiveSequences::new(block_size);
let decay_now = Instant::now();
seq_manager.add_request("request_1".to_string(), Some(vec![1, 2, 3]), 12, 0, None); seq_manager.add_request(
"request_1".to_string(),
Some(vec![1, 2, 3]),
12,
0,
None,
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 3); assert_eq!(seq_manager.active_blocks(), 3);
assert_eq!(seq_manager.active_tokens(), 12); assert_eq!(seq_manager.active_tokens(decay_now), 12);
seq_manager.add_request("request_2".to_string(), Some(vec![4]), 4, 0, None); seq_manager.add_request(
"request_2".to_string(),
Some(vec![4]),
4,
0,
None,
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 4); assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 16); assert_eq!(seq_manager.active_tokens(decay_now), 16);
seq_manager.add_request("request_3".to_string(), Some(vec![1, 2, 3, 4]), 16, 4, None); seq_manager.add_request(
"request_3".to_string(),
Some(vec![1, 2, 3, 4]),
16,
4,
None,
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 4); assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 16); assert_eq!(seq_manager.active_tokens(decay_now), 16);
seq_manager.free(&"request_2".to_string()); seq_manager.free(&"request_2".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 4); assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 12); assert_eq!(seq_manager.active_tokens(decay_now), 12);
seq_manager.free(&"request_3".to_string()); seq_manager.free(&"request_3".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 3); assert_eq!(seq_manager.active_blocks(), 3);
assert_eq!(seq_manager.active_tokens(), 12); assert_eq!(seq_manager.active_tokens(decay_now), 12);
seq_manager.free(&"request_1".to_string()); seq_manager.free(&"request_1".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 0); assert_eq!(seq_manager.active_blocks(), 0);
assert_eq!(seq_manager.active_tokens(), 0); assert_eq!(seq_manager.active_tokens(decay_now), 0);
} }
#[test] #[test]
fn test_output_blocks_with_fractional_decay() { fn test_output_blocks_with_fractional_decay() {
let block_size = 4; let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size); let mut seq_manager = ActiveSequences::new(block_size);
let decay_now = Instant::now();
// Add request with 3 prefill blocks seq_manager.add_request(
seq_manager.add_request("r1".to_string(), Some(vec![1, 2, 3]), 12, 0, None); "r1".to_string(),
Some(vec![1, 2, 3]),
12,
0,
None,
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 3); assert_eq!(seq_manager.active_blocks(), 3);
// Add output block with 0.5 decay fraction.
// This adds a random block and sets all single-ref blocks to 0.5.
assert!(seq_manager.add_output_block(&"r1".to_string(), Some(0.5))); assert!(seq_manager.add_output_block(&"r1".to_string(), Some(0.5)));
// 4 unique blocks, all single-ref → all fractional at 0.5
// active_blocks = 4 - 4 + 4*0.5 = 2
assert_eq!(seq_manager.active_blocks(), 2); assert_eq!(seq_manager.active_blocks(), 2);
// Add second request sharing prefix [1, 2] seq_manager.add_request("r2".to_string(), Some(vec![1, 2]), 8, 0, None, decay_now);
seq_manager.add_request("r2".to_string(), Some(vec![1, 2]), 8, 0, None);
// Blocks 1,2 now have strong_count=2 but still have fractional 0.5 from before
// No new unique blocks → active_blocks = 4 - 4 + 2.0 = 2
assert_eq!(seq_manager.active_blocks(), 2); assert_eq!(seq_manager.active_blocks(), 2);
// Add another output block with 0.0 decay for r1.
// set_single_ref_blocks_as_fractional updates only single-ref blocks:
// blocks 1,2: strong_count=2, NOT updated (remain 0.5)
// block 3, old output, new output: strong_count=1, set to 0.0
// active_blocks = 5 - 5 + (0.5+0.5+0.0+0.0+0.0) = 1
assert!(seq_manager.add_output_block(&"r1".to_string(), Some(0.0))); assert!(seq_manager.add_output_block(&"r1".to_string(), Some(0.0)));
assert_eq!(seq_manager.active_blocks(), 1); assert_eq!(seq_manager.active_blocks(), 1);
// Free both requests, verify clean state seq_manager.free(&"r2".to_string(), decay_now);
seq_manager.free(&"r2".to_string()); seq_manager.free(&"r1".to_string(), decay_now);
seq_manager.free(&"r1".to_string());
assert_eq!(seq_manager.active_blocks(), 0); assert_eq!(seq_manager.active_blocks(), 0);
assert_eq!(seq_manager.active_tokens(), 0); assert_eq!(seq_manager.active_tokens(decay_now), 0);
} }
#[test] #[test]
fn test_mark_prefill_completed() { fn test_mark_prefill_completed() {
let block_size = 4; let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size); let mut seq_manager = ActiveSequences::new(block_size);
let decay_now = Instant::now();
// Add request with isl=12, overlap=0 → active_tokens=12 seq_manager.add_request(
seq_manager.add_request("r1".to_string(), Some(vec![1, 2, 3]), 12, 0, None); "r1".to_string(),
assert_eq!(seq_manager.active_tokens(), 12); Some(vec![1, 2, 3]),
12,
0,
None,
decay_now,
);
assert_eq!(seq_manager.active_tokens(decay_now), 12);
// Mark prefill completed → active_tokens drops to 0 seq_manager.mark_prefill_completed(&"r1".to_string(), decay_now);
seq_manager.mark_prefill_completed(&"r1".to_string()); assert_eq!(seq_manager.active_tokens(decay_now), 0);
assert_eq!(seq_manager.active_tokens(), 0);
// Double-mark: no panic, still 0 seq_manager.mark_prefill_completed(&"r1".to_string(), decay_now);
seq_manager.mark_prefill_completed(&"r1".to_string()); assert_eq!(seq_manager.active_tokens(decay_now), 0);
assert_eq!(seq_manager.active_tokens(), 0);
// Add second request with isl=8 seq_manager.add_request("r2".to_string(), Some(vec![4, 5]), 8, 0, None, decay_now);
seq_manager.add_request("r2".to_string(), Some(vec![4, 5]), 8, 0, None); assert_eq!(seq_manager.active_tokens(decay_now), 8);
assert_eq!(seq_manager.active_tokens(), 8);
// Free it (internally calls mark_prefill_completed) → active_tokens=0 seq_manager.free(&"r2".to_string(), decay_now);
seq_manager.free(&"r2".to_string()); assert_eq!(seq_manager.active_tokens(decay_now), 0);
assert_eq!(seq_manager.active_tokens(), 0);
} }
#[test] #[test]
fn test_add_request_without_prefill_tracking_keeps_active_tokens_zero() { fn test_add_request_without_prefill_tracking_keeps_active_tokens_zero() {
let mut seq_manager = ActiveSequences::new(4); let mut seq_manager = ActiveSequences::new(4);
let decay_now = Instant::now();
seq_manager.add_request_with_prefill_tracking( seq_manager.add_request_with_prefill_tracking(
"r1".to_string(), "r1".to_string(),
...@@ -477,18 +588,24 @@ mod tests { ...@@ -477,18 +588,24 @@ mod tests {
0, 0,
None, None,
false, false,
None,
decay_now,
); );
assert_eq!(seq_manager.active_tokens(), 0); assert_eq!(seq_manager.active_tokens(decay_now), 0);
seq_manager.mark_prefill_completed(&"r1".to_string()); assert!(seq_manager.prefill.prefill_order.is_empty());
assert_eq!(seq_manager.active_tokens(), 0); assert_eq!(seq_manager.prefill.prefill_full_tokens_sum, 0);
seq_manager.free(&"r1".to_string());
seq_manager.mark_prefill_completed(&"r1".to_string(), decay_now);
assert_eq!(seq_manager.active_tokens(decay_now), 0);
seq_manager.free(&"r1".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 0); assert_eq!(seq_manager.active_blocks(), 0);
} }
#[test] #[test]
fn test_potential_blocks_and_tokens_without_prefill_tracking_ignores_prompt_load() { fn test_potential_blocks_and_tokens_without_prefill_tracking_ignores_prompt_load() {
let mut seq_manager = ActiveSequences::new(4); let mut seq_manager = ActiveSequences::new(4);
let decay_now = Instant::now();
seq_manager.add_request_with_prefill_tracking( seq_manager.add_request_with_prefill_tracking(
"r1".to_string(), "r1".to_string(),
Some(vec![1, 2, 3]), Some(vec![1, 2, 3]),
...@@ -496,6 +613,8 @@ mod tests { ...@@ -496,6 +613,8 @@ mod tests {
0, 0,
None, None,
false, false,
None,
decay_now,
); );
let (blocks, tokens) = seq_manager.potential_blocks_and_tokens_with_prefill_tracking( let (blocks, tokens) = seq_manager.potential_blocks_and_tokens_with_prefill_tracking(
...@@ -503,49 +622,331 @@ mod tests { ...@@ -503,49 +622,331 @@ mod tests {
16, 16,
0, 0,
false, false,
decay_now,
); );
assert_eq!(blocks, 4); assert_eq!(blocks, 4);
assert_eq!(tokens, 0); assert_eq!(tokens, 0);
} }
#[test]
fn test_prefill_decay_only_applies_to_oldest_request() {
let mut seq_manager = ActiveSequences::new(4);
let epoch = Instant::now();
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
Some(vec![1]),
100,
0,
None,
true,
Some(prefill_hint(100, 10)),
epoch,
);
seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![2]),
60,
0,
None,
true,
Some(prefill_hint(60, 6)),
epoch + Duration::from_secs(2),
);
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(2)),
140
);
let decayed = seq_manager.active_tokens(epoch + Duration::from_secs(5));
assert_eq!(decayed, 110);
assert!(decayed <= 160);
assert!(decayed >= 60);
}
#[test]
fn test_prefill_decay_hands_off_to_next_oldest_request() {
let mut seq_manager = ActiveSequences::new(4);
let epoch = Instant::now();
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
Some(vec![1]),
100,
0,
None,
true,
Some(prefill_hint(100, 10)),
epoch,
);
seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![2]),
40,
0,
None,
true,
Some(prefill_hint(40, 8)),
epoch,
);
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(3)),
110
);
seq_manager.mark_prefill_completed(&"r1".to_string(), epoch + Duration::from_secs(3));
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(3)),
40
);
assert_eq!(
seq_manager.prefill.prefill_order,
VecDeque::from(vec!["r2".to_string()])
);
assert!(
seq_manager
.prefill
.anchored_prefill
.as_ref()
.is_some_and(|(request_id, _)| request_id == "r2")
);
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(5)),
30
);
}
#[test]
fn test_prefill_decay_resets_when_request_becomes_oldest() {
let mut seq_manager = ActiveSequences::new(4);
let epoch = Instant::now();
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
Some(vec![1]),
100,
0,
None,
true,
Some(prefill_hint(100, 10)),
epoch,
);
seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![2]),
80,
0,
None,
true,
Some(prefill_hint(80, 8)),
epoch + Duration::from_secs(4),
);
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(8)),
100
);
seq_manager.mark_prefill_completed(&"r1".to_string(), epoch + Duration::from_secs(8));
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(8)),
80
);
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(10)),
60
);
}
#[test]
fn test_prefill_front_removal_reanchors_queue_front() {
let mut seq_manager = ActiveSequences::new(4);
let epoch = Instant::now();
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
Some(vec![1]),
30,
0,
None,
true,
Some(prefill_hint(30, 6)),
epoch,
);
seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![2]),
20,
0,
None,
true,
Some(prefill_hint(20, 4)),
epoch,
);
seq_manager.mark_prefill_completed(&"r1".to_string(), epoch + Duration::from_secs(2));
assert!(
seq_manager
.prefill
.anchored_prefill
.as_ref()
.is_some_and(|(request_id, _)| request_id == "r2")
);
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(2)),
20
);
}
#[test]
fn test_prefill_queue_and_sum_invariants_survive_idempotent_cleanup() {
let mut seq_manager = ActiveSequences::new(4);
let decay_now = Instant::now();
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
Some(vec![1]),
50,
0,
None,
true,
Some(prefill_hint(50, 10)),
decay_now,
);
seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![2]),
30,
0,
None,
true,
Some(prefill_hint(30, 10)),
decay_now,
);
assert_eq!(seq_manager.prefill.prefill_full_tokens_sum, 80);
assert_eq!(
seq_manager.prefill.prefill_order,
VecDeque::from(vec!["r1".to_string(), "r2".to_string()])
);
seq_manager.mark_prefill_completed(&"r1".to_string(), decay_now);
seq_manager.mark_prefill_completed(&"r1".to_string(), decay_now);
assert_eq!(seq_manager.prefill.prefill_full_tokens_sum, 30);
assert_eq!(
seq_manager.prefill.prefill_order,
VecDeque::from(vec!["r2".to_string()])
);
seq_manager.free(&"r1".to_string(), decay_now);
seq_manager.free(&"r1".to_string(), decay_now);
assert_eq!(seq_manager.prefill.prefill_full_tokens_sum, 30);
assert_eq!(
seq_manager.prefill.prefill_order,
VecDeque::from(vec!["r2".to_string()])
);
seq_manager.free(&"r2".to_string(), decay_now);
assert_eq!(seq_manager.prefill.prefill_full_tokens_sum, 0);
assert!(seq_manager.prefill.prefill_order.is_empty());
assert!(seq_manager.requests.is_empty());
}
#[tokio::test(start_paused = true)] #[tokio::test(start_paused = true)]
async fn test_force_expiry() { async fn test_force_expiry() {
let block_size = 4; let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size); let mut seq_manager = ActiveSequences::new(block_size);
// Add two requests at time 0 (paused clock) seq_manager.add_request(
seq_manager.add_request("r1".to_string(), Some(vec![1, 2]), 8, 0, None); "r1".to_string(),
seq_manager.add_request("r2".to_string(), Some(vec![3, 4]), 8, 0, None); Some(vec![1, 2]),
8,
0,
None,
Instant::now(),
);
seq_manager.add_request(
"r2".to_string(),
Some(vec![3, 4]),
8,
0,
None,
Instant::now(),
);
assert_eq!(seq_manager.active_blocks(), 4); assert_eq!(seq_manager.active_blocks(), 4);
// Advance 20s: check interval (CHECK_EXPIRY_FREQUENCY = 30s) not reached,
// force_expiry returns without running the check.
tokio::time::advance(Duration::from_secs(20)).await; tokio::time::advance(Duration::from_secs(20)).await;
let expired = seq_manager.force_expiry(); let expired = seq_manager.force_expiry();
assert!(expired.is_empty(), "no check before CHECK_EXPIRY_FREQUENCY"); assert!(expired.is_empty(), "no check before CHECK_EXPIRY_FREQUENCY");
assert_eq!(seq_manager.active_blocks(), 4); assert_eq!(seq_manager.active_blocks(), 4);
// Advance to 31s: first time we pass the check interval. Requests are 31s old,
// still under EXPIRY_DURATION (300s), so none are expired.
tokio::time::advance(Duration::from_secs(11)).await; tokio::time::advance(Duration::from_secs(11)).await;
let expired = seq_manager.force_expiry(); let expired = seq_manager.force_expiry();
assert!(expired.is_empty(), "requests not old enough to expire"); assert!(expired.is_empty(), "requests not old enough to expire");
assert_eq!(seq_manager.active_blocks(), 4); assert_eq!(seq_manager.active_blocks(), 4);
seq_manager.assert_consistent();
// Advance to 301s: requests are now older than EXPIRY_DURATION.
// force_expiry runs and expires r1, r2.
tokio::time::advance(Duration::from_secs(270)).await; tokio::time::advance(Duration::from_secs(270)).await;
let expired = seq_manager.force_expiry(); let expired = seq_manager.force_expiry();
assert_eq!(expired, HashSet::from(["r1".to_string(), "r2".to_string()])); assert_eq!(expired, HashSet::from(["r1".to_string(), "r2".to_string()]));
assert_eq!(seq_manager.active_blocks(), 0); assert_eq!(seq_manager.active_blocks(), 0);
assert_eq!(seq_manager.active_tokens(), 0); assert_eq!(seq_manager.active_tokens(Instant::now()), 0);
seq_manager.assert_consistent();
// add_request calls force_expiry internally. Add r3; no old requests remain,
// so expired set is empty and only r3 is active.
tokio::time::advance(Duration::from_secs(31)).await; tokio::time::advance(Duration::from_secs(31)).await;
let expired = seq_manager.add_request("r3".to_string(), Some(vec![5]), 4, 0, None); let expired =
seq_manager.add_request("r3".to_string(), Some(vec![5]), 4, 0, None, Instant::now());
assert!(expired.is_empty()); assert!(expired.is_empty());
assert_eq!(seq_manager.active_blocks(), 1); assert_eq!(seq_manager.active_blocks(), 1);
assert_eq!(seq_manager.active_tokens(), 4); assert_eq!(seq_manager.active_tokens(Instant::now()), 4);
seq_manager.assert_consistent();
}
#[tokio::test(start_paused = true)]
async fn test_force_expiry_reanchors_new_oldest_request() {
let mut seq_manager = ActiveSequences::new(4);
let first_decay_now = Instant::now();
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
Some(vec![1]),
40,
0,
None,
true,
Some(prefill_hint(40, 100)),
first_decay_now,
);
tokio::time::advance(Duration::from_secs(250)).await;
seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![2]),
30,
0,
None,
true,
Some(prefill_hint(30, 100)),
Instant::now(),
);
tokio::time::advance(Duration::from_secs(60)).await;
let expired = seq_manager.force_expiry();
assert_eq!(expired, HashSet::from(["r1".to_string()]));
assert_eq!(seq_manager.active_tokens(Instant::now()), 30);
assert!(
seq_manager
.prefill
.anchored_prefill
.as_ref()
.is_some_and(|(request_id, _)| request_id == "r2")
);
tokio::time::advance(Duration::from_secs(20)).await;
assert_eq!(seq_manager.active_tokens(Instant::now()), 24);
} }
} }
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
use std::{collections::HashSet, sync::Arc}; use std::{collections::HashSet, sync::Arc};
use dashmap::{DashMap, mapref::entry::Entry}; use dashmap::{DashMap, mapref::entry::Entry};
use dynamo_kv_router::{config::KvRouterConfig, protocols::WorkerId}; use dynamo_kv_router::{PrefillLoadEstimator, config::KvRouterConfig, protocols::WorkerId};
use tokio::sync::oneshot; use tokio::sync::oneshot;
use super::worker_monitor::LoadThresholdConfig; use super::worker_monitor::LoadThresholdConfig;
...@@ -568,6 +568,7 @@ impl ModelManager { ...@@ -568,6 +568,7 @@ impl ModelManager {
endpoint: &Endpoint, endpoint: &Endpoint,
kv_cache_block_size: u32, kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
worker_type: &'static str, worker_type: &'static str,
model_name: Option<String>, model_name: Option<String>,
is_eagle: bool, is_eagle: bool,
...@@ -604,6 +605,7 @@ impl ModelManager { ...@@ -604,6 +605,7 @@ impl ModelManager {
kv_cache_block_size, kv_cache_block_size,
selector, selector,
kv_router_config, kv_router_config,
prefill_load_estimator,
worker_type, worker_type,
model_name, model_name,
is_eagle, is_eagle,
......
...@@ -7,6 +7,7 @@ use tokio::sync::mpsc::Sender; ...@@ -7,6 +7,7 @@ use tokio::sync::mpsc::Sender;
use anyhow::Context as _; use anyhow::Context as _;
use dashmap::DashSet; use dashmap::DashSet;
use dynamo_kv_router::PrefillLoadEstimator;
use futures::StreamExt; use futures::StreamExt;
use dynamo_runtime::{ use dynamo_runtime::{
...@@ -74,6 +75,7 @@ pub struct ModelWatcher { ...@@ -74,6 +75,7 @@ pub struct ModelWatcher {
notify_on_model: Notify, notify_on_model: Notify,
model_update_tx: Option<Sender<ModelUpdate>>, model_update_tx: Option<Sender<ModelUpdate>>,
chat_engine_factory: Option<ChatEngineFactoryCallback>, chat_engine_factory: Option<ChatEngineFactoryCallback>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
/// Guards against concurrent pipeline construction for the same (model, namespace). /// Guards against concurrent pipeline construction for the same (model, namespace).
registering_worker_sets: DashSet<String>, registering_worker_sets: DashSet<String>,
...@@ -118,6 +120,7 @@ impl ModelWatcher { ...@@ -118,6 +120,7 @@ impl ModelWatcher {
router_config: RouterConfig, router_config: RouterConfig,
migration_limit: u32, migration_limit: u32,
chat_engine_factory: Option<ChatEngineFactoryCallback>, chat_engine_factory: Option<ChatEngineFactoryCallback>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
) -> ModelWatcher { ) -> ModelWatcher {
Self { Self {
...@@ -128,6 +131,7 @@ impl ModelWatcher { ...@@ -128,6 +131,7 @@ impl ModelWatcher {
notify_on_model: Notify::new(), notify_on_model: Notify::new(),
model_update_tx: None, model_update_tx: None,
chat_engine_factory, chat_engine_factory,
prefill_load_estimator,
metrics, metrics,
registering_worker_sets: DashSet::new(), registering_worker_sets: DashSet::new(),
} }
...@@ -465,6 +469,7 @@ impl ModelWatcher { ...@@ -465,6 +469,7 @@ impl ModelWatcher {
&endpoint, &endpoint,
card.kv_cache_block_size, card.kv_cache_block_size,
Some(self.router_config.kv_router_config.clone()), Some(self.router_config.kv_router_config.clone()),
self.prefill_load_estimator.clone(),
WORKER_TYPE_DECODE, // This is the decode router WORKER_TYPE_DECODE, // This is the decode router
Some(card.display_name.clone()), Some(card.display_name.clone()),
card.runtime_config.enable_eagle, card.runtime_config.enable_eagle,
...@@ -506,6 +511,7 @@ impl ModelWatcher { ...@@ -506,6 +511,7 @@ impl ModelWatcher {
self.router_config.router_mode, self.router_config.router_mode,
card.kv_cache_block_size, card.kv_cache_block_size,
Some(prefill_config), Some(prefill_config),
self.prefill_load_estimator.clone(),
self.router_config.enforce_disagg, self.router_config.enforce_disagg,
model_name.clone(), model_name.clone(),
namespace.clone(), namespace.clone(),
......
...@@ -12,7 +12,7 @@ use std::future::Future; ...@@ -12,7 +12,7 @@ use std::future::Future;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use dynamo_kv_router::config::KvRouterConfig; use dynamo_kv_router::{PrefillLoadEstimator, config::KvRouterConfig};
use dynamo_runtime::{discovery::ModelCardInstanceId, pipeline::RouterMode}; use dynamo_runtime::{discovery::ModelCardInstanceId, pipeline::RouterMode};
use crate::{ use crate::{
...@@ -68,6 +68,7 @@ pub enum EngineConfig { ...@@ -68,6 +68,7 @@ pub enum EngineConfig {
Dynamic { Dynamic {
model: Box<LocalModel>, model: Box<LocalModel>,
chat_engine_factory: Option<ChatEngineFactoryCallback>, chat_engine_factory: Option<ChatEngineFactoryCallback>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
}, },
/// A Text engine receives text, does it's own tokenization and prompt formatting. /// A Text engine receives text, does it's own tokenization and prompt formatting.
......
...@@ -94,7 +94,9 @@ pub async fn prepare_engine( ...@@ -94,7 +94,9 @@ pub async fn prepare_engine(
) -> anyhow::Result<PreparedEngine> { ) -> anyhow::Result<PreparedEngine> {
match engine_config { match engine_config {
EngineConfig::Dynamic { EngineConfig::Dynamic {
model: local_model, .. model: local_model,
prefill_load_estimator,
..
} => { } => {
let model_manager = Arc::new(ModelManager::new()); let model_manager = Arc::new(ModelManager::new());
// Create metrics for migration tracking (not exposed via /metrics in Dynamic engine mode) // Create metrics for migration tracking (not exposed via /metrics in Dynamic engine mode)
...@@ -105,6 +107,7 @@ pub async fn prepare_engine( ...@@ -105,6 +107,7 @@ pub async fn prepare_engine(
RouterConfig::default(), RouterConfig::default(),
local_model.migration_limit(), local_model.migration_limit(),
None, None,
prefill_load_estimator,
metrics, metrics,
)); ));
let discovery = distributed_runtime.discovery(); let discovery = distributed_runtime.discovery();
......
...@@ -33,7 +33,11 @@ pub async fn run( ...@@ -33,7 +33,11 @@ pub async fn run(
} }
let grpc_service = match engine_config { let grpc_service = match engine_config {
EngineConfig::Dynamic { ref model, .. } => { EngineConfig::Dynamic {
ref model,
ref prefill_load_estimator,
..
} => {
let grpc_service = grpc_service_builder.build()?; let grpc_service = grpc_service_builder.build()?;
let router_config = model.router_config(); let router_config = model.router_config();
let migration_limit = model.migration_limit(); let migration_limit = model.migration_limit();
...@@ -48,6 +52,7 @@ pub async fn run( ...@@ -48,6 +52,7 @@ pub async fn run(
router_config.clone(), router_config.clone(),
migration_limit, migration_limit,
namespace_filter, namespace_filter,
prefill_load_estimator.clone(),
) )
.await?; .await?;
grpc_service grpc_service
...@@ -111,6 +116,7 @@ async fn run_watcher( ...@@ -111,6 +116,7 @@ async fn run_watcher(
router_config: RouterConfig, router_config: RouterConfig,
migration_limit: u32, migration_limit: u32,
namespace_filter: NamespaceFilter, namespace_filter: NamespaceFilter,
prefill_load_estimator: Option<Arc<dyn dynamo_kv_router::PrefillLoadEstimator>>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// Create metrics for migration tracking (not exposed via /metrics in gRPC mode) // Create metrics for migration tracking (not exposed via /metrics in gRPC mode)
let metrics = Arc::new(Metrics::new()); let metrics = Arc::new(Metrics::new());
...@@ -120,6 +126,7 @@ async fn run_watcher( ...@@ -120,6 +126,7 @@ async fn run_watcher(
router_config, router_config,
migration_limit, migration_limit,
None, None,
prefill_load_estimator,
metrics, metrics,
); );
tracing::debug!("Waiting for remote model"); tracing::debug!("Waiting for remote model");
......
...@@ -67,6 +67,7 @@ pub async fn run( ...@@ -67,6 +67,7 @@ pub async fn run(
EngineConfig::Dynamic { EngineConfig::Dynamic {
ref model, ref model,
ref chat_engine_factory, ref chat_engine_factory,
ref prefill_load_estimator,
} => { } => {
// Pass the discovery client so the /health endpoint can query active instances // Pass the discovery client so the /health endpoint can query active instances
http_service_builder = http_service_builder =
...@@ -90,6 +91,7 @@ pub async fn run( ...@@ -90,6 +91,7 @@ pub async fn run(
Arc::new(http_service.clone()), Arc::new(http_service.clone()),
http_service.state().metrics_clone(), http_service.state().metrics_clone(),
chat_engine_factory.clone(), chat_engine_factory.clone(),
prefill_load_estimator.clone(),
) )
.await?; .await?;
http_service http_service
...@@ -167,6 +169,7 @@ async fn run_watcher( ...@@ -167,6 +169,7 @@ async fn run_watcher(
http_service: Arc<HttpService>, http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>, metrics: Arc<crate::http::service::metrics::Metrics>,
chat_engine_factory: Option<ChatEngineFactoryCallback>, chat_engine_factory: Option<ChatEngineFactoryCallback>,
prefill_load_estimator: Option<Arc<dyn dynamo_kv_router::PrefillLoadEstimator>>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let mut watch_obj = ModelWatcher::new( let mut watch_obj = ModelWatcher::new(
runtime.clone(), runtime.clone(),
...@@ -174,6 +177,7 @@ async fn run_watcher( ...@@ -174,6 +177,7 @@ async fn run_watcher(
router_config, router_config,
migration_limit, migration_limit,
chat_engine_factory, chat_engine_factory,
prefill_load_estimator,
metrics.clone(), metrics.clone(),
); );
tracing::debug!("Waiting for remote model"); tracing::debug!("Waiting for remote model");
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use anyhow::Result; use anyhow::Result;
use dynamo_kv_router::{ use dynamo_kv_router::{
PrefillLoadEstimator,
config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env}, config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env},
indexer::KvRouterError, indexer::KvRouterError,
protocols::KV_EVENT_SUBJECT, protocols::KV_EVENT_SUBJECT,
protocols::{ protocols::{
BlockExtraInfo, BlockHashOptions, DpRank, RouterEvent, RouterRequest, RouterResponse, BlockExtraInfo, BlockHashOptions, DpRank, PrefillLoadHint, RouterEvent, RouterRequest,
TokensWithHashes, WorkerId, WorkerWithDpRank, compute_block_hash_for_seq, RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank, compute_block_hash_for_seq,
}, },
}; };
use dynamo_runtime::{ use dynamo_runtime::{
...@@ -111,6 +113,7 @@ where ...@@ -111,6 +113,7 @@ where
scheduler: KvScheduler<Sel>, scheduler: KvScheduler<Sel>,
block_size: u32, block_size: u32,
kv_router_config: KvRouterConfig, kv_router_config: KvRouterConfig,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
cancellation_token: tokio_util::sync::CancellationToken, cancellation_token: tokio_util::sync::CancellationToken,
client: Client, client: Client,
is_eagle: bool, is_eagle: bool,
...@@ -128,6 +131,7 @@ where ...@@ -128,6 +131,7 @@ where
block_size: u32, block_size: u32,
selector: Sel, selector: Sel,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
worker_type: &'static str, worker_type: &'static str,
model_name: Option<String>, model_name: Option<String>,
is_eagle: bool, is_eagle: bool,
...@@ -159,6 +163,7 @@ where ...@@ -159,6 +163,7 @@ where
workers_with_configs.clone(), workers_with_configs.clone(),
selector, selector,
&kv_router_config, &kv_router_config,
prefill_load_estimator.clone(),
worker_type, worker_type,
) )
.await?; .await?;
...@@ -184,6 +189,7 @@ where ...@@ -184,6 +189,7 @@ where
scheduler, scheduler,
block_size, block_size,
kv_router_config, kv_router_config,
prefill_load_estimator,
cancellation_token, cancellation_token,
client, client,
is_eagle, is_eagle,
...@@ -345,6 +351,8 @@ where ...@@ -345,6 +351,8 @@ where
let track_prefill_tokens = self let track_prefill_tokens = self
.kv_router_config .kv_router_config
.track_prefill_tokens(router_config_override); .track_prefill_tokens(router_config_override);
let prefill_load_hint =
self.prefill_load_hint_for(isl_tokens, overlap_blocks, track_prefill_tokens);
if let Err(e) = self if let Err(e) = self
.scheduler .scheduler
...@@ -355,6 +363,7 @@ where ...@@ -355,6 +363,7 @@ where
overlap: overlap_blocks, overlap: overlap_blocks,
track_prefill_tokens, track_prefill_tokens,
expected_output_tokens, expected_output_tokens,
prefill_load_hint,
worker, worker,
lora_name, lora_name,
}) })
...@@ -377,6 +386,42 @@ where ...@@ -377,6 +386,42 @@ where
self.scheduler.pending_count() self.scheduler.pending_count()
} }
fn prefill_load_hint_for(
&self,
isl_tokens: usize,
overlap_blocks: u32,
track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> {
if !track_prefill_tokens {
return None;
}
let prefix = (overlap_blocks as usize) * (self.block_size as usize);
let effective_isl = isl_tokens.saturating_sub(prefix);
if effective_isl == 0 {
return None;
}
let Some(estimator) = &self.prefill_load_estimator else {
return None;
};
match estimator.predict_prefill_duration(1, effective_isl, prefix) {
Ok(expected_prefill_duration) => Some(PrefillLoadHint {
initial_effective_prefill_tokens: effective_isl,
expected_prefill_duration: Some(expected_prefill_duration),
}),
Err(error) => {
tracing::warn!(
effective_isl,
prefix,
"failed to predict prefill duration for direct add_request path: {error}"
);
None
}
}
}
/// Get the worker type for this router ("prefill" or "decode"). /// Get the worker type for this router ("prefill" or "decode").
/// Used for Prometheus metric labeling. /// Used for Prometheus metric labeling.
pub fn worker_type(&self) -> &'static str { pub fn worker_type(&self) -> &'static str {
......
...@@ -6,7 +6,7 @@ use std::sync::Arc; ...@@ -6,7 +6,7 @@ use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use dynamo_kv_router::config::KvRouterConfig; use dynamo_kv_router::{PrefillLoadEstimator, config::KvRouterConfig};
use dynamo_runtime::{ use dynamo_runtime::{
component::{Client, Endpoint}, component::{Client, Endpoint},
pipeline::{PushRouter, RouterMode}, pipeline::{PushRouter, RouterMode},
...@@ -37,6 +37,7 @@ impl PrefillRouter { ...@@ -37,6 +37,7 @@ impl PrefillRouter {
cancel_token: tokio_util::sync::CancellationToken::new(), cancel_token: tokio_util::sync::CancellationToken::new(),
router_mode, router_mode,
enforce_disagg, enforce_disagg,
prefill_load_estimator: None,
model_name: String::new(), // Not used for disabled router model_name: String::new(), // Not used for disabled router
namespace: String::new(), // Not used for disabled router namespace: String::new(), // Not used for disabled router
is_eagle: false, is_eagle: false,
...@@ -50,6 +51,7 @@ impl PrefillRouter { ...@@ -50,6 +51,7 @@ impl PrefillRouter {
router_mode: RouterMode, router_mode: RouterMode,
kv_cache_block_size: u32, kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
enforce_disagg: bool, enforce_disagg: bool,
model_name: String, model_name: String,
namespace: String, namespace: String,
...@@ -65,6 +67,7 @@ impl PrefillRouter { ...@@ -65,6 +67,7 @@ impl PrefillRouter {
cancel_token: cancel_token.clone(), cancel_token: cancel_token.clone(),
router_mode, router_mode,
enforce_disagg, enforce_disagg,
prefill_load_estimator,
model_name, model_name,
namespace, namespace,
is_eagle, is_eagle,
...@@ -85,6 +88,7 @@ impl PrefillRouter { ...@@ -85,6 +88,7 @@ impl PrefillRouter {
model_manager, model_manager,
kv_cache_block_size, kv_cache_block_size,
kv_router_config, kv_router_config,
router_clone.prefill_load_estimator.clone(),
).await { ).await {
tracing::error!(error = %e, "Failed to activate prefill router"); tracing::error!(error = %e, "Failed to activate prefill router");
} }
...@@ -105,6 +109,7 @@ impl PrefillRouter { ...@@ -105,6 +109,7 @@ impl PrefillRouter {
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
kv_cache_block_size: u32, kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
) -> Result<()> { ) -> Result<()> {
tracing::info!( tracing::info!(
router_mode = ?self.router_mode, router_mode = ?self.router_mode,
...@@ -127,6 +132,7 @@ impl PrefillRouter { ...@@ -127,6 +132,7 @@ impl PrefillRouter {
&endpoint, &endpoint,
kv_cache_block_size, kv_cache_block_size,
kv_router_config, kv_router_config,
prefill_load_estimator,
WORKER_TYPE_PREFILL, WORKER_TYPE_PREFILL,
Some(self.model_name.clone()), Some(self.model_name.clone()),
self.is_eagle, self.is_eagle,
......
...@@ -6,6 +6,7 @@ use std::sync::{Arc, OnceLock}; ...@@ -6,6 +6,7 @@ use std::sync::{Arc, OnceLock};
use anyhow::Result; use anyhow::Result;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use dynamo_kv_router::PrefillLoadEstimator;
use dynamo_runtime::{ use dynamo_runtime::{
pipeline::{ pipeline::{
AsyncEngineContextProvider, ManyOut, Operator, RouterMode, ServerStreamingEngine, SingleIn, AsyncEngineContextProvider, ManyOut, Operator, RouterMode, ServerStreamingEngine, SingleIn,
...@@ -47,6 +48,7 @@ pub struct PrefillRouter { ...@@ -47,6 +48,7 @@ pub struct PrefillRouter {
cancel_token: CancellationToken, cancel_token: CancellationToken,
router_mode: RouterMode, router_mode: RouterMode,
enforce_disagg: bool, enforce_disagg: bool,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
/// Model name used to look up the worker monitor for prefill client registration /// Model name used to look up the worker monitor for prefill client registration
model_name: String, model_name: String,
/// Namespace used to look up the correct WorkerSet's worker monitor /// Namespace used to look up the correct WorkerSet's worker monitor
......
...@@ -16,6 +16,7 @@ use crate::discovery::RuntimeConfigWatch; ...@@ -16,6 +16,7 @@ use crate::discovery::RuntimeConfigWatch;
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result; use anyhow::Result;
use dynamo_kv_router::{ use dynamo_kv_router::{
PrefillLoadEstimator,
config::{KvRouterConfig, RouterConfigOverride}, config::{KvRouterConfig, RouterConfigOverride},
protocols::{OverlapScores, WorkerId}, protocols::{OverlapScores, WorkerId},
}; };
...@@ -45,6 +46,7 @@ where ...@@ -45,6 +46,7 @@ where
workers_with_configs: RuntimeConfigWatch, workers_with_configs: RuntimeConfigWatch,
selector: Sel, selector: Sel,
kv_router_config: &KvRouterConfig, kv_router_config: &KvRouterConfig,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
worker_type: &'static str, worker_type: &'static str,
) -> Result<Self, KvSchedulerError> { ) -> Result<Self, KvSchedulerError> {
let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> = let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
...@@ -81,6 +83,8 @@ where ...@@ -81,6 +83,8 @@ where
block_size, block_size,
selector, selector,
policy, policy,
prefill_load_estimator,
kv_router_config.router_queue_recheck_interval(),
kv_router_config.router_track_prefill_tokens, kv_router_config.router_track_prefill_tokens,
component.drt().child_token(), component.drt().child_token(),
worker_type, worker_type,
......
...@@ -143,35 +143,58 @@ pub async fn create_multi_worker_sequences( ...@@ -143,35 +143,58 @@ pub async fn create_multi_worker_sequences(
mod tests { mod tests {
use super::*; use super::*;
use dynamo_runtime::{DistributedRuntime, Runtime}; use dynamo_runtime::{DistributedRuntime, Runtime};
use tokio::time::Instant;
#[test] #[test]
fn test_active_sequences_shared_blocks() { fn test_active_sequences_shared_blocks() {
let block_size = 4; let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size); let mut seq_manager = ActiveSequences::new(block_size);
let decay_now = Instant::now();
seq_manager.add_request("request_1".to_string(), Some(vec![1, 2, 3]), 12, 0, None); seq_manager.add_request(
"request_1".to_string(),
Some(vec![1, 2, 3]),
12,
0,
None,
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 3); assert_eq!(seq_manager.active_blocks(), 3);
assert_eq!(seq_manager.active_tokens(), 12); assert_eq!(seq_manager.active_tokens(decay_now), 12);
seq_manager.add_request("request_2".to_string(), Some(vec![4]), 4, 0, None); seq_manager.add_request(
"request_2".to_string(),
Some(vec![4]),
4,
0,
None,
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 4); assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 16); assert_eq!(seq_manager.active_tokens(decay_now), 16);
seq_manager.add_request("request_3".to_string(), Some(vec![1, 2, 3, 4]), 16, 4, None); seq_manager.add_request(
"request_3".to_string(),
Some(vec![1, 2, 3, 4]),
16,
4,
None,
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 4); assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 16); assert_eq!(seq_manager.active_tokens(decay_now), 16);
seq_manager.free(&"request_2".to_string()); seq_manager.free(&"request_2".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 4); assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 12); assert_eq!(seq_manager.active_tokens(decay_now), 12);
seq_manager.free(&"request_3".to_string()); seq_manager.free(&"request_3".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 3); assert_eq!(seq_manager.active_blocks(), 3);
assert_eq!(seq_manager.active_tokens(), 12); assert_eq!(seq_manager.active_tokens(decay_now), 12);
seq_manager.free(&"request_1".to_string()); seq_manager.free(&"request_1".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 0); assert_eq!(seq_manager.active_blocks(), 0);
assert_eq!(seq_manager.active_tokens(), 0); assert_eq!(seq_manager.active_tokens(decay_now), 0);
} }
#[tokio::test] #[tokio::test]
...@@ -217,43 +240,55 @@ mod tests { ...@@ -217,43 +240,55 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
seq_manager_1.add_request(SequenceRequest { seq_manager_1.add_request(
request_id: "request_0".to_string(), SequenceRequest {
token_sequence: Some(vec![0, 1, 2]), request_id: "request_0".to_string(),
isl: 12, token_sequence: Some(vec![0, 1, 2]),
overlap: 0, isl: 12,
track_prefill_tokens: true, overlap: 0,
expected_output_tokens: None, track_prefill_tokens: true,
worker: WorkerWithDpRank::new(0, 0), expected_output_tokens: None,
lora_name: None, prefill_load_hint: None,
})?; worker: WorkerWithDpRank::new(0, 0),
lora_name: None,
seq_manager_1.add_request(SequenceRequest { },
request_id: "request_1".to_string(), Instant::now(),
token_sequence: Some(vec![3, 4]), )?;
isl: 8,
overlap: 0, seq_manager_1.add_request(
track_prefill_tokens: true, SequenceRequest {
expected_output_tokens: None, request_id: "request_1".to_string(),
worker: WorkerWithDpRank::new(0, 1), token_sequence: Some(vec![3, 4]),
lora_name: None, isl: 8,
})?; overlap: 0,
track_prefill_tokens: true,
seq_manager_2.add_request(SequenceRequest { expected_output_tokens: None,
request_id: "request_2".to_string(), prefill_load_hint: None,
token_sequence: Some(vec![0, 1, 2, 3]), worker: WorkerWithDpRank::new(0, 1),
isl: 16, lora_name: None,
overlap: 0, },
track_prefill_tokens: true, Instant::now(),
expected_output_tokens: None, )?;
worker: WorkerWithDpRank::new(1, 0),
lora_name: None, seq_manager_2.add_request(
})?; SequenceRequest {
request_id: "request_2".to_string(),
token_sequence: Some(vec![0, 1, 2, 3]),
isl: 16,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
prefill_load_hint: None,
worker: WorkerWithDpRank::new(1, 0),
lora_name: None,
},
Instant::now(),
)?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
let blocks_phase1 = seq_manager_1.active_blocks(); let blocks_phase1 = seq_manager_1.active_blocks();
let tokens_phase1 = seq_manager_1.active_tokens(); let tokens_phase1 = seq_manager_1.active_tokens(Instant::now());
let worker_0_dp0 = WorkerWithDpRank::new(0, 0); let worker_0_dp0 = WorkerWithDpRank::new(0, 0);
let worker_0_dp1 = WorkerWithDpRank::new(0, 1); let worker_0_dp1 = WorkerWithDpRank::new(0, 1);
...@@ -284,15 +319,15 @@ mod tests { ...@@ -284,15 +319,15 @@ mod tests {
"Worker 1 dp_rank 0 should have 16 active tokens (from request_2 added by seq_manager_2)" "Worker 1 dp_rank 0 should have 16 active tokens (from request_2 added by seq_manager_2)"
); );
seq_manager_1.free(&"request_2".to_string())?; seq_manager_1.free(&"request_2".to_string(), Instant::now())?;
seq_manager_2.free(&"request_0".to_string())?; seq_manager_2.free(&"request_0".to_string(), Instant::now())?;
seq_manager_2.free(&"request_1".to_string())?; seq_manager_2.free(&"request_1".to_string(), Instant::now())?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
let blocks_phase2 = seq_manager_2.active_blocks(); let blocks_phase2 = seq_manager_2.active_blocks();
let tokens_phase2 = seq_manager_2.active_tokens(); let tokens_phase2 = seq_manager_2.active_tokens(Instant::now());
let all_workers = vec![ let all_workers = vec![
WorkerWithDpRank::new(0, 0), WorkerWithDpRank::new(0, 0),
...@@ -364,42 +399,54 @@ mod tests { ...@@ -364,42 +399,54 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
seq_manager_1.add_request(SequenceRequest { seq_manager_1.add_request(
request_id: "request_0".to_string(), SequenceRequest {
token_sequence: None, request_id: "request_0".to_string(),
isl: 12, token_sequence: None,
overlap: 0, isl: 12,
track_prefill_tokens: true, overlap: 0,
expected_output_tokens: None, track_prefill_tokens: true,
worker: WorkerWithDpRank::from_worker_id(0), expected_output_tokens: None,
lora_name: None, prefill_load_hint: None,
})?; worker: WorkerWithDpRank::from_worker_id(0),
lora_name: None,
seq_manager_1.add_request(SequenceRequest { },
request_id: "request_1".to_string(), Instant::now(),
token_sequence: None, )?;
isl: 8,
overlap: 0, seq_manager_1.add_request(
track_prefill_tokens: true, SequenceRequest {
expected_output_tokens: None, request_id: "request_1".to_string(),
worker: WorkerWithDpRank::from_worker_id(1), token_sequence: None,
lora_name: None, isl: 8,
})?; overlap: 0,
track_prefill_tokens: true,
seq_manager_2.add_request(SequenceRequest { expected_output_tokens: None,
request_id: "request_2".to_string(), prefill_load_hint: None,
token_sequence: None, worker: WorkerWithDpRank::from_worker_id(1),
isl: 16, lora_name: None,
overlap: 0, },
track_prefill_tokens: true, Instant::now(),
expected_output_tokens: None, )?;
worker: WorkerWithDpRank::from_worker_id(2),
lora_name: None, seq_manager_2.add_request(
})?; SequenceRequest {
request_id: "request_2".to_string(),
token_sequence: None,
isl: 16,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
prefill_load_hint: None,
worker: WorkerWithDpRank::from_worker_id(2),
lora_name: None,
},
Instant::now(),
)?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
let tokens_phase1 = seq_manager_1.active_tokens(); let tokens_phase1 = seq_manager_1.active_tokens(Instant::now());
let worker_0 = WorkerWithDpRank::from_worker_id(0); let worker_0 = WorkerWithDpRank::from_worker_id(0);
let worker_1 = WorkerWithDpRank::from_worker_id(1); let worker_1 = WorkerWithDpRank::from_worker_id(1);
...@@ -418,17 +465,17 @@ mod tests { ...@@ -418,17 +465,17 @@ mod tests {
"Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)" "Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)"
); );
seq_manager_1.mark_prefill_completed(&"request_2".to_string())?; seq_manager_1.mark_prefill_completed(&"request_2".to_string(), Instant::now())?;
seq_manager_1.free(&"request_2".to_string())?; seq_manager_1.free(&"request_2".to_string(), Instant::now())?;
seq_manager_2.mark_prefill_completed(&"request_0".to_string())?; seq_manager_2.mark_prefill_completed(&"request_0".to_string(), Instant::now())?;
seq_manager_2.mark_prefill_completed(&"request_1".to_string())?; seq_manager_2.mark_prefill_completed(&"request_1".to_string(), Instant::now())?;
seq_manager_2.free(&"request_0".to_string())?; seq_manager_2.free(&"request_0".to_string(), Instant::now())?;
seq_manager_2.free(&"request_1".to_string())?; seq_manager_2.free(&"request_1".to_string(), Instant::now())?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
let tokens_phase2 = seq_manager_2.active_tokens(); let tokens_phase2 = seq_manager_2.active_tokens(Instant::now());
for worker_id in 0..=2 { for worker_id in 0..=2 {
let worker = WorkerWithDpRank::from_worker_id(worker_id); let worker = WorkerWithDpRank::from_worker_id(worker_id);
......
...@@ -344,6 +344,7 @@ mod integration_tests { ...@@ -344,6 +344,7 @@ mod integration_tests {
dynamo_llm::entrypoint::RouterConfig::default(), dynamo_llm::entrypoint::RouterConfig::default(),
0, // migration_limit 0, // migration_limit
None, None,
None,
service.state().metrics_clone(), service.state().metrics_clone(),
); );
// Start watching for model registrations via discovery interface // Start watching for model registrations via discovery interface
......
...@@ -22,6 +22,7 @@ anyhow = { workspace = true } ...@@ -22,6 +22,7 @@ anyhow = { workspace = true }
dashmap = { workspace = true } dashmap = { workspace = true }
derive_builder = { workspace = true } derive_builder = { workspace = true }
derive-getters = { workspace = true } derive-getters = { workspace = true }
indicatif = "0.18"
rand = { workspace = true } rand = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
......
...@@ -29,8 +29,8 @@ pub trait DecodeInterpolator: Send + Sync { ...@@ -29,8 +29,8 @@ pub trait DecodeInterpolator: Send + Sync {
/// Implementors call the Python AIC SDK via PyO3 GIL. /// Implementors call the Python AIC SDK via PyO3 GIL.
pub trait AicCallback: Send + Sync { pub trait AicCallback: Send + Sync {
/// Predict prefill latency in ms. /// Predict prefill latency in ms.
/// Parameters: (batch_size, isl, prefix, osl) /// Parameters: (batch_size, effective_isl, prefix)
fn predict_prefill(&self, batch_size: usize, isl: usize, prefix: usize, osl: usize) -> f64; fn predict_prefill(&self, batch_size: usize, effective_isl: usize, prefix: usize) -> f64;
/// Predict decode (generation) latency in ms. /// Predict decode (generation) latency in ms.
/// Parameters: (batch_size, isl, osl) /// Parameters: (batch_size, isl, osl)
...@@ -83,7 +83,7 @@ pub enum PerfModel { ...@@ -83,7 +83,7 @@ pub enum PerfModel {
decode_interp: Arc<dyn DecodeInterpolator>, decode_interp: Arc<dyn DecodeInterpolator>,
}, },
/// AI Configurator SDK calls via Python callback. /// AI Configurator SDK calls via Python callback.
/// Passes full parameters (batch_size, isl, prefix, osl) for maximum accuracy. /// Passes the reduced prefill inputs (batch_size, effective_isl, prefix).
Aiconfigurator { callback: Arc<dyn AicCallback> }, Aiconfigurator { callback: Arc<dyn AicCallback> },
} }
...@@ -217,7 +217,7 @@ impl PerfModel { ...@@ -217,7 +217,7 @@ impl PerfModel {
/// Callers always pass all parameters; each variant uses what it needs: /// Callers always pass all parameters; each variant uses what it needs:
/// - Polynomial/Interpolated: uses total new tokens across the batch /// - Polynomial/Interpolated: uses total new tokens across the batch
/// (`batch_size * (isl - prefix)`), modeling GPU processing total tokens in parallel /// (`batch_size * (isl - prefix)`), modeling GPU processing total tokens in parallel
/// - Aiconfigurator: passes (batch_size, isl, prefix) directly to the AIC SDK /// - Aiconfigurator: passes (batch_size, isl - prefix, prefix) to the AIC SDK
pub fn predict_prefill_time(&self, batch_size: usize, isl: usize, prefix: usize) -> f64 { pub fn predict_prefill_time(&self, batch_size: usize, isl: usize, prefix: usize) -> f64 {
let new_tokens_per_req = isl.saturating_sub(prefix); let new_tokens_per_req = isl.saturating_sub(prefix);
let time = match self { let time = match self {
...@@ -231,7 +231,7 @@ impl PerfModel { ...@@ -231,7 +231,7 @@ impl PerfModel {
prefill_interp.interp(tokens).unwrap_or(0.0) prefill_interp.interp(tokens).unwrap_or(0.0)
} }
PerfModel::Aiconfigurator { callback } => { PerfModel::Aiconfigurator { callback } => {
callback.predict_prefill(batch_size, isl, prefix, 1) callback.predict_prefill(batch_size, new_tokens_per_req, prefix)
} }
}; };
time.max(0.0) time.max(0.0)
......
...@@ -82,16 +82,16 @@ pub struct WorkloadDriver { ...@@ -82,16 +82,16 @@ pub struct WorkloadDriver {
} }
impl WorkloadDriver { impl WorkloadDriver {
pub(crate) fn new_trace(trace: Trace) -> Result<Self> { pub(crate) fn new_trace(trace: Trace, engine_block_size: usize) -> Result<Self> {
Self::new(trace, DriverMode::Trace) Self::new(trace, engine_block_size, DriverMode::Trace)
} }
pub(crate) fn new_concurrency(trace: Trace) -> Result<Self> { pub(crate) fn new_concurrency(trace: Trace, engine_block_size: usize) -> Result<Self> {
Self::new(trace, DriverMode::Concurrency) Self::new(trace, engine_block_size, DriverMode::Concurrency)
} }
fn new(trace: Trace, mode: DriverMode) -> Result<Self> { fn new(trace: Trace, engine_block_size: usize, mode: DriverMode) -> Result<Self> {
let block_size = trace.block_size; let trace_block_size = trace.block_size;
let sessions: Vec<SessionRuntime> = trace let sessions: Vec<SessionRuntime> = trace
.sessions .sessions
.into_iter() .into_iter()
...@@ -105,10 +105,11 @@ impl WorkloadDriver { ...@@ -105,10 +105,11 @@ impl WorkloadDriver {
.into_iter() .into_iter()
.map(|turn| -> Result<TurnRuntime> { .map(|turn| -> Result<TurnRuntime> {
Ok(TurnRuntime { Ok(TurnRuntime {
tokens: turn.synthesize_tokens(block_size)?, tokens: turn.synthesize_tokens(trace_block_size)?,
max_output_tokens: turn.max_output_tokens, max_output_tokens: turn.max_output_tokens,
delay_after_previous_ms: turn.delay_after_previous_ms, delay_after_previous_ms: turn.delay_after_previous_ms,
replay_hashes: turn.to_replay_hashes(block_size)?, replay_hashes: turn
.to_replay_hashes(trace_block_size, engine_block_size)?,
}) })
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
...@@ -260,4 +261,11 @@ impl WorkloadDriver { ...@@ -260,4 +261,11 @@ impl WorkloadDriver {
.iter() .iter()
.all(|session| session.next_turn_index >= session.turns.len()) .all(|session| session.next_turn_index >= session.turns.len())
} }
pub fn total_turns(&self) -> usize {
self.sessions
.iter()
.map(|session| session.turns.len())
.sum()
}
} }
...@@ -112,7 +112,7 @@ fn test_turn_replay_hashes_match_full_blocks_only() { ...@@ -112,7 +112,7 @@ fn test_turn_replay_hashes_match_full_blocks_only() {
let request = turn let request = turn
.to_direct_request(4, Uuid::from_u128(1), Some(5.0)) .to_direct_request(4, Uuid::from_u128(1), Some(5.0))
.unwrap(); .unwrap();
let replay_hashes = turn.to_replay_hashes(4).unwrap(); let replay_hashes = turn.to_replay_hashes(4, 4).unwrap();
let expected_local = let expected_local =
compute_block_hash_for_seq(&request.tokens, 4, BlockHashOptions::default()); compute_block_hash_for_seq(&request.tokens, 4, BlockHashOptions::default());
...@@ -124,6 +124,30 @@ fn test_turn_replay_hashes_match_full_blocks_only() { ...@@ -124,6 +124,30 @@ fn test_turn_replay_hashes_match_full_blocks_only() {
assert_eq!(replay_hashes.local_block_hashes.len(), 1); assert_eq!(replay_hashes.local_block_hashes.len(), 1);
} }
#[test]
fn test_turn_replay_hashes_support_distinct_trace_and_engine_block_sizes() {
let turn = TurnTrace {
input_length: 6,
max_output_tokens: 3,
hash_ids: vec![1, 2],
delay_after_previous_ms: 0.0,
};
let request = turn
.to_direct_request(4, Uuid::from_u128(2), Some(5.0))
.unwrap();
let replay_hashes = turn.to_replay_hashes(4, 2).unwrap();
let expected_local =
compute_block_hash_for_seq(&request.tokens, 2, BlockHashOptions::default());
assert_eq!(replay_hashes.local_block_hashes, expected_local);
assert_eq!(
replay_hashes.sequence_hashes,
compute_seq_hash_for_block(&expected_local)
);
assert_eq!(replay_hashes.local_block_hashes.len(), 3);
}
#[test] #[test]
fn test_partition_by_session_round_robin_keeps_sessions_intact() { fn test_partition_by_session_round_robin_keeps_sessions_intact() {
let trace = Trace::synthetic(SyntheticTraceSpec { let trace = Trace::synthetic(SyntheticTraceSpec {
...@@ -406,7 +430,7 @@ fn test_trace_driver_round_trips_turn_semantics_into_ready_requests() { ...@@ -406,7 +430,7 @@ fn test_trace_driver_round_trips_turn_semantics_into_ready_requests() {
first.replay_hashes.as_ref(), first.replay_hashes.as_ref(),
Some( Some(
&expected.sessions[0].turns[0] &expected.sessions[0].turns[0]
.to_replay_hashes(expected.block_size) .to_replay_hashes(expected.block_size, expected.block_size)
.unwrap() .unwrap()
) )
); );
...@@ -443,7 +467,7 @@ fn test_trace_driver_round_trips_turn_semantics_into_ready_requests() { ...@@ -443,7 +467,7 @@ fn test_trace_driver_round_trips_turn_semantics_into_ready_requests() {
second.replay_hashes.as_ref(), second.replay_hashes.as_ref(),
Some( Some(
&expected.sessions[1].turns[0] &expected.sessions[1].turns[0]
.to_replay_hashes(expected.block_size) .to_replay_hashes(expected.block_size, expected.block_size)
.unwrap() .unwrap()
) )
); );
...@@ -470,7 +494,7 @@ fn test_trace_driver_round_trips_turn_semantics_into_ready_requests() { ...@@ -470,7 +494,7 @@ fn test_trace_driver_round_trips_turn_semantics_into_ready_requests() {
third.replay_hashes.as_ref(), third.replay_hashes.as_ref(),
Some( Some(
&expected.sessions[0].turns[1] &expected.sessions[0].turns[1]
.to_replay_hashes(expected.block_size) .to_replay_hashes(expected.block_size, expected.block_size)
.unwrap() .unwrap()
) )
); );
...@@ -488,3 +512,39 @@ fn test_trace_driver_round_trips_turn_semantics_into_ready_requests() { ...@@ -488,3 +512,39 @@ fn test_trace_driver_round_trips_turn_semantics_into_ready_requests() {
expected_third_request.arrival_timestamp_ms expected_third_request.arrival_timestamp_ms
); );
} }
#[test]
fn test_trace_driver_rechunks_trace_blocks_into_engine_blocks() {
let trace = Trace {
block_size: 4,
sessions: vec![SessionTrace {
session_id: "session-a".to_string(),
first_arrival_timestamp_ms: Some(10.0),
turns: vec![TurnTrace {
input_length: 6,
max_output_tokens: 2,
hash_ids: vec![1, 2],
delay_after_previous_ms: 0.0,
}],
}],
};
let mut driver = trace.into_trace_driver_with_block_size(2).unwrap();
let ready = driver.pop_ready(10.0, usize::MAX);
assert_eq!(ready.len(), 1);
let ready = &ready[0];
assert_eq!(ready.request.tokens, vec![1, 1, 1, 1, 2, 2]);
assert_eq!(
ready.replay_hashes.as_ref(),
Some(
&TurnTrace {
input_length: 6,
max_output_tokens: 2,
hash_ids: vec![1, 2],
delay_after_previous_ms: 0.0,
}
.to_replay_hashes(4, 2)
.unwrap()
)
);
}
...@@ -9,7 +9,8 @@ use std::path::Path; ...@@ -9,7 +9,8 @@ use std::path::Path;
use anyhow::{Context, Result, anyhow, bail}; use anyhow::{Context, Result, anyhow, bail};
use dynamo_kv_router::LocalBlockHash; use dynamo_kv_router::LocalBlockHash;
use dynamo_kv_router::protocols::{ use dynamo_kv_router::protocols::{
ExternalSequenceBlockHash, WorkerId, XXH3_SEED, compute_seq_hash_for_block, BlockHashOptions, ExternalSequenceBlockHash, WorkerId, XXH3_SEED, compute_block_hash_for_seq,
compute_seq_hash_for_block,
}; };
use dynamo_tokens::compute_hash_v2; use dynamo_tokens::compute_hash_v2;
use rand::rngs::StdRng; use rand::rngs::StdRng;
...@@ -45,27 +46,27 @@ struct RawMooncakeRecord { ...@@ -45,27 +46,27 @@ struct RawMooncakeRecord {
} }
impl TurnTrace { impl TurnTrace {
fn validate_block_size_and_capacity(&self, block_size: usize) -> Result<()> { fn validate_block_size_and_capacity(&self, trace_block_size: usize) -> Result<()> {
if block_size == 0 { if trace_block_size == 0 {
bail!("block_size must be greater than 0"); bail!("trace_block_size must be greater than 0");
} }
if self.hash_ids.len() * block_size < self.input_length { if self.hash_ids.len() * trace_block_size < self.input_length {
bail!( bail!(
"input_length {} exceeds synthesized capacity {}", "input_length {} exceeds synthesized capacity {}",
self.input_length, self.input_length,
self.hash_ids.len() * block_size self.hash_ids.len() * trace_block_size
); );
} }
Ok(()) Ok(())
} }
pub(crate) fn synthesize_tokens(&self, block_size: usize) -> Result<Vec<u32>> { pub(crate) fn synthesize_tokens(&self, trace_block_size: usize) -> Result<Vec<u32>> {
self.validate_block_size_and_capacity(block_size)?; self.validate_block_size_and_capacity(trace_block_size)?;
let mut tokens = Vec::with_capacity(self.input_length); let mut tokens = Vec::with_capacity(self.input_length);
for &hash_id in &self.hash_ids { for &hash_id in &self.hash_ids {
let token_id = hash_id as u32; let token_id = hash_id as u32;
tokens.extend((0..block_size).map(|_| token_id)); tokens.extend((0..trace_block_size).map(|_| token_id));
if tokens.len() >= self.input_length { if tokens.len() >= self.input_length {
tokens.truncate(self.input_length); tokens.truncate(self.input_length);
break; break;
...@@ -85,11 +86,11 @@ impl TurnTrace { ...@@ -85,11 +86,11 @@ impl TurnTrace {
pub fn to_direct_request( pub fn to_direct_request(
&self, &self,
block_size: usize, trace_block_size: usize,
request_uuid: Uuid, request_uuid: Uuid,
arrival_timestamp_ms: Option<f64>, arrival_timestamp_ms: Option<f64>,
) -> Result<DirectRequest> { ) -> Result<DirectRequest> {
let tokens = self.synthesize_tokens(block_size)?; let tokens = self.synthesize_tokens(trace_block_size)?;
Ok(DirectRequest { Ok(DirectRequest {
tokens, tokens,
max_output_tokens: self.max_output_tokens, max_output_tokens: self.max_output_tokens,
...@@ -99,16 +100,20 @@ impl TurnTrace { ...@@ -99,16 +100,20 @@ impl TurnTrace {
}) })
} }
pub fn to_replay_hashes(&self, block_size: usize) -> Result<ReplayRequestHashes> { pub fn to_replay_hashes(
self.validate_block_size_and_capacity(block_size)?; &self,
trace_block_size: usize,
let num_full_blocks = self.input_length / block_size; engine_block_size: usize,
let local_block_hashes = self ) -> Result<ReplayRequestHashes> {
.hash_ids if engine_block_size == 0 {
.iter() bail!("engine_block_size must be greater than 0");
.take(num_full_blocks) }
.map(|&hash_id| local_block_hash_from_id(hash_id, block_size))
.collect::<Vec<_>>(); let tokens = self.synthesize_tokens(trace_block_size)?;
let engine_block_size =
u32::try_from(engine_block_size).context("engine_block_size does not fit in u32")?;
let local_block_hashes =
compute_block_hash_for_seq(&tokens, engine_block_size, BlockHashOptions::default());
let sequence_hashes = compute_seq_hash_for_block(&local_block_hashes); let sequence_hashes = compute_seq_hash_for_block(&local_block_hashes);
Ok(ReplayRequestHashes { Ok(ReplayRequestHashes {
...@@ -119,9 +124,9 @@ impl TurnTrace { ...@@ -119,9 +124,9 @@ impl TurnTrace {
} }
impl Trace { impl Trace {
pub fn from_mooncake(path: &Path, block_size: usize) -> Result<Self> { pub fn from_mooncake(path: &Path, trace_block_size: usize) -> Result<Self> {
if block_size == 0 { if trace_block_size == 0 {
bail!("block_size must be greater than 0"); bail!("trace_block_size must be greater than 0");
} }
let file = File::open(path) let file = File::open(path)
...@@ -157,7 +162,9 @@ impl Trace { ...@@ -157,7 +162,9 @@ impl Trace {
let hash_ids = raw let hash_ids = raw
.hash_ids .hash_ids
.ok_or_else(|| anyhow!("trace line {} is missing hash_ids", line_idx + 1))?; .ok_or_else(|| anyhow!("trace line {} is missing hash_ids", line_idx + 1))?;
let input_length = raw.input_length.unwrap_or(hash_ids.len() * block_size); let input_length = raw
.input_length
.unwrap_or(hash_ids.len() * trace_block_size);
let output_length = raw let output_length = raw
.output_length .output_length
.ok_or_else(|| anyhow!("trace line {} is missing output_length", line_idx + 1))?; .ok_or_else(|| anyhow!("trace line {} is missing output_length", line_idx + 1))?;
...@@ -214,12 +221,12 @@ impl Trace { ...@@ -214,12 +221,12 @@ impl Trace {
); );
} }
if hash_ids.len() * block_size < input_length { if hash_ids.len() * trace_block_size < input_length {
bail!( bail!(
"trace line {} input_length {} exceeds synthesized capacity {}", "trace line {} input_length {} exceeds synthesized capacity {}",
line_idx + 1, line_idx + 1,
input_length, input_length,
hash_ids.len() * block_size hash_ids.len() * trace_block_size
); );
} }
...@@ -239,7 +246,7 @@ impl Trace { ...@@ -239,7 +246,7 @@ impl Trace {
} }
Ok(Self { Ok(Self {
block_size, block_size: trace_block_size,
sessions, sessions,
}) })
} }
...@@ -598,12 +605,30 @@ impl Trace { ...@@ -598,12 +605,30 @@ impl Trace {
pub fn into_trace_driver(self) -> Result<WorkloadDriver> { pub fn into_trace_driver(self) -> Result<WorkloadDriver> {
self.validate_for_trace_mode()?; self.validate_for_trace_mode()?;
WorkloadDriver::new_trace(self) let engine_block_size = self.block_size;
WorkloadDriver::new_trace(self, engine_block_size)
} }
pub fn into_concurrency_driver(self) -> Result<WorkloadDriver> { pub fn into_concurrency_driver(self) -> Result<WorkloadDriver> {
self.validate_for_concurrency_mode()?; self.validate_for_concurrency_mode()?;
WorkloadDriver::new_concurrency(self) let engine_block_size = self.block_size;
WorkloadDriver::new_concurrency(self, engine_block_size)
}
pub fn into_trace_driver_with_block_size(
self,
engine_block_size: usize,
) -> Result<WorkloadDriver> {
self.validate_for_trace_mode()?;
WorkloadDriver::new_trace(self, engine_block_size)
}
pub fn into_concurrency_driver_with_block_size(
self,
engine_block_size: usize,
) -> Result<WorkloadDriver> {
self.validate_for_concurrency_mode()?;
WorkloadDriver::new_concurrency(self, engine_block_size)
} }
fn validate(&self, allow_missing_first_timestamp: bool) -> Result<()> { fn validate(&self, allow_missing_first_timestamp: bool) -> Result<()> {
......
...@@ -14,7 +14,8 @@ use super::validate::{ ...@@ -14,7 +14,8 @@ use super::validate::{
validate_online_concurrency_args, validate_online_replay_args, validate_online_concurrency_args, validate_online_replay_args,
}; };
use super::{ use super::{
OfflineDisaggReplayConfig, ReplayRouterMode, ReplayWorkerArtifacts, TraceSimulationReport, OfflineDisaggReplayConfig, ReplayPrefillLoadEstimator, ReplayRouterMode, ReplayWorkerArtifacts,
TraceSimulationReport,
}; };
use crate::common::protocols::{DirectRequest, MockEngineArgs}; use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::Trace; use crate::loadgen::Trace;
...@@ -30,36 +31,43 @@ pub fn generate_trace_worker_artifacts_offline( ...@@ -30,36 +31,43 @@ pub fn generate_trace_worker_artifacts_offline(
pub fn simulate_trace_file( pub fn simulate_trace_file(
args: MockEngineArgs, args: MockEngineArgs,
trace_path: &Path, trace_path: &Path,
trace_block_size: usize,
num_workers: usize, num_workers: usize,
arrival_speedup_ratio: f64, arrival_speedup_ratio: f64,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
simulate_trace_file_with_router_mode( simulate_trace_file_with_router_mode(
args, args,
None, None,
None,
trace_path, trace_path,
trace_block_size,
num_workers, num_workers,
arrival_speedup_ratio, arrival_speedup_ratio,
ReplayRouterMode::RoundRobin, ReplayRouterMode::RoundRobin,
) )
} }
#[allow(clippy::too_many_arguments)]
pub fn simulate_trace_file_with_router_mode( pub fn simulate_trace_file_with_router_mode(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace_path: &Path, trace_path: &Path,
trace_block_size: usize,
num_workers: usize, num_workers: usize,
arrival_speedup_ratio: f64, arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
let args = args.normalized()?; let args = args.normalized()?;
validate_offline_replay_args(&args, num_workers, router_mode)?; validate_offline_replay_args(&args, num_workers, router_mode)?;
let trace = Trace::from_mooncake(trace_path, args.block_size)? let trace = Trace::from_mooncake(trace_path, trace_block_size)?
.normalize_session_starts()? .normalize_session_starts()?
.speed_up_timing(arrival_speedup_ratio)?; .speed_up_timing(arrival_speedup_ratio)?;
let started_at = Instant::now(); let started_at = Instant::now();
let report = crate::replay::offline::simulate_trace_workload( let report = crate::replay::offline::simulate_trace_workload(
args, args,
router_config, router_config,
prefill_load_estimator,
trace, trace,
num_workers, num_workers,
router_mode, router_mode,
...@@ -70,19 +78,22 @@ pub fn simulate_trace_file_with_router_mode( ...@@ -70,19 +78,22 @@ pub fn simulate_trace_file_with_router_mode(
pub fn simulate_trace_file_disagg_with_router_mode( pub fn simulate_trace_file_disagg_with_router_mode(
config: OfflineDisaggReplayConfig, config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace_path: &Path, trace_path: &Path,
trace_block_size: usize,
arrival_speedup_ratio: f64, arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
let config = config.normalized()?; let config = config.normalized()?;
validate_offline_disagg_replay_args(&config, router_mode)?; validate_offline_disagg_replay_args(&config, router_mode)?;
let trace = Trace::from_mooncake(trace_path, config.prefill_args.block_size)? let trace = Trace::from_mooncake(trace_path, trace_block_size)?
.normalize_session_starts()? .normalize_session_starts()?
.speed_up_timing(arrival_speedup_ratio)?; .speed_up_timing(arrival_speedup_ratio)?;
let started_at = Instant::now(); let started_at = Instant::now();
let report = crate::replay::offline::simulate_trace_workload_disagg( let report = crate::replay::offline::simulate_trace_workload_disagg(
config, config,
router_config, router_config,
prefill_load_estimator,
trace, trace,
router_mode, router_mode,
)?; )?;
...@@ -92,33 +103,46 @@ pub fn simulate_trace_file_disagg_with_router_mode( ...@@ -92,33 +103,46 @@ pub fn simulate_trace_file_disagg_with_router_mode(
pub fn simulate_trace_live_file( pub fn simulate_trace_live_file(
args: MockEngineArgs, args: MockEngineArgs,
trace_path: &Path, trace_path: &Path,
trace_block_size: usize,
num_workers: usize, num_workers: usize,
arrival_speedup_ratio: f64, arrival_speedup_ratio: f64,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
simulate_trace_live_file_with_router_mode( simulate_trace_live_file_with_router_mode(
args, args,
None, None,
None,
trace_path, trace_path,
trace_block_size,
num_workers, num_workers,
arrival_speedup_ratio, arrival_speedup_ratio,
ReplayRouterMode::RoundRobin, ReplayRouterMode::RoundRobin,
) )
} }
#[allow(clippy::too_many_arguments)]
pub fn simulate_trace_live_file_with_router_mode( pub fn simulate_trace_live_file_with_router_mode(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace_path: &Path, trace_path: &Path,
trace_block_size: usize,
num_workers: usize, num_workers: usize,
arrival_speedup_ratio: f64, arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
let args = args.normalized()?; let args = args.normalized()?;
validate_online_replay_args(&args, num_workers)?; validate_online_replay_args(&args, num_workers)?;
let trace = Trace::from_mooncake(trace_path, args.block_size)? let trace = Trace::from_mooncake(trace_path, trace_block_size)?
.normalize_session_starts()? .normalize_session_starts()?
.speed_up_timing(arrival_speedup_ratio)?; .speed_up_timing(arrival_speedup_ratio)?;
online::simulate_trace_workload(args, router_config, trace, num_workers, router_mode) online::simulate_trace_workload(
args,
router_config,
prefill_load_estimator,
trace,
num_workers,
router_mode,
)
} }
pub fn simulate_trace_requests( pub fn simulate_trace_requests(
...@@ -130,6 +154,7 @@ pub fn simulate_trace_requests( ...@@ -130,6 +154,7 @@ pub fn simulate_trace_requests(
simulate_trace_requests_with_router_mode( simulate_trace_requests_with_router_mode(
args, args,
None, None,
None,
requests, requests,
num_workers, num_workers,
arrival_speedup_ratio, arrival_speedup_ratio,
...@@ -140,6 +165,7 @@ pub fn simulate_trace_requests( ...@@ -140,6 +165,7 @@ pub fn simulate_trace_requests(
pub fn simulate_trace_requests_with_router_mode( pub fn simulate_trace_requests_with_router_mode(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>, requests: Vec<DirectRequest>,
num_workers: usize, num_workers: usize,
arrival_speedup_ratio: f64, arrival_speedup_ratio: f64,
...@@ -155,6 +181,7 @@ pub fn simulate_trace_requests_with_router_mode( ...@@ -155,6 +181,7 @@ pub fn simulate_trace_requests_with_router_mode(
let report = crate::replay::offline::simulate_trace( let report = crate::replay::offline::simulate_trace(
args, args,
router_config, router_config,
prefill_load_estimator,
requests, requests,
num_workers, num_workers,
arrival_speedup_ratio, arrival_speedup_ratio,
...@@ -166,6 +193,7 @@ pub fn simulate_trace_requests_with_router_mode( ...@@ -166,6 +193,7 @@ pub fn simulate_trace_requests_with_router_mode(
pub fn simulate_trace_requests_disagg_with_router_mode( pub fn simulate_trace_requests_disagg_with_router_mode(
config: OfflineDisaggReplayConfig, config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>, requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64, arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
...@@ -180,6 +208,7 @@ pub fn simulate_trace_requests_disagg_with_router_mode( ...@@ -180,6 +208,7 @@ pub fn simulate_trace_requests_disagg_with_router_mode(
let report = crate::replay::offline::simulate_trace_disagg( let report = crate::replay::offline::simulate_trace_disagg(
config, config,
router_config, router_config,
prefill_load_estimator,
requests, requests,
arrival_speedup_ratio, arrival_speedup_ratio,
router_mode, router_mode,
...@@ -196,6 +225,7 @@ pub fn simulate_trace_live_requests( ...@@ -196,6 +225,7 @@ pub fn simulate_trace_live_requests(
simulate_trace_live_requests_with_router_mode( simulate_trace_live_requests_with_router_mode(
args, args,
None, None,
None,
requests, requests,
num_workers, num_workers,
arrival_speedup_ratio, arrival_speedup_ratio,
...@@ -206,6 +236,7 @@ pub fn simulate_trace_live_requests( ...@@ -206,6 +236,7 @@ pub fn simulate_trace_live_requests(
pub fn simulate_trace_live_requests_with_router_mode( pub fn simulate_trace_live_requests_with_router_mode(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>, requests: Vec<DirectRequest>,
num_workers: usize, num_workers: usize,
arrival_speedup_ratio: f64, arrival_speedup_ratio: f64,
...@@ -220,6 +251,7 @@ pub fn simulate_trace_live_requests_with_router_mode( ...@@ -220,6 +251,7 @@ pub fn simulate_trace_live_requests_with_router_mode(
online::simulate_trace_requests( online::simulate_trace_requests(
args, args,
router_config, router_config,
prefill_load_estimator,
requests, requests,
num_workers, num_workers,
arrival_speedup_ratio, arrival_speedup_ratio,
...@@ -230,34 +262,41 @@ pub fn simulate_trace_live_requests_with_router_mode( ...@@ -230,34 +262,41 @@ pub fn simulate_trace_live_requests_with_router_mode(
pub fn simulate_concurrency_file( pub fn simulate_concurrency_file(
args: MockEngineArgs, args: MockEngineArgs,
trace_path: &Path, trace_path: &Path,
trace_block_size: usize,
max_in_flight: usize, max_in_flight: usize,
num_workers: usize, num_workers: usize,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
simulate_concurrency_file_with_router_mode( simulate_concurrency_file_with_router_mode(
args, args,
None, None,
None,
trace_path, trace_path,
trace_block_size,
max_in_flight, max_in_flight,
num_workers, num_workers,
ReplayRouterMode::RoundRobin, ReplayRouterMode::RoundRobin,
) )
} }
#[allow(clippy::too_many_arguments)]
pub fn simulate_concurrency_file_with_router_mode( pub fn simulate_concurrency_file_with_router_mode(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace_path: &Path, trace_path: &Path,
trace_block_size: usize,
max_in_flight: usize, max_in_flight: usize,
num_workers: usize, num_workers: usize,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
let args = args.normalized()?; let args = args.normalized()?;
validate_offline_concurrency_args(&args, num_workers, max_in_flight, router_mode)?; validate_offline_concurrency_args(&args, num_workers, max_in_flight, router_mode)?;
let trace = Trace::from_mooncake(trace_path, args.block_size)?; let trace = Trace::from_mooncake(trace_path, trace_block_size)?;
let started_at = Instant::now(); let started_at = Instant::now();
let report = simulate_concurrency_workload_with_router_mode( let report = simulate_concurrency_workload_with_router_mode(
args, args,
router_config, router_config,
prefill_load_estimator,
trace, trace,
max_in_flight, max_in_flight,
num_workers, num_workers,
...@@ -269,17 +308,20 @@ pub fn simulate_concurrency_file_with_router_mode( ...@@ -269,17 +308,20 @@ pub fn simulate_concurrency_file_with_router_mode(
pub fn simulate_concurrency_file_disagg_with_router_mode( pub fn simulate_concurrency_file_disagg_with_router_mode(
config: OfflineDisaggReplayConfig, config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace_path: &Path, trace_path: &Path,
trace_block_size: usize,
max_in_flight: usize, max_in_flight: usize,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
let config = config.normalized()?; let config = config.normalized()?;
validate_offline_disagg_concurrency_args(&config, max_in_flight, router_mode)?; validate_offline_disagg_concurrency_args(&config, max_in_flight, router_mode)?;
let trace = Trace::from_mooncake(trace_path, config.prefill_args.block_size)?; let trace = Trace::from_mooncake(trace_path, trace_block_size)?;
let started_at = Instant::now(); let started_at = Instant::now();
let report = simulate_concurrency_workload_disagg_with_router_mode( let report = simulate_concurrency_workload_disagg_with_router_mode(
config, config,
router_config, router_config,
prefill_load_estimator,
trace, trace,
max_in_flight, max_in_flight,
router_mode, router_mode,
...@@ -290,33 +332,40 @@ pub fn simulate_concurrency_file_disagg_with_router_mode( ...@@ -290,33 +332,40 @@ pub fn simulate_concurrency_file_disagg_with_router_mode(
pub fn simulate_concurrency_live_file( pub fn simulate_concurrency_live_file(
args: MockEngineArgs, args: MockEngineArgs,
trace_path: &Path, trace_path: &Path,
trace_block_size: usize,
max_in_flight: usize, max_in_flight: usize,
num_workers: usize, num_workers: usize,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
simulate_concurrency_live_file_with_router_mode( simulate_concurrency_live_file_with_router_mode(
args, args,
None, None,
None,
trace_path, trace_path,
trace_block_size,
max_in_flight, max_in_flight,
num_workers, num_workers,
ReplayRouterMode::RoundRobin, ReplayRouterMode::RoundRobin,
) )
} }
#[allow(clippy::too_many_arguments)]
pub fn simulate_concurrency_live_file_with_router_mode( pub fn simulate_concurrency_live_file_with_router_mode(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace_path: &Path, trace_path: &Path,
trace_block_size: usize,
max_in_flight: usize, max_in_flight: usize,
num_workers: usize, num_workers: usize,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
let args = args.normalized()?; let args = args.normalized()?;
validate_online_concurrency_args(&args, num_workers, max_in_flight)?; validate_online_concurrency_args(&args, num_workers, max_in_flight)?;
let trace = Trace::from_mooncake(trace_path, args.block_size)?; let trace = Trace::from_mooncake(trace_path, trace_block_size)?;
online::simulate_concurrency_workload( online::simulate_concurrency_workload(
args, args,
router_config, router_config,
prefill_load_estimator,
trace, trace,
max_in_flight, max_in_flight,
num_workers, num_workers,
...@@ -333,6 +382,7 @@ pub fn simulate_concurrency_live_requests( ...@@ -333,6 +382,7 @@ pub fn simulate_concurrency_live_requests(
simulate_concurrency_live_requests_with_router_mode( simulate_concurrency_live_requests_with_router_mode(
args, args,
None, None,
None,
requests, requests,
max_in_flight, max_in_flight,
num_workers, num_workers,
...@@ -343,6 +393,7 @@ pub fn simulate_concurrency_live_requests( ...@@ -343,6 +393,7 @@ pub fn simulate_concurrency_live_requests(
pub fn simulate_concurrency_live_requests_with_router_mode( pub fn simulate_concurrency_live_requests_with_router_mode(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>, requests: Vec<DirectRequest>,
max_in_flight: usize, max_in_flight: usize,
num_workers: usize, num_workers: usize,
...@@ -357,6 +408,7 @@ pub fn simulate_concurrency_live_requests_with_router_mode( ...@@ -357,6 +408,7 @@ pub fn simulate_concurrency_live_requests_with_router_mode(
online::simulate_concurrency_requests( online::simulate_concurrency_requests(
args, args,
router_config, router_config,
prefill_load_estimator,
requests, requests,
max_in_flight, max_in_flight,
num_workers, num_workers,
...@@ -373,6 +425,7 @@ pub fn simulate_concurrency_requests( ...@@ -373,6 +425,7 @@ pub fn simulate_concurrency_requests(
simulate_concurrency_requests_with_router_mode( simulate_concurrency_requests_with_router_mode(
args, args,
None, None,
None,
requests, requests,
max_in_flight, max_in_flight,
num_workers, num_workers,
...@@ -383,6 +436,7 @@ pub fn simulate_concurrency_requests( ...@@ -383,6 +436,7 @@ pub fn simulate_concurrency_requests(
pub fn simulate_concurrency_requests_with_router_mode( pub fn simulate_concurrency_requests_with_router_mode(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>, requests: Vec<DirectRequest>,
max_in_flight: usize, max_in_flight: usize,
num_workers: usize, num_workers: usize,
...@@ -397,6 +451,7 @@ pub fn simulate_concurrency_requests_with_router_mode( ...@@ -397,6 +451,7 @@ pub fn simulate_concurrency_requests_with_router_mode(
crate::replay::offline::simulate_concurrency( crate::replay::offline::simulate_concurrency(
args, args,
router_config, router_config,
prefill_load_estimator,
requests, requests,
max_in_flight, max_in_flight,
num_workers, num_workers,
...@@ -407,6 +462,7 @@ pub fn simulate_concurrency_requests_with_router_mode( ...@@ -407,6 +462,7 @@ pub fn simulate_concurrency_requests_with_router_mode(
pub fn simulate_concurrency_requests_disagg_with_router_mode( pub fn simulate_concurrency_requests_disagg_with_router_mode(
config: OfflineDisaggReplayConfig, config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>, requests: Vec<DirectRequest>,
max_in_flight: usize, max_in_flight: usize,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
...@@ -420,6 +476,7 @@ pub fn simulate_concurrency_requests_disagg_with_router_mode( ...@@ -420,6 +476,7 @@ pub fn simulate_concurrency_requests_disagg_with_router_mode(
crate::replay::offline::simulate_concurrency_disagg( crate::replay::offline::simulate_concurrency_disagg(
config, config,
router_config, router_config,
prefill_load_estimator,
requests, requests,
max_in_flight, max_in_flight,
router_mode, router_mode,
...@@ -434,6 +491,7 @@ pub fn simulate_trace_workload( ...@@ -434,6 +491,7 @@ pub fn simulate_trace_workload(
simulate_trace_workload_with_router_mode( simulate_trace_workload_with_router_mode(
args, args,
None, None,
None,
trace, trace,
num_workers, num_workers,
ReplayRouterMode::RoundRobin, ReplayRouterMode::RoundRobin,
...@@ -443,6 +501,7 @@ pub fn simulate_trace_workload( ...@@ -443,6 +501,7 @@ pub fn simulate_trace_workload(
pub fn simulate_trace_workload_with_router_mode( pub fn simulate_trace_workload_with_router_mode(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace, trace: Trace,
num_workers: usize, num_workers: usize,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
...@@ -453,6 +512,7 @@ pub fn simulate_trace_workload_with_router_mode( ...@@ -453,6 +512,7 @@ pub fn simulate_trace_workload_with_router_mode(
let report = crate::replay::offline::simulate_trace_workload( let report = crate::replay::offline::simulate_trace_workload(
args, args,
router_config, router_config,
prefill_load_estimator,
trace, trace,
num_workers, num_workers,
router_mode, router_mode,
...@@ -463,6 +523,7 @@ pub fn simulate_trace_workload_with_router_mode( ...@@ -463,6 +523,7 @@ pub fn simulate_trace_workload_with_router_mode(
pub fn simulate_trace_workload_disagg_with_router_mode( pub fn simulate_trace_workload_disagg_with_router_mode(
config: OfflineDisaggReplayConfig, config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace, trace: Trace,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
...@@ -472,6 +533,7 @@ pub fn simulate_trace_workload_disagg_with_router_mode( ...@@ -472,6 +533,7 @@ pub fn simulate_trace_workload_disagg_with_router_mode(
let report = crate::replay::offline::simulate_trace_workload_disagg( let report = crate::replay::offline::simulate_trace_workload_disagg(
config, config,
router_config, router_config,
prefill_load_estimator,
trace, trace,
router_mode, router_mode,
)?; )?;
...@@ -486,6 +548,7 @@ pub fn simulate_trace_live_workload( ...@@ -486,6 +548,7 @@ pub fn simulate_trace_live_workload(
simulate_trace_live_workload_with_router_mode( simulate_trace_live_workload_with_router_mode(
args, args,
None, None,
None,
trace, trace,
num_workers, num_workers,
ReplayRouterMode::RoundRobin, ReplayRouterMode::RoundRobin,
...@@ -495,13 +558,21 @@ pub fn simulate_trace_live_workload( ...@@ -495,13 +558,21 @@ pub fn simulate_trace_live_workload(
pub fn simulate_trace_live_workload_with_router_mode( pub fn simulate_trace_live_workload_with_router_mode(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace, trace: Trace,
num_workers: usize, num_workers: usize,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> { ) -> Result<TraceSimulationReport> {
let args = args.normalized()?; let args = args.normalized()?;
validate_online_replay_args(&args, num_workers)?; validate_online_replay_args(&args, num_workers)?;
online::simulate_trace_workload(args, router_config, trace, num_workers, router_mode) online::simulate_trace_workload(
args,
router_config,
prefill_load_estimator,
trace,
num_workers,
router_mode,
)
} }
pub fn simulate_concurrency_workload( pub fn simulate_concurrency_workload(
...@@ -513,6 +584,7 @@ pub fn simulate_concurrency_workload( ...@@ -513,6 +584,7 @@ pub fn simulate_concurrency_workload(
simulate_concurrency_workload_with_router_mode( simulate_concurrency_workload_with_router_mode(
args, args,
None, None,
None,
trace, trace,
max_in_flight, max_in_flight,
num_workers, num_workers,
...@@ -523,6 +595,7 @@ pub fn simulate_concurrency_workload( ...@@ -523,6 +595,7 @@ pub fn simulate_concurrency_workload(
pub fn simulate_concurrency_workload_with_router_mode( pub fn simulate_concurrency_workload_with_router_mode(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace, trace: Trace,
max_in_flight: usize, max_in_flight: usize,
num_workers: usize, num_workers: usize,
...@@ -533,6 +606,7 @@ pub fn simulate_concurrency_workload_with_router_mode( ...@@ -533,6 +606,7 @@ pub fn simulate_concurrency_workload_with_router_mode(
crate::replay::offline::simulate_concurrency_workload( crate::replay::offline::simulate_concurrency_workload(
args, args,
router_config, router_config,
prefill_load_estimator,
trace, trace,
max_in_flight, max_in_flight,
num_workers, num_workers,
...@@ -543,6 +617,7 @@ pub fn simulate_concurrency_workload_with_router_mode( ...@@ -543,6 +617,7 @@ pub fn simulate_concurrency_workload_with_router_mode(
pub fn simulate_concurrency_workload_disagg_with_router_mode( pub fn simulate_concurrency_workload_disagg_with_router_mode(
config: OfflineDisaggReplayConfig, config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace, trace: Trace,
max_in_flight: usize, max_in_flight: usize,
router_mode: ReplayRouterMode, router_mode: ReplayRouterMode,
...@@ -552,6 +627,7 @@ pub fn simulate_concurrency_workload_disagg_with_router_mode( ...@@ -552,6 +627,7 @@ pub fn simulate_concurrency_workload_disagg_with_router_mode(
crate::replay::offline::simulate_concurrency_workload_disagg( crate::replay::offline::simulate_concurrency_workload_disagg(
config, config,
router_config, router_config,
prefill_load_estimator,
trace, trace,
max_in_flight, max_in_flight,
router_mode, router_mode,
...@@ -567,6 +643,7 @@ pub fn simulate_concurrency_live_workload( ...@@ -567,6 +643,7 @@ pub fn simulate_concurrency_live_workload(
simulate_concurrency_live_workload_with_router_mode( simulate_concurrency_live_workload_with_router_mode(
args, args,
None, None,
None,
trace, trace,
max_in_flight, max_in_flight,
num_workers, num_workers,
...@@ -577,6 +654,7 @@ pub fn simulate_concurrency_live_workload( ...@@ -577,6 +654,7 @@ pub fn simulate_concurrency_live_workload(
pub fn simulate_concurrency_live_workload_with_router_mode( pub fn simulate_concurrency_live_workload_with_router_mode(
args: MockEngineArgs, args: MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace, trace: Trace,
max_in_flight: usize, max_in_flight: usize,
num_workers: usize, num_workers: usize,
...@@ -587,6 +665,7 @@ pub fn simulate_concurrency_live_workload_with_router_mode( ...@@ -587,6 +665,7 @@ pub fn simulate_concurrency_live_workload_with_router_mode(
online::simulate_concurrency_workload( online::simulate_concurrency_workload(
args, args,
router_config, router_config,
prefill_load_estimator,
trace, trace,
max_in_flight, max_in_flight,
num_workers, num_workers,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment