Unverified Commit 95a750f4 authored by Yan Ru Pei, committed by GitHub
Browse files

chore(replay): refactor offline components into cleaner lanes (#7866)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 210bbf5d
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use std::time::Duration;
use tokio::time::Instant;
use super::single::RequestId;
/// Prefill load recorded for a single request.
///
/// The value is `Copy`; the same value is passed to [`PrefillLoadTracker::insert`] and later
/// to [`PrefillLoadTracker::remove`] so that the token count added on insert is the one
/// subtracted on removal.
#[derive(Debug, Clone, Copy)]
pub(super) struct PrefillLoadState {
/// Token count added to `PrefillLoadTracker::prefill_full_tokens_sum` on insert and
/// subtracted from it on remove.
pub(super) initial_effective_prefill_tokens: usize,
/// Expected duration of the prefill, if known. This module only stores it.
// NOTE(review): presumably consumed by the sequence tracker to decay the oldest
// request's load over time — confirm against the caller.
pub(super) expected_prefill_duration: Option<Duration>,
}
/// Bookkeeping for the requests that currently carry prefill load.
///
/// `Default` yields an empty tracker: no queued requests, a zero token sum and no anchor.
#[derive(Debug, Default)]
pub(super) struct PrefillLoadTracker {
/// Request ids in insertion order; `insert` appends at the back, so the front is the
/// oldest request still tracked.
pub(super) prefill_order: VecDeque<RequestId>,
/// Running sum of `initial_effective_prefill_tokens` over all tracked requests.
pub(super) prefill_full_tokens_sum: usize,
/// The anchored request together with the instant at which it was anchored.
/// Set by `insert` when no anchor exists, and re-pointed at the queue front (with a fresh
/// instant) when the anchored request is removed.
pub(super) anchored_prefill: Option<(RequestId, Instant)>,
}
impl PrefillLoadTracker {
    /// Starts tracking `request_id` with the given prefill load.
    ///
    /// The load's token count is added to the running sum and the id is appended to the
    /// back of the queue. If nothing is anchored yet, this request becomes the anchor with
    /// `decay_now` as its anchor instant; otherwise the existing anchor is left untouched.
    ///
    /// The caller is responsible for not inserting the same id twice.
    pub(super) fn insert(
        &mut self,
        request_id: &RequestId,
        prefill: PrefillLoadState,
        decay_now: Instant,
    ) {
        self.prefill_full_tokens_sum += prefill.initial_effective_prefill_tokens;
        self.prefill_order.push_back(request_id.clone());

        // Only a request arriving while no anchor exists claims the anchor.
        if self.anchored_prefill.is_none() {
            self.anchored_prefill = Some((request_id.clone(), decay_now));
        }
    }

    /// Stops tracking `request_id`, undoing the effect of a matching [`Self::insert`].
    ///
    /// `prefill` must be the load that was inserted for this request; its token count is
    /// subtracted from the running sum. If the removed request was the anchor, the anchor
    /// moves to the new queue front and is stamped with `decay_now`.
    ///
    /// # Panics
    ///
    /// Panics if the subtraction would underflow the running sum.
    pub(super) fn remove(
        &mut self,
        request_id: &RequestId,
        prefill: PrefillLoadState,
        decay_now: Instant,
    ) {
        let remaining_sum = self
            .prefill_full_tokens_sum
            .checked_sub(prefill.initial_effective_prefill_tokens)
            .expect("prefill_full_tokens_sum underflow");
        self.prefill_full_tokens_sum = remaining_sum;

        // Fast path: the oldest request is popped directly; any other position needs a scan.
        let is_front = self
            .prefill_order
            .front()
            .is_some_and(|oldest| oldest == request_id);
        if is_front {
            let popped = self.prefill_order.pop_front();
            debug_assert_eq!(popped.as_ref(), Some(request_id));
        } else {
            self.prefill_order.retain(|queued| queued != request_id);
        }

        // The queue has already been updated, so the front is the next-oldest request.
        let anchor_was_removed = matches!(
            &self.anchored_prefill,
            Some((anchored_id, _)) if anchored_id == request_id
        );
        if anchor_was_removed {
            self.set_anchor_to_front(decay_now);
        }
    }

    /// Re-anchors on the current queue front, stamping it with `now`.
    ///
    /// Clears the anchor when the queue is empty.
    pub(super) fn set_anchor_to_front(&mut self, now: Instant) {
        let next_anchor = self
            .prefill_order
            .front()
            .map(|oldest| (oldest.clone(), now));
        self.anchored_prefill = next_anchor;
    }
}
......@@ -26,6 +26,10 @@ use std::time::Duration;
use tokio::time::Instant;
use uuid::Uuid;
use super::block_tracker::BlockTracker;
use super::prefill_tracker::{PrefillLoadState, PrefillLoadTracker};
use crate::protocols::PrefillLoadHint;
/// Duration after which stale requests may be expired (5 minutes).
const EXPIRY_DURATION: Duration = Duration::from_secs(300);
......@@ -36,106 +40,183 @@ const CHECK_EXPIRY_FREQUENCY: Duration = Duration::from_secs(30);
// TODO: use the common request_id if it exists in the repo
pub type RequestId = String;
/// Per-request bookkeeping stored in `ActiveSequences::requests`.
#[derive(Debug)]
pub(super) struct RequestState {
/// Block hashes held by this request, each paired with the `Arc` handle returned by the
/// block tracker for that hash.
blocks: Vec<(SequenceHash, Arc<()>)>,
/// When the request was added; `force_expiry` compares this against the expiry cutoff.
started_at: Instant,
/// Prefill load still attributed to this request. `None` if no load was tracked for it
/// or once the load has been taken out by `remove_prefill_load`.
prefill: Option<PrefillLoadState>,
/// Expected output token count supplied when the request was added, if any.
expected_output_tokens: Option<u32>,
}
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)]
pub struct ActiveSequences {
active_seqs: HashMap<RequestId, Vec<(SequenceHash, Arc<()>)>>,
prefill_tokens: HashMap<RequestId, usize>,
/// Expected output tokens per request (used for resource estimation)
expected_output_tokens: HashMap<RequestId, u32>,
unique_blocks: HashMap<SequenceHash, std::sync::Weak<()>>,
/// Fractional block counts for blocks that are partially cached
/// When a block is in both unique_blocks and fractional_blocks,
/// it contributes the fractional value instead of 1 to active_blocks()
fractional_blocks: HashMap<SequenceHash, f64>,
requests: HashMap<RequestId, RequestState>,
prefill: PrefillLoadTracker,
blocks: BlockTracker,
#[getter(copy)]
block_size: usize,
#[getter(copy)]
active_tokens: usize,
// Request timestamps, for expiration.
request_timestamps: HashMap<RequestId, Instant>,
last_expiry_check_time: Instant,
}
impl ActiveSequences {
/// Create a new `ActiveSequences` instance
pub fn new(block_size: usize) -> Self {
// TODO: make this not a hard req
assert!(block_size > 1, "block_size must be greater than 1");
Self {
active_seqs: HashMap::new(),
prefill_tokens: HashMap::new(),
expected_output_tokens: HashMap::new(),
unique_blocks: HashMap::new(),
fractional_blocks: HashMap::new(),
requests: HashMap::new(),
prefill: PrefillLoadTracker::default(),
blocks: BlockTracker::default(),
block_size,
active_tokens: 0,
request_timestamps: HashMap::new(),
last_expiry_check_time: Instant::now(),
}
}
fn touch_block(&mut self, block: &SequenceHash) -> Arc<()> {
if let Some(weak) = self.unique_blocks.get(block)
&& let Some(rc) = weak.upgrade()
{
return rc;
#[cfg(any(test, debug_assertions))]
fn assert_consistent(&self) {
let active_prefills: HashSet<RequestId> = self
.requests
.iter()
.filter(|(_, state)| state.prefill.is_some())
.map(|(request_id, _)| request_id.clone())
.collect();
let ordered_prefills: HashSet<RequestId> =
self.prefill.prefill_order.iter().cloned().collect();
let recomputed_prefill_sum: usize = self
.requests
.values()
.filter_map(|state| state.prefill)
.map(|prefill| prefill.initial_effective_prefill_tokens)
.sum();
assert_eq!(
ordered_prefills.len(),
self.prefill.prefill_order.len(),
"prefill_order contains duplicate request ids",
);
assert_eq!(
ordered_prefills, active_prefills,
"prefill_order must match requests with active prefill load",
);
assert_eq!(
self.prefill.prefill_full_tokens_sum, recomputed_prefill_sum,
"prefill_full_tokens_sum drifted from request state",
);
if let Some(oldest_request_id) = self.prefill.prefill_order.front() {
let Some((anchored_request_id, _)) = self.prefill.anchored_prefill.as_ref() else {
panic!("anchored_prefill must exist when prefill_order is non-empty");
};
assert!(
self.requests
.get(oldest_request_id)
.is_some_and(|state| state.prefill.is_some()),
"prefill_order front must point to an active prefill request",
);
assert_eq!(
anchored_request_id, oldest_request_id,
"anchored_prefill must match prefill_order.front()",
);
} else {
assert!(
self.prefill.anchored_prefill.is_none(),
"anchored_prefill must be absent when no active prefills remain",
);
}
let rc = Arc::new(());
self.unique_blocks.insert(*block, Arc::downgrade(&rc));
rc
assert!(
self.blocks
.fractional_blocks
.keys()
.all(|hash| self.blocks.unique_blocks.contains_key(hash)),
"fractional_blocks cannot reference non-active blocks",
);
}
fn try_remove_block(&mut self, block: &SequenceHash) {
if let Some(weak) = self.unique_blocks.get(block)
&& weak.strong_count() == 0
{
self.unique_blocks.remove(block);
self.fractional_blocks.remove(block);
}
/// Runs the internal consistency checks in test and debug builds.
///
/// In any other build the call below is compiled out and this function is empty.
#[inline]
fn validate_state(&self) {
// Gate must match the one on `assert_consistent`, which only exists under the same cfg.
#[cfg(any(test, debug_assertions))]
self.assert_consistent();
}
pub fn active_blocks(&self) -> usize {
let mut count = self.unique_blocks.len() as f64;
for (hash, frac) in &self.fractional_blocks {
if self.unique_blocks.contains_key(hash) {
// Subtract 1 (the full block) and add the fractional value
count = count - 1.0 + frac;
self.blocks.active_blocks()
}
/// Registers `prefill` for `request_id` with the prefill tracker.
///
/// Forwards to `PrefillLoadTracker::insert` unchanged. It does not touch `self.requests`;
/// the caller is expected to have stored the same load on the request's own state.
fn insert_prefill_load(
&mut self,
request_id: &RequestId,
prefill: PrefillLoadState,
decay_now: Instant,
) {
self.prefill.insert(request_id, prefill, decay_now);
}
/// Detaches the prefill load from `request_id` and removes it from the prefill tracker.
///
/// Returns the load that was removed, or `None` when the request is unknown or no longer
/// carries a prefill load (which makes repeated calls harmless). `decay_now` is forwarded
/// to the tracker, which uses it when the anchor has to move to another request.
fn remove_prefill_load(
    &mut self,
    request_id: &RequestId,
    decay_now: Instant,
) -> Option<PrefillLoadState> {
    // Taking the value out of the request state first guarantees the tracker is only
    // updated once per request.
    let removed = self
        .requests
        .get_mut(request_id)
        .and_then(|state| state.prefill.take())?;
    self.prefill.remove(request_id, removed, decay_now);
    Some(removed)
}
fn active_prefill_tokens_at(&self, now: Instant) -> usize {
let Some((oldest_request_id, oldest_since)) = self.prefill.anchored_prefill.as_ref() else {
return 0;
};
let prefill = self
.requests
.get(oldest_request_id)
.and_then(|state| state.prefill)
.expect("prefill_order front missing prefill load");
let oldest_full = prefill.initial_effective_prefill_tokens;
let oldest_remaining = match prefill.expected_prefill_duration {
None => oldest_full,
Some(expected_prefill_duration) if expected_prefill_duration.is_zero() => 0,
Some(expected_prefill_duration) => {
let elapsed = now.saturating_duration_since(*oldest_since);
let remaining_fraction = (1.0
- (elapsed.as_secs_f64() / expected_prefill_duration.as_secs_f64()))
.clamp(0.0, 1.0);
((oldest_full as f64) * remaining_fraction).ceil() as usize
}
}
count.round() as usize
};
self.prefill
.prefill_full_tokens_sum
.checked_sub(oldest_full)
.expect("prefill_full_tokens_sum smaller than oldest load")
+ oldest_remaining
}
/// Returns the prefill token load considered active at `decay_now`.
///
/// Delegates to `active_prefill_tokens_at`; the result depends on the instant passed in,
/// so callers comparing values should use the same `decay_now`.
pub fn active_tokens(&self, decay_now: Instant) -> usize {
self.active_prefill_tokens_at(decay_now)
}
/// Find all blocks in a request that have only a single strong reference (only used by this request)
/// and insert them into fractional_blocks with the given fraction value.
pub fn set_single_ref_blocks_as_fractional(&mut self, request_id: &RequestId, fraction: f64) {
let Some(blocks) = self.active_seqs.get(request_id) else {
let Some(request_state) = self.requests.get(request_id) else {
tracing::warn!(
"Request {request_id} not found for set_single_ref_blocks_as_fractional"
);
return;
};
for (hash, rc) in blocks {
// A block with strong_count == 1 means only this request holds a reference
for (hash, rc) in &request_state.blocks {
if Arc::strong_count(rc) == 1 {
self.fractional_blocks.insert(*hash, fraction);
self.blocks.fractional_blocks.insert(*hash, fraction);
}
}
}
/// Add a new request with its initial tokens
/// Returns the set of expired request IDs that were removed during cleanup
/// Add a new request with its initial tokens.
/// Returns the set of expired request IDs that were removed during cleanup.
pub fn add_request(
&mut self,
request_id: RequestId,
......@@ -143,6 +224,7 @@ impl ActiveSequences {
isl: usize,
overlap: u32,
expected_output_tokens: Option<u32>,
decay_now: Instant,
) -> HashSet<RequestId> {
self.add_request_with_prefill_tracking(
request_id,
......@@ -151,11 +233,14 @@ impl ActiveSequences {
overlap,
expected_output_tokens,
true,
None,
decay_now,
)
}
/// Add a new request with optional prompt-token load accounting.
/// Returns the set of expired request IDs that were removed during cleanup.
#[allow(clippy::too_many_arguments)]
pub fn add_request_with_prefill_tracking(
&mut self,
request_id: RequestId,
......@@ -164,68 +249,76 @@ impl ActiveSequences {
overlap: u32,
expected_output_tokens: Option<u32>,
track_prefill_tokens: bool,
prefill_load_hint: Option<PrefillLoadHint>,
decay_now: Instant,
) -> HashSet<RequestId> {
// Check for double-add and log error, returning early
if self.active_seqs.contains_key(&request_id) {
if self.requests.contains_key(&request_id) {
tracing::error!("Request {request_id} is already active. Ignoring duplicate add.");
return HashSet::new();
}
// Lazily check and clean up expired requests, capturing removed IDs
let removed_requests = self.force_expiry();
let started_at = Instant::now();
let blocks = match token_sequence {
Some(sequence) => sequence
.into_iter()
.map(|block| {
let rc = self.blocks.touch_block(&block);
(block, rc)
})
.collect(),
None => Vec::new(),
};
let prefill_tokens = if track_prefill_tokens {
self.new_tokens(isl, overlap)
let prefill = if track_prefill_tokens {
let default_tokens = self.new_tokens(isl, overlap);
let hint = prefill_load_hint.unwrap_or(PrefillLoadHint {
initial_effective_prefill_tokens: default_tokens,
expected_prefill_duration: None,
});
(hint.initial_effective_prefill_tokens > 0).then_some(PrefillLoadState {
initial_effective_prefill_tokens: hint.initial_effective_prefill_tokens,
expected_prefill_duration: hint.expected_prefill_duration,
})
} else {
0
None
};
self.prefill_tokens
.insert(request_id.clone(), prefill_tokens);
self.active_tokens += prefill_tokens;
// Store expected output tokens if provided
if let Some(tokens) = expected_output_tokens {
self.expected_output_tokens
.insert(request_id.clone(), tokens);
}
if let Some(sequence) = token_sequence {
let sequence_with_refs: Vec<(SequenceHash, Arc<()>)> = sequence
.iter()
.map(|block| (*block, self.touch_block(block)))
.collect();
self.active_seqs
.insert(request_id.clone(), sequence_with_refs);
} else {
// dummy empty sequence
self.active_seqs.insert(request_id.clone(), Vec::new());
self.requests.insert(
request_id.clone(),
RequestState {
blocks,
started_at,
prefill,
expected_output_tokens,
},
);
if let Some(prefill) = prefill {
self.insert_prefill_load(&request_id, prefill, decay_now);
}
self.request_timestamps
.insert(request_id.clone(), Instant::now());
self.validate_state();
removed_requests
}
/// Mark prefill as completed for a request, removing it from prefill_tokens tracking
pub fn mark_prefill_completed(&mut self, request_id: &RequestId) {
if let Some(tokens) = self.prefill_tokens.remove(request_id) {
self.active_tokens = self
.active_tokens
.checked_sub(tokens)
.expect("active_tokens underflow");
}
/// Mark prefill as completed for a request, removing it from prompt-load tracking.
pub fn mark_prefill_completed(&mut self, request_id: &RequestId, decay_now: Instant) {
let _ = self.remove_prefill_load(request_id, decay_now);
self.validate_state();
}
pub fn new_tokens(&self, isl: usize, overlap: u32) -> usize {
let cached_tokens = (overlap as usize) * self.block_size;
isl.checked_sub(cached_tokens)
.unwrap_or_else(|| {
tracing::error!(
"prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens} (overlap {overlap} * block_size {}), returning 0",
self.block_size
);
0
})
isl.checked_sub(cached_tokens).unwrap_or_else(|| {
tracing::error!(
"prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens} (overlap {overlap} * block_size {}), returning 0",
self.block_size
);
0
})
}
pub fn potential_blocks_and_tokens(
......@@ -233,8 +326,15 @@ impl ActiveSequences {
token_sequence: Option<&[SequenceHash]>,
isl: usize,
overlap: u32,
decay_now: Instant,
) -> (usize, usize) {
self.potential_blocks_and_tokens_with_prefill_tracking(token_sequence, isl, overlap, true)
self.potential_blocks_and_tokens_with_prefill_tracking(
token_sequence,
isl,
overlap,
true,
decay_now,
)
}
pub fn potential_blocks_and_tokens_with_prefill_tracking(
......@@ -243,17 +343,20 @@ impl ActiveSequences {
isl: usize,
overlap: u32,
track_prefill_tokens: bool,
decay_now: Instant,
) -> (usize, usize) {
let potential_blocks = if let Some(token_seq) = token_sequence {
self.new_blocks(token_seq) + self.active_blocks()
} else {
self.active_blocks()
};
let active_tokens = self.active_tokens(decay_now);
let potential_tokens = if track_prefill_tokens {
self.new_tokens(isl, overlap) + self.active_tokens
self.new_tokens(isl, overlap) + active_tokens
} else {
self.active_tokens
active_tokens
};
(potential_blocks, potential_tokens)
}
......@@ -261,12 +364,11 @@ impl ActiveSequences {
pub fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
token_sequence
.iter()
.filter(|block| !self.unique_blocks.contains_key(block))
.filter(|block| !self.blocks.unique_blocks.contains_key(block))
.count()
}
/// Return the total number of blocks that would be used if the token sequence was added.
///
/// This is the sum of the new blocks that would be added (see `new_blocks`) plus the
/// current active blocks (see `active_blocks`).
pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
self.new_blocks(token_sequence) + self.active_blocks()
}
......@@ -275,96 +377,77 @@ impl ActiveSequences {
///
/// This implicitly calls [`Self::mark_prefill_completed`] first, so callers do not need
/// to invoke both when the request is finishing.
pub fn free(&mut self, request_id: &RequestId) -> usize {
self.mark_prefill_completed(request_id);
// Remove expected output tokens tracking
self.expected_output_tokens.remove(request_id);
// Remove from active_seqs and get the token sequence
self.request_timestamps.remove(request_id);
let token_seq = match self.active_seqs.remove(request_id) {
Some(seq) => seq,
None => {
tracing::warn!("Trying to free non-existent request {request_id}");
return self.active_blocks();
}
pub fn free(&mut self, request_id: &RequestId, decay_now: Instant) -> usize {
self.mark_prefill_completed(request_id, decay_now);
let Some(request_state) = self.requests.remove(request_id) else {
tracing::warn!("Trying to free non-existent request {request_id}");
return self.active_blocks();
};
// Drop each Rc reference, then clean up the corresponding weak reference
for (block_hash, rc) in token_seq {
let _ = request_state.expected_output_tokens;
for (block_hash, rc) in request_state.blocks {
drop(rc);
self.try_remove_block(&block_hash);
self.blocks.try_remove_block(&block_hash);
}
self.validate_state();
self.active_blocks()
}
/// Add an output block with a random hash and optional fractional decay weight.
///
/// This is used during generation to track output blocks as they are created.
/// The decay_fraction (if provided) represents how "temporary" the block is:
/// - 1.0 means fully counted (early in generation)
/// - 0.0 means not counted (near end of expected output)
/// - Computed as: 1 - (current_osl / expected_output_tokens)
///
/// Returns true if the block was added, false if the request was not found.
pub fn add_output_block(
&mut self,
request_id: &RequestId,
decay_fraction: Option<f64>,
) -> bool {
// Check if request exists first (immutable borrow)
if !self.active_seqs.contains_key(request_id) {
if !self.requests.contains_key(request_id) {
tracing::warn!("Request {request_id} not found for add_output_block");
return false;
}
// Generate a random block hash using UUID
let random_hash: SequenceHash = Uuid::new_v4().as_u64_pair().0;
// Touch the block (adds to unique_blocks)
let rc = self.touch_block(&random_hash);
// Now we can safely get_mut and push
self.active_seqs
let rc = self.blocks.touch_block(&random_hash);
self.requests
.get_mut(request_id)
.unwrap()
.expect("request existence was checked above")
.blocks
.push((random_hash, rc));
// Apply fractional decay to all single-ref blocks in this request if provided
if let Some(frac) = decay_fraction {
self.set_single_ref_blocks_as_fractional(request_id, frac);
}
self.validate_state();
true
}
/// Force expiry of stale requests if the timer has elapsed
/// Returns the set of expired request IDs that were removed
/// Force expiry of stale requests if the timer has elapsed.
/// Returns the set of expired request IDs that were removed.
pub fn force_expiry(&mut self) -> HashSet<RequestId> {
let now = Instant::now();
// Early return if timer hasn't expired yet.
if now < self.last_expiry_check_time + CHECK_EXPIRY_FREQUENCY {
return HashSet::new();
}
self.last_expiry_check_time = now;
let expired_requests_time = now - EXPIRY_DURATION;
let mut expired_requests: HashSet<RequestId> = HashSet::new();
for (request_id, timestamp) in self.request_timestamps.iter() {
if *timestamp < expired_requests_time {
expired_requests.insert(request_id.clone());
}
}
let expired_requests: HashSet<RequestId> = self
.requests
.iter()
.filter(|(_, state)| state.started_at < expired_requests_time)
.map(|(request_id, _)| request_id.clone())
.collect();
for request_id in &expired_requests {
tracing::warn!("Expiring stale request: {}", request_id);
self.free(request_id);
self.free(request_id, now);
}
self.validate_state();
expired_requests
}
}
......@@ -372,103 +455,131 @@ impl ActiveSequences {
#[cfg(test)]
mod tests {
use super::*;
use std::collections::VecDeque;
/// Test helper: builds a `PrefillLoadHint` for `tokens` tokens whose expected prefill
/// duration is `duration_secs` seconds.
fn prefill_hint(tokens: usize, duration_secs: u64) -> PrefillLoadHint {
PrefillLoadHint {
initial_effective_prefill_tokens: tokens,
expected_prefill_duration: Some(Duration::from_secs(duration_secs)),
}
}
#[test]
fn test_active_sequences_shared_blocks() {
let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size);
let decay_now = Instant::now();
seq_manager.add_request("request_1".to_string(), Some(vec![1, 2, 3]), 12, 0, None);
seq_manager.add_request(
"request_1".to_string(),
Some(vec![1, 2, 3]),
12,
0,
None,
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 3);
assert_eq!(seq_manager.active_tokens(), 12);
assert_eq!(seq_manager.active_tokens(decay_now), 12);
seq_manager.add_request("request_2".to_string(), Some(vec![4]), 4, 0, None);
seq_manager.add_request(
"request_2".to_string(),
Some(vec![4]),
4,
0,
None,
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 16);
assert_eq!(seq_manager.active_tokens(decay_now), 16);
seq_manager.add_request("request_3".to_string(), Some(vec![1, 2, 3, 4]), 16, 4, None);
seq_manager.add_request(
"request_3".to_string(),
Some(vec![1, 2, 3, 4]),
16,
4,
None,
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 16);
assert_eq!(seq_manager.active_tokens(decay_now), 16);
seq_manager.free(&"request_2".to_string());
seq_manager.free(&"request_2".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 12);
assert_eq!(seq_manager.active_tokens(decay_now), 12);
seq_manager.free(&"request_3".to_string());
seq_manager.free(&"request_3".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 3);
assert_eq!(seq_manager.active_tokens(), 12);
assert_eq!(seq_manager.active_tokens(decay_now), 12);
seq_manager.free(&"request_1".to_string());
seq_manager.free(&"request_1".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 0);
assert_eq!(seq_manager.active_tokens(), 0);
assert_eq!(seq_manager.active_tokens(decay_now), 0);
}
#[test]
fn test_output_blocks_with_fractional_decay() {
let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size);
let decay_now = Instant::now();
// Add request with 3 prefill blocks
seq_manager.add_request("r1".to_string(), Some(vec![1, 2, 3]), 12, 0, None);
seq_manager.add_request(
"r1".to_string(),
Some(vec![1, 2, 3]),
12,
0,
None,
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 3);
// Add output block with 0.5 decay fraction.
// This adds a random block and sets all single-ref blocks to 0.5.
assert!(seq_manager.add_output_block(&"r1".to_string(), Some(0.5)));
// 4 unique blocks, all single-ref → all fractional at 0.5
// active_blocks = 4 - 4 + 4*0.5 = 2
assert_eq!(seq_manager.active_blocks(), 2);
// Add second request sharing prefix [1, 2]
seq_manager.add_request("r2".to_string(), Some(vec![1, 2]), 8, 0, None);
// Blocks 1,2 now have strong_count=2 but still have fractional 0.5 from before
// No new unique blocks → active_blocks = 4 - 4 + 2.0 = 2
seq_manager.add_request("r2".to_string(), Some(vec![1, 2]), 8, 0, None, decay_now);
assert_eq!(seq_manager.active_blocks(), 2);
// Add another output block with 0.0 decay for r1.
// set_single_ref_blocks_as_fractional updates only single-ref blocks:
// blocks 1,2: strong_count=2, NOT updated (remain 0.5)
// block 3, old output, new output: strong_count=1, set to 0.0
// active_blocks = 5 - 5 + (0.5+0.5+0.0+0.0+0.0) = 1
assert!(seq_manager.add_output_block(&"r1".to_string(), Some(0.0)));
assert_eq!(seq_manager.active_blocks(), 1);
// Free both requests, verify clean state
seq_manager.free(&"r2".to_string());
seq_manager.free(&"r1".to_string());
seq_manager.free(&"r2".to_string(), decay_now);
seq_manager.free(&"r1".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 0);
assert_eq!(seq_manager.active_tokens(), 0);
assert_eq!(seq_manager.active_tokens(decay_now), 0);
}
#[test]
fn test_mark_prefill_completed() {
let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size);
let decay_now = Instant::now();
// Add request with isl=12, overlap=0 → active_tokens=12
seq_manager.add_request("r1".to_string(), Some(vec![1, 2, 3]), 12, 0, None);
assert_eq!(seq_manager.active_tokens(), 12);
seq_manager.add_request(
"r1".to_string(),
Some(vec![1, 2, 3]),
12,
0,
None,
decay_now,
);
assert_eq!(seq_manager.active_tokens(decay_now), 12);
// Mark prefill completed → active_tokens drops to 0
seq_manager.mark_prefill_completed(&"r1".to_string());
assert_eq!(seq_manager.active_tokens(), 0);
seq_manager.mark_prefill_completed(&"r1".to_string(), decay_now);
assert_eq!(seq_manager.active_tokens(decay_now), 0);
// Double-mark: no panic, still 0
seq_manager.mark_prefill_completed(&"r1".to_string());
assert_eq!(seq_manager.active_tokens(), 0);
seq_manager.mark_prefill_completed(&"r1".to_string(), decay_now);
assert_eq!(seq_manager.active_tokens(decay_now), 0);
// Add second request with isl=8
seq_manager.add_request("r2".to_string(), Some(vec![4, 5]), 8, 0, None);
assert_eq!(seq_manager.active_tokens(), 8);
seq_manager.add_request("r2".to_string(), Some(vec![4, 5]), 8, 0, None, decay_now);
assert_eq!(seq_manager.active_tokens(decay_now), 8);
// Free it (internally calls mark_prefill_completed) → active_tokens=0
seq_manager.free(&"r2".to_string());
assert_eq!(seq_manager.active_tokens(), 0);
seq_manager.free(&"r2".to_string(), decay_now);
assert_eq!(seq_manager.active_tokens(decay_now), 0);
}
#[test]
fn test_add_request_without_prefill_tracking_keeps_active_tokens_zero() {
let mut seq_manager = ActiveSequences::new(4);
let decay_now = Instant::now();
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
......@@ -477,18 +588,24 @@ mod tests {
0,
None,
false,
None,
decay_now,
);
assert_eq!(seq_manager.active_tokens(), 0);
seq_manager.mark_prefill_completed(&"r1".to_string());
assert_eq!(seq_manager.active_tokens(), 0);
seq_manager.free(&"r1".to_string());
assert_eq!(seq_manager.active_tokens(decay_now), 0);
assert!(seq_manager.prefill.prefill_order.is_empty());
assert_eq!(seq_manager.prefill.prefill_full_tokens_sum, 0);
seq_manager.mark_prefill_completed(&"r1".to_string(), decay_now);
assert_eq!(seq_manager.active_tokens(decay_now), 0);
seq_manager.free(&"r1".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 0);
}
#[test]
fn test_potential_blocks_and_tokens_without_prefill_tracking_ignores_prompt_load() {
let mut seq_manager = ActiveSequences::new(4);
let decay_now = Instant::now();
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
Some(vec![1, 2, 3]),
......@@ -496,6 +613,8 @@ mod tests {
0,
None,
false,
None,
decay_now,
);
let (blocks, tokens) = seq_manager.potential_blocks_and_tokens_with_prefill_tracking(
......@@ -503,49 +622,331 @@ mod tests {
16,
0,
false,
decay_now,
);
assert_eq!(blocks, 4);
assert_eq!(tokens, 0);
}
// Only the anchored (oldest) prefill decays with time; younger prefills count in full.
#[test]
fn test_prefill_decay_only_applies_to_oldest_request() {
let mut seq_manager = ActiveSequences::new(4);
let epoch = Instant::now();
// r1: 100 tokens over an expected 10s. First tracked prefill, so it is anchored at `epoch`.
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
Some(vec![1]),
100,
0,
None,
true,
Some(prefill_hint(100, 10)),
epoch,
);
// r2: 60 tokens over an expected 6s, added 2s later. The anchor stays on r1.
seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![2]),
60,
0,
None,
true,
Some(prefill_hint(60, 6)),
epoch + Duration::from_secs(2),
);
// At +2s: r1 has 8/10 of 100 = 80 left, r2 is undecayed -> 80 + 60 = 140.
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(2)),
140
);
// At +5s: r1 has 5/10 of 100 = 50 left, r2 is still counted in full -> 50 + 60 = 110.
let decayed = seq_manager.active_tokens(epoch + Duration::from_secs(5));
assert_eq!(decayed, 110);
// Sanity bounds: never above the undecayed total (160), never below r2 alone (60).
assert!(decayed <= 160);
assert!(decayed >= 60);
}
// Completing the anchored prefill moves the anchor (and the decay) to the next request.
#[test]
fn test_prefill_decay_hands_off_to_next_oldest_request() {
let mut seq_manager = ActiveSequences::new(4);
let epoch = Instant::now();
// r1: 100 tokens / 10s, anchored at `epoch`.
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
Some(vec![1]),
100,
0,
None,
true,
Some(prefill_hint(100, 10)),
epoch,
);
// r2: 40 tokens / 8s, queued behind r1.
seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![2]),
40,
0,
None,
true,
Some(prefill_hint(40, 8)),
epoch,
);
// At +3s: r1 has 7/10 of 100 = 70 left, r2 is undecayed -> 70 + 40 = 110.
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(3)),
110
);
// Finishing r1 at +3s drops its load and re-anchors on r2 with +3s as the anchor instant.
seq_manager.mark_prefill_completed(&"r1".to_string(), epoch + Duration::from_secs(3));
// No time has elapsed since r2 was anchored, so its full 40 tokens remain.
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(3)),
40
);
assert_eq!(
seq_manager.prefill.prefill_order,
VecDeque::from(vec!["r2".to_string()])
);
assert!(
seq_manager
.prefill
.anchored_prefill
.as_ref()
.is_some_and(|(request_id, _)| request_id == "r2")
);
// At +5s: 2s of r2's 8s have elapsed since anchoring -> 6/8 of 40 = 30.
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(5)),
30
);
}
// A request's decay clock starts when it becomes the anchor, not when it was added.
#[test]
fn test_prefill_decay_resets_when_request_becomes_oldest() {
let mut seq_manager = ActiveSequences::new(4);
let epoch = Instant::now();
// r1: 100 tokens / 10s, anchored at `epoch`.
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
Some(vec![1]),
100,
0,
None,
true,
Some(prefill_hint(100, 10)),
epoch,
);
// r2: 80 tokens / 8s, added at +4s while r1 is still the anchor.
seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![2]),
80,
0,
None,
true,
Some(prefill_hint(80, 8)),
epoch + Duration::from_secs(4),
);
// At +8s: r1 has 2/10 of 100 = 20 left, r2 is undecayed -> 20 + 80 = 100.
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(8)),
100
);
// Finishing r1 at +8s anchors r2 at +8s.
seq_manager.mark_prefill_completed(&"r1".to_string(), epoch + Duration::from_secs(8));
// r2 starts from its full 80 tokens even though it was added 4s earlier.
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(8)),
80
);
// At +10s: 2s since anchoring -> 6/8 of 80 = 60 (it would be 20 if timed from +4s).
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(10)),
60
);
}
// Removing the queue front re-anchors on the next request in the queue.
#[test]
fn test_prefill_front_removal_reanchors_queue_front() {
let mut seq_manager = ActiveSequences::new(4);
let epoch = Instant::now();
// r1: 30 tokens / 6s, anchored at `epoch`.
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
Some(vec![1]),
30,
0,
None,
true,
Some(prefill_hint(30, 6)),
epoch,
);
// r2: 20 tokens / 4s, queued behind r1.
seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![2]),
20,
0,
None,
true,
Some(prefill_hint(20, 4)),
epoch,
);
// Finishing r1 at +2s leaves r2 as the front and the new anchor.
seq_manager.mark_prefill_completed(&"r1".to_string(), epoch + Duration::from_secs(2));
assert!(
seq_manager
.prefill
.anchored_prefill
.as_ref()
.is_some_and(|(request_id, _)| request_id == "r2")
);
// Queried at the anchoring instant, r2 has not decayed: all 20 tokens remain.
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(2)),
20
);
}
// Repeating `mark_prefill_completed` / `free` for the same request must not disturb the
// queue or the token sum.
#[test]
fn test_prefill_queue_and_sum_invariants_survive_idempotent_cleanup() {
let mut seq_manager = ActiveSequences::new(4);
let decay_now = Instant::now();
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
Some(vec![1]),
50,
0,
None,
true,
Some(prefill_hint(50, 10)),
decay_now,
);
seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![2]),
30,
0,
None,
true,
Some(prefill_hint(30, 10)),
decay_now,
);
// Both loads tracked: 50 + 30 = 80, queued in insertion order.
assert_eq!(seq_manager.prefill.prefill_full_tokens_sum, 80);
assert_eq!(
seq_manager.prefill.prefill_order,
VecDeque::from(vec!["r1".to_string(), "r2".to_string()])
);
// The second call finds no prefill load left on r1 and therefore changes nothing.
seq_manager.mark_prefill_completed(&"r1".to_string(), decay_now);
seq_manager.mark_prefill_completed(&"r1".to_string(), decay_now);
assert_eq!(seq_manager.prefill.prefill_full_tokens_sum, 30);
assert_eq!(
seq_manager.prefill.prefill_order,
VecDeque::from(vec!["r2".to_string()])
);
// The first `free` removes r1 entirely; the second targets a request that no longer exists.
seq_manager.free(&"r1".to_string(), decay_now);
seq_manager.free(&"r1".to_string(), decay_now);
assert_eq!(seq_manager.prefill.prefill_full_tokens_sum, 30);
assert_eq!(
seq_manager.prefill.prefill_order,
VecDeque::from(vec!["r2".to_string()])
);
// Freeing r2 also completes its prefill, leaving the tracker empty.
seq_manager.free(&"r2".to_string(), decay_now);
assert_eq!(seq_manager.prefill.prefill_full_tokens_sum, 0);
assert!(seq_manager.prefill.prefill_order.is_empty());
assert!(seq_manager.requests.is_empty());
}
#[tokio::test(start_paused = true)]
async fn test_force_expiry() {
let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size);
// Add two requests at time 0 (paused clock)
seq_manager.add_request("r1".to_string(), Some(vec![1, 2]), 8, 0, None);
seq_manager.add_request("r2".to_string(), Some(vec![3, 4]), 8, 0, None);
seq_manager.add_request(
"r1".to_string(),
Some(vec![1, 2]),
8,
0,
None,
Instant::now(),
);
seq_manager.add_request(
"r2".to_string(),
Some(vec![3, 4]),
8,
0,
None,
Instant::now(),
);
assert_eq!(seq_manager.active_blocks(), 4);
// Advance 20s: check interval (CHECK_EXPIRY_FREQUENCY = 30s) not reached,
// force_expiry returns without running the check.
tokio::time::advance(Duration::from_secs(20)).await;
let expired = seq_manager.force_expiry();
assert!(expired.is_empty(), "no check before CHECK_EXPIRY_FREQUENCY");
assert_eq!(seq_manager.active_blocks(), 4);
// Advance to 31s: first time we pass the check interval. Requests are 31s old,
// still under EXPIRY_DURATION (300s), so none are expired.
tokio::time::advance(Duration::from_secs(11)).await;
let expired = seq_manager.force_expiry();
assert!(expired.is_empty(), "requests not old enough to expire");
assert_eq!(seq_manager.active_blocks(), 4);
seq_manager.assert_consistent();
// Advance to 301s: requests are now older than EXPIRY_DURATION.
// force_expiry runs and expires r1, r2.
tokio::time::advance(Duration::from_secs(270)).await;
let expired = seq_manager.force_expiry();
assert_eq!(expired, HashSet::from(["r1".to_string(), "r2".to_string()]));
assert_eq!(seq_manager.active_blocks(), 0);
assert_eq!(seq_manager.active_tokens(), 0);
assert_eq!(seq_manager.active_tokens(Instant::now()), 0);
seq_manager.assert_consistent();
// add_request calls force_expiry internally. Add r3; no old requests remain,
// so expired set is empty and only r3 is active.
tokio::time::advance(Duration::from_secs(31)).await;
let expired = seq_manager.add_request("r3".to_string(), Some(vec![5]), 4, 0, None);
let expired =
seq_manager.add_request("r3".to_string(), Some(vec![5]), 4, 0, None, Instant::now());
assert!(expired.is_empty());
assert_eq!(seq_manager.active_blocks(), 1);
assert_eq!(seq_manager.active_tokens(), 4);
assert_eq!(seq_manager.active_tokens(Instant::now()), 4);
seq_manager.assert_consistent();
}
// Verifies that when `force_expiry` drops the request currently holding the prefill anchor,
// the anchor moves to the next tracked request (r2), and that the token count reported by
// `active_tokens` then shrinks as the paused clock advances.
#[tokio::test(start_paused = true)]
async fn test_force_expiry_reanchors_new_oldest_request() {
let mut seq_manager = ActiveSequences::new(4);
// t = 0s: r1 is the first request with prefill tracking enabled.
let first_decay_now = Instant::now();
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
Some(vec![1]),
40,
0,
None,
true,
Some(prefill_hint(40, 100)),
first_decay_now,
);
// t = 250s: r2 is added while r1 is still tracked.
tokio::time::advance(Duration::from_secs(250)).await;
seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![2]),
30,
0,
None,
true,
Some(prefill_hint(30, 100)),
Instant::now(),
);
// t = 310s: r1 is 310s old, r2 only 60s old; only r1 is expected to expire.
// NOTE(review): relies on EXPIRY_DURATION being 300s, as the comments in
// test_force_expiry state — confirm against the constant.
tokio::time::advance(Duration::from_secs(60)).await;
let expired = seq_manager.force_expiry();
assert_eq!(expired, HashSet::from(["r1".to_string()]));
// Right after expiry r2 still counts with its full 30 tokens.
assert_eq!(seq_manager.active_tokens(Instant::now()), 30);
// The anchor must now name r2 instead of the expired r1.
assert!(
seq_manager
.prefill
.anchored_prefill
.as_ref()
.is_some_and(|(request_id, _)| request_id == "r2")
);
// 20s later the reported count is 24.
// NOTE(review): consistent with a linear decay of r2's 30 tokens over the 100 passed to
// prefill_hint, restarted at expiry time (30 * (1 - 20/100) = 24) — confirm the unit of
// prefill_hint's second argument.
tokio::time::advance(Duration::from_secs(20)).await;
assert_eq!(seq_manager.active_tokens(Instant::now()), 24);
}
}
......@@ -4,7 +4,7 @@
use std::{collections::HashSet, sync::Arc};
use dashmap::{DashMap, mapref::entry::Entry};
use dynamo_kv_router::{config::KvRouterConfig, protocols::WorkerId};
use dynamo_kv_router::{PrefillLoadEstimator, config::KvRouterConfig, protocols::WorkerId};
use tokio::sync::oneshot;
use super::worker_monitor::LoadThresholdConfig;
......@@ -568,6 +568,7 @@ impl ModelManager {
endpoint: &Endpoint,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
worker_type: &'static str,
model_name: Option<String>,
is_eagle: bool,
......@@ -604,6 +605,7 @@ impl ModelManager {
kv_cache_block_size,
selector,
kv_router_config,
prefill_load_estimator,
worker_type,
model_name,
is_eagle,
......
......@@ -7,6 +7,7 @@ use tokio::sync::mpsc::Sender;
use anyhow::Context as _;
use dashmap::DashSet;
use dynamo_kv_router::PrefillLoadEstimator;
use futures::StreamExt;
use dynamo_runtime::{
......@@ -74,6 +75,7 @@ pub struct ModelWatcher {
notify_on_model: Notify,
model_update_tx: Option<Sender<ModelUpdate>>,
chat_engine_factory: Option<ChatEngineFactoryCallback>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
metrics: Arc<Metrics>,
/// Guards against concurrent pipeline construction for the same (model, namespace).
registering_worker_sets: DashSet<String>,
......@@ -118,6 +120,7 @@ impl ModelWatcher {
router_config: RouterConfig,
migration_limit: u32,
chat_engine_factory: Option<ChatEngineFactoryCallback>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
metrics: Arc<Metrics>,
) -> ModelWatcher {
Self {
......@@ -128,6 +131,7 @@ impl ModelWatcher {
notify_on_model: Notify::new(),
model_update_tx: None,
chat_engine_factory,
prefill_load_estimator,
metrics,
registering_worker_sets: DashSet::new(),
}
......@@ -465,6 +469,7 @@ impl ModelWatcher {
&endpoint,
card.kv_cache_block_size,
Some(self.router_config.kv_router_config.clone()),
self.prefill_load_estimator.clone(),
WORKER_TYPE_DECODE, // This is the decode router
Some(card.display_name.clone()),
card.runtime_config.enable_eagle,
......@@ -506,6 +511,7 @@ impl ModelWatcher {
self.router_config.router_mode,
card.kv_cache_block_size,
Some(prefill_config),
self.prefill_load_estimator.clone(),
self.router_config.enforce_disagg,
model_name.clone(),
namespace.clone(),
......
......@@ -12,7 +12,7 @@ use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::{PrefillLoadEstimator, config::KvRouterConfig};
use dynamo_runtime::{discovery::ModelCardInstanceId, pipeline::RouterMode};
use crate::{
......@@ -68,6 +68,7 @@ pub enum EngineConfig {
Dynamic {
model: Box<LocalModel>,
chat_engine_factory: Option<ChatEngineFactoryCallback>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
},
/// A Text engine receives text, does its own tokenization and prompt formatting.
......
......@@ -94,7 +94,9 @@ pub async fn prepare_engine(
) -> anyhow::Result<PreparedEngine> {
match engine_config {
EngineConfig::Dynamic {
model: local_model, ..
model: local_model,
prefill_load_estimator,
..
} => {
let model_manager = Arc::new(ModelManager::new());
// Create metrics for migration tracking (not exposed via /metrics in Dynamic engine mode)
......@@ -105,6 +107,7 @@ pub async fn prepare_engine(
RouterConfig::default(),
local_model.migration_limit(),
None,
prefill_load_estimator,
metrics,
));
let discovery = distributed_runtime.discovery();
......
......@@ -33,7 +33,11 @@ pub async fn run(
}
let grpc_service = match engine_config {
EngineConfig::Dynamic { ref model, .. } => {
EngineConfig::Dynamic {
ref model,
ref prefill_load_estimator,
..
} => {
let grpc_service = grpc_service_builder.build()?;
let router_config = model.router_config();
let migration_limit = model.migration_limit();
......@@ -48,6 +52,7 @@ pub async fn run(
router_config.clone(),
migration_limit,
namespace_filter,
prefill_load_estimator.clone(),
)
.await?;
grpc_service
......@@ -111,6 +116,7 @@ async fn run_watcher(
router_config: RouterConfig,
migration_limit: u32,
namespace_filter: NamespaceFilter,
prefill_load_estimator: Option<Arc<dyn dynamo_kv_router::PrefillLoadEstimator>>,
) -> anyhow::Result<()> {
// Create metrics for migration tracking (not exposed via /metrics in gRPC mode)
let metrics = Arc::new(Metrics::new());
......@@ -120,6 +126,7 @@ async fn run_watcher(
router_config,
migration_limit,
None,
prefill_load_estimator,
metrics,
);
tracing::debug!("Waiting for remote model");
......
......@@ -67,6 +67,7 @@ pub async fn run(
EngineConfig::Dynamic {
ref model,
ref chat_engine_factory,
ref prefill_load_estimator,
} => {
// Pass the discovery client so the /health endpoint can query active instances
http_service_builder =
......@@ -90,6 +91,7 @@ pub async fn run(
Arc::new(http_service.clone()),
http_service.state().metrics_clone(),
chat_engine_factory.clone(),
prefill_load_estimator.clone(),
)
.await?;
http_service
......@@ -167,6 +169,7 @@ async fn run_watcher(
http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>,
chat_engine_factory: Option<ChatEngineFactoryCallback>,
prefill_load_estimator: Option<Arc<dyn dynamo_kv_router::PrefillLoadEstimator>>,
) -> anyhow::Result<()> {
let mut watch_obj = ModelWatcher::new(
runtime.clone(),
......@@ -174,6 +177,7 @@ async fn run_watcher(
router_config,
migration_limit,
chat_engine_factory,
prefill_load_estimator,
metrics.clone(),
);
tracing::debug!("Waiting for remote model");
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::time::Instant;
use anyhow::Result;
use dynamo_kv_router::{
PrefillLoadEstimator,
config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env},
indexer::KvRouterError,
protocols::KV_EVENT_SUBJECT,
protocols::{
BlockExtraInfo, BlockHashOptions, DpRank, RouterEvent, RouterRequest, RouterResponse,
TokensWithHashes, WorkerId, WorkerWithDpRank, compute_block_hash_for_seq,
BlockExtraInfo, BlockHashOptions, DpRank, PrefillLoadHint, RouterEvent, RouterRequest,
RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank, compute_block_hash_for_seq,
},
};
use dynamo_runtime::{
......@@ -111,6 +113,7 @@ where
scheduler: KvScheduler<Sel>,
block_size: u32,
kv_router_config: KvRouterConfig,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
cancellation_token: tokio_util::sync::CancellationToken,
client: Client,
is_eagle: bool,
......@@ -128,6 +131,7 @@ where
block_size: u32,
selector: Sel,
kv_router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
worker_type: &'static str,
model_name: Option<String>,
is_eagle: bool,
......@@ -159,6 +163,7 @@ where
workers_with_configs.clone(),
selector,
&kv_router_config,
prefill_load_estimator.clone(),
worker_type,
)
.await?;
......@@ -184,6 +189,7 @@ where
scheduler,
block_size,
kv_router_config,
prefill_load_estimator,
cancellation_token,
client,
is_eagle,
......@@ -345,6 +351,8 @@ where
let track_prefill_tokens = self
.kv_router_config
.track_prefill_tokens(router_config_override);
let prefill_load_hint =
self.prefill_load_hint_for(isl_tokens, overlap_blocks, track_prefill_tokens);
if let Err(e) = self
.scheduler
......@@ -355,6 +363,7 @@ where
overlap: overlap_blocks,
track_prefill_tokens,
expected_output_tokens,
prefill_load_hint,
worker,
lora_name,
})
......@@ -377,6 +386,42 @@ where
self.scheduler.pending_count()
}
/// Derives the prefill load hint attached to a request on the direct `add_request` path.
///
/// Returns `None` when prefill-token tracking is disabled, when the overlapping prefix
/// (`overlap_blocks * block_size`) already covers the whole input, when no estimator is
/// configured, or when the estimator reports an error (which is logged at warn level).
/// Otherwise the hint carries the uncached token count and the predicted prefill duration.
fn prefill_load_hint_for(
    &self,
    isl_tokens: usize,
    overlap_blocks: u32,
    track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> {
    if !track_prefill_tokens {
        return None;
    }

    // Tokens covered by overlapping blocks are excluded from the prefill estimate.
    let prefix = (self.block_size as usize) * (overlap_blocks as usize);
    let effective_isl = match isl_tokens.saturating_sub(prefix) {
        0 => return None,
        uncached => uncached,
    };

    let estimator = self.prefill_load_estimator.as_ref()?;
    let predicted = estimator.predict_prefill_duration(1, effective_isl, prefix);
    let expected_prefill_duration = match predicted {
        Ok(duration) => duration,
        Err(error) => {
            // A failed prediction is not fatal: the request simply goes without a hint.
            tracing::warn!(
                effective_isl,
                prefix,
                "failed to predict prefill duration for direct add_request path: {error}"
            );
            return None;
        }
    };

    Some(PrefillLoadHint {
        initial_effective_prefill_tokens: effective_isl,
        expected_prefill_duration: Some(expected_prefill_duration),
    })
}
/// Get the worker type for this router ("prefill" or "decode").
/// Used for Prometheus metric labeling.
pub fn worker_type(&self) -> &'static str {
......
......@@ -6,7 +6,7 @@ use std::sync::Arc;
use anyhow::Result;
use tokio::sync::oneshot;
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_kv_router::{PrefillLoadEstimator, config::KvRouterConfig};
use dynamo_runtime::{
component::{Client, Endpoint},
pipeline::{PushRouter, RouterMode},
......@@ -37,6 +37,7 @@ impl PrefillRouter {
cancel_token: tokio_util::sync::CancellationToken::new(),
router_mode,
enforce_disagg,
prefill_load_estimator: None,
model_name: String::new(), // Not used for disabled router
namespace: String::new(), // Not used for disabled router
is_eagle: false,
......@@ -50,6 +51,7 @@ impl PrefillRouter {
router_mode: RouterMode,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
enforce_disagg: bool,
model_name: String,
namespace: String,
......@@ -65,6 +67,7 @@ impl PrefillRouter {
cancel_token: cancel_token.clone(),
router_mode,
enforce_disagg,
prefill_load_estimator,
model_name,
namespace,
is_eagle,
......@@ -85,6 +88,7 @@ impl PrefillRouter {
model_manager,
kv_cache_block_size,
kv_router_config,
router_clone.prefill_load_estimator.clone(),
).await {
tracing::error!(error = %e, "Failed to activate prefill router");
}
......@@ -105,6 +109,7 @@ impl PrefillRouter {
model_manager: Arc<ModelManager>,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
) -> Result<()> {
tracing::info!(
router_mode = ?self.router_mode,
......@@ -127,6 +132,7 @@ impl PrefillRouter {
&endpoint,
kv_cache_block_size,
kv_router_config,
prefill_load_estimator,
WORKER_TYPE_PREFILL,
Some(self.model_name.clone()),
self.is_eagle,
......
......@@ -6,6 +6,7 @@ use std::sync::{Arc, OnceLock};
use anyhow::Result;
use tokio_util::sync::CancellationToken;
use dynamo_kv_router::PrefillLoadEstimator;
use dynamo_runtime::{
pipeline::{
AsyncEngineContextProvider, ManyOut, Operator, RouterMode, ServerStreamingEngine, SingleIn,
......@@ -47,6 +48,7 @@ pub struct PrefillRouter {
cancel_token: CancellationToken,
router_mode: RouterMode,
enforce_disagg: bool,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
/// Model name used to look up the worker monitor for prefill client registration
model_name: String,
/// Namespace used to look up the correct WorkerSet's worker monitor
......
......@@ -16,6 +16,7 @@ use crate::discovery::RuntimeConfigWatch;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result;
use dynamo_kv_router::{
PrefillLoadEstimator,
config::{KvRouterConfig, RouterConfigOverride},
protocols::{OverlapScores, WorkerId},
};
......@@ -45,6 +46,7 @@ where
workers_with_configs: RuntimeConfigWatch,
selector: Sel,
kv_router_config: &KvRouterConfig,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
worker_type: &'static str,
) -> Result<Self, KvSchedulerError> {
let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
......@@ -81,6 +83,8 @@ where
block_size,
selector,
policy,
prefill_load_estimator,
kv_router_config.router_queue_recheck_interval(),
kv_router_config.router_track_prefill_tokens,
component.drt().child_token(),
worker_type,
......
......@@ -143,35 +143,58 @@ pub async fn create_multi_worker_sequences(
mod tests {
use super::*;
use dynamo_runtime::{DistributedRuntime, Runtime};
use tokio::time::Instant;
#[test]
fn test_active_sequences_shared_blocks() {
let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size);
let decay_now = Instant::now();
seq_manager.add_request("request_1".to_string(), Some(vec![1, 2, 3]), 12, 0, None);
seq_manager.add_request(
"request_1".to_string(),
Some(vec![1, 2, 3]),
12,
0,
None,
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 3);
assert_eq!(seq_manager.active_tokens(), 12);
assert_eq!(seq_manager.active_tokens(decay_now), 12);
seq_manager.add_request("request_2".to_string(), Some(vec![4]), 4, 0, None);
seq_manager.add_request(
"request_2".to_string(),
Some(vec![4]),
4,
0,
None,
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 16);
seq_manager.add_request("request_3".to_string(), Some(vec![1, 2, 3, 4]), 16, 4, None);
assert_eq!(seq_manager.active_tokens(decay_now), 16);
seq_manager.add_request(
"request_3".to_string(),
Some(vec![1, 2, 3, 4]),
16,
4,
None,
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 16);
assert_eq!(seq_manager.active_tokens(decay_now), 16);
seq_manager.free(&"request_2".to_string());
seq_manager.free(&"request_2".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 12);
assert_eq!(seq_manager.active_tokens(decay_now), 12);
seq_manager.free(&"request_3".to_string());
seq_manager.free(&"request_3".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 3);
assert_eq!(seq_manager.active_tokens(), 12);
assert_eq!(seq_manager.active_tokens(decay_now), 12);
seq_manager.free(&"request_1".to_string());
seq_manager.free(&"request_1".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 0);
assert_eq!(seq_manager.active_tokens(), 0);
assert_eq!(seq_manager.active_tokens(decay_now), 0);
}
#[tokio::test]
......@@ -217,43 +240,55 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
seq_manager_1.add_request(SequenceRequest {
request_id: "request_0".to_string(),
token_sequence: Some(vec![0, 1, 2]),
isl: 12,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::new(0, 0),
lora_name: None,
})?;
seq_manager_1.add_request(SequenceRequest {
request_id: "request_1".to_string(),
token_sequence: Some(vec![3, 4]),
isl: 8,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::new(0, 1),
lora_name: None,
})?;
seq_manager_2.add_request(SequenceRequest {
request_id: "request_2".to_string(),
token_sequence: Some(vec![0, 1, 2, 3]),
isl: 16,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::new(1, 0),
lora_name: None,
})?;
seq_manager_1.add_request(
SequenceRequest {
request_id: "request_0".to_string(),
token_sequence: Some(vec![0, 1, 2]),
isl: 12,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
prefill_load_hint: None,
worker: WorkerWithDpRank::new(0, 0),
lora_name: None,
},
Instant::now(),
)?;
seq_manager_1.add_request(
SequenceRequest {
request_id: "request_1".to_string(),
token_sequence: Some(vec![3, 4]),
isl: 8,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
prefill_load_hint: None,
worker: WorkerWithDpRank::new(0, 1),
lora_name: None,
},
Instant::now(),
)?;
seq_manager_2.add_request(
SequenceRequest {
request_id: "request_2".to_string(),
token_sequence: Some(vec![0, 1, 2, 3]),
isl: 16,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
prefill_load_hint: None,
worker: WorkerWithDpRank::new(1, 0),
lora_name: None,
},
Instant::now(),
)?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
let blocks_phase1 = seq_manager_1.active_blocks();
let tokens_phase1 = seq_manager_1.active_tokens();
let tokens_phase1 = seq_manager_1.active_tokens(Instant::now());
let worker_0_dp0 = WorkerWithDpRank::new(0, 0);
let worker_0_dp1 = WorkerWithDpRank::new(0, 1);
......@@ -284,15 +319,15 @@ mod tests {
"Worker 1 dp_rank 0 should have 16 active tokens (from request_2 added by seq_manager_2)"
);
seq_manager_1.free(&"request_2".to_string())?;
seq_manager_1.free(&"request_2".to_string(), Instant::now())?;
seq_manager_2.free(&"request_0".to_string())?;
seq_manager_2.free(&"request_1".to_string())?;
seq_manager_2.free(&"request_0".to_string(), Instant::now())?;
seq_manager_2.free(&"request_1".to_string(), Instant::now())?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
let blocks_phase2 = seq_manager_2.active_blocks();
let tokens_phase2 = seq_manager_2.active_tokens();
let tokens_phase2 = seq_manager_2.active_tokens(Instant::now());
let all_workers = vec![
WorkerWithDpRank::new(0, 0),
......@@ -364,42 +399,54 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
seq_manager_1.add_request(SequenceRequest {
request_id: "request_0".to_string(),
token_sequence: None,
isl: 12,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::from_worker_id(0),
lora_name: None,
})?;
seq_manager_1.add_request(SequenceRequest {
request_id: "request_1".to_string(),
token_sequence: None,
isl: 8,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::from_worker_id(1),
lora_name: None,
})?;
seq_manager_2.add_request(SequenceRequest {
request_id: "request_2".to_string(),
token_sequence: None,
isl: 16,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
worker: WorkerWithDpRank::from_worker_id(2),
lora_name: None,
})?;
seq_manager_1.add_request(
SequenceRequest {
request_id: "request_0".to_string(),
token_sequence: None,
isl: 12,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
prefill_load_hint: None,
worker: WorkerWithDpRank::from_worker_id(0),
lora_name: None,
},
Instant::now(),
)?;
seq_manager_1.add_request(
SequenceRequest {
request_id: "request_1".to_string(),
token_sequence: None,
isl: 8,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
prefill_load_hint: None,
worker: WorkerWithDpRank::from_worker_id(1),
lora_name: None,
},
Instant::now(),
)?;
seq_manager_2.add_request(
SequenceRequest {
request_id: "request_2".to_string(),
token_sequence: None,
isl: 16,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
prefill_load_hint: None,
worker: WorkerWithDpRank::from_worker_id(2),
lora_name: None,
},
Instant::now(),
)?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
let tokens_phase1 = seq_manager_1.active_tokens();
let tokens_phase1 = seq_manager_1.active_tokens(Instant::now());
let worker_0 = WorkerWithDpRank::from_worker_id(0);
let worker_1 = WorkerWithDpRank::from_worker_id(1);
......@@ -418,17 +465,17 @@ mod tests {
"Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)"
);
seq_manager_1.mark_prefill_completed(&"request_2".to_string())?;
seq_manager_1.free(&"request_2".to_string())?;
seq_manager_1.mark_prefill_completed(&"request_2".to_string(), Instant::now())?;
seq_manager_1.free(&"request_2".to_string(), Instant::now())?;
seq_manager_2.mark_prefill_completed(&"request_0".to_string())?;
seq_manager_2.mark_prefill_completed(&"request_1".to_string())?;
seq_manager_2.free(&"request_0".to_string())?;
seq_manager_2.free(&"request_1".to_string())?;
seq_manager_2.mark_prefill_completed(&"request_0".to_string(), Instant::now())?;
seq_manager_2.mark_prefill_completed(&"request_1".to_string(), Instant::now())?;
seq_manager_2.free(&"request_0".to_string(), Instant::now())?;
seq_manager_2.free(&"request_1".to_string(), Instant::now())?;
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
let tokens_phase2 = seq_manager_2.active_tokens();
let tokens_phase2 = seq_manager_2.active_tokens(Instant::now());
for worker_id in 0..=2 {
let worker = WorkerWithDpRank::from_worker_id(worker_id);
......
......@@ -344,6 +344,7 @@ mod integration_tests {
dynamo_llm::entrypoint::RouterConfig::default(),
0, // migration_limit
None,
None,
service.state().metrics_clone(),
);
// Start watching for model registrations via discovery interface
......
......@@ -22,6 +22,7 @@ anyhow = { workspace = true }
dashmap = { workspace = true }
derive_builder = { workspace = true }
derive-getters = { workspace = true }
indicatif = "0.18"
rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
......
......@@ -29,8 +29,8 @@ pub trait DecodeInterpolator: Send + Sync {
/// Implementors call the Python AIC SDK via PyO3 GIL.
pub trait AicCallback: Send + Sync {
/// Predict prefill latency in ms.
/// Parameters: (batch_size, isl, prefix, osl)
fn predict_prefill(&self, batch_size: usize, isl: usize, prefix: usize, osl: usize) -> f64;
/// Parameters: (batch_size, effective_isl, prefix)
fn predict_prefill(&self, batch_size: usize, effective_isl: usize, prefix: usize) -> f64;
/// Predict decode (generation) latency in ms.
/// Parameters: (batch_size, isl, osl)
......@@ -83,7 +83,7 @@ pub enum PerfModel {
decode_interp: Arc<dyn DecodeInterpolator>,
},
/// AI Configurator SDK calls via Python callback.
/// Passes full parameters (batch_size, isl, prefix, osl) for maximum accuracy.
/// Passes the reduced prefill inputs (batch_size, effective_isl, prefix).
Aiconfigurator { callback: Arc<dyn AicCallback> },
}
......@@ -217,7 +217,7 @@ impl PerfModel {
/// Callers always pass all parameters; each variant uses what it needs:
/// - Polynomial/Interpolated: uses total new tokens across the batch
/// (`batch_size * (isl - prefix)`), modeling GPU processing total tokens in parallel
/// - Aiconfigurator: passes (batch_size, isl, prefix) directly to the AIC SDK
/// - Aiconfigurator: passes (batch_size, isl - prefix, prefix) to the AIC SDK
pub fn predict_prefill_time(&self, batch_size: usize, isl: usize, prefix: usize) -> f64 {
let new_tokens_per_req = isl.saturating_sub(prefix);
let time = match self {
......@@ -231,7 +231,7 @@ impl PerfModel {
prefill_interp.interp(tokens).unwrap_or(0.0)
}
PerfModel::Aiconfigurator { callback } => {
callback.predict_prefill(batch_size, isl, prefix, 1)
callback.predict_prefill(batch_size, new_tokens_per_req, prefix)
}
};
time.max(0.0)
......
......@@ -82,16 +82,16 @@ pub struct WorkloadDriver {
}
impl WorkloadDriver {
pub(crate) fn new_trace(trace: Trace) -> Result<Self> {
Self::new(trace, DriverMode::Trace)
pub(crate) fn new_trace(trace: Trace, engine_block_size: usize) -> Result<Self> {
Self::new(trace, engine_block_size, DriverMode::Trace)
}
pub(crate) fn new_concurrency(trace: Trace) -> Result<Self> {
Self::new(trace, DriverMode::Concurrency)
pub(crate) fn new_concurrency(trace: Trace, engine_block_size: usize) -> Result<Self> {
Self::new(trace, engine_block_size, DriverMode::Concurrency)
}
fn new(trace: Trace, mode: DriverMode) -> Result<Self> {
let block_size = trace.block_size;
fn new(trace: Trace, engine_block_size: usize, mode: DriverMode) -> Result<Self> {
let trace_block_size = trace.block_size;
let sessions: Vec<SessionRuntime> = trace
.sessions
.into_iter()
......@@ -105,10 +105,11 @@ impl WorkloadDriver {
.into_iter()
.map(|turn| -> Result<TurnRuntime> {
Ok(TurnRuntime {
tokens: turn.synthesize_tokens(block_size)?,
tokens: turn.synthesize_tokens(trace_block_size)?,
max_output_tokens: turn.max_output_tokens,
delay_after_previous_ms: turn.delay_after_previous_ms,
replay_hashes: turn.to_replay_hashes(block_size)?,
replay_hashes: turn
.to_replay_hashes(trace_block_size, engine_block_size)?,
})
})
.collect::<Result<Vec<_>>>()?;
......@@ -260,4 +261,11 @@ impl WorkloadDriver {
.iter()
.all(|session| session.next_turn_index >= session.turns.len())
}
/// Returns the number of turns summed over every session held by this driver.
pub fn total_turns(&self) -> usize {
    let mut turn_count = 0usize;
    for session in &self.sessions {
        turn_count += session.turns.len();
    }
    turn_count
}
}
......@@ -112,7 +112,7 @@ fn test_turn_replay_hashes_match_full_blocks_only() {
let request = turn
.to_direct_request(4, Uuid::from_u128(1), Some(5.0))
.unwrap();
let replay_hashes = turn.to_replay_hashes(4).unwrap();
let replay_hashes = turn.to_replay_hashes(4, 4).unwrap();
let expected_local =
compute_block_hash_for_seq(&request.tokens, 4, BlockHashOptions::default());
......@@ -124,6 +124,30 @@ fn test_turn_replay_hashes_match_full_blocks_only() {
assert_eq!(replay_hashes.local_block_hashes.len(), 1);
}
// Replay hashes must be computed over engine-sized blocks even when the trace was recorded
// with a different block size: here the trace uses 4-token blocks and the engine 2-token blocks.
#[test]
fn test_turn_replay_hashes_support_distinct_trace_and_engine_block_sizes() {
// Two trace blocks (hash ids 1 and 2) truncated to 6 input tokens.
let turn = TurnTrace {
input_length: 6,
max_output_tokens: 3,
hash_ids: vec![1, 2],
delay_after_previous_ms: 0.0,
};
// Synthesize the request tokens with the trace block size (4).
let request = turn
.to_direct_request(4, Uuid::from_u128(2), Some(5.0))
.unwrap();
// Trace block size 4, engine block size 2.
let replay_hashes = turn.to_replay_hashes(4, 2).unwrap();
// Reference: hash the synthesized tokens directly with the engine block size.
let expected_local =
compute_block_hash_for_seq(&request.tokens, 2, BlockHashOptions::default());
assert_eq!(replay_hashes.local_block_hashes, expected_local);
assert_eq!(
replay_hashes.sequence_hashes,
compute_seq_hash_for_block(&expected_local)
);
// 6 tokens / 2 tokens per engine block = 3 full blocks.
assert_eq!(replay_hashes.local_block_hashes.len(), 3);
}
#[test]
fn test_partition_by_session_round_robin_keeps_sessions_intact() {
let trace = Trace::synthetic(SyntheticTraceSpec {
......@@ -406,7 +430,7 @@ fn test_trace_driver_round_trips_turn_semantics_into_ready_requests() {
first.replay_hashes.as_ref(),
Some(
&expected.sessions[0].turns[0]
.to_replay_hashes(expected.block_size)
.to_replay_hashes(expected.block_size, expected.block_size)
.unwrap()
)
);
......@@ -443,7 +467,7 @@ fn test_trace_driver_round_trips_turn_semantics_into_ready_requests() {
second.replay_hashes.as_ref(),
Some(
&expected.sessions[1].turns[0]
.to_replay_hashes(expected.block_size)
.to_replay_hashes(expected.block_size, expected.block_size)
.unwrap()
)
);
......@@ -470,7 +494,7 @@ fn test_trace_driver_round_trips_turn_semantics_into_ready_requests() {
third.replay_hashes.as_ref(),
Some(
&expected.sessions[0].turns[1]
.to_replay_hashes(expected.block_size)
.to_replay_hashes(expected.block_size, expected.block_size)
.unwrap()
)
);
......@@ -488,3 +512,39 @@ fn test_trace_driver_round_trips_turn_semantics_into_ready_requests() {
expected_third_request.arrival_timestamp_ms
);
}
// A trace driver built with an explicit engine block size must still synthesize tokens using
// the trace's own block size (4) while producing replay hashes for the engine block size (2).
#[test]
fn test_trace_driver_rechunks_trace_blocks_into_engine_blocks() {
// Single session with one turn: two 4-token trace blocks truncated to 6 input tokens.
let trace = Trace {
block_size: 4,
sessions: vec![SessionTrace {
session_id: "session-a".to_string(),
first_arrival_timestamp_ms: Some(10.0),
turns: vec![TurnTrace {
input_length: 6,
max_output_tokens: 2,
hash_ids: vec![1, 2],
delay_after_previous_ms: 0.0,
}],
}],
};
// Engine block size (2) differs from the trace block size (4).
let mut driver = trace.into_trace_driver_with_block_size(2).unwrap();
// Poll at the session's first arrival timestamp with no limit on the batch size.
let ready = driver.pop_ready(10.0, usize::MAX);
assert_eq!(ready.len(), 1);
let ready = &ready[0];
// Tokens repeat each hash id per trace block: four 1s, then two 2s after truncation to 6.
assert_eq!(ready.request.tokens, vec![1, 1, 1, 1, 2, 2]);
// The attached hashes must equal what the same turn yields for (trace = 4, engine = 2).
assert_eq!(
ready.replay_hashes.as_ref(),
Some(
&TurnTrace {
input_length: 6,
max_output_tokens: 2,
hash_ids: vec![1, 2],
delay_after_previous_ms: 0.0,
}
.to_replay_hashes(4, 2)
.unwrap()
)
);
}
......@@ -9,7 +9,8 @@ use std::path::Path;
use anyhow::{Context, Result, anyhow, bail};
use dynamo_kv_router::LocalBlockHash;
use dynamo_kv_router::protocols::{
ExternalSequenceBlockHash, WorkerId, XXH3_SEED, compute_seq_hash_for_block,
BlockHashOptions, ExternalSequenceBlockHash, WorkerId, XXH3_SEED, compute_block_hash_for_seq,
compute_seq_hash_for_block,
};
use dynamo_tokens::compute_hash_v2;
use rand::rngs::StdRng;
......@@ -45,27 +46,27 @@ struct RawMooncakeRecord {
}
impl TurnTrace {
fn validate_block_size_and_capacity(&self, block_size: usize) -> Result<()> {
if block_size == 0 {
bail!("block_size must be greater than 0");
fn validate_block_size_and_capacity(&self, trace_block_size: usize) -> Result<()> {
if trace_block_size == 0 {
bail!("trace_block_size must be greater than 0");
}
if self.hash_ids.len() * block_size < self.input_length {
if self.hash_ids.len() * trace_block_size < self.input_length {
bail!(
"input_length {} exceeds synthesized capacity {}",
self.input_length,
self.hash_ids.len() * block_size
self.hash_ids.len() * trace_block_size
);
}
Ok(())
}
pub(crate) fn synthesize_tokens(&self, block_size: usize) -> Result<Vec<u32>> {
self.validate_block_size_and_capacity(block_size)?;
pub(crate) fn synthesize_tokens(&self, trace_block_size: usize) -> Result<Vec<u32>> {
self.validate_block_size_and_capacity(trace_block_size)?;
let mut tokens = Vec::with_capacity(self.input_length);
for &hash_id in &self.hash_ids {
let token_id = hash_id as u32;
tokens.extend((0..block_size).map(|_| token_id));
tokens.extend((0..trace_block_size).map(|_| token_id));
if tokens.len() >= self.input_length {
tokens.truncate(self.input_length);
break;
......@@ -85,11 +86,11 @@ impl TurnTrace {
pub fn to_direct_request(
&self,
block_size: usize,
trace_block_size: usize,
request_uuid: Uuid,
arrival_timestamp_ms: Option<f64>,
) -> Result<DirectRequest> {
let tokens = self.synthesize_tokens(block_size)?;
let tokens = self.synthesize_tokens(trace_block_size)?;
Ok(DirectRequest {
tokens,
max_output_tokens: self.max_output_tokens,
......@@ -99,16 +100,20 @@ impl TurnTrace {
})
}
pub fn to_replay_hashes(&self, block_size: usize) -> Result<ReplayRequestHashes> {
self.validate_block_size_and_capacity(block_size)?;
let num_full_blocks = self.input_length / block_size;
let local_block_hashes = self
.hash_ids
.iter()
.take(num_full_blocks)
.map(|&hash_id| local_block_hash_from_id(hash_id, block_size))
.collect::<Vec<_>>();
pub fn to_replay_hashes(
&self,
trace_block_size: usize,
engine_block_size: usize,
) -> Result<ReplayRequestHashes> {
if engine_block_size == 0 {
bail!("engine_block_size must be greater than 0");
}
let tokens = self.synthesize_tokens(trace_block_size)?;
let engine_block_size =
u32::try_from(engine_block_size).context("engine_block_size does not fit in u32")?;
let local_block_hashes =
compute_block_hash_for_seq(&tokens, engine_block_size, BlockHashOptions::default());
let sequence_hashes = compute_seq_hash_for_block(&local_block_hashes);
Ok(ReplayRequestHashes {
......@@ -119,9 +124,9 @@ impl TurnTrace {
}
impl Trace {
pub fn from_mooncake(path: &Path, block_size: usize) -> Result<Self> {
if block_size == 0 {
bail!("block_size must be greater than 0");
pub fn from_mooncake(path: &Path, trace_block_size: usize) -> Result<Self> {
if trace_block_size == 0 {
bail!("trace_block_size must be greater than 0");
}
let file = File::open(path)
......@@ -157,7 +162,9 @@ impl Trace {
let hash_ids = raw
.hash_ids
.ok_or_else(|| anyhow!("trace line {} is missing hash_ids", line_idx + 1))?;
let input_length = raw.input_length.unwrap_or(hash_ids.len() * block_size);
let input_length = raw
.input_length
.unwrap_or(hash_ids.len() * trace_block_size);
let output_length = raw
.output_length
.ok_or_else(|| anyhow!("trace line {} is missing output_length", line_idx + 1))?;
......@@ -214,12 +221,12 @@ impl Trace {
);
}
if hash_ids.len() * block_size < input_length {
if hash_ids.len() * trace_block_size < input_length {
bail!(
"trace line {} input_length {} exceeds synthesized capacity {}",
line_idx + 1,
input_length,
hash_ids.len() * block_size
hash_ids.len() * trace_block_size
);
}
......@@ -239,7 +246,7 @@ impl Trace {
}
Ok(Self {
block_size,
block_size: trace_block_size,
sessions,
})
}
......@@ -598,12 +605,30 @@ impl Trace {
/// Consumes the trace and builds a trace-mode [`WorkloadDriver`].
///
/// The trace is checked with `validate_for_trace_mode` first. The driver is
/// then created with the trace's own `block_size` as the engine block size;
/// call `into_trace_driver_with_block_size` to supply a different value.
///
/// # Errors
///
/// Propagates any error returned by validation or by
/// `WorkloadDriver::new_trace`.
pub fn into_trace_driver(self) -> Result<WorkloadDriver> {
    self.validate_for_trace_mode()?;
    // Read the block size before `self` is moved into the driver.
    let engine_block_size = self.block_size;
    WorkloadDriver::new_trace(self, engine_block_size)
}
/// Consumes the trace and builds a concurrency-mode [`WorkloadDriver`].
///
/// The trace is checked with `validate_for_concurrency_mode` first. The driver
/// is then created with the trace's own `block_size` as the engine block size;
/// call `into_concurrency_driver_with_block_size` to supply a different value.
///
/// # Errors
///
/// Propagates any error returned by validation or by
/// `WorkloadDriver::new_concurrency`.
pub fn into_concurrency_driver(self) -> Result<WorkloadDriver> {
    self.validate_for_concurrency_mode()?;
    // Read the block size before `self` is moved into the driver.
    let engine_block_size = self.block_size;
    WorkloadDriver::new_concurrency(self, engine_block_size)
}
/// Consumes the trace and builds a trace-mode [`WorkloadDriver`] for a
/// caller-chosen engine block size.
///
/// `validate_for_trace_mode` runs first; only if it succeeds is the trace
/// handed to `WorkloadDriver::new_trace` together with `engine_block_size`.
///
/// # Errors
///
/// Returns the validation error, or whatever `WorkloadDriver::new_trace`
/// reports.
pub fn into_trace_driver_with_block_size(
    self,
    engine_block_size: usize,
) -> Result<WorkloadDriver> {
    self.validate_for_trace_mode()
        .and_then(|_| WorkloadDriver::new_trace(self, engine_block_size))
}
/// Consumes the trace and builds a concurrency-mode [`WorkloadDriver`] for a
/// caller-chosen engine block size.
///
/// `validate_for_concurrency_mode` runs first; only if it succeeds is the
/// trace handed to `WorkloadDriver::new_concurrency` together with
/// `engine_block_size`.
///
/// # Errors
///
/// Returns the validation error, or whatever
/// `WorkloadDriver::new_concurrency` reports.
pub fn into_concurrency_driver_with_block_size(
    self,
    engine_block_size: usize,
) -> Result<WorkloadDriver> {
    self.validate_for_concurrency_mode()
        .and_then(|_| WorkloadDriver::new_concurrency(self, engine_block_size))
}
fn validate(&self, allow_missing_first_timestamp: bool) -> Result<()> {
......
......@@ -14,7 +14,8 @@ use super::validate::{
validate_online_concurrency_args, validate_online_replay_args,
};
use super::{
OfflineDisaggReplayConfig, ReplayRouterMode, ReplayWorkerArtifacts, TraceSimulationReport,
OfflineDisaggReplayConfig, ReplayPrefillLoadEstimator, ReplayRouterMode, ReplayWorkerArtifacts,
TraceSimulationReport,
};
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::Trace;
......@@ -30,36 +31,43 @@ pub fn generate_trace_worker_artifacts_offline(
/// Convenience wrapper around [`simulate_trace_file_with_router_mode`].
///
/// Forwards every argument unchanged and fills in the remaining ones with
/// fixed choices: no router config, no prefill-load estimator, and
/// [`ReplayRouterMode::RoundRobin`] routing.
pub fn simulate_trace_file(
    args: MockEngineArgs,
    trace_path: &Path,
    trace_block_size: usize,
    num_workers: usize,
    arrival_speedup_ratio: f64,
) -> Result<TraceSimulationReport> {
    // Name the two optional inputs so the call below reads by role.
    let router_config = None;
    let prefill_load_estimator = None;
    simulate_trace_file_with_router_mode(
        args,
        router_config,
        prefill_load_estimator,
        trace_path,
        trace_block_size,
        num_workers,
        arrival_speedup_ratio,
        ReplayRouterMode::RoundRobin,
    )
}
#[allow(clippy::too_many_arguments)]
pub fn simulate_trace_file_with_router_mode(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace_path: &Path,
trace_block_size: usize,
num_workers: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
validate_offline_replay_args(&args, num_workers, router_mode)?;
let trace = Trace::from_mooncake(trace_path, args.block_size)?
let trace = Trace::from_mooncake(trace_path, trace_block_size)?
.normalize_session_starts()?
.speed_up_timing(arrival_speedup_ratio)?;
let started_at = Instant::now();
let report = crate::replay::offline::simulate_trace_workload(
args,
router_config,
prefill_load_estimator,
trace,
num_workers,
router_mode,
......@@ -70,19 +78,22 @@ pub fn simulate_trace_file_with_router_mode(
pub fn simulate_trace_file_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace_path: &Path,
trace_block_size: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let config = config.normalized()?;
validate_offline_disagg_replay_args(&config, router_mode)?;
let trace = Trace::from_mooncake(trace_path, config.prefill_args.block_size)?
let trace = Trace::from_mooncake(trace_path, trace_block_size)?
.normalize_session_starts()?
.speed_up_timing(arrival_speedup_ratio)?;
let started_at = Instant::now();
let report = crate::replay::offline::simulate_trace_workload_disagg(
config,
router_config,
prefill_load_estimator,
trace,
router_mode,
)?;
......@@ -92,33 +103,46 @@ pub fn simulate_trace_file_disagg_with_router_mode(
/// Convenience wrapper around [`simulate_trace_live_file_with_router_mode`].
///
/// Forwards every argument unchanged and fills in the remaining ones with
/// fixed choices: no router config, no prefill-load estimator, and
/// [`ReplayRouterMode::RoundRobin`] routing.
pub fn simulate_trace_live_file(
    args: MockEngineArgs,
    trace_path: &Path,
    trace_block_size: usize,
    num_workers: usize,
    arrival_speedup_ratio: f64,
) -> Result<TraceSimulationReport> {
    // Name the two optional inputs so the call below reads by role.
    let router_config = None;
    let prefill_load_estimator = None;
    simulate_trace_live_file_with_router_mode(
        args,
        router_config,
        prefill_load_estimator,
        trace_path,
        trace_block_size,
        num_workers,
        arrival_speedup_ratio,
        ReplayRouterMode::RoundRobin,
    )
}
/// Loads a trace file and runs it through the online trace simulation.
///
/// Steps, in order:
/// 1. `args` is normalized and checked with `validate_online_replay_args`.
/// 2. The file at `trace_path` is parsed by [`Trace::from_mooncake`] using
///    `trace_block_size` (not the engine's block size from `args`).
/// 3. Session starts are normalized and timing is scaled by
///    `arrival_speedup_ratio`.
/// 4. The resulting trace is passed to `online::simulate_trace_workload`
///    together with the router config, prefill-load estimator, worker count
///    and router mode.
///
/// # Errors
///
/// Propagates the first error from any of the steps above.
#[allow(clippy::too_many_arguments)]
pub fn simulate_trace_live_file_with_router_mode(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
    trace_path: &Path,
    trace_block_size: usize,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let args = args.normalized()?;
    validate_online_replay_args(&args, num_workers)?;
    let trace = Trace::from_mooncake(trace_path, trace_block_size)?
        .normalize_session_starts()?
        .speed_up_timing(arrival_speedup_ratio)?;
    online::simulate_trace_workload(
        args,
        router_config,
        prefill_load_estimator,
        trace,
        num_workers,
        router_mode,
    )
}
pub fn simulate_trace_requests(
......@@ -130,6 +154,7 @@ pub fn simulate_trace_requests(
simulate_trace_requests_with_router_mode(
args,
None,
None,
requests,
num_workers,
arrival_speedup_ratio,
......@@ -140,6 +165,7 @@ pub fn simulate_trace_requests(
pub fn simulate_trace_requests_with_router_mode(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
......@@ -155,6 +181,7 @@ pub fn simulate_trace_requests_with_router_mode(
let report = crate::replay::offline::simulate_trace(
args,
router_config,
prefill_load_estimator,
requests,
num_workers,
arrival_speedup_ratio,
......@@ -166,6 +193,7 @@ pub fn simulate_trace_requests_with_router_mode(
pub fn simulate_trace_requests_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
......@@ -180,6 +208,7 @@ pub fn simulate_trace_requests_disagg_with_router_mode(
let report = crate::replay::offline::simulate_trace_disagg(
config,
router_config,
prefill_load_estimator,
requests,
arrival_speedup_ratio,
router_mode,
......@@ -196,6 +225,7 @@ pub fn simulate_trace_live_requests(
simulate_trace_live_requests_with_router_mode(
args,
None,
None,
requests,
num_workers,
arrival_speedup_ratio,
......@@ -206,6 +236,7 @@ pub fn simulate_trace_live_requests(
pub fn simulate_trace_live_requests_with_router_mode(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>,
num_workers: usize,
arrival_speedup_ratio: f64,
......@@ -220,6 +251,7 @@ pub fn simulate_trace_live_requests_with_router_mode(
online::simulate_trace_requests(
args,
router_config,
prefill_load_estimator,
requests,
num_workers,
arrival_speedup_ratio,
......@@ -230,34 +262,41 @@ pub fn simulate_trace_live_requests_with_router_mode(
/// Convenience wrapper around [`simulate_concurrency_file_with_router_mode`].
///
/// Forwards every argument unchanged and fills in the remaining ones with
/// fixed choices: no router config, no prefill-load estimator, and
/// [`ReplayRouterMode::RoundRobin`] routing.
pub fn simulate_concurrency_file(
    args: MockEngineArgs,
    trace_path: &Path,
    trace_block_size: usize,
    max_in_flight: usize,
    num_workers: usize,
) -> Result<TraceSimulationReport> {
    // Name the two optional inputs so the call below reads by role.
    let router_config = None;
    let prefill_load_estimator = None;
    simulate_concurrency_file_with_router_mode(
        args,
        router_config,
        prefill_load_estimator,
        trace_path,
        trace_block_size,
        max_in_flight,
        num_workers,
        ReplayRouterMode::RoundRobin,
    )
}
#[allow(clippy::too_many_arguments)]
pub fn simulate_concurrency_file_with_router_mode(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace_path: &Path,
trace_block_size: usize,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
validate_offline_concurrency_args(&args, num_workers, max_in_flight, router_mode)?;
let trace = Trace::from_mooncake(trace_path, args.block_size)?;
let trace = Trace::from_mooncake(trace_path, trace_block_size)?;
let started_at = Instant::now();
let report = simulate_concurrency_workload_with_router_mode(
args,
router_config,
prefill_load_estimator,
trace,
max_in_flight,
num_workers,
......@@ -269,17 +308,20 @@ pub fn simulate_concurrency_file_with_router_mode(
pub fn simulate_concurrency_file_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace_path: &Path,
trace_block_size: usize,
max_in_flight: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let config = config.normalized()?;
validate_offline_disagg_concurrency_args(&config, max_in_flight, router_mode)?;
let trace = Trace::from_mooncake(trace_path, config.prefill_args.block_size)?;
let trace = Trace::from_mooncake(trace_path, trace_block_size)?;
let started_at = Instant::now();
let report = simulate_concurrency_workload_disagg_with_router_mode(
config,
router_config,
prefill_load_estimator,
trace,
max_in_flight,
router_mode,
......@@ -290,33 +332,40 @@ pub fn simulate_concurrency_file_disagg_with_router_mode(
/// Convenience wrapper around
/// [`simulate_concurrency_live_file_with_router_mode`].
///
/// Forwards every argument unchanged and fills in the remaining ones with
/// fixed choices: no router config, no prefill-load estimator, and
/// [`ReplayRouterMode::RoundRobin`] routing.
pub fn simulate_concurrency_live_file(
    args: MockEngineArgs,
    trace_path: &Path,
    trace_block_size: usize,
    max_in_flight: usize,
    num_workers: usize,
) -> Result<TraceSimulationReport> {
    // Name the two optional inputs so the call below reads by role.
    let router_config = None;
    let prefill_load_estimator = None;
    simulate_concurrency_live_file_with_router_mode(
        args,
        router_config,
        prefill_load_estimator,
        trace_path,
        trace_block_size,
        max_in_flight,
        num_workers,
        ReplayRouterMode::RoundRobin,
    )
}
#[allow(clippy::too_many_arguments)]
pub fn simulate_concurrency_live_file_with_router_mode(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace_path: &Path,
trace_block_size: usize,
max_in_flight: usize,
num_workers: usize,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
let args = args.normalized()?;
validate_online_concurrency_args(&args, num_workers, max_in_flight)?;
let trace = Trace::from_mooncake(trace_path, args.block_size)?;
let trace = Trace::from_mooncake(trace_path, trace_block_size)?;
online::simulate_concurrency_workload(
args,
router_config,
prefill_load_estimator,
trace,
max_in_flight,
num_workers,
......@@ -333,6 +382,7 @@ pub fn simulate_concurrency_live_requests(
simulate_concurrency_live_requests_with_router_mode(
args,
None,
None,
requests,
max_in_flight,
num_workers,
......@@ -343,6 +393,7 @@ pub fn simulate_concurrency_live_requests(
pub fn simulate_concurrency_live_requests_with_router_mode(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
......@@ -357,6 +408,7 @@ pub fn simulate_concurrency_live_requests_with_router_mode(
online::simulate_concurrency_requests(
args,
router_config,
prefill_load_estimator,
requests,
max_in_flight,
num_workers,
......@@ -373,6 +425,7 @@ pub fn simulate_concurrency_requests(
simulate_concurrency_requests_with_router_mode(
args,
None,
None,
requests,
max_in_flight,
num_workers,
......@@ -383,6 +436,7 @@ pub fn simulate_concurrency_requests(
pub fn simulate_concurrency_requests_with_router_mode(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
num_workers: usize,
......@@ -397,6 +451,7 @@ pub fn simulate_concurrency_requests_with_router_mode(
crate::replay::offline::simulate_concurrency(
args,
router_config,
prefill_load_estimator,
requests,
max_in_flight,
num_workers,
......@@ -407,6 +462,7 @@ pub fn simulate_concurrency_requests_with_router_mode(
pub fn simulate_concurrency_requests_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
requests: Vec<DirectRequest>,
max_in_flight: usize,
router_mode: ReplayRouterMode,
......@@ -420,6 +476,7 @@ pub fn simulate_concurrency_requests_disagg_with_router_mode(
crate::replay::offline::simulate_concurrency_disagg(
config,
router_config,
prefill_load_estimator,
requests,
max_in_flight,
router_mode,
......@@ -434,6 +491,7 @@ pub fn simulate_trace_workload(
simulate_trace_workload_with_router_mode(
args,
None,
None,
trace,
num_workers,
ReplayRouterMode::RoundRobin,
......@@ -443,6 +501,7 @@ pub fn simulate_trace_workload(
pub fn simulate_trace_workload_with_router_mode(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace,
num_workers: usize,
router_mode: ReplayRouterMode,
......@@ -453,6 +512,7 @@ pub fn simulate_trace_workload_with_router_mode(
let report = crate::replay::offline::simulate_trace_workload(
args,
router_config,
prefill_load_estimator,
trace,
num_workers,
router_mode,
......@@ -463,6 +523,7 @@ pub fn simulate_trace_workload_with_router_mode(
pub fn simulate_trace_workload_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace,
router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
......@@ -472,6 +533,7 @@ pub fn simulate_trace_workload_disagg_with_router_mode(
let report = crate::replay::offline::simulate_trace_workload_disagg(
config,
router_config,
prefill_load_estimator,
trace,
router_mode,
)?;
......@@ -486,6 +548,7 @@ pub fn simulate_trace_live_workload(
simulate_trace_live_workload_with_router_mode(
args,
None,
None,
trace,
num_workers,
ReplayRouterMode::RoundRobin,
......@@ -495,13 +558,21 @@ pub fn simulate_trace_live_workload(
/// Runs an already-loaded [`Trace`] through the online trace simulation.
///
/// `args` is normalized and checked with `validate_online_replay_args`; the
/// trace is then passed, untouched, to `online::simulate_trace_workload`
/// together with the router config, prefill-load estimator, worker count and
/// router mode.
///
/// # Errors
///
/// Propagates any error from normalization, validation, or the simulation
/// itself.
pub fn simulate_trace_live_workload_with_router_mode(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
    prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let args = args.normalized()?;
    validate_online_replay_args(&args, num_workers)?;
    online::simulate_trace_workload(
        args,
        router_config,
        prefill_load_estimator,
        trace,
        num_workers,
        router_mode,
    )
}
pub fn simulate_concurrency_workload(
......@@ -513,6 +584,7 @@ pub fn simulate_concurrency_workload(
simulate_concurrency_workload_with_router_mode(
args,
None,
None,
trace,
max_in_flight,
num_workers,
......@@ -523,6 +595,7 @@ pub fn simulate_concurrency_workload(
pub fn simulate_concurrency_workload_with_router_mode(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
......@@ -533,6 +606,7 @@ pub fn simulate_concurrency_workload_with_router_mode(
crate::replay::offline::simulate_concurrency_workload(
args,
router_config,
prefill_load_estimator,
trace,
max_in_flight,
num_workers,
......@@ -543,6 +617,7 @@ pub fn simulate_concurrency_workload_with_router_mode(
pub fn simulate_concurrency_workload_disagg_with_router_mode(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace,
max_in_flight: usize,
router_mode: ReplayRouterMode,
......@@ -552,6 +627,7 @@ pub fn simulate_concurrency_workload_disagg_with_router_mode(
crate::replay::offline::simulate_concurrency_workload_disagg(
config,
router_config,
prefill_load_estimator,
trace,
max_in_flight,
router_mode,
......@@ -567,6 +643,7 @@ pub fn simulate_concurrency_live_workload(
simulate_concurrency_live_workload_with_router_mode(
args,
None,
None,
trace,
max_in_flight,
num_workers,
......@@ -577,6 +654,7 @@ pub fn simulate_concurrency_live_workload(
pub fn simulate_concurrency_live_workload_with_router_mode(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace: Trace,
max_in_flight: usize,
num_workers: usize,
......@@ -587,6 +665,7 @@ pub fn simulate_concurrency_live_workload_with_router_mode(
online::simulate_concurrency_workload(
args,
router_config,
prefill_load_estimator,
trace,
max_in_flight,
num_workers,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment