Unverified Commit d94b350d authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore(kv-router): make sequences stop doing token math (#8260)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent c2ec3359
...@@ -7,7 +7,7 @@ use common::*; ...@@ -7,7 +7,7 @@ use common::*;
use clap::Parser; use clap::Parser;
use common::NoopSequencePublisher; use common::NoopSequencePublisher;
use dynamo_kv_router::protocols::WorkerWithDpRank; use dynamo_kv_router::protocols::{PrefillLoadHint, WorkerWithDpRank};
use dynamo_kv_router::{ActiveSequencesMultiWorker, OverlapScores, SequenceRequest}; use dynamo_kv_router::{ActiveSequencesMultiWorker, OverlapScores, SequenceRequest};
use dynamo_mocker::loadgen::Trace; use dynamo_mocker::loadgen::Trace;
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
...@@ -389,11 +389,12 @@ async fn apply_entry( ...@@ -389,11 +389,12 @@ async fn apply_entry(
SequenceRequest { SequenceRequest {
request_id, request_id,
token_sequence: Some(block_hashes), token_sequence: Some(block_hashes),
isl,
overlap: 0,
track_prefill_tokens: true, track_prefill_tokens: true,
expected_output_tokens: Some(output_length as u32), expected_output_tokens: Some(output_length as u32),
prefill_load_hint: None, prefill_load_hint: Some(PrefillLoadHint {
initial_effective_prefill_tokens: isl,
expected_prefill_duration: None,
}),
worker, worker,
lora_name: None, lora_name: None,
}, },
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Shared stale-child cleanup machinery for rooted tree structures.
//!
//! Provides a throttled, one-in-flight sweep that unlinks empty child nodes
//! from their parent. It is used by [`ConcurrentRadixTree`](super::concurrent_radix_tree),
//! [`ConcurrentRadixTreeCompressed`](super::concurrent_radix_tree_compressed)
//! and the sequence-side
//! [`PromptMembershipTrie`](super::sequences::prompt_membership_trie::PromptMembershipTrie),
//! each of which embeds a [`CleanupState`] and implements [`CleanableNode`]
//! for its node type.
//!
//! # Sweep semantics
//!
//! [`sweep_stale_children`] is a reverse-BFS prune:
//! - BFS from the root under read locks, collecting `(parent_weak, key, child_weak)` edges.
//! - Iterate edges deepest-first so children are swept before parents.
//! - For each edge: upgrade weaks, take the parent write lock, verify the
//! child pointer still matches, `try_write` the child, and unlink only when
//! the child has no workers, no children, and `Arc::strong_count == 2`
//! (parent map ref + our local upgrade). The strong-count gate is what
//! prevents reclaiming a node that a concurrent `find_matches` is currently
//! traversing — such edges are skipped and retried on the next sweep.
use std::collections::VecDeque;
use std::hash::Hash;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Weak};
use std::time::Instant;
use parking_lot::RwLock;
use rustc_hash::FxHashMap;
/// Minimum spacing between two cleanup sweeps, in milliseconds (5 minutes).
///
/// [`CleanupState::try_schedule`] compares the time since the last completed
/// sweep against this value and refuses to schedule a new sweep sooner.
pub const CLEANUP_INTERVAL_MS: u64 = 5 * 60 * 1000;
/// Node type that participates in the reverse-BFS cleanup sweep.
///
/// Implementors expose just enough of their structure for
/// [`sweep_stale_children`] to walk the tree and unlink empty children.
pub trait CleanableNode: Sized + Send + Sync + 'static {
/// Key type used in this node's children map (e.g. `LocalBlockHash`,
/// `SequenceHash`).
type ChildKey: Copy + Eq + Hash + Send + Sync + 'static;
/// True if this node still carries worker state that pins it in the tree.
///
/// The sweep never unlinks a child for which this returns `true`.
fn has_any_workers(&self) -> bool;
/// Read-only view of this node's children keyed by the first edge element.
///
/// The sweep uses this map to discover edges, to confirm that an edge
/// still points at the same child, and to check that a child is a leaf.
fn children(&self) -> &FxHashMap<Self::ChildKey, Arc<RwLock<Self>>>;
/// Unlink a child edge.
///
/// The sweep calls this on the parent while holding the parent's write lock.
fn remove_child(&mut self, key: &Self::ChildKey);
}
/// Throttle and single-flight bookkeeping for the stale-child sweep.
///
/// Times are kept as milliseconds elapsed since `clock_origin` so they fit in
/// an atomic integer and can be read without taking a lock.
pub struct CleanupState {
    /// Reference point from which every elapsed-millisecond value is measured.
    clock_origin: Instant,
    /// Elapsed-ms timestamp of the last completed sweep; `0` until one completes.
    last_cleanup_elapsed_ms: AtomicU64,
    /// `true` while a sweep is scheduled or running.
    scheduled: AtomicBool,
}

impl CleanupState {
    /// Creates a state whose clock starts now, with no sweep scheduled.
    ///
    /// Because the last-sweep timestamp starts at `0`, the first sweep cannot
    /// be scheduled until `CLEANUP_INTERVAL_MS` has passed since construction.
    pub fn new() -> Self {
        Self {
            clock_origin: Instant::now(),
            last_cleanup_elapsed_ms: AtomicU64::new(0),
            scheduled: AtomicBool::new(false),
        }
    }

    /// Milliseconds elapsed since this state was created.
    pub fn elapsed_ms(&self) -> u64 {
        // Truncating cast is intentional: a u64 of milliseconds only
        // overflows after hundreds of millions of years of uptime.
        self.clock_origin.elapsed().as_millis() as u64
    }

    /// True once at least `CLEANUP_INTERVAL_MS` has passed since the last
    /// completed sweep.
    fn interval_elapsed(&self) -> bool {
        let last_ms = self.last_cleanup_elapsed_ms.load(Ordering::Relaxed);
        self.elapsed_ms().saturating_sub(last_ms) >= CLEANUP_INTERVAL_MS
    }

    /// Tries to claim the right to run the next sweep.
    ///
    /// Returns `true` only when the throttle interval has elapsed and no other
    /// sweep is scheduled or running. On success the scheduled flag stays set
    /// until [`CleanupState::cancel`] is called or a [`CleanupGuard`] for this
    /// state is dropped.
    pub fn try_schedule(&self) -> bool {
        // Cheap pre-check so throttled callers never touch the flag.
        if !self.interval_elapsed() {
            return false;
        }

        let claimed = self
            .scheduled
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_ok();
        if !claimed {
            return false;
        }

        // Re-validate after winning the flag. A sweep may have completed
        // between the pre-check above and the exchange; the Acquire exchange
        // pairs with the Release store that cleared the flag, so the timestamp
        // written by that sweep is visible here. Without this second check the
        // caller would start a back-to-back sweep that defeats the throttle.
        if self.interval_elapsed() {
            return true;
        }
        self.scheduled.store(false, Ordering::Release);
        false
    }

    /// Releases a claim obtained from [`CleanupState::try_schedule`] without
    /// recording a completed sweep.
    pub fn cancel(&self) {
        self.scheduled.store(false, Ordering::Release);
    }
}

impl Default for CleanupState {
    /// Equivalent to [`CleanupState::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// RAII guard held for the duration of one cleanup sweep.
///
/// Dropping the guard always clears the state's scheduled flag. The
/// last-sweep timestamp is only updated when [`CleanupGuard::mark_completed`]
/// was called first, so a sweep that exits early does not reset the throttle.
#[must_use = "dropping the guard immediately clears the scheduled flag before the sweep runs"]
pub struct CleanupGuard<'a> {
    /// Cleanup state whose flag and timestamp this guard updates on drop.
    state: &'a CleanupState,
    /// Elapsed-ms timestamp captured by `mark_completed`; `None` until then.
    completed_elapsed_ms: Option<u64>,
}

impl<'a> CleanupGuard<'a> {
    /// Creates a guard for `state` that has not yet been marked completed.
    pub fn new(state: &'a CleanupState) -> Self {
        Self {
            state,
            completed_elapsed_ms: None,
        }
    }

    /// Records the current elapsed time as the sweep's completion time.
    ///
    /// The value is published to the shared state when the guard is dropped.
    pub fn mark_completed(&mut self) {
        self.completed_elapsed_ms = Some(self.state.elapsed_ms());
    }
}

impl Drop for CleanupGuard<'_> {
    fn drop(&mut self) {
        // Publish the completion time first: the Release store below makes it
        // visible to whoever next acquires the scheduled flag.
        if let Some(elapsed_ms) = self.completed_elapsed_ms {
            self.state
                .last_cleanup_elapsed_ms
                .store(elapsed_ms, Ordering::Relaxed);
        }
        self.state.scheduled.store(false, Ordering::Release);
    }
}
/// One parent-to-child link recorded during the BFS phase of
/// [`sweep_stale_children`].
///
/// Both ends are held as `Weak` references, so a recorded edge does not
/// contribute to the strong count that the sweep inspects before unlinking.
struct CleanupEdge<N: CleanableNode> {
/// Node whose children map contained `child` when the edge was recorded.
parent: Weak<RwLock<N>>,
/// Key under which `child` was stored in the parent's children map.
key: N::ChildKey,
/// Child node the edge pointed at when it was recorded.
child: Weak<RwLock<N>>,
}
/// Reverse-BFS sweep that unlinks empty, unreferenced leaf nodes from the tree.
///
/// Phase 1 walks the tree breadth-first from `root` under read locks and
/// records every parent/child edge as a pair of weak references. Phase 2
/// replays those edges in reverse discovery order, so deeper edges are handled
/// before the edges above them and a chain of empty nodes can collapse
/// bottom-up within a single call.
///
/// An edge is unlinked only if the child is still attached under the same key,
/// its write lock can be taken without waiting, it has no workers and no
/// children, and nothing other than the parent map and this function holds a
/// strong reference to it. Any edge failing one of these checks is left in
/// place for a later sweep.
pub fn sweep_stale_children<N: CleanableNode>(root: &Arc<RwLock<N>>) {
let mut queue: VecDeque<Arc<RwLock<N>>> = VecDeque::from([root.clone()]);
let mut edges: Vec<CleanupEdge<N>> = Vec::new();
// Phase 1: breadth-first discovery. Only read locks are taken, and the
// strong clones pushed onto `queue` are all popped before phase 2 begins.
while let Some(parent) = queue.pop_front() {
let guard = parent.read();
for (&key, child) in guard.children() {
queue.push_back(child.clone());
edges.push(CleanupEdge {
parent: Arc::downgrade(&parent),
key,
child: Arc::downgrade(child),
});
}
}
// Phase 2: visit edges in reverse discovery order (deepest first).
for edge in edges.into_iter().rev() {
// Either end may have been dropped since discovery; nothing to unlink then.
let (Some(parent), Some(child)) = (edge.parent.upgrade(), edge.child.upgrade()) else {
continue;
};
let mut parent_guard = parent.write();
// The key may have been removed or re-pointed at a different node since
// phase 1; only act on the exact child that was recorded.
let still_attached = parent_guard
.children()
.get(&edge.key)
.is_some_and(|current| Arc::ptr_eq(current, &child));
if !still_attached {
continue;
}
// Non-blocking attempt: a child that is currently locked is in use, and
// waiting here would stall while the parent write lock is held.
let Some(child_guard) = child.try_write() else {
continue;
};
if child_guard.has_any_workers() || !child_guard.children().is_empty() {
continue;
}
// Exactly two strong refs are expected: the parent's map entry and the
// local `child` upgraded above. Any more means another holder (for
// example a concurrent traversal) still references the node.
if Arc::strong_count(&child) != 2 {
continue;
}
parent_guard.remove_child(&edge.key);
// Release the child lock explicitly before the remaining locals go out
// of scope at the end of the iteration.
drop(child_guard);
}
}
...@@ -34,6 +34,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; ...@@ -34,6 +34,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use super::{EventKind, KvIndexerMetrics, SyncIndexer, WorkerTask}; use super::{EventKind, KvIndexerMetrics, SyncIndexer, WorkerTask};
use crate::active_set::reconcile_active_workers; use crate::active_set::reconcile_active_workers;
use crate::cleanup::{self, CleanableNode, CleanupGuard, CleanupState};
use crate::protocols::*; use crate::protocols::*;
/// Thread-safe shared reference to a Block. /// Thread-safe shared reference to a Block.
...@@ -83,6 +84,22 @@ impl Block { ...@@ -83,6 +84,22 @@ impl Block {
} }
} }
/// Hooks that let the shared cleanup sweep walk and prune [`Block`] nodes.
impl CleanableNode for Block {
    type ChildKey = LocalBlockHash;

    /// A block is pinned while its `workers` collection is non-empty.
    fn has_any_workers(&self) -> bool {
        let worker_free = self.workers.is_empty();
        !worker_free
    }

    /// Borrows the block's child map.
    fn children(&self) -> &FxHashMap<LocalBlockHash, SharedBlock> {
        &self.children
    }

    /// Drops the child entry stored under `key`, if there is one.
    fn remove_child(&mut self, key: &LocalBlockHash) {
        let _unlinked = self.children.remove(key);
    }
}
/// Thread-safe radix tree for concurrent KV cache lookups. /// Thread-safe radix tree for concurrent KV cache lookups.
/// ///
/// Unlike `RadixTree` which uses `Rc<RefCell<>>` and requires single-threaded access, /// Unlike `RadixTree` which uses `Rc<RefCell<>>` and requires single-threaded access,
...@@ -109,6 +126,7 @@ pub struct ConcurrentRadixTree { ...@@ -109,6 +126,7 @@ pub struct ConcurrentRadixTree {
root: SharedBlock, root: SharedBlock,
tree_sizes: DashMap<WorkerWithDpRank, AtomicUsize, FxBuildHasher>, tree_sizes: DashMap<WorkerWithDpRank, AtomicUsize, FxBuildHasher>,
cleanup: CleanupState,
} }
impl Default for ConcurrentRadixTree { impl Default for ConcurrentRadixTree {
...@@ -147,6 +165,7 @@ impl ConcurrentRadixTree { ...@@ -147,6 +165,7 @@ impl ConcurrentRadixTree {
Self { Self {
root: Arc::new(RwLock::new(Block::new())), root: Arc::new(RwLock::new(Block::new())),
tree_sizes: DashMap::with_hasher(FxBuildHasher), tree_sizes: DashMap::with_hasher(FxBuildHasher),
cleanup: CleanupState::new(),
} }
} }
...@@ -665,6 +684,20 @@ impl SyncIndexer for ConcurrentRadixTree { ...@@ -665,6 +684,20 @@ impl SyncIndexer for ConcurrentRadixTree {
self.find_matches_impl(sequence, early_exit) self.find_matches_impl(sequence, early_exit)
} }
/// Forwards to [`CleanupState::try_schedule`] on this tree's cleanup state
/// and returns its result.
fn try_schedule_cleanup(&self) -> bool {
self.cleanup.try_schedule()
}
/// Forwards to [`CleanupState::cancel`], clearing the scheduled flag without
/// recording a completed sweep.
fn cancel_scheduled_cleanup(&self) {
self.cleanup.cancel();
}
/// Runs one stale-child sweep over this tree.
///
/// The guard is created before the sweep so the scheduled flag is cleared
/// when this function exits. The completion timestamp is recorded only after
/// the sweep call returns.
fn run_cleanup_task(&self) {
let mut cleanup_guard = CleanupGuard::new(&self.cleanup);
cleanup::sweep_stale_children(&self.root);
// Reached only if the sweep returned normally; the guard publishes the
// timestamp and clears the flag when it is dropped at the end of scope.
cleanup_guard.mark_completed();
}
fn dump_events(&self) -> Option<Vec<RouterEvent>> { fn dump_events(&self) -> Option<Vec<RouterEvent>> {
Some(self.dump_tree_as_events()) Some(self.dump_tree_as_events())
} }
......
...@@ -59,16 +59,16 @@ ...@@ -59,16 +59,16 @@
//! - `new_with_frequency()` is not provided //! - `new_with_frequency()` is not provided
//! - `find_matches` does not populate `OverlapScores.frequencies` //! - `find_matches` does not populate `OverlapScores.frequencies`
use std::sync::{Arc, Weak}; use std::sync::Arc;
use std::time::Instant;
use dashmap::DashMap; use dashmap::DashMap;
use parking_lot::RwLock; use parking_lot::RwLock;
use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet}; use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet};
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
use super::{EventKind, KvIndexerMetrics, SyncIndexer, WorkerTask}; use super::{EventKind, KvIndexerMetrics, SyncIndexer, WorkerTask};
use crate::cleanup::{self, CleanableNode, CleanupGuard, CleanupState};
use crate::protocols::*; use crate::protocols::*;
macro_rules! read_lock { macro_rules! read_lock {
...@@ -87,8 +87,6 @@ type SharedNode = Arc<RwLock<Node>>; ...@@ -87,8 +87,6 @@ type SharedNode = Arc<RwLock<Node>>;
/// stored here, keeping the map compact and correct across concurrent splits. /// stored here, keeping the map compact and correct across concurrent splits.
type WorkerLookup = FxHashMap<ExternalSequenceBlockHash, SharedNode>; type WorkerLookup = FxHashMap<ExternalSequenceBlockHash, SharedNode>;
const CLEANUP_INTERVAL_MS: u64 = 5 * 60 * 1000;
/// A node in the concurrent radix tree. /// A node in the concurrent radix tree.
/// ///
/// Stores a compressed edge with per-worker match indices. Workers with full coverage /// Stores a compressed edge with per-worker match indices. Workers with full coverage
...@@ -122,10 +120,6 @@ impl Node { ...@@ -122,10 +120,6 @@ impl Node {
} }
} }
fn has_any_workers(&self) -> bool {
!self.full_edge_workers.is_empty() || !self.worker_cutoffs.is_empty()
}
#[inline] #[inline]
fn current_cutoff(&self, worker: WorkerWithDpRank) -> usize { fn current_cutoff(&self, worker: WorkerWithDpRank) -> usize {
if self.full_edge_workers.contains(&worker) { if self.full_edge_workers.contains(&worker) {
...@@ -226,6 +220,22 @@ impl Node { ...@@ -226,6 +220,22 @@ impl Node {
} }
} }
/// Hooks that let the shared cleanup sweep walk and prune compressed-tree [`Node`]s.
impl CleanableNode for Node {
    type ChildKey = LocalBlockHash;

    /// A node is pinned while either worker collection has an entry.
    fn has_any_workers(&self) -> bool {
        !(self.full_edge_workers.is_empty() && self.worker_cutoffs.is_empty())
    }

    /// Borrows the node's child map.
    fn children(&self) -> &FxHashMap<LocalBlockHash, SharedNode> {
        &self.children
    }

    /// Drops the child entry stored under `key`, if there is one.
    fn remove_child(&mut self, key: &LocalBlockHash) {
        let _unlinked = self.children.remove(key);
    }
}
/// Data returned by [`ConcurrentRadixTreeCompressed::split_node`] for deferred lookup updates. /// Data returned by [`ConcurrentRadixTreeCompressed::split_node`] for deferred lookup updates.
/// ///
/// Callers must call [`ConcurrentRadixTreeCompressed::apply_split_lookup`] **after** /// Callers must call [`ConcurrentRadixTreeCompressed::apply_split_lookup`] **after**
...@@ -240,64 +250,6 @@ struct RemoveOutcome { ...@@ -240,64 +250,6 @@ struct RemoveOutcome {
stale_hashes: Vec<ExternalSequenceBlockHash>, stale_hashes: Vec<ExternalSequenceBlockHash>,
} }
struct CleanupEdge {
parent: Weak<RwLock<Node>>,
key: LocalBlockHash,
child: Weak<RwLock<Node>>,
}
struct CleanupState {
clock_origin: Instant,
last_cleanup_elapsed_ms: AtomicU64,
scheduled: AtomicBool,
}
impl CleanupState {
fn new() -> Self {
Self {
clock_origin: Instant::now(),
last_cleanup_elapsed_ms: AtomicU64::new(0),
scheduled: AtomicBool::new(false),
}
}
fn elapsed_ms(&self) -> u64 {
self.clock_origin.elapsed().as_millis() as u64
}
fn try_schedule(&self) -> bool {
let now_ms = self.elapsed_ms();
let last_ms = self.last_cleanup_elapsed_ms.load(Ordering::Relaxed);
if now_ms.saturating_sub(last_ms) < CLEANUP_INTERVAL_MS {
return false;
}
self.scheduled
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
.is_ok()
}
fn cancel(&self) {
self.scheduled.store(false, Ordering::Release);
}
}
struct CleanupGuard<'a> {
state: &'a CleanupState,
completed_elapsed_ms: Option<u64>,
}
impl Drop for CleanupGuard<'_> {
fn drop(&mut self) {
if let Some(elapsed_ms) = self.completed_elapsed_ms {
self.state
.last_cleanup_elapsed_ms
.store(elapsed_ms, Ordering::Relaxed);
}
self.state.scheduled.store(false, Ordering::Release);
}
}
/// Thread-safe radix tree (compressed trie) for concurrent KV cache lookups. /// Thread-safe radix tree (compressed trie) for concurrent KV cache lookups.
pub struct ConcurrentRadixTreeCompressed { pub struct ConcurrentRadixTreeCompressed {
/// The root of the radix tree. Has an empty edge and only contains children. /// The root of the radix tree. Has an empty edge and only contains children.
...@@ -340,50 +292,6 @@ impl ConcurrentRadixTreeCompressed { ...@@ -340,50 +292,6 @@ impl ConcurrentRadixTreeCompressed {
} }
} }
fn cleanup_stale_children(&self) {
let mut queue = VecDeque::from([self.root.clone()]);
let mut edges = Vec::new();
while let Some(parent) = queue.pop_front() {
let guard = parent.read();
for (&key, child) in &guard.children {
queue.push_back(child.clone());
edges.push(CleanupEdge {
parent: Arc::downgrade(&parent),
key,
child: Arc::downgrade(child),
});
}
}
for edge in edges.into_iter().rev() {
let (Some(parent), Some(child)) = (edge.parent.upgrade(), edge.child.upgrade()) else {
continue;
};
let mut parent_guard = parent.write();
let Some(current) = parent_guard.children.get(&edge.key) else {
continue;
};
if !Arc::ptr_eq(current, &child) {
continue;
}
let Some(child_guard) = child.try_write() else {
continue;
};
if child_guard.has_any_workers() || !child_guard.children.is_empty() {
continue;
}
if Arc::strong_count(&child) != 2 {
continue;
}
parent_guard.children.remove(&edge.key);
drop(child_guard);
}
}
#[cfg(test)] #[cfg(test)]
pub(crate) fn raw_child_edge_count(&self) -> usize { pub(crate) fn raw_child_edge_count(&self) -> usize {
let mut queue = VecDeque::from([self.root.clone()]); let mut queue = VecDeque::from([self.root.clone()]);
...@@ -400,7 +308,7 @@ impl ConcurrentRadixTreeCompressed { ...@@ -400,7 +308,7 @@ impl ConcurrentRadixTreeCompressed {
#[cfg(test)] #[cfg(test)]
pub(crate) fn run_cleanup_for_test(&self) { pub(crate) fn run_cleanup_for_test(&self) {
self.cleanup_stale_children(); cleanup::sweep_stale_children(&self.root);
} }
// ------------------------------------------------------------------ // ------------------------------------------------------------------
...@@ -1378,13 +1286,9 @@ impl SyncIndexer for ConcurrentRadixTreeCompressed { ...@@ -1378,13 +1286,9 @@ impl SyncIndexer for ConcurrentRadixTreeCompressed {
} }
fn run_cleanup_task(&self) { fn run_cleanup_task(&self) {
let mut cleanup_guard = CleanupGuard { let mut cleanup_guard = CleanupGuard::new(&self.cleanup);
state: &self.cleanup, cleanup::sweep_stale_children(&self.root);
completed_elapsed_ms: None, cleanup_guard.mark_completed();
};
self.cleanup_stale_children();
cleanup_guard.completed_elapsed_ms = Some(self.cleanup.elapsed_ms());
} }
fn dump_events(&self) -> Option<Vec<RouterEvent>> { fn dump_events(&self) -> Option<Vec<RouterEvent>> {
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
//! efficient KV cache lookup and routing in distributed LLM inference systems. //! efficient KV cache lookup and routing in distributed LLM inference systems.
mod active_set; mod active_set;
pub(crate) mod cleanup;
pub mod indexer; pub mod indexer;
pub mod protocols; pub mod protocols;
......
...@@ -446,8 +446,6 @@ pub struct PrefillLoadHint { ...@@ -446,8 +446,6 @@ pub struct PrefillLoadHint {
pub enum ActiveSequenceEventData { pub enum ActiveSequenceEventData {
AddRequest { AddRequest {
token_sequence: Option<Vec<SequenceHash>>, token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
#[serde(default = "default_track_prefill_tokens")] #[serde(default = "default_track_prefill_tokens")]
track_prefill_tokens: bool, track_prefill_tokens: bool,
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
......
...@@ -285,8 +285,6 @@ impl< ...@@ -285,8 +285,6 @@ impl<
SequenceRequest { SequenceRequest {
request_id: request_id.clone(), request_id: request_id.clone(),
token_sequence: request.token_seq, token_sequence: request.token_seq,
isl: request.isl_tokens,
overlap: selection.overlap_blocks,
track_prefill_tokens: request.track_prefill_tokens, track_prefill_tokens: request.track_prefill_tokens,
expected_output_tokens: request.expected_output_tokens, expected_output_tokens: request.expected_output_tokens,
prefill_load_hint, prefill_load_hint,
...@@ -315,24 +313,25 @@ impl< ...@@ -315,24 +313,25 @@ impl<
return None; return None;
} }
let Some(estimator) = &self.prefill_load_estimator else { let expected_prefill_duration = match &self.prefill_load_estimator {
return None; Some(estimator) => match estimator.predict_prefill_duration(1, effective_isl, prefix) {
Ok(expected_prefill_duration) => Some(expected_prefill_duration),
Err(error) => {
tracing::warn!(
effective_isl,
prefix,
"failed to predict prefill duration for active load tracking: {error}"
);
None
}
},
None => None,
}; };
match estimator.predict_prefill_duration(1, effective_isl, prefix) { Some(PrefillLoadHint {
Ok(expected_prefill_duration) => Some(PrefillLoadHint { initial_effective_prefill_tokens: effective_isl,
initial_effective_prefill_tokens: effective_isl, expected_prefill_duration,
expected_prefill_duration: Some(expected_prefill_duration), })
}),
Err(error) => {
tracing::warn!(
effective_isl,
prefix,
"failed to predict prefill duration for active load tracking: {error}"
);
None
}
}
} }
/// Number of requests currently parked in the pending queue (lock-free). /// Number of requests currently parked in the pending queue (lock-free).
......
# Sequence State Model
This directory implements the router's active-sequence state for local request routing and replica sync.
For the local, non-remote path, the model is intentionally organized as a one-way write pipeline:
```mermaid
flowchart TD
A["Routing event<br/>AddRequest / MarkPrefillCompleted / Free"]
B["WorkerTable + RequestIndex<br/>lookup authoritative worker-local state"]
C["ActiveSequences<br/>authoritative per-worker write model"]
D["PromptRegistry<br/>derived read model"]
E["Scheduler reads projected load"]
A --> B
B --> C
C --> D
D -. read .-> E
```
## Source of truth
- `topology.rs` owns `WorkerTable`, which maps a worker identity to its slot.
- `request_maps.rs` owns `RequestIndex`, which maps `request_id -> worker`.
- `single.rs` owns `ActiveSequences`, the authoritative per-worker request, prefill, and block state.
- `prompt_registry.rs` owns `PromptRegistry`, which is not a source of truth. It is a derived routing view.
The local orchestrator in `multi_worker.rs` reads `WorkerTable` and `RequestIndex`, mutates the chosen worker's `ActiveSequences`, then projects the resulting membership/load delta into `PromptRegistry`.
## Why this is a DAG
Within a single local mutation, data moves in one direction:
`event -> authoritative state -> derived read model -> scheduler`
`PromptRegistry` does not write back into `ActiveSequences`, so there is no write-back loop inside the local mutation path.
At runtime there is still a feedback loop across mutations: the scheduler reads the derived view and later emits the next `AddRequest`. That loop is system-level feedback between separate mutations, not cyclic state ownership.
## Torn reads are intentional
`PromptRegistry` is allowed to be only eventually consistent with `ActiveSequences`.
That means a reader may temporarily observe:
- a worker-load snapshot from one moment
- prompt membership from another moment
- a combined view that never existed atomically
This is an intentional tradeoff. The derived read model is optimized for lower contention and higher concurrency, not perfect snapshot consistency.
The important safety boundary is:
- lifecycle and ownership invariants live in the write model (`WorkerTable`, `RequestIndex`, `ActiveSequences`)
- scheduling quality lives in the read model (`PromptRegistry`)
So a stale or torn read can lead to a suboptimal routing choice, but it should not cause catastrophic invariant breakage such as losing request ownership or corrupting block membership.
## Eventual consistency contract
- Local writes update `ActiveSequences` first.
- `PromptRegistry` is projected from that authoritative state afterward.
- Replica sync and scheduler decisions may lag behind temporarily.
- The system accepts this lag because the read side is advisory.
This is the core design: a strict local write DAG with an eventually consistent read projection.
...@@ -19,6 +19,8 @@ use tokio::sync::watch; ...@@ -19,6 +19,8 @@ use tokio::sync::watch;
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
#[cfg(any(test, feature = "bench"))]
use super::prompt_membership_trie::lookup_live_hashes;
use super::prompt_registry::{PromptRegistry, WorkerLoadSnapshot}; use super::prompt_registry::{PromptRegistry, WorkerLoadSnapshot};
use super::request_maps::RequestIndex; use super::request_maps::RequestIndex;
use super::single::{ActiveSequences, PromptMembershipDelta, RequestId}; use super::single::{ActiveSequences, PromptMembershipDelta, RequestId};
...@@ -94,8 +96,6 @@ pub enum SequenceError { ...@@ -94,8 +96,6 @@ pub enum SequenceError {
pub struct SequenceRequest { pub struct SequenceRequest {
pub request_id: RequestId, pub request_id: RequestId,
pub token_sequence: Option<Vec<SequenceHash>>, pub token_sequence: Option<Vec<SequenceHash>>,
pub isl: usize,
pub overlap: u32,
pub track_prefill_tokens: bool, pub track_prefill_tokens: bool,
pub expected_output_tokens: Option<u32>, pub expected_output_tokens: Option<u32>,
pub prefill_load_hint: Option<PrefillLoadHint>, pub prefill_load_hint: Option<PrefillLoadHint>,
...@@ -182,6 +182,23 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -182,6 +182,23 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self.prompt_registry.is_block_index_empty(), self.prompt_registry.is_block_index_empty(),
"expected reverse block index to be empty after drain", "expected reverse block index to be empty after drain",
); );
let trie_lookup_live_hashes: Vec<_> = {
let table = self.workers.read();
table
.slots
.iter()
.filter_map(|slot| {
let live_hashes = lookup_live_hashes(&slot.trie_lookup);
(!live_hashes.is_empty()).then_some((slot.worker, live_hashes))
})
.collect()
};
assert!(
trie_lookup_live_hashes.is_empty(),
"expected all worker trie lookups to reference only dead nodes after drain, found {:?}",
trie_lookup_live_hashes,
);
} }
fn publish_worker_load_snapshot( fn publish_worker_load_snapshot(
...@@ -280,8 +297,6 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -280,8 +297,6 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
match &event.data { match &event.data {
ActiveSequenceEventData::AddRequest { ActiveSequenceEventData::AddRequest {
token_sequence, token_sequence,
isl,
overlap,
track_prefill_tokens, track_prefill_tokens,
expected_output_tokens, expected_output_tokens,
prefill_load_hint, prefill_load_hint,
...@@ -300,8 +315,6 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -300,8 +315,6 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
let outcome = seq.add_request_with_prefill_tracking( let outcome = seq.add_request_with_prefill_tracking(
event.request_id.clone(), event.request_id.clone(),
token_sequence.clone(), token_sequence.clone(),
*isl,
*overlap,
*expected_output_tokens, *expected_output_tokens,
*track_prefill_tokens, *track_prefill_tokens,
*prefill_load_hint, *prefill_load_hint,
...@@ -434,8 +447,6 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -434,8 +447,6 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
worker: req.worker, worker: req.worker,
data: ActiveSequenceEventData::AddRequest { data: ActiveSequenceEventData::AddRequest {
token_sequence: req.token_sequence.clone(), token_sequence: req.token_sequence.clone(),
isl: req.isl,
overlap: req.overlap,
track_prefill_tokens: req.track_prefill_tokens, track_prefill_tokens: req.track_prefill_tokens,
expected_output_tokens: req.expected_output_tokens, expected_output_tokens: req.expected_output_tokens,
prefill_load_hint: req.prefill_load_hint, prefill_load_hint: req.prefill_load_hint,
...@@ -727,8 +738,6 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -727,8 +738,6 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
let SequenceRequest { let SequenceRequest {
request_id, request_id,
token_sequence, token_sequence,
isl,
overlap,
track_prefill_tokens, track_prefill_tokens,
expected_output_tokens, expected_output_tokens,
prefill_load_hint, prefill_load_hint,
...@@ -758,8 +767,6 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -758,8 +767,6 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
let outcome = seq.add_request_with_prefill_tracking( let outcome = seq.add_request_with_prefill_tracking(
request_id, request_id,
token_sequence, token_sequence,
isl,
overlap,
expected_output_tokens, expected_output_tokens,
track_prefill_tokens, track_prefill_tokens,
prefill_load_hint, prefill_load_hint,
...@@ -938,6 +945,7 @@ mod tests { ...@@ -938,6 +945,7 @@ mod tests {
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use super::super::prefill_tracker::added_prefill_tokens;
use super::*; use super::*;
use crate::protocols::{ use crate::protocols::{
ActiveSequenceEvent, ActiveSequenceEventData, BlockHashOptions, OverlapScores, ActiveSequenceEvent, ActiveSequenceEventData, BlockHashOptions, OverlapScores,
...@@ -994,7 +1002,7 @@ mod tests { ...@@ -994,7 +1002,7 @@ mod tests {
token_sequence.map_or(0, |query| query.len().saturating_sub(overlap_depth)); token_sequence.map_or(0, |query| query.len().saturating_sub(overlap_depth));
let overlap = *overlaps.scores.get(&slot.worker).unwrap_or(&0); let overlap = *overlaps.scores.get(&slot.worker).unwrap_or(&0);
let added_tokens = if track_prefill_tokens { let added_tokens = if track_prefill_tokens {
seq.new_tokens(isl, overlap) added_prefill_tokens(sequences.block_size, isl, overlap)
} else { } else {
0 0
}; };
...@@ -1016,6 +1024,13 @@ mod tests { ...@@ -1016,6 +1024,13 @@ mod tests {
compute_seq_hash_for_block(&block_hashes) compute_seq_hash_for_block(&block_hashes)
} }
/// Test helper: builds a hint that carries only the initial effective prefill
/// token count and leaves the expected prefill duration unset.
fn tracking_hint(tokens: usize) -> Option<PrefillLoadHint> {
    let hint = PrefillLoadHint {
        expected_prefill_duration: None,
        initial_effective_prefill_tokens: tokens,
    };
    Some(hint)
}
struct VecSubscriber { struct VecSubscriber {
events: VecDeque<anyhow::Result<ActiveSequenceEvent>>, events: VecDeque<anyhow::Result<ActiveSequenceEvent>>,
} }
...@@ -1039,8 +1054,6 @@ mod tests { ...@@ -1039,8 +1054,6 @@ mod tests {
SequenceRequest { SequenceRequest {
request_id: "req-1".to_string(), request_id: "req-1".to_string(),
token_sequence: Some(vec![1, 2, 3]), token_sequence: Some(vec![1, 2, 3]),
isl: 12,
overlap: 0,
track_prefill_tokens: false, track_prefill_tokens: false,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: None,
...@@ -1069,11 +1082,9 @@ mod tests { ...@@ -1069,11 +1082,9 @@ mod tests {
SequenceRequest { SequenceRequest {
request_id: "req-a".to_string(), request_id: "req-a".to_string(),
token_sequence: Some(vec![1, 2, 3]), token_sequence: Some(vec![1, 2, 3]),
isl: 12,
overlap: 0,
track_prefill_tokens: true, track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: tracking_hint(12),
worker: worker_a, worker: worker_a,
lora_name: None, lora_name: None,
}, },
...@@ -1092,11 +1103,9 @@ mod tests { ...@@ -1092,11 +1103,9 @@ mod tests {
SequenceRequest { SequenceRequest {
request_id: "req-b".to_string(), request_id: "req-b".to_string(),
token_sequence: Some(vec![1, 2, 4]), token_sequence: Some(vec![1, 2, 4]),
isl: 12,
overlap: 0,
track_prefill_tokens: true, track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: tracking_hint(12),
worker: worker_b, worker: worker_b,
lora_name: None, lora_name: None,
}, },
...@@ -1149,8 +1158,6 @@ mod tests { ...@@ -1149,8 +1158,6 @@ mod tests {
SequenceRequest { SequenceRequest {
request_id: "base".to_string(), request_id: "base".to_string(),
token_sequence: Some(base_prompt.clone()), token_sequence: Some(base_prompt.clone()),
isl: 8,
overlap: 0,
track_prefill_tokens: false, track_prefill_tokens: false,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: None,
...@@ -1165,8 +1172,6 @@ mod tests { ...@@ -1165,8 +1172,6 @@ mod tests {
SequenceRequest { SequenceRequest {
request_id: "lora".to_string(), request_id: "lora".to_string(),
token_sequence: Some(lora_prompt), token_sequence: Some(lora_prompt),
isl: 8,
overlap: 0,
track_prefill_tokens: false, track_prefill_tokens: false,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: None,
...@@ -1213,11 +1218,9 @@ mod tests { ...@@ -1213,11 +1218,9 @@ mod tests {
SequenceRequest { SequenceRequest {
request_id: "req-1".to_string(), request_id: "req-1".to_string(),
token_sequence: Some(vec![1, 2, 3]), token_sequence: Some(vec![1, 2, 3]),
isl: 12,
overlap: 0,
track_prefill_tokens: true, track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: tracking_hint(12),
worker, worker,
lora_name: None, lora_name: None,
}, },
...@@ -1243,11 +1246,9 @@ mod tests { ...@@ -1243,11 +1246,9 @@ mod tests {
SequenceRequest { SequenceRequest {
request_id: "req-1".to_string(), request_id: "req-1".to_string(),
token_sequence: Some(vec![1, 2, 3]), token_sequence: Some(vec![1, 2, 3]),
isl: 12,
overlap: 0,
track_prefill_tokens: true, track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: tracking_hint(12),
worker, worker,
lora_name: None, lora_name: None,
}, },
...@@ -1262,11 +1263,9 @@ mod tests { ...@@ -1262,11 +1263,9 @@ mod tests {
SequenceRequest { SequenceRequest {
request_id: "req-2".to_string(), request_id: "req-2".to_string(),
token_sequence: Some(vec![1, 2, 3]), token_sequence: Some(vec![1, 2, 3]),
isl: 12,
overlap: 0,
track_prefill_tokens: true, track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: tracking_hint(12),
worker, worker,
lora_name: None, lora_name: None,
}, },
...@@ -1313,11 +1312,9 @@ mod tests { ...@@ -1313,11 +1312,9 @@ mod tests {
worker, worker,
data: ActiveSequenceEventData::AddRequest { data: ActiveSequenceEventData::AddRequest {
token_sequence: Some(vec![1, 2, 3]), token_sequence: Some(vec![1, 2, 3]),
isl: 12,
overlap: 0,
track_prefill_tokens: true, track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: tracking_hint(12),
}, },
router_id: 99, router_id: 99,
lora_name: None, lora_name: None,
...@@ -1359,11 +1356,9 @@ mod tests { ...@@ -1359,11 +1356,9 @@ mod tests {
worker, worker,
data: ActiveSequenceEventData::AddRequest { data: ActiveSequenceEventData::AddRequest {
token_sequence: Some(vec![1, 2, 3]), token_sequence: Some(vec![1, 2, 3]),
isl: 12,
overlap: 0,
track_prefill_tokens: true, track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: tracking_hint(12),
}, },
router_id: 99, router_id: 99,
lora_name: None, lora_name: None,
...@@ -1395,8 +1390,6 @@ mod tests { ...@@ -1395,8 +1390,6 @@ mod tests {
SequenceRequest { SequenceRequest {
request_id: "req-1".to_string(), request_id: "req-1".to_string(),
token_sequence: Some(vec![1, 2, 3]), token_sequence: Some(vec![1, 2, 3]),
isl: 12,
overlap: 0,
track_prefill_tokens: false, track_prefill_tokens: false,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: None,
...@@ -1429,8 +1422,6 @@ mod tests { ...@@ -1429,8 +1422,6 @@ mod tests {
SequenceRequest { SequenceRequest {
request_id: request_id.clone(), request_id: request_id.clone(),
token_sequence: Some(vec![1, 2, 3]), token_sequence: Some(vec![1, 2, 3]),
isl: 12,
overlap: 0,
track_prefill_tokens: false, track_prefill_tokens: false,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: None,
...@@ -1481,8 +1472,6 @@ mod tests { ...@@ -1481,8 +1472,6 @@ mod tests {
SequenceRequest { SequenceRequest {
request_id: "req-1".to_string(), request_id: "req-1".to_string(),
token_sequence: Some(vec![1, 2, 3]), token_sequence: Some(vec![1, 2, 3]),
isl: 100,
overlap: 0,
track_prefill_tokens: true, track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: Some(PrefillLoadHint { prefill_load_hint: Some(PrefillLoadHint {
......
...@@ -7,6 +7,7 @@ use dynamo_tokens::SequenceHash; ...@@ -7,6 +7,7 @@ use dynamo_tokens::SequenceHash;
use parking_lot::RwLock; use parking_lot::RwLock;
use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hash::{FxHashMap, FxHashSet};
use crate::cleanup::{self, CleanableNode, CleanupGuard, CleanupState};
use crate::protocols::WorkerWithDpRank; use crate::protocols::WorkerWithDpRank;
type SharedNode = Arc<RwLock<PromptTrieNode>>; type SharedNode = Arc<RwLock<PromptTrieNode>>;
...@@ -32,11 +33,6 @@ impl PromptTrieNode { ...@@ -32,11 +33,6 @@ impl PromptTrieNode {
} }
} }
/// Reports whether any worker is still recorded on this node, either in
/// `full_edge_workers` or in `worker_cutoffs`.
#[cfg(any(test, feature = "bench"))]
fn has_any_workers(&self) -> bool {
    let no_workers = self.full_edge_workers.is_empty() && self.worker_cutoffs.is_empty();
    !no_workers
}
fn current_cutoff(&self, worker: WorkerWithDpRank) -> usize { fn current_cutoff(&self, worker: WorkerWithDpRank) -> usize {
if self.full_edge_workers.contains(&worker) { if self.full_edge_workers.contains(&worker) {
self.edge.len() self.edge.len()
...@@ -114,12 +110,29 @@ impl PromptTrieNode { ...@@ -114,12 +110,29 @@ impl PromptTrieNode {
} }
} }
/// Exposes `PromptTrieNode` to the shared stale-child cleanup machinery in
/// `crate::cleanup`.
impl CleanableNode for PromptTrieNode {
    type ChildKey = SequenceHash;

    /// Returns `true` while any worker is still recorded on this node, either
    /// in `full_edge_workers` or in `worker_cutoffs`.
    fn has_any_workers(&self) -> bool {
        let no_workers = self.full_edge_workers.is_empty() && self.worker_cutoffs.is_empty();
        !no_workers
    }

    /// Borrows this node's child map.
    fn children(&self) -> &FxHashMap<SequenceHash, SharedNode> {
        let child_map = &self.children;
        child_map
    }

    /// Unlinks the child stored under `key`; a missing key is silently ignored.
    fn remove_child(&mut self, key: &SequenceHash) {
        let _unlinked = self.children.remove(key);
    }
}
struct RemoveOutcome { struct RemoveOutcome {
stale_hashes: Vec<SequenceHash>, stale_hashes: Vec<SequenceHash>,
} }
pub(super) struct PromptMembershipTrie { pub(super) struct PromptMembershipTrie {
root: SharedNode, root: SharedNode,
cleanup: CleanupState,
} }
impl Default for PromptMembershipTrie { impl Default for PromptMembershipTrie {
...@@ -149,7 +162,22 @@ impl PromptMembershipTrie { ...@@ -149,7 +162,22 @@ impl PromptMembershipTrie {
pub(super) fn new() -> Self { pub(super) fn new() -> Self {
Self { Self {
root: Arc::new(RwLock::new(PromptTrieNode::new())), root: Arc::new(RwLock::new(PromptTrieNode::new())),
cleanup: CleanupState::new(),
}
}
/// Run the stale-child sweep if the throttle interval has elapsed.
///
/// Safe to call from any write path; the sweep is a no-op until
/// [`CLEANUP_INTERVAL_MS`](crate::cleanup::CLEANUP_INTERVAL_MS) has passed
/// since the last completion, and only one sweep runs at a time.
pub(super) fn maybe_cleanup(&self) {
if !self.cleanup.try_schedule() {
return;
} }
let mut guard = CleanupGuard::new(&self.cleanup);
cleanup::sweep_stale_children(&self.root);
guard.mark_completed();
} }
fn find_in_subtree(start: &SharedNode, hash: SequenceHash) -> Option<SharedNode> { fn find_in_subtree(start: &SharedNode, hash: SequenceHash) -> Option<SharedNode> {
...@@ -656,6 +684,15 @@ impl PromptMembershipTrie { ...@@ -656,6 +684,15 @@ impl PromptMembershipTrie {
} }
} }
/// Collects the hashes in `lookup` whose node still reports at least one
/// worker.
///
/// The lookup's read lock is held for the whole scan, and each node is
/// read-locked only for the duration of its own check.
#[cfg(any(test, feature = "bench"))]
pub(super) fn lookup_live_hashes(lookup: &Arc<RwLock<WorkerLookup>>) -> Vec<SequenceHash> {
    let entries = lookup.read();
    entries
        .iter()
        .filter(|(_, node)| node.read().has_any_workers())
        .map(|(&hash, _)| hash)
        .collect()
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
......
...@@ -80,6 +80,7 @@ impl PromptRegistry { ...@@ -80,6 +80,7 @@ impl PromptRegistry {
.store_chain(worker, lookup, store.parent, &store.hashes); .store_chain(worker, lookup, store.parent, &store.hashes);
} }
self.loads.insert(worker, load); self.loads.insert(worker, load);
self.membership.maybe_cleanup();
} }
pub(super) fn apply_topology_change(&self, change: WorkerTopologyChange) { pub(super) fn apply_topology_change(&self, change: WorkerTopologyChange) {
...@@ -92,6 +93,7 @@ impl PromptRegistry { ...@@ -92,6 +93,7 @@ impl PromptRegistry {
for worker in change.added { for worker in change.added {
self.loads.entry(worker).or_default(); self.loads.entry(worker).or_default();
} }
self.membership.maybe_cleanup();
} }
#[expect(clippy::too_many_arguments)] #[expect(clippy::too_many_arguments)]
......
...@@ -29,7 +29,9 @@ use uuid::Uuid; ...@@ -29,7 +29,9 @@ use uuid::Uuid;
use rustc_hash::FxHashSet; use rustc_hash::FxHashSet;
use super::block_tracker::BlockTracker; use super::block_tracker::BlockTracker;
use super::prefill_tracker::{PrefillLoadState, PrefillLoadTracker, added_prefill_tokens}; #[cfg(test)]
use super::prefill_tracker::added_prefill_tokens;
use super::prefill_tracker::{PrefillLoadState, PrefillLoadTracker};
use super::prompt_registry::WorkerLoadSnapshot; use super::prompt_registry::WorkerLoadSnapshot;
use crate::protocols::PrefillLoadHint; use crate::protocols::PrefillLoadHint;
...@@ -107,6 +109,7 @@ pub struct ActiveSequences { ...@@ -107,6 +109,7 @@ pub struct ActiveSequences {
requests: HashMap<RequestId, RequestState>, requests: HashMap<RequestId, RequestState>,
prefill: PrefillLoadTracker, prefill: PrefillLoadTracker,
blocks: BlockTracker, blocks: BlockTracker,
#[cfg(test)]
block_size: usize, block_size: usize,
last_expiry_check_time: Instant, last_expiry_check_time: Instant,
} }
...@@ -120,6 +123,7 @@ impl ActiveSequences { ...@@ -120,6 +123,7 @@ impl ActiveSequences {
requests: HashMap::new(), requests: HashMap::new(),
prefill: PrefillLoadTracker::default(), prefill: PrefillLoadTracker::default(),
blocks: BlockTracker::default(), blocks: BlockTracker::default(),
#[cfg(test)]
block_size, block_size,
last_expiry_check_time: Instant::now(), last_expiry_check_time: Instant::now(),
} }
...@@ -158,30 +162,6 @@ impl ActiveSequences { ...@@ -158,30 +162,6 @@ impl ActiveSequences {
self.prefill.snapshot().active_tokens_at(decay_now) self.prefill.snapshot().active_tokens_at(decay_now)
} }
/// Test-only shorthand for registering a request with prompt-token tracking
/// enabled and no explicit prefill load hint.
///
/// Forwards every argument to `add_request_with_prefill_tracking` and returns
/// its outcome: block membership transitions plus any expired request IDs
/// removed during cleanup.
#[cfg(test)]
pub(super) fn add_request(
    &mut self,
    request_id: RequestId,
    token_sequence: Option<Vec<SequenceHash>>,
    isl: usize,
    overlap: u32,
    expected_output_tokens: Option<u32>,
    decay_now: Instant,
) -> SequenceMutationOutcome {
    // Fixed defaults this wrapper supplies on behalf of the caller.
    let track_prefill_tokens = true;
    let prefill_load_hint: Option<PrefillLoadHint> = None;

    self.add_request_with_prefill_tracking(
        request_id,
        token_sequence,
        isl,
        overlap,
        expected_output_tokens,
        track_prefill_tokens,
        prefill_load_hint,
        decay_now,
    )
}
/// Add a new request with optional prompt-token load accounting. /// Add a new request with optional prompt-token load accounting.
/// Returns block membership transitions plus any expired request IDs removed during cleanup. /// Returns block membership transitions plus any expired request IDs removed during cleanup.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
...@@ -189,8 +169,6 @@ impl ActiveSequences { ...@@ -189,8 +169,6 @@ impl ActiveSequences {
&mut self, &mut self,
request_id: RequestId, request_id: RequestId,
token_sequence: Option<Vec<SequenceHash>>, token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
track_prefill_tokens: bool, track_prefill_tokens: bool,
prefill_load_hint: Option<PrefillLoadHint>, prefill_load_hint: Option<PrefillLoadHint>,
...@@ -241,15 +219,11 @@ impl ActiveSequences { ...@@ -241,15 +219,11 @@ impl ActiveSequences {
}; };
let prefill = if track_prefill_tokens { let prefill = if track_prefill_tokens {
let default_tokens = self.new_tokens(isl, overlap); prefill_load_hint.and_then(|hint| {
let hint = prefill_load_hint.unwrap_or(PrefillLoadHint { (hint.initial_effective_prefill_tokens > 0).then_some(PrefillLoadState {
initial_effective_prefill_tokens: default_tokens, initial_effective_prefill_tokens: hint.initial_effective_prefill_tokens,
expected_prefill_duration: None, expected_prefill_duration: hint.expected_prefill_duration,
}); })
(hint.initial_effective_prefill_tokens > 0).then_some(PrefillLoadState {
initial_effective_prefill_tokens: hint.initial_effective_prefill_tokens,
expected_prefill_duration: hint.expected_prefill_duration,
}) })
} else { } else {
None None
...@@ -356,10 +330,6 @@ impl ActiveSequences { ...@@ -356,10 +330,6 @@ impl ActiveSequences {
acquire.became_present_on_worker.then_some(random_hash) acquire.became_present_on_worker.then_some(random_hash)
} }
/// Delegates to `added_prefill_tokens`, supplying this instance's
/// `block_size` alongside the caller's `isl` and `overlap`.
pub(super) fn new_tokens(&self, isl: usize, overlap: u32) -> usize {
    let block_size = self.block_size;
    added_prefill_tokens(block_size, isl, overlap)
}
#[cfg(test)] #[cfg(test)]
fn potential_blocks_and_tokens_with_prefill_tracking( fn potential_blocks_and_tokens_with_prefill_tracking(
&self, &self,
...@@ -376,7 +346,7 @@ impl ActiveSequences { ...@@ -376,7 +346,7 @@ impl ActiveSequences {
}; };
let active_tokens = self.active_tokens(decay_now); let active_tokens = self.active_tokens(decay_now);
let potential_tokens = if track_prefill_tokens { let potential_tokens = if track_prefill_tokens {
self.new_tokens(isl, overlap) + active_tokens added_prefill_tokens(self.block_size, isl, overlap) + active_tokens
} else { } else {
active_tokens active_tokens
}; };
...@@ -479,6 +449,14 @@ mod tests { ...@@ -479,6 +449,14 @@ mod tests {
} }
} }
/// Test helper: builds the prefill load hint for a request from
/// `added_prefill_tokens(block_size, isl, overlap)`.
///
/// Returns `None` when that token count is zero; otherwise the hint carries
/// the count and no expected prefill duration.
fn tracking_hint(block_size: usize, isl: usize, overlap: u32) -> Option<PrefillLoadHint> {
    let effective_tokens = added_prefill_tokens(block_size, isl, overlap);
    if effective_tokens > 0 {
        Some(PrefillLoadHint {
            initial_effective_prefill_tokens: effective_tokens,
            expected_prefill_duration: None,
        })
    } else {
        None
    }
}
#[test] #[test]
fn test_prompt_membership_delta_only_reports_first_add_and_last_remove() { fn test_prompt_membership_delta_only_reports_first_add_and_last_remove() {
let mut seq_manager = ActiveSequences::new(4); let mut seq_manager = ActiveSequences::new(4);
...@@ -487,11 +465,9 @@ mod tests { ...@@ -487,11 +465,9 @@ mod tests {
let first = seq_manager.add_request_with_prefill_tracking( let first = seq_manager.add_request_with_prefill_tracking(
"r1".to_string(), "r1".to_string(),
Some(vec![1, 2]), Some(vec![1, 2]),
8,
0,
None, None,
true, true,
None, tracking_hint(4, 8, 0),
decay_now, decay_now,
); );
assert_eq!( assert_eq!(
...@@ -509,11 +485,9 @@ mod tests { ...@@ -509,11 +485,9 @@ mod tests {
let second = seq_manager.add_request_with_prefill_tracking( let second = seq_manager.add_request_with_prefill_tracking(
"r2".to_string(), "r2".to_string(),
Some(vec![1, 2, 3]), Some(vec![1, 2, 3]),
12,
0,
None, None,
true, true,
None, tracking_hint(4, 12, 0),
decay_now, decay_now,
); );
assert_eq!( assert_eq!(
...@@ -549,11 +523,9 @@ mod tests { ...@@ -549,11 +523,9 @@ mod tests {
let outcome = seq_manager.add_request_with_prefill_tracking( let outcome = seq_manager.add_request_with_prefill_tracking(
"r1".to_string(), "r1".to_string(),
Some(vec![1, 2, 3]), Some(vec![1, 2, 3]),
12,
0,
None, None,
true, true,
None, tracking_hint(4, 12, 0),
decay_now, decay_now,
); );
assert_eq!( assert_eq!(
...@@ -598,34 +570,34 @@ mod tests { ...@@ -598,34 +570,34 @@ mod tests {
let mut seq_manager = ActiveSequences::new(block_size); let mut seq_manager = ActiveSequences::new(block_size);
let decay_now = Instant::now(); let decay_now = Instant::now();
seq_manager.add_request( seq_manager.add_request_with_prefill_tracking(
"request_1".to_string(), "request_1".to_string(),
Some(vec![1, 2, 3]), Some(vec![1, 2, 3]),
12,
0,
None, None,
true,
tracking_hint(block_size, 12, 0),
decay_now, decay_now,
); );
assert_eq!(seq_manager.active_blocks(), 3); assert_eq!(seq_manager.active_blocks(), 3);
assert_eq!(seq_manager.active_tokens(decay_now), 12); assert_eq!(seq_manager.active_tokens(decay_now), 12);
seq_manager.add_request( seq_manager.add_request_with_prefill_tracking(
"request_2".to_string(), "request_2".to_string(),
Some(vec![4]), Some(vec![4]),
4,
0,
None, None,
true,
tracking_hint(block_size, 4, 0),
decay_now, decay_now,
); );
assert_eq!(seq_manager.active_blocks(), 4); assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(decay_now), 16); assert_eq!(seq_manager.active_tokens(decay_now), 16);
seq_manager.add_request( seq_manager.add_request_with_prefill_tracking(
"request_3".to_string(), "request_3".to_string(),
Some(vec![1, 2, 3, 4]), Some(vec![1, 2, 3, 4]),
16,
4,
None, None,
true,
tracking_hint(block_size, 16, 4),
decay_now, decay_now,
); );
assert_eq!(seq_manager.active_blocks(), 4); assert_eq!(seq_manager.active_blocks(), 4);
...@@ -650,12 +622,12 @@ mod tests { ...@@ -650,12 +622,12 @@ mod tests {
let mut seq_manager = ActiveSequences::new(block_size); let mut seq_manager = ActiveSequences::new(block_size);
let decay_now = Instant::now(); let decay_now = Instant::now();
seq_manager.add_request( seq_manager.add_request_with_prefill_tracking(
"r1".to_string(), "r1".to_string(),
Some(vec![1, 2, 3]), Some(vec![1, 2, 3]),
12,
0,
None, None,
true,
tracking_hint(block_size, 12, 0),
decay_now, decay_now,
); );
assert_eq!(seq_manager.active_blocks(), 3); assert_eq!(seq_manager.active_blocks(), 3);
...@@ -667,7 +639,14 @@ mod tests { ...@@ -667,7 +639,14 @@ mod tests {
); );
assert_eq!(seq_manager.active_blocks(), 2); assert_eq!(seq_manager.active_blocks(), 2);
seq_manager.add_request("r2".to_string(), Some(vec![1, 2]), 8, 0, None, decay_now); seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![1, 2]),
None,
true,
tracking_hint(block_size, 8, 0),
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 2); assert_eq!(seq_manager.active_blocks(), 2);
assert!( assert!(
...@@ -689,12 +668,12 @@ mod tests { ...@@ -689,12 +668,12 @@ mod tests {
let mut seq_manager = ActiveSequences::new(block_size); let mut seq_manager = ActiveSequences::new(block_size);
let decay_now = Instant::now(); let decay_now = Instant::now();
seq_manager.add_request( seq_manager.add_request_with_prefill_tracking(
"r1".to_string(), "r1".to_string(),
Some(vec![1, 2, 3]), Some(vec![1, 2, 3]),
12,
0,
None, None,
true,
tracking_hint(block_size, 12, 0),
decay_now, decay_now,
); );
assert_eq!(seq_manager.active_tokens(decay_now), 12); assert_eq!(seq_manager.active_tokens(decay_now), 12);
...@@ -705,7 +684,14 @@ mod tests { ...@@ -705,7 +684,14 @@ mod tests {
seq_manager.mark_prefill_completed(&"r1".to_string(), decay_now); seq_manager.mark_prefill_completed(&"r1".to_string(), decay_now);
assert_eq!(seq_manager.active_tokens(decay_now), 0); assert_eq!(seq_manager.active_tokens(decay_now), 0);
seq_manager.add_request("r2".to_string(), Some(vec![4, 5]), 8, 0, None, decay_now); seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![4, 5]),
None,
true,
tracking_hint(block_size, 8, 0),
decay_now,
);
assert_eq!(seq_manager.active_tokens(decay_now), 8); assert_eq!(seq_manager.active_tokens(decay_now), 8);
seq_manager.free(&"r2".to_string(), decay_now); seq_manager.free(&"r2".to_string(), decay_now);
...@@ -720,8 +706,6 @@ mod tests { ...@@ -720,8 +706,6 @@ mod tests {
seq_manager.add_request_with_prefill_tracking( seq_manager.add_request_with_prefill_tracking(
"r1".to_string(), "r1".to_string(),
Some(vec![1, 2, 3]), Some(vec![1, 2, 3]),
12,
0,
None, None,
false, false,
None, None,
...@@ -745,8 +729,6 @@ mod tests { ...@@ -745,8 +729,6 @@ mod tests {
seq_manager.add_request_with_prefill_tracking( seq_manager.add_request_with_prefill_tracking(
"r1".to_string(), "r1".to_string(),
Some(vec![1, 2, 3]), Some(vec![1, 2, 3]),
12,
0,
None, None,
false, false,
None, None,
...@@ -772,8 +754,6 @@ mod tests { ...@@ -772,8 +754,6 @@ mod tests {
seq_manager.add_request_with_prefill_tracking( seq_manager.add_request_with_prefill_tracking(
"r1".to_string(), "r1".to_string(),
Some(vec![1]), Some(vec![1]),
50,
0,
None, None,
true, true,
Some(prefill_hint(50, 10)), Some(prefill_hint(50, 10)),
...@@ -782,8 +762,6 @@ mod tests { ...@@ -782,8 +762,6 @@ mod tests {
seq_manager.add_request_with_prefill_tracking( seq_manager.add_request_with_prefill_tracking(
"r2".to_string(), "r2".to_string(),
Some(vec![2]), Some(vec![2]),
30,
0,
None, None,
true, true,
Some(prefill_hint(30, 10)), Some(prefill_hint(30, 10)),
...@@ -823,20 +801,20 @@ mod tests { ...@@ -823,20 +801,20 @@ mod tests {
let block_size = 4; let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size); let mut seq_manager = ActiveSequences::new(block_size);
seq_manager.add_request( seq_manager.add_request_with_prefill_tracking(
"r1".to_string(), "r1".to_string(),
Some(vec![1, 2]), Some(vec![1, 2]),
8,
0,
None, None,
true,
tracking_hint(block_size, 8, 0),
Instant::now(), Instant::now(),
); );
seq_manager.add_request( seq_manager.add_request_with_prefill_tracking(
"r2".to_string(), "r2".to_string(),
Some(vec![3, 4]), Some(vec![3, 4]),
8,
0,
None, None,
true,
tracking_hint(block_size, 8, 0),
Instant::now(), Instant::now(),
); );
assert_eq!(seq_manager.active_blocks(), 4); assert_eq!(seq_manager.active_blocks(), 4);
...@@ -869,8 +847,14 @@ mod tests { ...@@ -869,8 +847,14 @@ mod tests {
seq_manager.assert_consistent(); seq_manager.assert_consistent();
tokio::time::advance(Duration::from_secs(31)).await; tokio::time::advance(Duration::from_secs(31)).await;
let expired = let expired = seq_manager.add_request_with_prefill_tracking(
seq_manager.add_request("r3".to_string(), Some(vec![5]), 4, 0, None, Instant::now()); "r3".to_string(),
Some(vec![5]),
None,
true,
tracking_hint(block_size, 4, 0),
Instant::now(),
);
assert!(expired.expired_request_ids.is_empty()); assert!(expired.expired_request_ids.is_empty());
assert_eq!(seq_manager.active_blocks(), 1); assert_eq!(seq_manager.active_blocks(), 1);
assert_eq!(seq_manager.active_tokens(Instant::now()), 4); assert_eq!(seq_manager.active_tokens(Instant::now()), 4);
...@@ -885,8 +869,6 @@ mod tests { ...@@ -885,8 +869,6 @@ mod tests {
seq_manager.add_request_with_prefill_tracking( seq_manager.add_request_with_prefill_tracking(
"r1".to_string(), "r1".to_string(),
Some(vec![1]), Some(vec![1]),
40,
0,
None, None,
true, true,
Some(prefill_hint(40, 100)), Some(prefill_hint(40, 100)),
...@@ -896,8 +878,6 @@ mod tests { ...@@ -896,8 +878,6 @@ mod tests {
seq_manager.add_request_with_prefill_tracking( seq_manager.add_request_with_prefill_tracking(
"r2".to_string(), "r2".to_string(),
Some(vec![2]), Some(vec![2]),
30,
0,
None, None,
true, true,
Some(prefill_hint(30, 100)), Some(prefill_hint(30, 100)),
......
...@@ -219,12 +219,15 @@ mod tests { ...@@ -219,12 +219,15 @@ mod tests {
{ {
let idx = table.index[&existing]; let idx = table.index[&existing];
let mut seq = table.slots[idx].sequences.write(); let mut seq = table.slots[idx].sequences.write();
let outcome = seq.add_request( let outcome = seq.add_request_with_prefill_tracking(
"req-1".to_string(), "req-1".to_string(),
Some(vec![1, 2, 3]), Some(vec![1, 2, 3]),
12,
0,
None, None,
true,
Some(crate::protocols::PrefillLoadHint {
initial_effective_prefill_tokens: 12,
expected_prefill_duration: None,
}),
Instant::now(), Instant::now(),
); );
assert_eq!(outcome.membership_delta.stores[0].hashes, vec![1, 2, 3],); assert_eq!(outcome.membership_delta.stores[0].hashes, vec![1, 2, 3],);
......
...@@ -421,8 +421,6 @@ where ...@@ -421,8 +421,6 @@ where
.add_request(SequenceRequest { .add_request(SequenceRequest {
request_id: request_id.clone(), request_id: request_id.clone(),
token_sequence: maybe_seq_hashes, token_sequence: maybe_seq_hashes,
isl: isl_tokens,
overlap: overlap_blocks,
track_prefill_tokens, track_prefill_tokens,
expected_output_tokens, expected_output_tokens,
prefill_load_hint, prefill_load_hint,
...@@ -464,24 +462,25 @@ where ...@@ -464,24 +462,25 @@ where
return None; return None;
} }
let Some(estimator) = &self.prefill_load_estimator else { let expected_prefill_duration = match &self.prefill_load_estimator {
return None; Some(estimator) => match estimator.predict_prefill_duration(1, effective_isl, prefix) {
Ok(expected_prefill_duration) => Some(expected_prefill_duration),
Err(error) => {
tracing::warn!(
effective_isl,
prefix,
"failed to predict prefill duration for direct add_request path: {error}"
);
None
}
},
None => None,
}; };
match estimator.predict_prefill_duration(1, effective_isl, prefix) { Some(PrefillLoadHint {
Ok(expected_prefill_duration) => Some(PrefillLoadHint { initial_effective_prefill_tokens: effective_isl,
initial_effective_prefill_tokens: effective_isl, expected_prefill_duration,
expected_prefill_duration: Some(expected_prefill_duration), })
}),
Err(error) => {
tracing::warn!(
effective_isl,
prefix,
"failed to predict prefill duration for direct add_request path: {error}"
);
None
}
}
} }
/// Get the worker type for this router ("prefill" or "decode"). /// Get the worker type for this router ("prefill" or "decode").
......
...@@ -24,6 +24,8 @@ use std::sync::Arc; ...@@ -24,6 +24,8 @@ use std::sync::Arc;
use super::metrics::WORKER_LOAD_METRICS; use super::metrics::WORKER_LOAD_METRICS;
use crate::kv_router::{ACTIVE_SEQUENCES_SUBJECT, KV_METRICS_SUBJECT}; use crate::kv_router::{ACTIVE_SEQUENCES_SUBJECT, KV_METRICS_SUBJECT};
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
#[cfg(test)]
use dynamo_kv_router::protocols::PrefillLoadHint;
/// Concrete [`SequencePublisher`] backed by NATS [`EventPublisher`] and Prometheus gauges. /// Concrete [`SequencePublisher`] backed by NATS [`EventPublisher`] and Prometheus gauges.
pub struct RuntimeSequencePublisher { pub struct RuntimeSequencePublisher {
...@@ -145,6 +147,13 @@ mod tests { ...@@ -145,6 +147,13 @@ mod tests {
use dynamo_runtime::{DistributedRuntime, Runtime}; use dynamo_runtime::{DistributedRuntime, Runtime};
use tokio::time::Instant; use tokio::time::Instant;
/// Test helper: wraps `tokens` in a `PrefillLoadHint` with no expected
/// prefill duration. Always returns `Some`, including for `tokens == 0`.
fn tracking_hint(tokens: usize) -> Option<PrefillLoadHint> {
    let hint = PrefillLoadHint {
        initial_effective_prefill_tokens: tokens,
        expected_prefill_duration: None,
    };
    Some(hint)
}
#[tokio::test] #[tokio::test]
#[ignore] #[ignore]
async fn test_multi_worker_cross_instance_sync() -> Result<()> { async fn test_multi_worker_cross_instance_sync() -> Result<()> {
...@@ -192,11 +201,9 @@ mod tests { ...@@ -192,11 +201,9 @@ mod tests {
SequenceRequest { SequenceRequest {
request_id: "request_0".to_string(), request_id: "request_0".to_string(),
token_sequence: Some(vec![0, 1, 2]), token_sequence: Some(vec![0, 1, 2]),
isl: 12,
overlap: 0,
track_prefill_tokens: true, track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: tracking_hint(12),
worker: WorkerWithDpRank::new(0, 0), worker: WorkerWithDpRank::new(0, 0),
lora_name: None, lora_name: None,
}, },
...@@ -207,11 +214,9 @@ mod tests { ...@@ -207,11 +214,9 @@ mod tests {
SequenceRequest { SequenceRequest {
request_id: "request_1".to_string(), request_id: "request_1".to_string(),
token_sequence: Some(vec![3, 4]), token_sequence: Some(vec![3, 4]),
isl: 8,
overlap: 0,
track_prefill_tokens: true, track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: tracking_hint(8),
worker: WorkerWithDpRank::new(0, 1), worker: WorkerWithDpRank::new(0, 1),
lora_name: None, lora_name: None,
}, },
...@@ -222,11 +227,9 @@ mod tests { ...@@ -222,11 +227,9 @@ mod tests {
SequenceRequest { SequenceRequest {
request_id: "request_2".to_string(), request_id: "request_2".to_string(),
token_sequence: Some(vec![0, 1, 2, 3]), token_sequence: Some(vec![0, 1, 2, 3]),
isl: 16,
overlap: 0,
track_prefill_tokens: true, track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: tracking_hint(16),
worker: WorkerWithDpRank::new(1, 0), worker: WorkerWithDpRank::new(1, 0),
lora_name: None, lora_name: None,
}, },
...@@ -351,11 +354,9 @@ mod tests { ...@@ -351,11 +354,9 @@ mod tests {
SequenceRequest { SequenceRequest {
request_id: "request_0".to_string(), request_id: "request_0".to_string(),
token_sequence: None, token_sequence: None,
isl: 12,
overlap: 0,
track_prefill_tokens: true, track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: tracking_hint(12),
worker: WorkerWithDpRank::from_worker_id(0), worker: WorkerWithDpRank::from_worker_id(0),
lora_name: None, lora_name: None,
}, },
...@@ -366,11 +367,9 @@ mod tests { ...@@ -366,11 +367,9 @@ mod tests {
SequenceRequest { SequenceRequest {
request_id: "request_1".to_string(), request_id: "request_1".to_string(),
token_sequence: None, token_sequence: None,
isl: 8,
overlap: 0,
track_prefill_tokens: true, track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: tracking_hint(8),
worker: WorkerWithDpRank::from_worker_id(1), worker: WorkerWithDpRank::from_worker_id(1),
lora_name: None, lora_name: None,
}, },
...@@ -381,11 +380,9 @@ mod tests { ...@@ -381,11 +380,9 @@ mod tests {
SequenceRequest { SequenceRequest {
request_id: "request_2".to_string(), request_id: "request_2".to_string(),
token_sequence: None, token_sequence: None,
isl: 16,
overlap: 0,
track_prefill_tokens: true, track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
prefill_load_hint: None, prefill_load_hint: tracking_hint(16),
worker: WorkerWithDpRank::from_worker_id(2), worker: WorkerWithDpRank::from_worker_id(2),
lora_name: None, lora_name: None,
}, },
......
...@@ -516,8 +516,6 @@ impl OfflineReplayRouter { ...@@ -516,8 +516,6 @@ impl OfflineReplayRouter {
SequenceRequest { SequenceRequest {
request_id, request_id,
token_sequence: request.token_seq, token_sequence: request.token_seq,
isl: request.isl_tokens,
overlap: selection.overlap_blocks,
track_prefill_tokens: request.track_prefill_tokens, track_prefill_tokens: request.track_prefill_tokens,
expected_output_tokens: request.expected_output_tokens, expected_output_tokens: request.expected_output_tokens,
prefill_load_hint, prefill_load_hint,
...@@ -583,24 +581,25 @@ impl OfflineReplayRouter { ...@@ -583,24 +581,25 @@ impl OfflineReplayRouter {
return None; return None;
} }
let Some(estimator) = &self.prefill_load_estimator else { let expected_prefill_duration = match &self.prefill_load_estimator {
return None; Some(estimator) => match estimator.predict_prefill_duration(1, effective_isl, prefix) {
Ok(expected_prefill_duration) => Some(expected_prefill_duration),
Err(error) => {
tracing::warn!(
effective_isl,
prefix,
"failed to predict replay prefill duration for active load tracking: {error}"
);
None
}
},
None => None,
}; };
match estimator.predict_prefill_duration(1, effective_isl, prefix) { Some(PrefillLoadHint {
Ok(expected_prefill_duration) => Some(PrefillLoadHint { initial_effective_prefill_tokens: effective_isl,
initial_effective_prefill_tokens: effective_isl, expected_prefill_duration,
expected_prefill_duration: Some(expected_prefill_duration), })
}),
Err(error) => {
tracing::warn!(
effective_isl,
prefix,
"failed to predict replay prefill duration for active load tracking: {error}"
);
None
}
}
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment