Commit 134d484d (unverified), authored by Yan Ru Pei, committed by GitHub
Browse files

feat(kv-router): add prompt membership index for scheduler reads (#8175)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent d0d9c030
...@@ -2,35 +2,50 @@ ...@@ -2,35 +2,50 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
use std::collections::HashMap; use rustc_hash::FxHashMap;
use std::sync::{Arc, Weak}; use std::sync::{Arc, Weak};
/// Result of `BlockTracker::touch_block`: a strong handle on the block plus a flag
/// telling the caller whether this touch created the tracker's live entry for it.
#[derive(Debug)]
pub(super) struct BlockAcquire {
/// Strong reference for the block. The tracker itself only stores a `Weak`, and
/// `try_remove_block` only removes the entry once no strong references remain.
pub(super) rc: Arc<()>,
/// `true` when `touch_block` found no live reference and inserted a fresh one;
/// `false` when an existing live reference was upgraded and reused.
pub(super) became_present_on_worker: bool,
}
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub(super) struct BlockTracker { pub(super) struct BlockTracker {
pub(super) unique_blocks: HashMap<SequenceHash, Weak<()>>, pub(super) unique_blocks: FxHashMap<SequenceHash, Weak<()>>,
pub(super) fractional_blocks: HashMap<SequenceHash, f64>, pub(super) fractional_blocks: FxHashMap<SequenceHash, f64>,
} }
impl BlockTracker { impl BlockTracker {
pub(super) fn touch_block(&mut self, block: &SequenceHash) -> Arc<()> { pub(super) fn touch_block(&mut self, block: &SequenceHash) -> BlockAcquire {
if let Some(weak) = self.unique_blocks.get(block) if let Some(weak) = self.unique_blocks.get(block)
&& let Some(rc) = weak.upgrade() && let Some(rc) = weak.upgrade()
{ {
return rc; return BlockAcquire {
rc,
became_present_on_worker: false,
};
} }
let rc = Arc::new(()); let rc = Arc::new(());
self.unique_blocks.insert(*block, Arc::downgrade(&rc)); self.unique_blocks.insert(*block, Arc::downgrade(&rc));
rc BlockAcquire {
rc,
became_present_on_worker: true,
}
} }
pub(super) fn try_remove_block(&mut self, block: &SequenceHash) { pub(super) fn try_remove_block(&mut self, block: &SequenceHash) -> bool {
if let Some(weak) = self.unique_blocks.get(block) if let Some(weak) = self.unique_blocks.get(block)
&& weak.strong_count() == 0 && weak.strong_count() == 0
{ {
self.unique_blocks.remove(block); self.unique_blocks.remove(block);
self.fractional_blocks.remove(block); self.fractional_blocks.remove(block);
return true;
} }
false
} }
pub(super) fn active_blocks(&self) -> usize { pub(super) fn active_blocks(&self) -> usize {
...@@ -43,3 +58,68 @@ impl BlockTracker { ...@@ -43,3 +58,68 @@ impl BlockTracker {
count.round() as usize count.round() as usize
} }
} }
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the first touch of a block reports a presence transition, and removal
    /// succeeds only after every outstanding handle has been released.
    #[test]
    fn first_touch_and_last_remove_report_presence_transitions() {
        let mut tracker = BlockTracker::default();
        let block = 1;

        let initial = tracker.touch_block(&block);
        let repeat = tracker.touch_block(&block);
        assert!(initial.became_present_on_worker);
        assert!(!repeat.became_present_on_worker);
        assert_eq!(tracker.active_blocks(), 1);

        // One handle is still alive, so the block must stay tracked.
        drop(initial.rc);
        let removed_early = tracker.try_remove_block(&block);
        assert!(!removed_early);
        assert_eq!(tracker.active_blocks(), 1);

        // Releasing the last handle makes the removal go through.
        drop(repeat.rc);
        let removed_last = tracker.try_remove_block(&block);
        assert!(removed_last);
        assert_eq!(tracker.active_blocks(), 0);
    }

    /// Fractional entries change the reported block count and are cleaned up
    /// together with the block they belong to.
    #[test]
    fn fractional_blocks_adjust_active_block_count() {
        let mut tracker = BlockTracker::default();
        let (block_a, block_b) = (1, 2);

        let handle_a = tracker.touch_block(&block_a);
        let handle_b = tracker.touch_block(&block_b);
        for block in [block_a, block_b] {
            tracker.fractional_blocks.insert(block, 0.5);
        }
        assert_eq!(tracker.active_blocks(), 1);

        drop(handle_a.rc);
        assert!(tracker.try_remove_block(&block_a));
        assert!(!tracker.fractional_blocks.contains_key(&block_a));
        assert_eq!(tracker.active_blocks(), 1);

        drop(handle_b.rc);
        assert!(tracker.try_remove_block(&block_b));
        assert!(tracker.fractional_blocks.is_empty());
        assert_eq!(tracker.active_blocks(), 0);
    }

    /// A block touched several times counts once and stays tracked until the
    /// final handle is dropped.
    #[test]
    fn shared_block_counts_once_until_last_reference_drops() {
        let mut tracker = BlockTracker::default();
        let block = 7;

        let mut handles: Vec<_> = (0..3).map(|_| tracker.touch_block(&block)).collect();
        assert_eq!(tracker.active_blocks(), 1);

        // Keep the most recent handle alive and release the two earlier ones.
        let survivor = handles.pop().expect("three handles were acquired");
        drop(handles);
        assert!(!tracker.try_remove_block(&block));
        assert_eq!(tracker.active_blocks(), 1);

        drop(survivor.rc);
        assert!(tracker.try_remove_block(&block));
        assert_eq!(tracker.active_blocks(), 0);
    }
}
...@@ -4,7 +4,11 @@ ...@@ -4,7 +4,11 @@
mod block_tracker; mod block_tracker;
pub mod multi_worker; pub mod multi_worker;
mod prefill_tracker; mod prefill_tracker;
mod prompt_membership_trie;
mod prompt_registry;
mod request_maps;
pub mod single; pub mod single;
mod topology;
pub use multi_worker::*; pub use multi_worker::*;
pub use single::*; pub use single::*;
...@@ -9,17 +9,20 @@ ...@@ -9,17 +9,20 @@
//! transport (e.g., NATS EventPublisher, Prometheus gauges) so that all business logic lives in //! transport (e.g., NATS EventPublisher, Prometheus gauges) so that all business logic lives in
//! this crate while the runtime glue stays in `lib/llm`. //! this crate while the runtime glue stays in `lib/llm`.
use dashmap::DashMap;
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
use parking_lot::RwLock; use parking_lot::RwLock;
use std::collections::{HashMap, HashSet}; use rustc_hash::FxHashMap;
use std::collections::HashMap;
use std::future::Future; use std::future::Future;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::watch; use tokio::sync::watch;
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use super::single::{ActiveSequences, RequestId}; use super::prompt_registry::{PromptRegistry, WorkerLoadSnapshot};
use super::request_maps::RequestIndex;
use super::single::{ActiveSequences, PromptMembershipDelta, RequestId};
use super::topology::WorkerTable;
use crate::protocols::{ use crate::protocols::{
ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores, PrefillLoadHint, ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores, PrefillLoadHint,
WorkerWithDpRank, WorkerWithDpRank,
...@@ -100,35 +103,6 @@ pub struct SequenceRequest { ...@@ -100,35 +103,6 @@ pub struct SequenceRequest {
pub lora_name: Option<String>, pub lora_name: Option<String>,
} }
// ---------------------------------------------------------------------------
// WorkerTable
// ---------------------------------------------------------------------------
struct WorkerTable {
slots: Vec<(WorkerWithDpRank, RwLock<ActiveSequences>)>,
index: HashMap<WorkerWithDpRank, usize>,
}
impl WorkerTable {
fn new(block_size: usize, dp_range: &HashMap<u64, (u32, u32)>) -> Self {
let mut slots = Vec::new();
let mut index = HashMap::new();
for (&worker_id, &(dp_start, dp_size)) in dp_range {
for dp_rank in dp_start..dp_start + dp_size {
let worker = WorkerWithDpRank::new(worker_id, dp_rank);
let idx = slots.len();
slots.push((worker, RwLock::new(ActiveSequences::new(block_size))));
index.insert(worker, idx);
}
}
Self { slots, index }
}
}
// ---------------------------------------------------------------------------
// ActiveSequencesMultiWorker
// ---------------------------------------------------------------------------
/// Multi-worker extension of [`ActiveSequences`] with per-worker `parking_lot::RwLock` for /// Multi-worker extension of [`ActiveSequences`] with per-worker `parking_lot::RwLock` for
/// fine-grained concurrent access. /// fine-grained concurrent access.
/// ///
...@@ -140,8 +114,8 @@ impl WorkerTable { ...@@ -140,8 +114,8 @@ impl WorkerTable {
/// and metrics infrastructure. /// and metrics infrastructure.
pub struct ActiveSequencesMultiWorker<P: SequencePublisher> { pub struct ActiveSequencesMultiWorker<P: SequencePublisher> {
workers: RwLock<WorkerTable>, workers: RwLock<WorkerTable>,
request_to_worker: DashMap<RequestId, WorkerWithDpRank>, request_index: RequestIndex,
request_to_lora: DashMap<RequestId, String>, prompt_registry: PromptRegistry,
block_size: usize, block_size: usize,
router_id: u64, router_id: u64,
publisher: Arc<P>, publisher: Arc<P>,
...@@ -164,11 +138,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -164,11 +138,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
) -> Self { ) -> Self {
assert!(block_size > 1, "block_size must be greater than 1"); assert!(block_size > 1, "block_size must be greater than 1");
let (remote_state_updates, _) = watch::channel(()); let (remote_state_updates, _) = watch::channel(());
let workers = WorkerTable::new(block_size, &dp_range);
let prompt_registry = PromptRegistry::new(workers.workers());
Self { Self {
workers: RwLock::new(WorkerTable::new(block_size, &dp_range)), workers: RwLock::new(workers),
request_to_worker: DashMap::new(), request_index: RequestIndex::default(),
request_to_lora: DashMap::new(), prompt_registry,
block_size, block_size,
router_id, router_id,
publisher: Arc::new(publisher), publisher: Arc::new(publisher),
...@@ -178,6 +154,59 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -178,6 +154,59 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
} }
} }
/// Test/bench-only sanity check: panics unless no tracked state is left behind.
///
/// Verifies, in order, that every worker reports zero active blocks and zero
/// active tokens (evaluated at `decay_now`), that the request index is empty,
/// that no LoRA counts remain, and that the prompt registry's block index is empty.
///
/// # Panics
///
/// Panics with a descriptive message on the first check that fails.
#[cfg(any(test, feature = "bench"))]
pub fn assert_completely_drained(&self, decay_now: Instant) {
    let active_blocks = self.active_blocks();
    let blocks_drained = active_blocks.values().copied().all(|count| count == 0);
    assert!(
        blocks_drained,
        "expected all workers to have zero active blocks, got {active_blocks:?}",
    );

    let active_tokens = self.active_tokens(decay_now);
    let tokens_drained = active_tokens.values().copied().all(|count| count == 0);
    assert!(
        tokens_drained,
        "expected all workers to have zero active tokens, got {active_tokens:?}",
    );

    let request_index_empty = self.request_index.is_empty();
    assert!(
        request_index_empty,
        "expected no active request-to-worker mappings, found {}",
        self.request_index.worker_len(),
    );

    // The counts are only recomputed for the message when the check fails.
    let lora_counts_empty = self.get_active_lora_counts().is_empty();
    assert!(
        lora_counts_empty,
        "expected no active LoRA counts, found {:?}",
        self.get_active_lora_counts(),
    );

    let block_index_empty = self.prompt_registry.is_block_index_empty();
    assert!(
        block_index_empty,
        "expected reverse block index to be empty after drain",
    );
}
/// Reports a worker's load snapshot through the publisher.
///
/// Reads the block count from `load` and its token count evaluated at `decay_now`,
/// passes both to `observe_load`, then sends the same values as an [`ActiveLoad`]
/// via `publish_load`. `kv_used_blocks` is always left unset here.
fn publish_worker_load_snapshot(
    &self,
    worker: WorkerWithDpRank,
    load: WorkerLoadSnapshot,
    decay_now: Instant,
) {
    let blocks = load.active_blocks;
    let tokens = load.active_tokens(decay_now);
    let publisher = &self.publisher;

    publisher.observe_load(&worker, self.worker_type, blocks, tokens);
    publisher.publish_load(ActiveLoad {
        worker_id: worker.worker_id,
        dp_rank: worker.dp_rank,
        active_decode_blocks: Some(blocks as u64),
        active_prefill_tokens: Some(tokens as u64),
        kv_used_blocks: None,
    });
}
fn spawn_publish_event(&self, event: ActiveSequenceEvent) { fn spawn_publish_event(&self, event: ActiveSequenceEvent) {
if !self.replica_sync { if !self.replica_sync {
return; return;
...@@ -257,26 +286,40 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -257,26 +286,40 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
expected_output_tokens, expected_output_tokens,
prefill_load_hint, prefill_load_hint,
} => { } => {
self.request_to_worker self.ensure_worker_registered(event.worker);
.insert(event.request_id.clone(), event.worker);
if let Some(ref lora_name) = event.lora_name {
self.request_to_lora
.insert(event.request_id.clone(), lora_name.clone());
}
let table = self.workers.read(); let table = self.workers.read();
if let Some(&idx) = table.index.get(&event.worker) { if let Some(&idx) = table.index.get(&event.worker) {
table.slots[idx].1.write().add_request_with_prefill_tracking( self.request_index.set_request(
event.request_id.clone(), event.request_id.clone(),
token_sequence.clone(), event.worker,
*isl, event.lora_name.clone(),
*overlap,
*expected_output_tokens,
*track_prefill_tokens,
*prefill_load_hint,
decay_now,
); );
let (expired_request_ids, load) = {
let slot = &table.slots[idx];
let mut seq = slot.sequences.write();
let outcome = seq.add_request_with_prefill_tracking(
event.request_id.clone(),
token_sequence.clone(),
*isl,
*overlap,
*expected_output_tokens,
*track_prefill_tokens,
*prefill_load_hint,
decay_now,
);
let load = seq.worker_load_snapshot();
self.prompt_registry.apply_membership_delta_and_load(
event.worker,
&slot.trie_lookup,
outcome.membership_delta,
load,
);
(outcome.expired_request_ids, load)
};
drop(table);
self.request_index.remove_requests(expired_request_ids.iter());
self.publish_worker_load_snapshot(event.worker, load, decay_now);
continue;
} else { } else {
tracing::warn!( tracing::warn!(
"Worker {:?} not found, cannot process AddRequest", "Worker {:?} not found, cannot process AddRequest",
...@@ -285,27 +328,40 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -285,27 +328,40 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
} }
} }
ActiveSequenceEventData::Free => { ActiveSequenceEventData::Free => {
if let Some((_, worker)) = if let Some(worker) = self.request_index.remove_request(&event.request_id) {
self.request_to_worker.remove(&event.request_id)
{
let table = self.workers.read(); let table = self.workers.read();
if let Some(&idx) = table.index.get(&worker) { if let Some(&idx) = table.index.get(&worker) {
table.slots[idx].1.write().free(&event.request_id, decay_now); let load = {
let slot = &table.slots[idx];
let mut seq = slot.sequences.write();
let delta = seq.free(&event.request_id, decay_now);
let load = seq.worker_load_snapshot();
self.prompt_registry.apply_membership_delta_and_load(
worker,
&slot.trie_lookup,
delta,
load,
);
load
};
drop(table);
self.publish_worker_load_snapshot(worker, load, decay_now);
remote_capacity_changed = true; remote_capacity_changed = true;
} }
} }
self.request_to_lora.remove(&event.request_id);
} }
ActiveSequenceEventData::MarkPrefillCompleted => { ActiveSequenceEventData::MarkPrefillCompleted => {
let worker = let worker = self.request_index.worker_for(&event.request_id);
self.request_to_worker.get(&event.request_id).map(|r| *r);
if let Some(worker) = worker { if let Some(worker) = worker {
let table = self.workers.read(); let table = self.workers.read();
if let Some(&idx) = table.index.get(&worker) { if let Some(&idx) = table.index.get(&worker) {
table.slots[idx] {
.1 let mut seq = table.slots[idx].sequences.write();
.write() seq.mark_prefill_completed(&event.request_id, decay_now);
.mark_prefill_completed(&event.request_id, decay_now); let load = seq.worker_load_snapshot();
self.prompt_registry.replace_worker_load_state(worker, load);
}
drop(table);
remote_capacity_changed = true; remote_capacity_changed = true;
} }
} }
...@@ -337,151 +393,35 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -337,151 +393,35 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// Worker removal in External mode will be handled separately via GAIE /// Worker removal in External mode will be handled separately via GAIE
/// lifecycle events (not yet implemented). TODO (atchernych) once we upgrade to GAIE latest. /// lifecycle events (not yet implemented). TODO (atchernych) once we upgrade to GAIE latest.
pub fn register_external_workers(&self, dp_range: &HashMap<u64, (u32, u32)>) { pub fn register_external_workers(&self, dp_range: &HashMap<u64, (u32, u32)>) {
let mut table = self.workers.write(); let change = {
for (&worker_id, &(dp_start, dp_size)) in dp_range { let mut table = self.workers.write();
for dp_rank in dp_start..(dp_start + dp_size) { table.register_external(self.block_size, dp_range)
let worker = WorkerWithDpRank::new(worker_id, dp_rank); };
if !table.index.contains_key(&worker) {
tracing::debug!("Lazily registering external worker {:?}", worker); for worker in &change.added {
let idx = table.slots.len(); tracing::debug!("Lazily registering external worker {:?}", worker);
table
.slots
.push((worker, RwLock::new(ActiveSequences::new(self.block_size))));
table.index.insert(worker, idx);
}
}
} }
self.prompt_registry.apply_topology_change(change);
} }
/// Update the set of workers, adding and removing as needed. /// Update the set of workers, adding and removing as needed.
/// ///
/// `new_dp_range` maps worker IDs to their data-parallel range (start, size). /// `new_dp_range` maps worker IDs to their data-parallel range (start, size).
pub fn update_workers(&self, new_dp_range: &HashMap<u64, (u32, u32)>) { pub fn update_workers(&self, new_dp_range: &HashMap<u64, (u32, u32)>) {
let mut table = self.workers.write(); let change = {
let mut target_workers: HashSet<WorkerWithDpRank> = HashSet::new();
for (&worker_id, &(dp_start, dp_size)) in new_dp_range {
for dp_rank in dp_start..(dp_start + dp_size) {
target_workers.insert(WorkerWithDpRank::new(worker_id, dp_rank));
}
}
// Clean up request mappings for workers being removed.
for (worker, _) in &table.slots {
if target_workers.contains(worker) {
continue;
}
tracing::warn!("Removing worker {:?}", worker);
let requests_to_remove: Vec<RequestId> = self
.request_to_worker
.iter()
.filter(|entry| entry.value() == worker)
.map(|entry| entry.key().clone())
.collect();
self.request_to_worker
.retain(|_request_id, mapped_worker| mapped_worker != worker);
for request_id in requests_to_remove {
self.request_to_lora.remove(&request_id);
}
}
// Drain old slots, preserving ActiveSequences for retained workers.
let mut old: HashMap<WorkerWithDpRank, ActiveSequences> = table
.slots
.drain(..)
.map(|(w, lock)| (w, lock.into_inner()))
.collect();
table.index.clear();
// Rebuild with target workers, reusing state where possible.
for worker in target_workers {
if !old.contains_key(&worker) {
tracing::warn!("Adding worker {:?}", worker);
}
let idx = table.slots.len();
let seq = old
.remove(&worker)
.unwrap_or_else(|| ActiveSequences::new(self.block_size));
table.slots.push((worker, RwLock::new(seq)));
table.index.insert(worker, idx);
}
}
fn add_request_local(
&self,
req: SequenceRequest,
decay_now: Instant,
) -> Result<(), SequenceError> {
let SequenceRequest {
request_id,
token_sequence,
isl,
overlap,
track_prefill_tokens,
expected_output_tokens,
prefill_load_hint,
worker,
lora_name,
} = req;
if !self.workers.read().index.contains_key(&worker) {
// The selector already picked this worker from the discovery watch,
// but the slot tracker hasn't been updated yet. Lazily register it
// so we don't drop tracking for this request.
let mut table = self.workers.write(); let mut table = self.workers.write();
if !table.index.contains_key(&worker) { table.reconcile(self.block_size, new_dp_range)
tracing::debug!(?worker, "Lazily registering worker in slot tracker");
let idx = table.slots.len();
table
.slots
.push((worker, RwLock::new(ActiveSequences::new(self.block_size))));
table.index.insert(worker, idx);
}
}
if let Some(existing_worker) = self.request_to_worker.get(&request_id) {
return Err(SequenceError::DuplicateRequest {
request_id,
worker: *existing_worker,
});
}
self.request_to_worker.insert(request_id.clone(), worker);
if let Some(lora) = lora_name {
self.request_to_lora.insert(request_id.clone(), lora);
}
let removed_requests = {
let table = self.workers.read();
let &idx = table
.index
.get(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?;
let mut seq = table.slots[idx].1.write();
seq.add_request_with_prefill_tracking(
request_id,
token_sequence,
isl,
overlap,
expected_output_tokens,
track_prefill_tokens,
prefill_load_hint,
decay_now,
)
}; };
for expired_id in &removed_requests { for removed in &change.removed {
self.request_to_worker.remove(expired_id); tracing::warn!("Removing worker {:?}", removed.worker);
self.request_to_lora.remove(expired_id); self.request_index.remove_worker_requests(removed.worker);
}
for worker in &change.added {
tracing::warn!("Adding worker {:?}", worker);
} }
self.publish_active_load_for_worker(worker, decay_now); self.prompt_registry.apply_topology_change(change);
Ok(())
} }
pub fn add_request( pub fn add_request(
...@@ -489,7 +429,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -489,7 +429,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
req: SequenceRequest, req: SequenceRequest,
decay_now: Instant, decay_now: Instant,
) -> Result<(), SequenceError> { ) -> Result<(), SequenceError> {
self.spawn_publish_event(ActiveSequenceEvent { let event = ActiveSequenceEvent {
request_id: req.request_id.clone(), request_id: req.request_id.clone(),
worker: req.worker, worker: req.worker,
data: ActiveSequenceEventData::AddRequest { data: ActiveSequenceEventData::AddRequest {
...@@ -502,78 +442,12 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -502,78 +442,12 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
}, },
router_id: self.router_id, router_id: self.router_id,
lora_name: req.lora_name.clone(), lora_name: req.lora_name.clone(),
}); };
self.add_request_local(req, decay_now) self.add_request_local(req, decay_now)?;
} self.spawn_publish_event(event);
/// Send a mutation to the worker assigned to a request, optionally publishing
/// a replica-sync event and cleaning up request mappings afterward.
fn mutate_request_worker_local(
&self,
request_id: &RequestId,
decay_now: Instant,
mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId, Instant),
remove_mapping: bool,
) -> Result<(), SequenceError> {
let worker = self
.request_to_worker
.get(request_id)
.map(|entry| *entry)
.ok_or_else(|| SequenceError::RequestNotFound {
request_id: request_id.clone(),
})?;
{
let table = self.workers.read();
let &idx = table
.index
.get(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?;
let mut seq = table.slots[idx].1.write();
mutate_fn(&mut seq, request_id, decay_now);
}
if remove_mapping {
self.request_to_worker.remove(request_id);
self.request_to_lora.remove(request_id);
}
self.publish_active_load_for_worker(worker, decay_now);
Ok(()) Ok(())
} }
fn mutate_request_worker(
&self,
request_id: &RequestId,
decay_now: Instant,
event_data: ActiveSequenceEventData,
mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId, Instant),
remove_mapping: bool,
) -> Result<(), SequenceError> {
let worker = self
.request_to_worker
.get(request_id)
.map(|entry| *entry)
.ok_or_else(|| SequenceError::RequestNotFound {
request_id: request_id.clone(),
})?;
let lora_name = self
.request_to_lora
.get(request_id)
.map(|entry| entry.value().clone());
self.spawn_publish_event(ActiveSequenceEvent {
request_id: request_id.clone(),
worker,
data: event_data,
router_id: self.router_id,
lora_name,
});
self.mutate_request_worker_local(request_id, decay_now, mutate_fn, remove_mapping)
}
/// Free all blocks associated with a request. /// Free all blocks associated with a request.
/// ///
/// Note: This operation is idempotent. Calling it multiple times for the same request /// Note: This operation is idempotent. Calling it multiple times for the same request
...@@ -583,20 +457,20 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -583,20 +457,20 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// [`ActiveSequences::free`], so callers do not need to call /// [`ActiveSequences::free`], so callers do not need to call
/// [`Self::mark_prefill_completed`] before freeing a completed request. /// [`Self::mark_prefill_completed`] before freeing a completed request.
pub fn free(&self, request_id: &RequestId, decay_now: Instant) -> Result<(), SequenceError> { pub fn free(&self, request_id: &RequestId, decay_now: Instant) -> Result<(), SequenceError> {
if !self.request_to_worker.contains_key(request_id) { match self.mutate_request_worker_prompt_state(
tracing::debug!("Request {request_id} not found, already freed (idempotent)");
return Ok(());
}
self.mutate_request_worker(
request_id, request_id,
decay_now, decay_now,
ActiveSequenceEventData::Free, ActiveSequenceEventData::Free,
|seqs, rid, decay_now| { |seqs, rid, decay_now| seqs.free(rid, decay_now),
seqs.free(rid, decay_now);
},
true, true,
) ) {
Ok(()) => Ok(()),
Err(SequenceError::RequestNotFound { .. }) => {
tracing::debug!("Request {request_id} not found, already freed (idempotent)");
Ok(())
}
Err(err) => Err(err),
}
} }
/// Mark prefill as completed for a request. /// Mark prefill as completed for a request.
...@@ -608,14 +482,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -608,14 +482,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
request_id: &RequestId, request_id: &RequestId,
decay_now: Instant, decay_now: Instant,
) -> Result<(), SequenceError> { ) -> Result<(), SequenceError> {
self.mutate_request_worker( self.mutate_request_worker_load_state(
request_id, request_id,
decay_now, decay_now,
ActiveSequenceEventData::MarkPrefillCompleted, ActiveSequenceEventData::MarkPrefillCompleted,
|seqs, rid, decay_now| { |seqs, rid, decay_now| {
seqs.mark_prefill_completed(rid, decay_now); seqs.mark_prefill_completed(rid, decay_now);
}, },
false,
) )
} }
...@@ -630,63 +503,37 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -630,63 +503,37 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
request_id: &RequestId, request_id: &RequestId,
decay_fraction: Option<f64>, decay_fraction: Option<f64>,
) -> Result<(), SequenceError> { ) -> Result<(), SequenceError> {
let worker = self let worker = self.request_index.worker_for(request_id).ok_or_else(|| {
.request_to_worker SequenceError::RequestNotFound {
.get(request_id)
.map(|entry| *entry)
.ok_or_else(|| SequenceError::RequestNotFound {
request_id: request_id.clone(),
})?;
let success = {
let table = self.workers.read();
let &idx = table
.index
.get(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?;
let mut seq = table.slots[idx].1.write();
seq.add_output_block(request_id, decay_fraction)
};
if !success {
return Err(SequenceError::RequestNotFound {
request_id: request_id.clone(), request_id: request_id.clone(),
}); }
} })?;
self.publish_active_load_for_worker(worker, Instant::now());
Ok(())
}
/// Read active blocks/tokens from a worker and publish ActiveLoad metrics. let load = {
fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank, decay_now: Instant) {
let (active_blocks, active_tokens) = {
let table = self.workers.read(); let table = self.workers.read();
let Some(&idx) = table.index.get(&worker) else { let Some(&idx) = table.index.get(&worker) else {
tracing::warn!("Worker {worker:?} not found when publishing ActiveLoad"); drop(table);
return; return Err(self.stale_request_not_found(request_id, worker, "add_output_block"));
}; };
let seq = table.slots[idx].1.read(); let mut seq = table.slots[idx].sequences.write();
(seq.active_blocks(), seq.active_tokens(decay_now)) let Some(_new_block_hash) = seq.add_output_block(request_id, decay_fraction) else {
return Err(SequenceError::RequestNotFound {
request_id: request_id.clone(),
});
};
let load = seq.worker_load_snapshot();
self.prompt_registry.replace_worker_load_state(worker, load);
load
}; };
self.publisher self.publish_worker_load_snapshot(worker, load, Instant::now());
.observe_load(&worker, self.worker_type, active_blocks, active_tokens);
let active_load = ActiveLoad {
worker_id: worker.worker_id,
dp_rank: worker.dp_rank,
active_decode_blocks: Some(active_blocks as u64),
active_prefill_tokens: Some(active_tokens as u64),
kv_used_blocks: None,
};
self.publisher.publish_load(active_load); Ok(())
} }
/// Get the number of workers. /// Get the number of workers.
pub fn num_workers(&self) -> usize { #[cfg(test)]
pub(crate) fn num_workers(&self) -> usize {
self.workers.read().slots.len() self.workers.read().slots.len()
} }
...@@ -699,8 +546,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -699,8 +546,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
pub fn new_blocks(&self, token_sequence: &[SequenceHash]) -> HashMap<WorkerWithDpRank, usize> { pub fn new_blocks(&self, token_sequence: &[SequenceHash]) -> HashMap<WorkerWithDpRank, usize> {
let table = self.workers.read(); let table = self.workers.read();
let mut results = HashMap::with_capacity(table.slots.len()); let mut results = HashMap::with_capacity(table.slots.len());
for (worker, lock) in &table.slots { for slot in &table.slots {
results.insert(*worker, lock.read().new_blocks(token_sequence)); results.insert(
slot.worker,
slot.sequences.read().new_blocks(token_sequence),
);
} }
results results
} }
...@@ -712,8 +562,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -712,8 +562,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
) -> HashMap<WorkerWithDpRank, usize> { ) -> HashMap<WorkerWithDpRank, usize> {
let table = self.workers.read(); let table = self.workers.read();
let mut results = HashMap::with_capacity(table.slots.len()); let mut results = HashMap::with_capacity(table.slots.len());
for (worker, lock) in &table.slots { for slot in &table.slots {
results.insert(*worker, lock.read().potential_blocks(token_sequence)); results.insert(
slot.worker,
slot.sequences.read().potential_blocks(token_sequence),
);
} }
results results
} }
...@@ -726,8 +579,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -726,8 +579,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlaps: OverlapScores, overlaps: OverlapScores,
decay_now: Instant, decay_now: Instant,
) -> ( ) -> (
HashMap<WorkerWithDpRank, usize>, FxHashMap<WorkerWithDpRank, usize>,
HashMap<WorkerWithDpRank, usize>, FxHashMap<WorkerWithDpRank, usize>,
) { ) {
self.potential_blocks_and_tokens_with_prefill_tracking( self.potential_blocks_and_tokens_with_prefill_tracking(
token_sequence, token_sequence,
...@@ -746,91 +599,42 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -746,91 +599,42 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
track_prefill_tokens: bool, track_prefill_tokens: bool,
decay_now: Instant, decay_now: Instant,
) -> ( ) -> (
HashMap<WorkerWithDpRank, usize>, FxHashMap<WorkerWithDpRank, usize>,
HashMap<WorkerWithDpRank, usize>, FxHashMap<WorkerWithDpRank, usize>,
) { ) {
#[cfg(feature = "bench")] self.prompt_registry
let start = tokio::time::Instant::now(); .potential_blocks_and_tokens_with_prefill_tracking(
token_sequence,
let table = self.workers.read(); isl,
&overlaps,
#[cfg(feature = "bench")] track_prefill_tokens,
let num_workers = table.slots.len(); self.block_size,
decay_now,
let mut potential_blocks = HashMap::with_capacity(table.slots.len()); )
let mut potential_tokens = HashMap::with_capacity(table.slots.len());
for (worker, lock) in &table.slots {
let overlap = *overlaps.scores.get(worker).unwrap_or(&0);
let (blocks, tokens) = lock
.read()
.potential_blocks_and_tokens_with_prefill_tracking(
token_sequence,
isl,
overlap,
track_prefill_tokens,
decay_now,
);
potential_blocks.insert(*worker, blocks);
potential_tokens.insert(*worker, tokens);
}
#[cfg(feature = "bench")]
{
let total_elapsed = start.elapsed();
tracing::info!(
num_workers,
total_us = total_elapsed.as_micros() as u64,
"potential_blocks_and_tokens completed"
);
}
(potential_blocks, potential_tokens)
} }
/// Query all workers for their current number of active blocks. /// Query all workers for their current number of active blocks.
pub fn active_blocks(&self) -> HashMap<WorkerWithDpRank, usize> { pub fn active_blocks(&self) -> HashMap<WorkerWithDpRank, usize> {
let table = self.workers.read(); self.prompt_registry.active_blocks()
let mut results = HashMap::with_capacity(table.slots.len());
for (worker, lock) in &table.slots {
results.insert(*worker, lock.read().active_blocks());
}
results
} }
/// Query all workers for their current number of active tokens. /// Query all workers for their current number of active tokens.
pub fn active_tokens(&self, decay_now: Instant) -> HashMap<WorkerWithDpRank, usize> { pub fn active_tokens(&self, decay_now: Instant) -> HashMap<WorkerWithDpRank, usize> {
let table = self.workers.read(); self.prompt_registry.active_tokens(decay_now)
let mut results = HashMap::with_capacity(table.slots.len());
for (worker, lock) in &table.slots {
results.insert(*worker, lock.read().active_tokens(decay_now));
}
results
} }
/// Return true if any worker satisfies the provided predicate on active token count. /// Return true if any worker satisfies the provided predicate on active token count.
pub fn any_worker_matches_active_tokens( pub fn any_worker_matches_active_tokens(
&self, &self,
decay_now: Instant, decay_now: Instant,
mut predicate: impl FnMut(WorkerWithDpRank, usize) -> bool, predicate: impl FnMut(WorkerWithDpRank, usize) -> bool,
) -> bool { ) -> bool {
let table = self.workers.read(); self.prompt_registry
for (worker, lock) in &table.slots { .any_worker_matches_active_tokens(decay_now, predicate)
if predicate(*worker, lock.read().active_tokens(decay_now)) {
return true;
}
}
false
} }
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> { pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
let mut counts: HashMap<String, usize> = HashMap::new(); self.request_index.active_lora_counts()
for entry in self.request_to_lora.iter() {
let lora_name = entry.value().clone();
*counts.entry(lora_name).or_insert(0) += 1;
}
counts
} }
/// Force expire stale requests across all workers (one-shot). /// Force expire stale requests across all workers (one-shot).
...@@ -844,17 +648,24 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -844,17 +648,24 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
let now = Instant::now(); let now = Instant::now();
let table = self.workers.read(); let table = self.workers.read();
let mut removed_request_count = 0; let mut removed_request_count = 0;
for (worker, lock) in &table.slots { for slot in &table.slots {
let removed_requests = lock.write().force_expiry(); let mut seq = slot.sequences.write();
if !removed_requests.is_empty() { let outcome = seq.force_expiry();
for expired_id in &removed_requests { if !outcome.expired_request_ids.is_empty() {
self.request_to_worker.remove(expired_id); let load = seq.worker_load_snapshot();
self.request_to_lora.remove(expired_id); self.prompt_registry.apply_membership_delta_and_load(
removed_request_count += 1; slot.worker,
} &slot.trie_lookup,
self.publish_active_load_for_worker(*worker, now); outcome.membership_delta,
load,
);
removed_request_count += outcome.expired_request_ids.len();
self.request_index
.remove_requests(outcome.expired_request_ids.iter());
self.publish_worker_load_snapshot(slot.worker, load, now);
} }
} }
drop(table);
let duration = now.elapsed(); let duration = now.elapsed();
tracing::debug!( tracing::debug!(
duration = duration.as_secs_f64(), duration = duration.as_secs_f64(),
...@@ -890,31 +701,336 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -890,31 +701,336 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
} }
}); });
} }
}
#[cfg(test)] fn ensure_worker_registered(&self, worker: WorkerWithDpRank) {
mod tests { if self.workers.read().index.contains_key(&worker) {
use std::collections::HashMap; return;
use std::time::Duration; }
use super::*; let mut table = self.workers.write();
use crate::protocols::{OverlapScores, PrefillLoadHint}; if table.index.contains_key(&worker) {
use crate::test_utils::NoopSequencePublisher; return;
}
fn make_sequences() -> ActiveSequencesMultiWorker<NoopSequencePublisher> { tracing::debug!(?worker, "Lazily registering worker in slot tracker");
ActiveSequencesMultiWorker::new( let change = table.ensure_worker(self.block_size, worker);
NoopSequencePublisher, drop(table);
4,
HashMap::from([(1_u64, (0_u32, 1_u32))]), self.prompt_registry.apply_topology_change(change);
false,
0,
"test",
)
} }
#[tokio::test] fn add_request_local(
async fn add_request_can_skip_prefill_token_tracking() { &self,
let sequences = make_sequences(); req: SequenceRequest,
decay_now: Instant,
) -> Result<(), SequenceError> {
let SequenceRequest {
request_id,
token_sequence,
isl,
overlap,
track_prefill_tokens,
expected_output_tokens,
prefill_load_hint,
worker,
lora_name,
} = req;
self.ensure_worker_registered(worker);
let (expired_request_ids, load) = {
let table = self.workers.read();
let &idx = table
.index
.get(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?;
if let Err(existing_worker) =
self.request_index
.try_insert_request(request_id.clone(), worker, lora_name)
{
return Err(SequenceError::DuplicateRequest {
request_id,
worker: existing_worker,
});
}
let slot = &table.slots[idx];
let mut seq = slot.sequences.write();
let outcome = seq.add_request_with_prefill_tracking(
request_id,
token_sequence,
isl,
overlap,
expected_output_tokens,
track_prefill_tokens,
prefill_load_hint,
decay_now,
);
let load = seq.worker_load_snapshot();
self.prompt_registry.apply_membership_delta_and_load(
worker,
&slot.trie_lookup,
outcome.membership_delta,
load,
);
(outcome.expired_request_ids, load)
};
self.request_index
.remove_requests(expired_request_ids.iter());
self.publish_worker_load_snapshot(worker, load, decay_now);
Ok(())
}
/// Builds the `RequestNotFound` error for a mutation whose worker slot is absent.
///
/// If the request index still maps `request_id` to `worker`, that mapping is
/// removed first. A warning tagged with `operation` is logged in both cases.
fn stale_request_not_found(
    &self,
    request_id: &RequestId,
    worker: WorkerWithDpRank,
    operation: &'static str,
) -> SequenceError {
    let index_points_at_worker = self.request_index.worker_for(request_id) == Some(worker);

    if !index_points_at_worker {
        tracing::warn!(
            %request_id,
            ?worker,
            operation,
            "request worker slot disappeared before the mutation ran"
        );
    } else {
        // The index is the only thing still referencing this worker; drop the entry.
        self.request_index.remove_request(request_id);
        tracing::warn!(
            %request_id,
            ?worker,
            operation,
            "request index referenced a missing worker slot; removed stale mapping"
        );
    }

    let request_id = request_id.clone();
    SequenceError::RequestNotFound { request_id }
}
/// Applies `mutate_fn` to the sequences of `worker` and mirrors the result into
/// the prompt registry.
///
/// While the worker table read lock and the slot's write lock are held, the
/// closure runs, a load snapshot is taken, and the returned membership delta
/// plus that snapshot are pushed to the prompt registry. After both locks are
/// released the request mapping is removed when `remove_mapping` is set, and the
/// load snapshot is published.
///
/// Returns the error built by `stale_request_not_found` when `worker` has no slot.
fn mutate_request_worker_prompt_state_local(
    &self,
    worker: WorkerWithDpRank,
    request_id: &RequestId,
    decay_now: Instant,
    mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId, Instant) -> PromptMembershipDelta,
    remove_mapping: bool,
) -> Result<(), SequenceError> {
    let workers = self.workers.read();
    let Some(slot_idx) = workers.index.get(&worker).copied() else {
        // Release the table lock before touching the request index.
        drop(workers);
        return Err(self.stale_request_not_found(request_id, worker, "free_or_mutate"));
    };

    let slot = &workers.slots[slot_idx];
    let mut sequences = slot.sequences.write();
    let membership_delta = mutate_fn(&mut sequences, request_id, decay_now);
    let load = sequences.worker_load_snapshot();
    self.prompt_registry.apply_membership_delta_and_load(
        worker,
        &slot.trie_lookup,
        membership_delta,
        load,
    );
    // Release in the same order as the original scope exit: slot lock, then table lock.
    drop(sequences);
    drop(workers);

    if remove_mapping {
        self.request_index.remove_request(request_id);
    }
    self.publish_worker_load_snapshot(worker, load, decay_now);
    Ok(())
}
/// Applies a load-only mutation to the sequences of `worker`.
///
/// With the worker table read lock and the slot's write lock held, `mutate_fn`
/// runs, a load snapshot is taken and stored in the prompt registry through
/// `replace_worker_load_state`. The snapshot is published once both locks are
/// released.
///
/// Returns the error built by `stale_request_not_found` when `worker` has no slot.
fn mutate_request_worker_load_state_local(
    &self,
    worker: WorkerWithDpRank,
    request_id: &RequestId,
    decay_now: Instant,
    mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId, Instant),
) -> Result<(), SequenceError> {
    let workers = self.workers.read();
    let Some(slot_idx) = workers.index.get(&worker).copied() else {
        // Release the table lock before touching the request index.
        drop(workers);
        return Err(self.stale_request_not_found(request_id, worker, "load_only_mutate"));
    };

    let mut sequences = workers.slots[slot_idx].sequences.write();
    mutate_fn(&mut sequences, request_id, decay_now);
    let load = sequences.worker_load_snapshot();
    self.prompt_registry.replace_worker_load_state(worker, load);
    // Release in the same order as the original scope exit: slot lock, then table lock.
    drop(sequences);
    drop(workers);

    self.publish_worker_load_snapshot(worker, load, decay_now);
    Ok(())
}
/// Resolves the worker that owns `request_id`, applies a prompt-state mutation
/// locally, then publishes the matching event.
///
/// Fails with `RequestNotFound` when the request index has no worker for the
/// request, and propagates any error from the local mutation. No event is
/// published on failure.
fn mutate_request_worker_prompt_state(
    &self,
    request_id: &RequestId,
    decay_now: Instant,
    event_data: ActiveSequenceEventData,
    mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId, Instant) -> PromptMembershipDelta,
    remove_mapping: bool,
) -> Result<(), SequenceError> {
    let Some(worker) = self.request_index.worker_for(request_id) else {
        return Err(SequenceError::RequestNotFound {
            request_id: request_id.clone(),
        });
    };
    // Read the LoRA name before mutating: the local step may remove the
    // request's index entry when `remove_mapping` is set.
    let lora_name = self.request_index.lora_for(request_id);

    self.mutate_request_worker_prompt_state_local(
        worker,
        request_id,
        decay_now,
        mutate_fn,
        remove_mapping,
    )?;

    let event = ActiveSequenceEvent {
        request_id: request_id.clone(),
        worker,
        data: event_data,
        router_id: self.router_id,
        lora_name,
    };
    self.spawn_publish_event(event);
    Ok(())
}
/// Resolves the worker that owns `request_id`, applies a load-only mutation
/// locally, then publishes the matching event.
///
/// Fails with `RequestNotFound` when the request index has no worker for the
/// request, and propagates any error from the local mutation. No event is
/// published on failure.
fn mutate_request_worker_load_state(
    &self,
    request_id: &RequestId,
    decay_now: Instant,
    event_data: ActiveSequenceEventData,
    mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId, Instant),
) -> Result<(), SequenceError> {
    let Some(worker) = self.request_index.worker_for(request_id) else {
        return Err(SequenceError::RequestNotFound {
            request_id: request_id.clone(),
        });
    };
    let lora_name = self.request_index.lora_for(request_id);

    self.mutate_request_worker_load_state_local(worker, request_id, decay_now, mutate_fn)?;

    let event = ActiveSequenceEvent {
        request_id: request_id.clone(),
        worker,
        data: event_data,
        router_id: self.router_id,
        lora_name,
    };
    self.spawn_publish_event(event);
    Ok(())
}
}
#[cfg(test)]
mod tests {
use std::collections::{HashMap, VecDeque};
use std::future::{self, Future};
use std::time::Duration;
use rustc_hash::FxHashMap;
use super::*;
use crate::protocols::{
ActiveSequenceEvent, ActiveSequenceEventData, BlockHashOptions, OverlapScores,
PrefillLoadHint, compute_block_hash_for_seq, compute_seq_hash_for_block,
};
use crate::test_utils::NoopSequencePublisher;
/// Test fixture: a multi-worker tracker seeded with the single worker id `1`.
fn make_sequences() -> ActiveSequencesMultiWorker<NoopSequencePublisher> {
    let mut initial_workers: HashMap<u64, (u32, u32)> = HashMap::new();
    initial_workers.insert(1_u64, (0_u32, 1_u32));

    ActiveSequencesMultiWorker::new(NoopSequencePublisher, 4, initial_workers, false, 0, "test")
}
/// Test fixture: a multi-worker tracker seeded with worker ids `1` and `2`.
fn make_multi_sequences() -> ActiveSequencesMultiWorker<NoopSequencePublisher> {
    let initial_workers: HashMap<u64, (u32, u32)> = [1_u64, 2_u64]
        .into_iter()
        .map(|worker_id| (worker_id, (0_u32, 1_u32)))
        .collect();

    ActiveSequencesMultiWorker::new(NoopSequencePublisher, 4, initial_workers, false, 0, "test")
}
/// Reference implementation of the potential-load query, used to cross-check
/// the indexed path.
///
/// For every worker slot it reads the per-worker sequences directly and returns
/// `(potential_blocks, potential_tokens)`:
/// * blocks = active blocks + the query hashes that follow the longest leading
///   run of hashes already in the worker's active prompt hashes;
/// * tokens = active tokens at `decay_now` + `new_tokens(isl, overlap)` when
///   `track_prefill_tokens` is set, using the worker's entry in `overlaps`
///   (0 when absent).
fn naive_potential_loads(
    sequences: &ActiveSequencesMultiWorker<NoopSequencePublisher>,
    token_sequence: Option<&[SequenceHash]>,
    isl: usize,
    overlaps: &OverlapScores,
    track_prefill_tokens: bool,
    decay_now: Instant,
) -> (
    FxHashMap<WorkerWithDpRank, usize>,
    FxHashMap<WorkerWithDpRank, usize>,
) {
    let mut blocks_by_worker = FxHashMap::default();
    let mut tokens_by_worker = FxHashMap::default();

    let workers = sequences.workers.read();
    for slot in &workers.slots {
        let seq = slot.sequences.read();

        let new_blocks = match token_sequence {
            None => 0,
            Some(query) => {
                let active_hashes = seq.active_prompt_hashes();
                // Length of the leading run of query hashes the worker already holds.
                let matched_prefix = query
                    .iter()
                    .take_while(|hash| active_hashes.contains(*hash))
                    .count();
                query.len().saturating_sub(matched_prefix)
            }
        };

        let overlap = overlaps.scores.get(&slot.worker).copied().unwrap_or(0);
        let mut added_tokens = 0;
        if track_prefill_tokens {
            added_tokens = seq.new_tokens(isl, overlap);
        }

        blocks_by_worker.insert(slot.worker, seq.active_blocks() + new_blocks);
        tokens_by_worker.insert(slot.worker, seq.active_tokens(decay_now) + added_tokens);
    }

    (blocks_by_worker, tokens_by_worker)
}
/// Hashes `tokens` into block hashes (block size 4, optional LoRA name) and
/// converts them into sequence hashes.
fn seq_hashes_for_tokens(tokens: &[u32], lora_name: Option<&str>) -> Vec<SequenceHash> {
    let options = BlockHashOptions {
        lora_name,
        ..Default::default()
    };
    let per_block_hashes = compute_block_hash_for_seq(tokens, 4, options);
    compute_seq_hash_for_block(&per_block_hashes)
}
/// Test subscriber that replays a fixed queue of events.
struct VecSubscriber {
    /// Events still to be delivered, front first.
    events: VecDeque<anyhow::Result<ActiveSequenceEvent>>,
}

impl SequenceSubscriber for VecSubscriber {
    /// Pops the next queued event when called and returns it as an already
    /// resolved future; yields `None` once the queue is empty.
    fn next_event(
        &mut self,
    ) -> impl Future<Output = Option<anyhow::Result<ActiveSequenceEvent>>> + Send {
        let next = self.events.pop_front();
        future::ready(next)
    }
}
#[tokio::test]
async fn add_request_can_skip_prefill_token_tracking() {
let sequences = make_sequences();
let worker = WorkerWithDpRank::new(1, 0); let worker = WorkerWithDpRank::new(1, 0);
let decay_now = Instant::now(); let decay_now = Instant::now();
...@@ -941,6 +1057,419 @@ mod tests { ...@@ -941,6 +1057,419 @@ mod tests {
); );
} }
/// The indexed potential-load query must agree with the naive per-worker scan
/// after a mix of request adds, an output block and a prefill completion.
#[test]
fn block_membership_index_matches_naive_loads_with_output_blocks_and_prefill_updates() {
    let sequences = make_multi_sequences();
    let (worker_a, worker_b) = (WorkerWithDpRank::new(1, 0), WorkerWithDpRank::new(2, 0));
    let decay_now = Instant::now();
    let req_a = "req-a".to_string();

    let tracked_request =
        |id: &str, hashes: Vec<SequenceHash>, worker: WorkerWithDpRank| SequenceRequest {
            request_id: id.to_string(),
            token_sequence: Some(hashes),
            isl: 12,
            overlap: 0,
            track_prefill_tokens: true,
            expected_output_tokens: None,
            prefill_load_hint: None,
            worker,
            lora_name: None,
        };

    sequences
        .add_request(tracked_request("req-a", vec![1, 2, 3], worker_a), decay_now)
        .unwrap();
    sequences.add_output_block(&req_a, Some(0.5)).unwrap();
    sequences.mark_prefill_completed(&req_a, decay_now).unwrap();
    sequences
        .add_request(tracked_request("req-b", vec![1, 2, 4], worker_b), decay_now)
        .unwrap();

    let prompt = vec![1, 2, 3, 5];
    // Both code paths get an identical, independently built score table.
    let overlap_scores = || {
        let mut overlaps = OverlapScores::new();
        overlaps.scores.insert(worker_a, 2);
        overlaps.scores.insert(worker_b, 1);
        overlaps
    };

    let (expected_blocks, expected_tokens) = naive_potential_loads(
        &sequences,
        Some(&prompt),
        16,
        &overlap_scores(),
        true,
        decay_now,
    );
    let (actual_blocks, actual_tokens) = sequences
        .potential_blocks_and_tokens_with_prefill_tracking(
            Some(&prompt),
            16,
            overlap_scores(),
            true,
            decay_now,
        );

    assert_eq!(actual_blocks, expected_blocks);
    assert_eq!(actual_tokens, expected_tokens);
}
/// Sequence hashes computed with a LoRA name must not match the base-model
/// hashes of the same tokens: querying with the base prompt treats every block
/// as new on the worker that only holds the LoRA prompt.
#[test]
fn lora_specific_sequence_hashes_do_not_cross_match() {
    let sequences = make_multi_sequences();
    let (worker_a, worker_b) = (WorkerWithDpRank::new(1, 0), WorkerWithDpRank::new(2, 0));
    let decay_now = Instant::now();

    let tokens = [1_u32, 2, 3, 4, 5, 6, 7, 8];
    let base_prompt = seq_hashes_for_tokens(&tokens, None);
    let lora_prompt = seq_hashes_for_tokens(&tokens, Some("adapter-a"));
    assert_ne!(base_prompt, lora_prompt);

    let untracked_request = |id: &str,
                             hashes: Vec<SequenceHash>,
                             worker: WorkerWithDpRank,
                             lora: Option<&str>| SequenceRequest {
        request_id: id.to_string(),
        token_sequence: Some(hashes),
        isl: 8,
        overlap: 0,
        track_prefill_tokens: false,
        expected_output_tokens: None,
        prefill_load_hint: None,
        worker,
        lora_name: lora.map(str::to_string),
    };

    sequences
        .add_request(
            untracked_request("base", base_prompt.clone(), worker_a, None),
            decay_now,
        )
        .unwrap();
    sequences
        .add_request(
            untracked_request("lora", lora_prompt, worker_b, Some("adapter-a")),
            decay_now,
        )
        .unwrap();

    let (expected_blocks, expected_tokens) = naive_potential_loads(
        &sequences,
        Some(&base_prompt),
        8,
        &OverlapScores::default(),
        false,
        decay_now,
    );
    let (actual_blocks, actual_tokens) = sequences
        .potential_blocks_and_tokens_with_prefill_tracking(
            Some(&base_prompt),
            8,
            OverlapScores::default(),
            false,
            decay_now,
        );
    assert_eq!(actual_blocks, expected_blocks);
    assert_eq!(actual_tokens, expected_tokens);

    // The LoRA-only worker gets no credit for any base-prompt block.
    let active_blocks = sequences.active_blocks();
    let all_blocks_new_on_b = active_blocks[&worker_b] + base_prompt.len();
    assert_eq!(
        actual_blocks.get(&worker_b).copied(),
        Some(all_blocks_new_on_b),
    );
}
/// After a forced expiry sweep the request index, the block index and the
/// worker's active block count must all be empty.
#[tokio::test(start_paused = true)]
async fn force_expiry_clears_block_membership_index() {
    let sequences = make_multi_sequences();
    let worker = WorkerWithDpRank::new(1, 0);

    let request = SequenceRequest {
        request_id: "req-1".to_string(),
        token_sequence: Some(vec![1, 2, 3]),
        isl: 12,
        overlap: 0,
        track_prefill_tokens: true,
        expected_output_tokens: None,
        prefill_load_hint: None,
        worker,
        lora_name: None,
    };
    sequences.add_request(request, Instant::now()).unwrap();

    // Presumably long enough for the request to count as stale.
    let elapsed = Duration::from_secs(331);
    tokio::time::advance(elapsed).await;
    sequences.force_expire_requests_across_all_workers();

    assert!(sequences.request_index.is_empty());
    assert!(sequences.prompt_registry.is_block_index_empty());
    let remaining_blocks = sequences.active_blocks().get(&worker).copied();
    assert_eq!(remaining_blocks, Some(0));
}
#[tokio::test(start_paused = true)]
async fn expiry_then_immediate_readd_preserves_block_membership() {
let sequences = make_sequences();
let worker = WorkerWithDpRank::new(1, 0);
sequences
.add_request(
SequenceRequest {
request_id: "req-1".to_string(),
token_sequence: Some(vec![1, 2, 3]),
isl: 12,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
prefill_load_hint: None,
worker,
lora_name: None,
},
Instant::now(),
)
.unwrap();
tokio::time::advance(Duration::from_secs(331)).await;
sequences
.add_request(
SequenceRequest {
request_id: "req-2".to_string(),
token_sequence: Some(vec![1, 2, 3]),
isl: 12,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
prefill_load_hint: None,
worker,
lora_name: None,
},
Instant::now(),
)
.unwrap();
assert!(!sequences.prompt_registry.is_block_index_empty());
assert_eq!(sequences.active_blocks().get(&worker).copied(), Some(3));
let expected = naive_potential_loads(
&sequences,
Some(&[1, 2, 3]),
12,
&OverlapScores::default(),
false,
Instant::now(),
);
let actual = sequences.potential_blocks_and_tokens_with_prefill_tracking(
Some(&[1, 2, 3]),
12,
OverlapScores::default(),
false,
Instant::now(),
);
assert_eq!(actual, expected);
}
/// Replaying an `AddRequest` followed by a `Free` through replica sync must
/// leave the request index and block index empty.
#[tokio::test]
async fn replica_sync_add_and_free_keep_block_membership_consistent() {
    let initial_workers = HashMap::from([(1_u64, (0_u32, 1_u32))]);
    let sequences = ActiveSequencesMultiWorker::new(
        NoopSequencePublisher,
        4,
        initial_workers,
        true,
        0,
        "test",
    );
    let worker = WorkerWithDpRank::new(1, 0);

    let add_event = ActiveSequenceEvent {
        request_id: "req-1".to_string(),
        worker,
        data: ActiveSequenceEventData::AddRequest {
            token_sequence: Some(vec![1, 2, 3]),
            isl: 12,
            overlap: 0,
            track_prefill_tokens: true,
            expected_output_tokens: None,
            prefill_load_hint: None,
        },
        router_id: 99,
        lora_name: None,
    };
    let free_event = ActiveSequenceEvent {
        request_id: "req-1".to_string(),
        worker,
        data: ActiveSequenceEventData::Free,
        router_id: 99,
        lora_name: None,
    };

    let mut events = VecDeque::new();
    events.push_back(Ok(add_event));
    events.push_back(Ok(free_event));
    let subscriber = VecSubscriber { events };

    sequences
        .run_replica_sync(subscriber, CancellationToken::new())
        .await
        .unwrap();

    assert!(sequences.request_index.is_empty());
    assert!(sequences.prompt_registry.is_block_index_empty());
    let remaining_blocks = sequences.active_blocks().get(&worker).copied();
    assert_eq!(remaining_blocks, Some(0));
}
/// An `AddRequest` replica event for a worker that was never configured must
/// register that worker and record the request against it.
#[tokio::test]
async fn replica_sync_add_lazily_registers_missing_worker() {
    // Start with no workers at all.
    let sequences = ActiveSequencesMultiWorker::new(
        NoopSequencePublisher,
        4,
        HashMap::new(),
        true,
        0,
        "test",
    );
    let worker = WorkerWithDpRank::new(1, 0);
    let request_id = "req-1".to_string();

    let add_event = ActiveSequenceEvent {
        request_id: request_id.clone(),
        worker,
        data: ActiveSequenceEventData::AddRequest {
            token_sequence: Some(vec![1, 2, 3]),
            isl: 12,
            overlap: 0,
            track_prefill_tokens: true,
            expected_output_tokens: None,
            prefill_load_hint: None,
        },
        router_id: 99,
        lora_name: None,
    };
    let subscriber = VecSubscriber {
        events: VecDeque::from(vec![Ok(add_event)]),
    };

    sequences
        .run_replica_sync(subscriber, CancellationToken::new())
        .await
        .unwrap();

    assert_eq!(sequences.num_workers(), 1);
    assert_eq!(sequences.request_index.worker_for(&request_id), Some(worker));
    assert!(!sequences.prompt_registry.is_block_index_empty());
    let blocks_on_worker = sequences.active_blocks().get(&worker).copied();
    assert_eq!(blocks_on_worker, Some(3));
}
/// Removing every worker must clear all registry state, and re-adding the same
/// worker must start it from zero active blocks.
#[test]
fn worker_removal_then_readd_starts_with_empty_registry_state() {
    let sequences = make_sequences();
    let worker = WorkerWithDpRank::new(1, 0);
    let decay_now = Instant::now();

    let request = SequenceRequest {
        request_id: "req-1".to_string(),
        token_sequence: Some(vec![1, 2, 3]),
        isl: 12,
        overlap: 0,
        track_prefill_tokens: false,
        expected_output_tokens: None,
        prefill_load_hint: None,
        worker,
        lora_name: None,
    };
    sequences.add_request(request, decay_now).unwrap();

    // Drop all workers.
    let no_workers = HashMap::new();
    sequences.update_workers(&no_workers);
    assert!(sequences.prompt_registry.is_block_index_empty());
    assert!(sequences.active_blocks().is_empty());
    assert!(sequences.request_index.is_empty());

    // Bring the same worker back.
    let restored_workers = HashMap::from([(1_u64, (0_u32, 1_u32))]);
    sequences.update_workers(&restored_workers);
    let blocks_on_worker = sequences.active_blocks().get(&worker).copied();
    assert_eq!(blocks_on_worker, Some(0));
    assert!(sequences.prompt_registry.is_block_index_empty());
}
/// Freeing the same request twice must succeed both times and leave no state
/// behind.
#[test]
fn free_is_idempotent_after_request_is_removed() {
    let sequences = make_sequences();
    let worker = WorkerWithDpRank::new(1, 0);
    let request_id = "req-1".to_string();
    let decay_now = Instant::now();

    let request = SequenceRequest {
        request_id: request_id.clone(),
        token_sequence: Some(vec![1, 2, 3]),
        isl: 12,
        overlap: 0,
        track_prefill_tokens: false,
        expected_output_tokens: None,
        prefill_load_hint: None,
        worker,
        lora_name: None,
    };
    sequences.add_request(request, decay_now).unwrap();

    // The second iteration frees a request that is already gone.
    for _ in 0..2 {
        sequences.free(&request_id, decay_now).unwrap();
    }

    assert!(sequences.request_index.is_empty());
    assert!(sequences.prompt_registry.is_block_index_empty());
    let remaining_blocks = sequences.active_blocks().get(&worker).copied();
    assert_eq!(remaining_blocks, Some(0));
}
/// `free` on a request whose index entry points at a worker without a slot
/// must succeed and remove the stale index entry.
#[test]
fn free_cleans_stale_request_mapping_when_worker_slot_is_missing() {
    let sequences = make_sequences();
    let worker = WorkerWithDpRank::new(1, 0);
    let request_id = "stale-request".to_string();

    // Plant an index entry directly, bypassing `add_request`.
    let lora_name = Some("adapter".to_string());
    sequences
        .request_index
        .set_request(request_id.clone(), worker, lora_name);

    // Swap in an empty worker table so the mapped worker has no slot.
    {
        let mut workers = sequences.workers.write();
        let empty_table = WorkerTable::new(sequences.block_size, &HashMap::new());
        *workers = empty_table;
    }

    let freed = sequences.free(&request_id, Instant::now());
    freed.unwrap();
    assert!(sequences.request_index.is_empty());
}
#[test] #[test]
fn explicit_decay_time_drives_multi_worker_load_queries_consistently() { fn explicit_decay_time_drives_multi_worker_load_queries_consistently() {
let sequences = make_sequences(); let sequences = make_sequences();
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque; use std::collections::{HashMap, VecDeque};
use std::time::Duration; use std::time::Duration;
use tokio::time::Instant; use tokio::time::Instant;
use super::single::RequestId; use super::single::RequestId;
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) struct PrefillLoadState { pub(super) struct PrefillLoadState {
pub(super) initial_effective_prefill_tokens: usize, pub(super) initial_effective_prefill_tokens: usize,
pub(super) expected_prefill_duration: Option<Duration>, pub(super) expected_prefill_duration: Option<Duration>,
} }
/// Copy of the anchored prefill's state, as captured by
/// `PrefillLoadTracker::snapshot`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) struct AnchoredPrefillSnapshot {
    /// Token count of the anchored prefill before any decay is applied.
    pub(super) initial_effective_prefill_tokens: usize,
    /// Expected prefill duration; `None` means the load is not decayed, and a
    /// zero duration means it counts as fully decayed.
    pub(super) expected_prefill_duration: Option<Duration>,
    /// Instant from which elapsed time is measured when computing decay.
    pub(super) anchored_since: Instant,
}
/// Point-in-time copy of a tracker's prefill load that can be evaluated at any
/// later instant without access to the tracker.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub(super) struct PrefillLoadSnapshot {
    /// The tracker's `prefill_full_tokens_sum` at snapshot time.
    pub(super) prefill_full_tokens_sum: usize,
    /// State of the anchored prefill, if the tracker had one.
    pub(super) anchored_prefill: Option<AnchoredPrefillSnapshot>,
}

impl PrefillLoadSnapshot {
    /// Returns the prefill token load as seen at `now`.
    ///
    /// Only the anchored prefill decays: its contribution shrinks linearly from
    /// its full token count to zero over its expected duration (rounded up).
    /// All other tokens in the sum count in full. Returns 0 when nothing is
    /// anchored.
    ///
    /// # Panics
    ///
    /// Panics if `prefill_full_tokens_sum` is smaller than the anchored
    /// prefill's token count.
    pub(super) fn active_tokens_at(&self, now: Instant) -> usize {
        let Some(anchor) = self.anchored_prefill else {
            return 0;
        };
        let anchored_full = anchor.initial_effective_prefill_tokens;

        let anchored_remaining = match anchor.expected_prefill_duration {
            Some(expected) if !expected.is_zero() => {
                let elapsed = now.saturating_duration_since(anchor.anchored_since);
                let progress = elapsed.as_secs_f64() / expected.as_secs_f64();
                let remaining_fraction = (1.0 - progress).clamp(0.0, 1.0);
                ((anchored_full as f64) * remaining_fraction).ceil() as usize
            }
            // A zero expected duration is treated as already finished.
            Some(_) => 0,
            // Without an expected duration there is nothing to decay against.
            None => anchored_full,
        };

        let unanchored_tokens = self
            .prefill_full_tokens_sum
            .checked_sub(anchored_full)
            .expect("prefill_full_tokens_sum smaller than anchored load");
        unanchored_tokens + anchored_remaining
    }
}
/// Returns how many of the `isl` input tokens are not covered by cached blocks.
///
/// The cached token count is `overlap * block_size`. When that exceeds `isl`
/// an error is logged and 0 is returned.
pub(super) fn added_prefill_tokens(block_size: usize, isl: usize, overlap: u32) -> usize {
    // Saturate instead of overflowing: a wrapped product in release builds would
    // look like a small cache hit and report a bogus prefill count, while a
    // saturated one falls into the "more cached than requested" path below.
    let cached_tokens = (overlap as usize).saturating_mul(block_size);
    isl.checked_sub(cached_tokens).unwrap_or_else(|| {
        tracing::error!(
            "prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens} (overlap {overlap} * block_size {block_size}), returning 0",
        );
        0
    })
}
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub(super) struct PrefillLoadTracker { pub(super) struct PrefillLoadTracker {
pub(super) prefills: HashMap<RequestId, PrefillLoadState>,
pub(super) prefill_order: VecDeque<RequestId>, pub(super) prefill_order: VecDeque<RequestId>,
pub(super) prefill_full_tokens_sum: usize, pub(super) prefill_full_tokens_sum: usize,
pub(super) anchored_prefill: Option<(RequestId, Instant)>, pub(super) anchored_prefill: Option<(RequestId, Instant)>,
...@@ -27,6 +76,7 @@ impl PrefillLoadTracker { ...@@ -27,6 +76,7 @@ impl PrefillLoadTracker {
prefill: PrefillLoadState, prefill: PrefillLoadState,
decay_now: Instant, decay_now: Instant,
) { ) {
self.prefills.insert(request_id.clone(), prefill);
self.prefill_full_tokens_sum += prefill.initial_effective_prefill_tokens; self.prefill_full_tokens_sum += prefill.initial_effective_prefill_tokens;
let should_anchor = self.anchored_prefill.is_none(); let should_anchor = self.anchored_prefill.is_none();
self.prefill_order.push_back(request_id.clone()); self.prefill_order.push_back(request_id.clone());
...@@ -38,9 +88,9 @@ impl PrefillLoadTracker { ...@@ -38,9 +88,9 @@ impl PrefillLoadTracker {
pub(super) fn remove( pub(super) fn remove(
&mut self, &mut self,
request_id: &RequestId, request_id: &RequestId,
prefill: PrefillLoadState,
decay_now: Instant, decay_now: Instant,
) { ) -> Option<PrefillLoadState> {
let prefill = self.prefills.remove(request_id)?;
self.prefill_full_tokens_sum = self self.prefill_full_tokens_sum = self
.prefill_full_tokens_sum .prefill_full_tokens_sum
.checked_sub(prefill.initial_effective_prefill_tokens) .checked_sub(prefill.initial_effective_prefill_tokens)
...@@ -60,6 +110,7 @@ impl PrefillLoadTracker { ...@@ -60,6 +110,7 @@ impl PrefillLoadTracker {
{ {
self.set_anchor_to_front(decay_now); self.set_anchor_to_front(decay_now);
} }
Some(prefill)
} }
pub(super) fn set_anchor_to_front(&mut self, now: Instant) { pub(super) fn set_anchor_to_front(&mut self, now: Instant) {
...@@ -69,4 +120,209 @@ impl PrefillLoadTracker { ...@@ -69,4 +120,209 @@ impl PrefillLoadTracker {
.cloned() .cloned()
.map(|request_id| (request_id, now)); .map(|request_id| (request_id, now));
} }
/// Captures the current token sum and, if a prefill is anchored, a copy of
/// that prefill's state together with its anchor instant.
///
/// # Panics
///
/// Panics if the anchored request id has no entry in `prefills`.
pub(super) fn snapshot(&self) -> PrefillLoadSnapshot {
    let anchored_prefill = match &self.anchored_prefill {
        None => None,
        Some((request_id, anchored_since)) => {
            let state = self
                .prefills
                .get(request_id)
                .copied()
                .expect("anchored prefill missing request state");
            Some(AnchoredPrefillSnapshot {
                initial_effective_prefill_tokens: state.initial_effective_prefill_tokens,
                expected_prefill_duration: state.expected_prefill_duration,
                anchored_since: *anchored_since,
            })
        }
    };

    PrefillLoadSnapshot {
        prefill_full_tokens_sum: self.prefill_full_tokens_sum,
        anchored_prefill,
    }
}
/// Panics unless the tracker's internal structures agree with each other.
///
/// Checked invariants:
/// * `prefill_order` has no duplicate ids and holds exactly the keys of `prefills`;
/// * `prefill_full_tokens_sum` equals the sum over all tracked prefills;
/// * the anchor exists iff `prefill_order` is non-empty, and then names its front.
#[cfg(any(test, debug_assertions))]
pub(super) fn assert_consistent(&self) {
    let tracked_ids: std::collections::HashSet<&RequestId> = self.prefills.keys().collect();
    let ordered_ids: std::collections::HashSet<&RequestId> = self.prefill_order.iter().collect();

    // A set smaller than the queue means some id was queued more than once.
    assert_eq!(
        ordered_ids.len(),
        self.prefill_order.len(),
        "prefill_order contains duplicate request ids",
    );
    assert_eq!(
        ordered_ids, tracked_ids,
        "prefill_order must match active prefill requests",
    );

    let mut recomputed_sum: usize = 0;
    for state in self.prefills.values() {
        recomputed_sum += state.initial_effective_prefill_tokens;
    }
    assert_eq!(
        self.prefill_full_tokens_sum, recomputed_sum,
        "prefill_full_tokens_sum drifted from tracker state",
    );

    match (self.prefill_order.front(), self.anchored_prefill.as_ref()) {
        (Some(_), None) => {
            panic!("anchored_prefill must exist when prefill_order is non-empty");
        }
        (Some(oldest_request_id), Some((anchored_request_id, _))) => {
            assert!(
                self.prefills.contains_key(oldest_request_id),
                "prefill_order front must point to an active prefill request",
            );
            assert_eq!(
                anchored_request_id, oldest_request_id,
                "anchored_prefill must match prefill_order.front()",
            );
        }
        (None, anchor) => {
            assert!(
                anchor.is_none(),
                "anchored_prefill must be absent when no active prefills remain",
            );
        }
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Test helper: a prefill of `tokens` tokens expected to take `duration_secs` seconds.
fn prefill_state(tokens: usize, duration_secs: u64) -> PrefillLoadState {
    let expected = Duration::from_secs(duration_secs);
    PrefillLoadState {
        initial_effective_prefill_tokens: tokens,
        expected_prefill_duration: Some(expected),
    }
}
/// A snapshot of an empty tracker reports no active tokens.
#[test]
fn snapshot_without_anchor_reports_zero_active_tokens() {
    let empty_snapshot = PrefillLoadTracker::default().snapshot();
    let now = Instant::now();
    assert_eq!(empty_snapshot.active_tokens_at(now), 0);
}
/// Only the first-inserted (anchored) prefill decays over time; the later one
/// keeps counting in full.
#[test]
fn snapshot_only_decays_oldest_prefill() {
    let epoch = Instant::now();
    let mut tracker = PrefillLoadTracker::default();
    let (first_id, second_id) = ("r1".to_string(), "r2".to_string());

    tracker.insert(&first_id, prefill_state(100, 10), epoch);
    tracker.insert(
        &second_id,
        prefill_state(60, 6),
        epoch + Duration::from_secs(2),
    );
    let snapshot = tracker.snapshot();

    // (seconds after epoch, expected active tokens)
    for (offset_secs, expected_tokens) in [(2_u64, 140_usize), (5, 110)] {
        let at = epoch + Duration::from_secs(offset_secs);
        assert_eq!(snapshot.active_tokens_at(at), expected_tokens);
    }
}
/// Removing the anchored prefill moves the anchor to the next queued request
/// and restarts its decay from the removal time.
#[test]
fn removing_anchored_prefill_reanchors_front_and_resets_decay() {
    let epoch = Instant::now();
    let removal_time = epoch + Duration::from_secs(3);
    let mut tracker = PrefillLoadTracker::default();
    let (first_id, second_id) = ("r1".to_string(), "r2".to_string());
    let first = prefill_state(100, 10);
    let second = prefill_state(40, 8);

    tracker.insert(&first_id, first, epoch);
    tracker.insert(&second_id, second, epoch);

    let removed = tracker.remove(&first_id, removal_time);
    assert_eq!(removed, Some(first));
    assert_eq!(tracker.prefill_order, VecDeque::from([second_id.clone()]));

    let anchored_on_second = tracker
        .anchored_prefill
        .as_ref()
        .is_some_and(|(request_id, _)| request_id == &second_id);
    assert!(anchored_on_second);

    let snapshot = tracker.snapshot();
    assert_eq!(snapshot.active_tokens_at(removal_time), 40);
    assert_eq!(
        snapshot.active_tokens_at(epoch + Duration::from_secs(5)),
        30
    );
}
/// Removing a prefill that is not anchored leaves the anchor and its start
/// time untouched.
#[test]
fn removing_nonfront_prefill_preserves_existing_anchor() {
    let epoch = Instant::now();
    let two_seconds_in = epoch + Duration::from_secs(2);
    let mut tracker = PrefillLoadTracker::default();
    let (first_id, second_id) = ("r1".to_string(), "r2".to_string());
    let first = prefill_state(30, 6);
    let second = prefill_state(20, 4);

    tracker.insert(&first_id, first, epoch);
    tracker.insert(&second_id, second, epoch);

    let removed = tracker.remove(&second_id, two_seconds_in);
    assert_eq!(removed, Some(second));
    assert_eq!(tracker.prefill_order, VecDeque::from([first_id.clone()]));

    let anchor_unchanged = match tracker.anchored_prefill.as_ref() {
        Some((request_id, anchored_since)) => {
            request_id == &first_id && *anchored_since == epoch
        }
        None => false,
    };
    assert!(anchor_unchanged);

    let snapshot = tracker.snapshot();
    assert_eq!(snapshot.active_tokens_at(two_seconds_in), 21);
}
/// A second `remove` of the same request returns `None` and changes nothing.
#[test]
fn duplicate_cleanup_is_idempotent() {
    let epoch = Instant::now();
    let mut tracker = PrefillLoadTracker::default();
    let (first_id, second_id) = ("r1".to_string(), "r2".to_string());
    let first = prefill_state(50, 10);
    let second = prefill_state(30, 10);

    tracker.insert(&first_id, first, epoch);
    tracker.insert(&second_id, second, epoch);
    tracker.assert_consistent();

    // First request: the real removal, then a no-op repeat.
    let first_removal = tracker.remove(&first_id, epoch);
    assert_eq!(first_removal, Some(first));
    let repeated_removal = tracker.remove(&first_id, epoch);
    assert_eq!(repeated_removal, None);
    assert_eq!(tracker.prefill_full_tokens_sum, 30);
    assert_eq!(tracker.prefill_order, VecDeque::from([second_id.clone()]));

    // Second request: same pattern, leaving the tracker empty.
    let second_removal = tracker.remove(&second_id, epoch);
    assert_eq!(second_removal, Some(second));
    let repeated_removal = tracker.remove(&second_id, epoch);
    assert_eq!(repeated_removal, None);
    tracker.assert_consistent();

    assert_eq!(tracker.prefill_full_tokens_sum, 0);
    assert!(tracker.prefill_order.is_empty());
    assert!(tracker.prefills.is_empty());
}
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use dynamo_tokens::SequenceHash;
use parking_lot::RwLock;
use rustc_hash::{FxHashMap, FxHashSet};
use crate::protocols::WorkerWithDpRank;
type SharedNode = Arc<RwLock<PromptTrieNode>>;
pub(super) type WorkerLookup = FxHashMap<SequenceHash, SharedNode>;
/// One node of the prompt membership trie: an edge of sequence hashes plus
/// per-worker coverage of that edge.
#[derive(Debug)]
pub(super) struct PromptTrieNode {
    /// Sequence hashes stored on this node's edge, in order.
    edge: Vec<SequenceHash>,
    // NOTE(review): presumably maps each edge hash to its position in `edge`;
    // confirm against the code that populates it.
    edge_index: FxHashMap<SequenceHash, usize>,
    /// Cutoff per partially covering worker: the worker covers edge positions
    /// strictly below its cutoff.
    worker_cutoffs: FxHashMap<WorkerWithDpRank, usize>,
    /// Workers that cover the whole edge.
    full_edge_workers: FxHashSet<WorkerWithDpRank>,
    /// Child nodes keyed by sequence hash; cleared once no worker covers the
    /// full edge.
    children: FxHashMap<SequenceHash, SharedNode>,
}
impl PromptTrieNode {
    /// Creates a node with an empty edge, no worker coverage and no children.
    fn new() -> Self {
        let edge = Vec::new();
        let edge_index = FxHashMap::default();
        let worker_cutoffs = FxHashMap::default();
        let full_edge_workers = FxHashSet::default();
        let children = FxHashMap::default();
        Self {
            edge,
            edge_index,
            worker_cutoffs,
            full_edge_workers,
            children,
        }
    }

    /// True when at least one worker covers this edge fully or partially.
    #[cfg(any(test, feature = "bench"))]
    fn has_any_workers(&self) -> bool {
        !(self.full_edge_workers.is_empty() && self.worker_cutoffs.is_empty())
    }

    /// Number of leading edge positions covered by `worker`: the whole edge for
    /// full-edge workers, the stored cutoff for partial ones, otherwise 0.
    fn current_cutoff(&self, worker: WorkerWithDpRank) -> usize {
        if self.full_edge_workers.contains(&worker) {
            return self.edge.len();
        }
        match self.worker_cutoffs.get(&worker) {
            Some(&cutoff) => cutoff,
            None => 0,
        }
    }

    /// True when `worker` covers edge position `pos`.
    fn covers_pos(&self, worker: WorkerWithDpRank, pos: usize) -> bool {
        if self.full_edge_workers.contains(&worker) {
            return true;
        }
        self.worker_cutoffs
            .get(&worker)
            .is_some_and(|&cutoff| pos < cutoff)
    }

    /// Drops all children once no worker covers the full edge.
    fn clear_children_if_unreachable(&mut self) {
        let nobody_reaches_end = self.full_edge_workers.is_empty();
        if nobody_reaches_end {
            self.children.clear();
        }
    }

    /// Copies the edge hashes from position `cutoff` to the end of the edge.
    fn uncovered_suffix_hashes(&self, cutoff: usize) -> Vec<SequenceHash> {
        debug_assert!(cutoff <= self.edge.len());
        let suffix = &self.edge[cutoff..];
        Vec::from(suffix)
    }

    /// Removes every trace of `worker` from this node.
    fn drop_worker(&mut self, worker: WorkerWithDpRank) {
        self.worker_cutoffs.remove(&worker);
        self.full_edge_workers.remove(&worker);
        self.clear_children_if_unreachable();
    }

    /// Marks `worker` as covering the full edge, discarding any partial cutoff.
    fn promote_to_full(&mut self, worker: WorkerWithDpRank) {
        let newly_promoted = self.full_edge_workers.insert(worker);
        if newly_promoted {
            self.worker_cutoffs.remove(&worker);
        }
    }

    /// Truncates `worker`'s coverage so that it ends just before `pos`.
    ///
    /// When `pos` is already outside the worker's coverage nothing changes and
    /// only `removed_hash` is reported as stale. Otherwise every edge hash from
    /// `pos` onward is reported, and the worker is either dropped (`pos == 0`)
    /// or demoted to a partial cutoff of `pos`.
    fn remove_worker_at_pos(
        &mut self,
        worker: WorkerWithDpRank,
        pos: usize,
        removed_hash: SequenceHash,
    ) -> RemoveOutcome {
        let covered_until = self.current_cutoff(worker);
        if pos >= covered_until {
            let stale_hashes = vec![removed_hash];
            return RemoveOutcome { stale_hashes };
        }

        let stale_hashes = self.uncovered_suffix_hashes(pos);
        match pos {
            0 => self.drop_worker(worker),
            new_cutoff => {
                self.full_edge_workers.remove(&worker);
                self.worker_cutoffs.insert(worker, new_cutoff);
                self.clear_children_if_unreachable();
            }
        }
        RemoveOutcome { stale_hashes }
    }

    /// Children that still have worker coverage or children of their own.
    #[cfg(any(test, feature = "bench"))]
    fn live_children(&self) -> Vec<SharedNode> {
        let mut live = Vec::new();
        for child in self.children.values() {
            let node = child.read();
            if node.has_any_workers() || !node.children.is_empty() {
                live.push(Arc::clone(child));
            }
        }
        live
    }
}
/// Result of trimming a worker's coverage on a single node.
struct RemoveOutcome {
/// Hashes the caller should drop from the worker's lookup.
stale_hashes: Vec<SequenceHash>,
}
/// Trie over sequence-hash chains that records, per node, which workers cover it.
pub(super) struct PromptMembershipTrie {
/// Root node; its own edge stays empty and top-level chains hang off `children`.
root: SharedNode,
}
impl Default for PromptMembershipTrie {
/// Equivalent to [`PromptMembershipTrie::new`].
fn default() -> Self {
Self::new()
}
}
// Tears the trie down with an explicit work stack instead of relying on the nested
// drop of each node's `children` map.
// NOTE(review): presumably this guards against deep recursion on long chains — confirm.
impl Drop for PromptMembershipTrie {
fn drop(&mut self) {
let mut stack: Vec<SharedNode> = Vec::new();
{
// Detach the root's children so the stack owns them.
let mut root = self.root.write();
stack.extend(root.children.drain().map(|(_, child)| child));
}
while let Some(node) = stack.pop() {
// Only a uniquely owned node can be unwrapped; a node still referenced elsewhere
// is released here without draining its children.
if let Ok(rwlock) = Arc::try_unwrap(node) {
let mut inner = rwlock.into_inner();
stack.extend(inner.children.drain().map(|(_, child)| child));
}
}
}
}
impl PromptMembershipTrie {
/// Creates an empty trie consisting of a single root node.
pub(super) fn new() -> Self {
    let root = Arc::new(RwLock::new(PromptTrieNode::new()));
    Self { root }
}
/// Searches the descendants of `start` (not `start` itself) for a node whose edge
/// contains `hash`, using an explicit stack. Returns the first match found, if any.
fn find_in_subtree(start: &SharedNode, hash: SequenceHash) -> Option<SharedNode> {
    let mut pending: Vec<SharedNode> = start.read().children.values().cloned().collect();
    while let Some(candidate) = pending.pop() {
        {
            let node = candidate.read();
            if !node.edge_index.contains_key(&hash) {
                // Not here: keep descending through this node's children.
                pending.extend(node.children.values().cloned());
                continue;
            }
        }
        return Some(candidate);
    }
    None
}
fn resolve_lookup(worker_lookup: &mut WorkerLookup, hash: SequenceHash) -> Option<SharedNode> {
let node = worker_lookup.get(&hash)?.clone();
let found = {
let guard = node.read();
guard.edge_index.contains_key(&hash)
};
if found {
return Some(node);
}
let resolved = Self::find_in_subtree(&node, hash)?;
worker_lookup.insert(hash, resolved.clone());
Some(resolved)
}
fn split_node(node: &mut PromptTrieNode, pos: usize) -> SharedNode {
debug_assert!(pos > 0 && pos < node.edge.len());
let suffix_edge = node.edge.split_off(pos);
let suffix_first_hash = suffix_edge[0];
let mut suffix_edge_index = FxHashMap::default();
for (i, &hash) in suffix_edge.iter().enumerate() {
suffix_edge_index.insert(hash, i);
}
for &hash in &suffix_edge {
node.edge_index.remove(&hash);
}
let mut suffix_full = FxHashSet::default();
let mut suffix_cutoffs = FxHashMap::default();
let mut to_promote = Vec::new();
for &worker in &node.full_edge_workers {
suffix_full.insert(worker);
}
for (&worker, &cutoff) in &node.worker_cutoffs {
if cutoff >= pos {
to_promote.push(worker);
let suffix_cutoff = cutoff - pos;
if suffix_cutoff > 0 {
suffix_cutoffs.insert(worker, suffix_cutoff);
}
}
}
for worker in to_promote {
node.worker_cutoffs.remove(&worker);
node.full_edge_workers.insert(worker);
}
let suffix_children = std::mem::take(&mut node.children);
let suffix = Arc::new(RwLock::new(PromptTrieNode {
edge: suffix_edge,
edge_index: suffix_edge_index,
worker_cutoffs: suffix_cutoffs,
full_edge_workers: suffix_full,
children: suffix_children,
}));
node.children.insert(suffix_first_hash, suffix.clone());
suffix
}
/// Records that `worker` holds the chain `hashes`, attached directly after `parent`
/// (or at the root when `parent` is `None`).
///
/// Does nothing for an empty `hashes`. When a parent is given it must be present in the
/// worker's lookup and still covered by the worker; otherwise a warning is logged and
/// the store is dropped. If the parent hash sits in the middle of a node's edge, that
/// node is split so the parent becomes the last hash of its node before inserting.
///
/// The worker's lookup write lock is held for the whole call.
pub(super) fn store_chain(
    &self,
    worker: WorkerWithDpRank,
    lookup: &Arc<RwLock<WorkerLookup>>,
    parent: Option<SequenceHash>,
    hashes: &[SequenceHash],
) {
    if hashes.is_empty() {
        return;
    }
    let mut worker_lookup = lookup.write();
    let parent = match parent {
        Some(parent_hash) => loop {
            let Some(node) = Self::resolve_lookup(&mut worker_lookup, parent_hash) else {
                tracing::warn!(?worker, ?parent_hash, "prompt parent hash not found");
                return;
            };
            {
                let guard = node.read();
                // The node may have been split since it was resolved; resolve again.
                let Some(&pos) = guard.edge_index.get(&parent_hash) else {
                    continue;
                };
                if !guard.covers_pos(worker, pos) {
                    worker_lookup.remove(&parent_hash);
                    tracing::warn!(
                        ?worker,
                        ?parent_hash,
                        pos,
                        "worker no longer covers prompt parent"
                    );
                    return;
                }
            }
            let did_split = {
                let mut guard = node.write();
                // Re-check under the write lock, and take the position from the index
                // rather than scanning the edge.
                let Some(&parent_pos) = guard.edge_index.get(&parent_hash) else {
                    continue;
                };
                let split_at = parent_pos + 1;
                if split_at < guard.edge.len() {
                    // Parent is not the last hash of the edge: cut right after it.
                    Self::split_node(&mut guard, split_at);
                    true
                } else {
                    false
                }
            };
            if did_split {
                // Re-resolve and re-validate against the shortened node.
                continue;
            }
            break node;
        },
        None => self.root.clone(),
    };
    self.insert_hashes_from(worker, &mut worker_lookup, &parent, hashes);
}
fn insert_hashes_from(
&self,
worker: WorkerWithDpRank,
worker_lookup: &mut WorkerLookup,
parent: &SharedNode,
hashes: &[SequenceHash],
) {
let mut current_parent = parent.clone();
let mut remaining = hashes;
let mut last_hash = None;
while !remaining.is_empty() {
let first_hash = remaining[0];
let child = {
let mut parent_guard = current_parent.write();
if let Some(last_hash) = last_hash
&& !parent_guard.edge_index.contains_key(&last_hash)
{
drop(parent_guard);
if let Some(resolved) = Self::resolve_lookup(worker_lookup, last_hash) {
current_parent = resolved;
}
continue;
}
match parent_guard.children.get(&first_hash).cloned() {
Some(existing) => existing,
None => {
let edge = remaining.to_vec();
let mut edge_index = FxHashMap::default();
for (i, &hash) in edge.iter().enumerate() {
edge_index.insert(hash, i);
}
let mut full_edge_workers = FxHashSet::default();
full_edge_workers.insert(worker);
let new_node = Arc::new(RwLock::new(PromptTrieNode {
edge,
edge_index,
worker_cutoffs: FxHashMap::default(),
full_edge_workers,
children: FxHashMap::default(),
}));
parent_guard.children.insert(first_hash, new_node.clone());
drop(parent_guard);
for &hash in remaining {
worker_lookup.insert(hash, new_node.clone());
}
return;
}
}
};
{
let mut child_guard = child.write();
let edge_len = child_guard.edge.len();
let mut match_len = 0;
for (&edge_hash, &query_hash) in child_guard.edge.iter().zip(remaining.iter()) {
if edge_hash != query_hash {
break;
}
match_len += 1;
}
debug_assert!(match_len >= 1);
if match_len < edge_len {
let _suffix = Self::split_node(&mut child_guard, match_len);
child_guard.promote_to_full(worker);
let tail = &remaining[match_len..];
if !tail.is_empty() {
let edge = tail.to_vec();
let mut edge_index = FxHashMap::default();
for (i, &hash) in edge.iter().enumerate() {
edge_index.insert(hash, i);
}
let mut full_edge_workers = FxHashSet::default();
full_edge_workers.insert(worker);
let tail_first_hash = tail[0];
let new_node = Arc::new(RwLock::new(PromptTrieNode {
edge,
edge_index,
worker_cutoffs: FxHashMap::default(),
full_edge_workers,
children: FxHashMap::default(),
}));
child_guard
.children
.insert(tail_first_hash, new_node.clone());
drop(child_guard);
for &hash in &remaining[..match_len] {
worker_lookup.insert(hash, child.clone());
}
for &hash in tail {
worker_lookup.insert(hash, new_node.clone());
}
} else {
drop(child_guard);
for &hash in &remaining[..match_len] {
worker_lookup.insert(hash, child.clone());
}
}
return;
}
child_guard.promote_to_full(worker);
drop(child_guard);
for &hash in &remaining[..edge_len] {
worker_lookup.insert(hash, child.clone());
}
last_hash = Some(remaining[edge_len - 1]);
remaining = &remaining[edge_len..];
current_parent = child;
}
}
}
/// Removes `worker`'s coverage of each hash in `hashes`.
///
/// Removing a hash trims the worker's coverage of the owning node to end just before
/// that hash, so later hashes on the same edge are uncovered as well. Hashes that are
/// unknown to the worker's lookup are skipped. The lookup write lock is held throughout.
pub(super) fn remove_chain(
&self,
worker: WorkerWithDpRank,
lookup: &Arc<RwLock<WorkerLookup>>,
hashes: &[SequenceHash],
) {
let mut worker_lookup = lookup.write();
if worker_lookup.is_empty() {
return;
}
'outer: for &hash in hashes {
let mut current_node = match Self::resolve_lookup(&mut worker_lookup, hash) {
Some(node) => node,
None => continue,
};
loop {
// Trim under the node's write lock; `None` means the hash is no longer on this
// node's edge.
let update = {
let mut guard = current_node.write();
guard
.edge_index
.get(&hash)
.copied()
.map(|pos| guard.remove_worker_at_pos(worker, pos, hash))
};
match update {
Some(outcome) => {
// Forget every hash the worker no longer covers on this node.
for stale_hash in outcome.stale_hashes {
worker_lookup.remove(&stale_hash);
}
continue 'outer;
}
// The hash is not on this node: look for it further down and retry there.
None => match Self::find_in_subtree(&current_node, hash) {
Some(resolved) => {
worker_lookup.insert(hash, resolved.clone());
current_node = resolved;
}
None => {
worker_lookup.remove(&hash);
continue 'outer;
}
},
}
}
}
}
pub(super) fn remove_worker(
&self,
worker: WorkerWithDpRank,
lookup: &Arc<RwLock<WorkerLookup>>,
) {
let mut worker_lookup = lookup.write();
if worker_lookup.is_empty() {
return;
}
let hashes: Vec<_> = worker_lookup.keys().copied().collect();
let mut nodes = Vec::new();
let mut seen = FxHashSet::<usize>::default();
for hash in hashes {
let Some(node) = Self::resolve_lookup(&mut worker_lookup, hash) else {
worker_lookup.remove(&hash);
continue;
};
let ptr = Arc::as_ptr(&node) as usize;
if seen.insert(ptr) {
nodes.push(node);
}
}
worker_lookup.clear();
drop(worker_lookup);
for node in nodes {
let mut guard = node.write();
guard.drop_worker(worker);
}
}
/// For each worker, computes how many leading hashes of `query` it covers.
///
/// Walks the trie from the root along `query`, one node per iteration. Returns an empty
/// map when `query` is `None` or empty. Workers with no coverage along the path are
/// absent from the result.
pub(super) fn compute_overlap_depths(
&self,
query: Option<&[SequenceHash]>,
) -> FxHashMap<WorkerWithDpRank, usize> {
let Some(query) = query else {
return FxHashMap::default();
};
if query.is_empty() {
return FxHashMap::default();
}
let mut matched_depth = FxHashMap::default();
// Workers that fully cover every node walked so far.
let mut active = FxHashSet::default();
let mut active_count = 0usize;
let mut query_pos = 0usize;
// Number of query hashes matched by the nodes walked so far.
let mut depth = 0usize;
let mut first_node = true;
let mut next_child = {
let root = self.root.read();
root.children.get(&query[0]).cloned()
};
loop {
if query_pos >= query.len() {
break;
}
let Some(child) = next_child.take() else {
break;
};
let edge_len;
let edge_match_len;
{
let guard = child.read();
edge_len = guard.edge.len();
let walk_len = edge_len.min(query.len() - query_pos);
// The child was selected by its first hash, so position 0 is counted as matched.
let mut match_len = 1usize;
for i in 1..walk_len {
if guard.edge[i] != query[query_pos + i] {
break;
}
match_len += 1;
}
edge_match_len = match_len;
let prev_depth = depth;
if first_node {
// Seed the active set; partially covering workers are finalized immediately.
active = guard.full_edge_workers.clone();
active_count = active.len();
for (&worker, &cutoff) in &guard.worker_cutoffs {
let contribution = cutoff.min(edge_match_len);
if contribution > 0 {
matched_depth.insert(worker, contribution);
}
}
first_node = false;
} else if !guard.worker_cutoffs.is_empty() {
// Active workers that are not full on this node stop here, credited with
// whatever part of this edge they cover.
active.retain(|worker| {
if guard.full_edge_workers.contains(worker) {
true
} else if let Some(&cutoff) = guard.worker_cutoffs.get(worker) {
matched_depth.insert(*worker, prev_depth + cutoff.min(edge_match_len));
false
} else {
matched_depth.insert(*worker, prev_depth);
false
}
});
active_count = active.len();
} else {
// No partial workers on this node: the filter pass runs only when the counts differ.
// NOTE(review): skipping on equal counts assumes this node's full workers are a
// subset of `active`; confirm that invariant holds for all writers.
let full_count = guard.full_edge_workers.len();
if full_count != active_count {
active.retain(|worker| {
if guard.full_edge_workers.contains(worker) {
true
} else {
matched_depth.insert(*worker, prev_depth);
false
}
});
active_count = active.len();
}
}
// Descend only if the whole edge matched, someone is still active, and query remains.
next_child = if edge_match_len == edge_len
&& active_count > 0
&& query_pos + edge_match_len < query.len()
{
guard
.children
.get(&query[query_pos + edge_match_len])
.cloned()
} else {
None
};
}
if active_count == 0 {
break;
}
depth += edge_match_len;
if edge_match_len < edge_len {
break;
}
query_pos += edge_match_len;
}
// Workers still active at the end matched everything walked.
for worker in active {
matched_depth.insert(worker, depth);
}
matched_depth
}
/// Test helper: collects, per worker, every hash the trie says that worker covers.
#[cfg(test)]
pub(super) fn worker_hashes(&self) -> FxHashMap<WorkerWithDpRank, FxHashSet<SequenceHash>> {
    let mut by_worker: FxHashMap<WorkerWithDpRank, FxHashSet<SequenceHash>> =
        FxHashMap::default();
    let mut pending = vec![Arc::clone(&self.root)];
    while let Some(current) = pending.pop() {
        let node = current.read();
        for worker in node.full_edge_workers.iter().copied() {
            by_worker
                .entry(worker)
                .or_default()
                .extend(node.edge.iter().copied());
        }
        for (worker, cutoff) in node.worker_cutoffs.iter().map(|(&w, &c)| (w, c)) {
            // A partial worker covers only the first `cutoff` hashes of the edge.
            by_worker
                .entry(worker)
                .or_default()
                .extend(node.edge[..cutoff].iter().copied());
        }
        pending.extend(node.children.values().cloned());
    }
    by_worker
}
/// Returns `true` when the root has no child that still carries workers or children.
#[cfg(any(test, feature = "bench"))]
pub(super) fn is_empty(&self) -> bool {
    self.root.read().live_children().is_empty()
}
}
#[cfg(test)]
mod tests {
use super::*;
// Shorthand for building a worker identity.
fn worker(worker_id: u64, dp_rank: u32) -> WorkerWithDpRank {
WorkerWithDpRank::new(worker_id, dp_rank)
}
// Fresh, empty per-worker lookup.
fn lookup() -> Arc<RwLock<WorkerLookup>> {
Arc::new(RwLock::new(WorkerLookup::default()))
}
// A chain stored under an existing parent extends the match; removing it trims back.
#[test]
fn parent_continuation_chains_extend_and_trim() {
let trie = PromptMembershipTrie::new();
let worker = worker(1, 0);
let lookup = lookup();
trie.store_chain(worker, &lookup, None, &[1, 2, 3]);
trie.store_chain(worker, &lookup, Some(3), &[4, 5]);
assert_eq!(
trie.compute_overlap_depths(Some(&[1, 2, 3, 4, 5])),
FxHashMap::from_iter([(worker, 5)]),
);
trie.remove_chain(worker, &lookup, &[4, 5]);
assert_eq!(
trie.compute_overlap_depths(Some(&[1, 2, 3, 4, 5])),
FxHashMap::from_iter([(worker, 3)]),
);
}
// Two workers share a prefix and then diverge; each query credits only the shared part
// to the other worker.
#[test]
fn branching_continuations_across_workers_match_expected_depths() {
let trie = PromptMembershipTrie::new();
let worker_a = worker(1, 0);
let worker_b = worker(2, 0);
let lookup_a = lookup();
let lookup_b = lookup();
trie.store_chain(worker_a, &lookup_a, None, &[1, 2, 3, 4]);
trie.store_chain(worker_b, &lookup_b, None, &[1, 2, 5]);
assert_eq!(
trie.compute_overlap_depths(Some(&[1, 2, 3, 4])),
FxHashMap::from_iter([(worker_a, 4), (worker_b, 2)]),
);
assert_eq!(
trie.compute_overlap_depths(Some(&[1, 2, 5])),
FxHashMap::from_iter([(worker_a, 2), (worker_b, 3)]),
);
}
// Removing the tail of a stored chain leaves the leading hashes covered.
#[test]
fn partial_suffix_removal_keeps_prefix() {
let trie = PromptMembershipTrie::new();
let worker = worker(1, 0);
let lookup = lookup();
trie.store_chain(worker, &lookup, None, &[1, 2, 3, 4, 5]);
trie.remove_chain(worker, &lookup, &[3, 4, 5]);
assert_eq!(
trie.compute_overlap_depths(Some(&[1, 2, 3, 4, 5])),
FxHashMap::from_iter([(worker, 2)]),
);
}
// Dropping one worker must not disturb another worker on the same chain.
#[test]
fn remove_worker_preserves_other_workers() {
let trie = PromptMembershipTrie::new();
let worker_a = worker(1, 0);
let worker_b = worker(2, 0);
let lookup_a = lookup();
let lookup_b = lookup();
trie.store_chain(worker_a, &lookup_a, None, &[1, 2, 3]);
trie.store_chain(worker_b, &lookup_b, None, &[1, 2, 3]);
trie.remove_worker(worker_a, &lookup_a);
assert_eq!(
trie.compute_overlap_depths(Some(&[1, 2, 3])),
FxHashMap::from_iter([(worker_b, 3)]),
);
}
// The same worker id with different dp ranks is tracked as two separate workers.
#[test]
fn multiple_dp_ranks_with_same_worker_id_remain_isolated() {
let trie = PromptMembershipTrie::new();
let worker_a = worker(1, 0);
let worker_b = worker(1, 1);
let lookup_a = lookup();
let lookup_b = lookup();
trie.store_chain(worker_a, &lookup_a, None, &[1, 2, 3]);
trie.store_chain(worker_b, &lookup_b, None, &[1, 2]);
assert_eq!(
trie.compute_overlap_depths(Some(&[1, 2, 3])),
FxHashMap::from_iter([(worker_a, 3), (worker_b, 2)]),
);
}
// After a full worker removal, the same worker and lookup can be reused from scratch.
#[test]
fn clear_worker_state_then_reuse_starts_empty() {
let trie = PromptMembershipTrie::new();
let worker = worker(1, 0);
let lookup = lookup();
trie.store_chain(worker, &lookup, None, &[1, 2, 3]);
trie.remove_worker(worker, &lookup);
assert!(trie.compute_overlap_depths(Some(&[1, 2, 3])).is_empty());
trie.store_chain(worker, &lookup, None, &[1, 2]);
assert_eq!(
trie.compute_overlap_depths(Some(&[1, 2, 3])),
FxHashMap::from_iter([(worker, 2)]),
);
}
// Applying the same removal twice yields the same state as applying it once.
#[test]
fn redundant_batched_remove_is_idempotent() {
let trie = PromptMembershipTrie::new();
let worker = worker(1, 0);
let lookup = lookup();
trie.store_chain(worker, &lookup, None, &[1, 2, 3, 4]);
trie.remove_chain(worker, &lookup, &[2, 3, 4]);
trie.remove_chain(worker, &lookup, &[2, 3, 4]);
assert_eq!(
trie.compute_overlap_depths(Some(&[1, 2, 3, 4])),
FxHashMap::from_iter([(worker, 1)]),
);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dashmap::DashMap;
use dynamo_tokens::SequenceHash;
#[cfg(test)]
use rustc_hash::FxHashSet;
use rustc_hash::{FxBuildHasher, FxHashMap};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::time::Instant;
use super::prefill_tracker::{PrefillLoadSnapshot, added_prefill_tokens};
use super::prompt_membership_trie::{PromptMembershipTrie, WorkerLookup};
use super::single::PromptMembershipDelta;
use super::topology::WorkerTopologyChange;
use crate::protocols::{OverlapScores, WorkerWithDpRank};
/// Point-in-time load figures published for one worker.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub(super) struct WorkerLoadSnapshot {
/// Active block count reported for the worker.
pub(super) active_blocks: usize,
/// Prefill load snapshot from which active tokens are derived.
pub(super) prefill: PrefillLoadSnapshot,
}
impl WorkerLoadSnapshot {
/// Active token count as of `decay_now`, delegated to the prefill snapshot.
pub(super) fn active_tokens(&self, decay_now: Instant) -> usize {
self.prefill.active_tokens_at(decay_now)
}
}
/// Read-side registry combining prompt membership (which worker holds which hashes)
/// with the latest load snapshot of every known worker.
pub(super) struct PromptRegistry {
// WARNING: prompt membership and worker load are only eventually consistent.
// Each mutation still starts from one worker-local source of truth: we mutate the chosen
// `ActiveSequences`, derive an exact `PromptMembershipDelta` plus `WorkerLoadSnapshot`, then
// publish them separately here. The trie and load map converge to the correct final state
// after the write finishes, but reads can still observe a mixed membership/load state that
// never existed atomically and make a suboptimal routing choice.
/// Trie recording which workers cover which hash chains.
membership: PromptMembershipTrie,
/// Latest load snapshot per worker; the key set is the set of known workers.
loads: DashMap<WorkerWithDpRank, WorkerLoadSnapshot, FxBuildHasher>,
}
impl Default for PromptRegistry {
    /// Builds a registry with an empty membership trie and no known workers.
    fn default() -> Self {
        let loads = DashMap::with_hasher(FxBuildHasher);
        let membership = PromptMembershipTrie::new();
        Self { membership, loads }
    }
}
impl PromptRegistry {
/// Creates a registry in which every worker in `workers` starts with a default load.
pub(super) fn new(workers: impl IntoIterator<Item = WorkerWithDpRank>) -> Self {
    let registry = Self::default();
    workers.into_iter().for_each(|worker| {
        registry.loads.entry(worker).or_default();
    });
    registry
}
/// Overwrites (or creates) the load snapshot stored for `worker`; membership is untouched.
pub(super) fn replace_worker_load_state(
&self,
worker: WorkerWithDpRank,
load: WorkerLoadSnapshot,
) {
self.loads.insert(worker, load);
}
pub(super) fn apply_membership_delta_and_load(
&self,
worker: WorkerWithDpRank,
lookup: &Arc<parking_lot::RwLock<WorkerLookup>>,
delta: PromptMembershipDelta,
load: WorkerLoadSnapshot,
) {
for remove in delta.removes {
self.membership.remove_chain(worker, lookup, &remove.hashes);
}
for store in delta.stores {
self.membership
.store_chain(worker, lookup, store.parent, &store.hashes);
}
self.loads.insert(worker, load);
}
pub(super) fn apply_topology_change(&self, change: WorkerTopologyChange) {
for removed in change.removed {
self.membership
.remove_worker(removed.worker, &removed.trie_lookup);
self.loads.remove(&removed.worker);
}
for worker in change.added {
self.loads.entry(worker).or_default();
}
}
/// Projects, for every known worker, the block and token load it would have after
/// taking a request of `query_len` blocks.
///
/// Blocks: current active blocks plus the query blocks beyond the worker's matched depth.
/// Tokens: current active tokens at `decay_now`, plus `added_prefill_tokens(..)` for the
/// worker's overlap score when `track_prefill_tokens` is set.
///
/// Returns `(potential_blocks, potential_tokens)`.
#[expect(clippy::too_many_arguments)]
fn project_loads_from_overlap(
    &self,
    query_len: usize,
    matched_depth: &FxHashMap<WorkerWithDpRank, usize>,
    isl: usize,
    overlaps: &OverlapScores,
    track_prefill_tokens: bool,
    block_size: usize,
    decay_now: Instant,
) -> (
    FxHashMap<WorkerWithDpRank, usize>,
    FxHashMap<WorkerWithDpRank, usize>,
) {
    let worker_count = self.loads.len();
    let mut potential_blocks = FxHashMap::with_capacity_and_hasher(worker_count, FxBuildHasher);
    let mut potential_tokens = FxHashMap::with_capacity_and_hasher(worker_count, FxBuildHasher);

    for entry in self.loads.iter() {
        let (worker, load) = (*entry.key(), *entry.value());

        let shared_prefix = matched_depth.get(&worker).map_or(0, |&depth| depth);
        let blocks_after = load.active_blocks + query_len.saturating_sub(shared_prefix);

        let prefill_delta = if track_prefill_tokens {
            let overlap = overlaps.scores.get(&worker).copied().unwrap_or(0);
            added_prefill_tokens(block_size, isl, overlap)
        } else {
            0
        };

        potential_blocks.insert(worker, blocks_after);
        potential_tokens.insert(worker, load.active_tokens(decay_now) + prefill_delta);
    }

    (potential_blocks, potential_tokens)
}
/// Computes per-worker projected block and token loads for a prospective request.
///
/// The matched depth per worker comes from the membership trie; a missing
/// `token_sequence` counts as a query of length zero. Returns
/// `(potential_blocks, potential_tokens)`.
pub(super) fn potential_blocks_and_tokens_with_prefill_tracking(
    &self,
    token_sequence: Option<&[SequenceHash]>,
    isl: usize,
    overlaps: &OverlapScores,
    track_prefill_tokens: bool,
    block_size: usize,
    decay_now: Instant,
) -> (
    FxHashMap<WorkerWithDpRank, usize>,
    FxHashMap<WorkerWithDpRank, usize>,
) {
    let query_len = match token_sequence {
        Some(query) => query.len(),
        None => 0,
    };
    let depths = self.membership.compute_overlap_depths(token_sequence);
    self.project_loads_from_overlap(
        query_len,
        &depths,
        isl,
        overlaps,
        track_prefill_tokens,
        block_size,
        decay_now,
    )
}
/// Snapshot of the active block count of every known worker.
pub(super) fn active_blocks(&self) -> HashMap<WorkerWithDpRank, usize> {
    let mut blocks = HashMap::new();
    for entry in self.loads.iter() {
        blocks.insert(*entry.key(), entry.value().active_blocks);
    }
    blocks
}
/// Snapshot of the active token count of every known worker, evaluated at `decay_now`.
pub(super) fn active_tokens(&self, decay_now: Instant) -> HashMap<WorkerWithDpRank, usize> {
    let mut tokens = HashMap::new();
    for entry in self.loads.iter() {
        tokens.insert(*entry.key(), entry.value().active_tokens(decay_now));
    }
    tokens
}
/// Returns `true` as soon as `predicate` accepts some worker together with its active
/// token count at `decay_now`; returns `false` if no worker is accepted.
pub(super) fn any_worker_matches_active_tokens(
    &self,
    decay_now: Instant,
    mut predicate: impl FnMut(WorkerWithDpRank, usize) -> bool,
) -> bool {
    for entry in self.loads.iter() {
        let tokens = entry.value().active_tokens(decay_now);
        if predicate(*entry.key(), tokens) {
            return true;
        }
    }
    false
}
/// Test helper: panics unless the registry's load map and trie membership equal the
/// expected per-worker state exactly.
#[cfg(test)]
pub(super) fn assert_consistent_with_workers(
&self,
expected_loads: &FxHashMap<WorkerWithDpRank, WorkerLoadSnapshot>,
expected_blocks: &FxHashMap<WorkerWithDpRank, FxHashSet<SequenceHash>>,
) {
// Copy the concurrent map into a plain map so it can be compared as a whole.
let actual_loads: FxHashMap<_, _> = self
.loads
.iter()
.map(|entry| (*entry.key(), *entry.value()))
.collect();
let actual_blocks = self.membership.worker_hashes();
assert_eq!(
actual_loads, *expected_loads,
"prompt registry worker loads drifted from per-worker state",
);
assert_eq!(
actual_blocks, *expected_blocks,
"prompt registry prompt membership drifted from per-worker state",
);
}
/// Returns `true` when the membership trie reports itself empty.
#[cfg(any(test, feature = "bench"))]
pub(super) fn is_block_index_empty(&self) -> bool {
self.membership.is_empty()
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use rustc_hash::{FxHashMap, FxHashSet};
use super::*;
use crate::protocols::WorkerWithDpRank;
use crate::sequences::prefill_tracker::AnchoredPrefillSnapshot;
use crate::sequences::single::{PromptMembershipRemove, PromptMembershipStore};
use crate::sequences::topology::RemovedWorkerState;
// Shorthand for building a worker identity.
fn worker(worker_id: u64, dp_rank: u32) -> WorkerWithDpRank {
WorkerWithDpRank::new(worker_id, dp_rank)
}
// Fresh, empty per-worker lookup.
fn lookup() -> Arc<parking_lot::RwLock<WorkerLookup>> {
Arc::new(parking_lot::RwLock::new(WorkerLookup::default()))
}
// Delta containing a single store and no removals.
fn store(parent: Option<SequenceHash>, hashes: &[SequenceHash]) -> PromptMembershipDelta {
PromptMembershipDelta {
stores: vec![PromptMembershipStore {
parent,
hashes: hashes.to_vec(),
}],
removes: Vec::new(),
}
}
// Load snapshot with the given block count and a default prefill state.
fn worker_load_snapshot(active_blocks: usize) -> WorkerLoadSnapshot {
WorkerLoadSnapshot {
active_blocks,
prefill: PrefillLoadSnapshot::default(),
}
}
// Load snapshot carrying an anchored prefill entry.
fn anchored_load_snapshot(
active_blocks: usize,
prefill_full_tokens_sum: usize,
anchored_tokens: usize,
expected_prefill_duration: Option<Duration>,
anchored_since: Instant,
) -> WorkerLoadSnapshot {
WorkerLoadSnapshot {
active_blocks,
prefill: PrefillLoadSnapshot {
prefill_full_tokens_sum,
anchored_prefill: Some(AnchoredPrefillSnapshot {
initial_effective_prefill_tokens: anchored_tokens,
expected_prefill_duration,
anchored_since,
}),
},
}
}
// Collects a slice of hashes into a set.
fn hash_set(hashes: &[SequenceHash]) -> FxHashSet<SequenceHash> {
hashes.iter().copied().collect()
}
// Reference implementation of the load projection that works from plain per-worker
// hash sets: overlap depth is the index of the first query hash the worker lacks.
#[expect(clippy::too_many_arguments)]
fn naive_potential_loads(
expected_loads: &FxHashMap<WorkerWithDpRank, WorkerLoadSnapshot>,
expected_blocks: &FxHashMap<WorkerWithDpRank, FxHashSet<SequenceHash>>,
token_sequence: Option<&[SequenceHash]>,
isl: usize,
overlaps: &OverlapScores,
track_prefill_tokens: bool,
block_size: usize,
decay_now: Instant,
) -> (
FxHashMap<WorkerWithDpRank, usize>,
FxHashMap<WorkerWithDpRank, usize>,
) {
let mut potential_blocks =
FxHashMap::with_capacity_and_hasher(expected_loads.len(), FxBuildHasher);
let mut potential_tokens =
FxHashMap::with_capacity_and_hasher(expected_loads.len(), FxBuildHasher);
for (&worker, load) in expected_loads {
let overlap_depth = token_sequence.map_or(0, |query| {
let worker_blocks = expected_blocks
.get(&worker)
.expect("worker must have a prompt membership entry");
query
.iter()
.position(|hash| !worker_blocks.contains(hash))
.unwrap_or(query.len())
});
let new_blocks =
token_sequence.map_or(0, |query| query.len().saturating_sub(overlap_depth));
let overlap = *overlaps.scores.get(&worker).unwrap_or(&0);
let added_tokens = if track_prefill_tokens {
added_prefill_tokens(block_size, isl, overlap)
} else {
0
};
potential_blocks.insert(worker, load.active_blocks + new_blocks);
potential_tokens.insert(worker, load.active_tokens(decay_now) + added_tokens);
}
(potential_blocks, potential_tokens)
}
// Store, remove, then store the same hash again: the hash must be present at the end.
#[test]
fn removed_hash_can_be_restored_by_later_store() {
let worker = worker(1, 0);
let registry = PromptRegistry::new([worker]);
let lookup = lookup();
let mut expected_loads = FxHashMap::default();
let mut expected_blocks = FxHashMap::default();
registry.apply_membership_delta_and_load(
worker,
&lookup,
store(None, &[42]),
worker_load_snapshot(1),
);
let load = worker_load_snapshot(1);
registry.apply_membership_delta_and_load(
worker,
&lookup,
PromptMembershipDelta {
removes: vec![PromptMembershipRemove { hashes: vec![42] }],
..Default::default()
},
load,
);
registry.apply_membership_delta_and_load(worker, &lookup, store(None, &[42]), load);
expected_loads.insert(worker, load);
expected_blocks.insert(worker, hash_set(&[42]));
registry.assert_consistent_with_workers(&expected_loads, &expected_blocks);
}
// Three workers hold prefixes of different lengths of one prompt; the registry's
// projection must equal the naive reference.
#[test]
fn staggered_prefix_overlap_matches_naive_projection() {
let worker_a = worker(1, 0);
let worker_b = worker(2, 0);
let worker_c = worker(3, 0);
let registry = PromptRegistry::new([worker_a, worker_b, worker_c]);
let lookup_a = lookup();
let lookup_b = lookup();
let lookup_c = lookup();
let decay_now = Instant::now();
let full_prompt: Vec<SequenceHash> = (1_u64..=96).collect();
let mut expected_loads = FxHashMap::default();
let mut expected_blocks = FxHashMap::default();
for (worker, lookup, prompt_len) in [
(worker_a, &lookup_a, 96usize),
(worker_b, &lookup_b, 64),
(worker_c, &lookup_c, 33),
] {
let blocks = full_prompt[..prompt_len].to_vec();
let load = worker_load_snapshot(prompt_len);
registry.apply_membership_delta_and_load(worker, lookup, store(None, &blocks), load);
expected_loads.insert(worker, load);
expected_blocks.insert(worker, blocks.into_iter().collect());
}
registry.assert_consistent_with_workers(&expected_loads, &expected_blocks);
let expected = naive_potential_loads(
&expected_loads,
&expected_blocks,
Some(&full_prompt),
384,
&OverlapScores::default(),
false,
4,
decay_now,
);
let actual = registry.potential_blocks_and_tokens_with_prefill_tracking(
Some(&full_prompt),
384,
&OverlapScores::default(),
false,
4,
decay_now,
);
assert_eq!(actual, expected);
}
// Replacing only the load snapshot must leave membership intact and be reflected in
// the active-token and projection reads.
#[test]
fn load_only_update_preserves_prompt_membership_and_active_token_projection() {
let worker = worker(1, 0);
let registry = PromptRegistry::new([worker]);
let lookup = lookup();
let now = Instant::now();
let anchored_since = now.checked_sub(Duration::from_secs(3)).unwrap_or(now);
let mut expected_loads = FxHashMap::default();
let mut expected_blocks = FxHashMap::default();
registry.apply_membership_delta_and_load(
worker,
&lookup,
store(None, &[1, 2, 3]),
worker_load_snapshot(3),
);
expected_blocks.insert(worker, hash_set(&[1, 2, 3]));
let updated_load =
anchored_load_snapshot(5, 12, 10, Some(Duration::from_secs(10)), anchored_since);
registry.replace_worker_load_state(worker, updated_load);
expected_loads.insert(worker, updated_load);
registry.assert_consistent_with_workers(&expected_loads, &expected_blocks);
assert_eq!(registry.active_tokens(now).get(&worker).copied(), Some(9));
let actual = registry.potential_blocks_and_tokens_with_prefill_tracking(
Some(&[1, 2, 3]),
12,
&OverlapScores::default(),
false,
4,
now,
);
assert_eq!(actual.0.get(&worker).copied(), Some(5));
assert_eq!(actual.1.get(&worker).copied(), Some(9));
}
// A topology change that removes a worker drops both its membership and its load entry,
// leaving the other worker untouched.
#[test]
fn removing_worker_clears_prompt_membership_and_load_state() {
let worker_a = worker(1, 0);
let worker_b = worker(2, 0);
let registry = PromptRegistry::new([worker_a, worker_b]);
let lookup_a = lookup();
let lookup_b = lookup();
let mut expected_loads = FxHashMap::default();
let mut expected_blocks = FxHashMap::default();
let load_a = worker_load_snapshot(3);
let load_b = worker_load_snapshot(2);
registry.apply_membership_delta_and_load(
worker_a,
&lookup_a,
store(None, &[1, 2, 3]),
load_a,
);
registry.apply_membership_delta_and_load(worker_b, &lookup_b, store(None, &[1, 2]), load_b);
expected_loads.insert(worker_a, load_a);
expected_loads.insert(worker_b, load_b);
expected_blocks.insert(worker_a, hash_set(&[1, 2, 3]));
expected_blocks.insert(worker_b, hash_set(&[1, 2]));
registry.apply_topology_change(WorkerTopologyChange {
added: Vec::new(),
removed: vec![RemovedWorkerState {
worker: worker_a,
trie_lookup: Arc::clone(&lookup_a),
}],
});
expected_loads.remove(&worker_a);
expected_blocks.remove(&worker_a);
registry.assert_consistent_with_workers(&expected_loads, &expected_blocks);
assert!(!registry.active_blocks().contains_key(&worker_a));
let actual = registry.potential_blocks_and_tokens_with_prefill_tracking(
Some(&[1, 2, 3]),
12,
&OverlapScores::default(),
false,
4,
Instant::now(),
);
assert_eq!(actual.0.get(&worker_b).copied(), Some(3));
}
// The same worker id with different dp ranks keeps separate membership and load.
#[test]
fn dp_ranks_with_same_worker_id_remain_isolated() {
let worker_a = worker(1, 0);
let worker_b = worker(1, 1);
let registry = PromptRegistry::new([worker_a, worker_b]);
let lookup_a = lookup();
let lookup_b = lookup();
let decay_now = Instant::now();
let mut expected_loads = FxHashMap::default();
let mut expected_blocks = FxHashMap::default();
let load_a = worker_load_snapshot(3);
let load_b = worker_load_snapshot(1);
registry.apply_membership_delta_and_load(
worker_a,
&lookup_a,
store(None, &[1, 2, 3]),
load_a,
);
registry.apply_membership_delta_and_load(worker_b, &lookup_b, store(None, &[1]), load_b);
expected_loads.insert(worker_a, load_a);
expected_loads.insert(worker_b, load_b);
expected_blocks.insert(worker_a, hash_set(&[1, 2, 3]));
expected_blocks.insert(worker_b, hash_set(&[1]));
registry.assert_consistent_with_workers(&expected_loads, &expected_blocks);
let expected = naive_potential_loads(
&expected_loads,
&expected_blocks,
Some(&[1, 2, 3]),
12,
&OverlapScores::default(),
false,
4,
decay_now,
);
let actual = registry.potential_blocks_and_tokens_with_prefill_tracking(
Some(&[1, 2, 3]),
12,
&OverlapScores::default(),
false,
4,
decay_now,
);
assert_eq!(actual, expected);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dashmap::{DashMap, mapref::entry::Entry};
use std::collections::HashMap;
use super::single::RequestId;
use crate::protocols::WorkerWithDpRank;
/// Concurrent index of in-flight requests: which worker owns each request and, when one
/// was supplied, which LoRA name is attached to it.
#[derive(Debug, Default)]
pub(super) struct RequestIndex {
/// Owning worker per request id.
request_to_worker: DashMap<RequestId, WorkerWithDpRank>,
/// LoRA name per request id; present only for requests registered with one.
request_to_lora: DashMap<RequestId, String>,
}
impl RequestIndex {
    /// Registers `request_id` for `worker` only if the request is not yet known.
    ///
    /// On success the optional LoRA name is recorded as well. If the request already
    /// exists, nothing is changed and the existing owner is returned as the error.
    pub(super) fn try_insert_request(
        &self,
        request_id: RequestId,
        worker: WorkerWithDpRank,
        lora_name: Option<String>,
    ) -> Result<(), WorkerWithDpRank> {
        let slot = match self.request_to_worker.entry(request_id.clone()) {
            Entry::Occupied(existing) => return Err(*existing.get()),
            Entry::Vacant(slot) => slot,
        };
        slot.insert(worker);
        if let Some(name) = lora_name {
            self.request_to_lora.insert(request_id, name);
        }
        Ok(())
    }

    /// Unconditionally sets the owner of `request_id`, and sets or clears its LoRA name
    /// to match `lora_name`.
    pub(super) fn set_request(
        &self,
        request_id: RequestId,
        worker: WorkerWithDpRank,
        lora_name: Option<String>,
    ) {
        self.request_to_worker.insert(request_id.clone(), worker);
        match lora_name {
            Some(name) => {
                self.request_to_lora.insert(request_id, name);
            }
            None => {
                self.request_to_lora.remove(&request_id);
            }
        }
    }

    /// Worker currently recorded for `request_id`, if any.
    pub(super) fn worker_for(&self, request_id: &RequestId) -> Option<WorkerWithDpRank> {
        let entry = self.request_to_worker.get(request_id)?;
        Some(*entry.value())
    }

    /// LoRA name currently recorded for `request_id`, if any.
    pub(super) fn lora_for(&self, request_id: &RequestId) -> Option<String> {
        let entry = self.request_to_lora.get(request_id)?;
        Some(entry.value().clone())
    }

    /// Forgets `request_id` (owner and LoRA name) and returns the worker it was
    /// assigned to, or `None` if it was unknown.
    pub(super) fn remove_request(&self, request_id: &RequestId) -> Option<WorkerWithDpRank> {
        let removed = self.request_to_worker.remove(request_id);
        self.request_to_lora.remove(request_id);
        removed.map(|(_, worker)| worker)
    }

    /// Forgets every request in `request_ids`.
    pub(super) fn remove_requests<'a>(&self, request_ids: impl IntoIterator<Item = &'a RequestId>) {
        request_ids.into_iter().for_each(|request_id| {
            self.remove_request(request_id);
        });
    }

    /// Forgets every request currently assigned to `worker` and returns their ids.
    pub(super) fn remove_worker_requests(&self, worker: WorkerWithDpRank) -> Vec<RequestId> {
        // Collect first so the map iterator is gone before any entry is removed.
        let mut owned = Vec::new();
        for entry in self.request_to_worker.iter() {
            if *entry.value() == worker {
                owned.push(entry.key().clone());
            }
        }
        self.remove_requests(&owned);
        owned
    }

    /// Number of tracked requests per LoRA name.
    pub(super) fn active_lora_counts(&self) -> HashMap<String, usize> {
        self.request_to_lora
            .iter()
            .fold(HashMap::<String, usize>::new(), |mut counts, entry| {
                *counts.entry(entry.value().clone()).or_insert(0) += 1;
                counts
            })
    }

    /// Returns `true` when neither the worker map nor the LoRA map holds any entry.
    #[cfg(any(test, feature = "bench"))]
    pub(super) fn is_empty(&self) -> bool {
        let no_workers = self.request_to_worker.is_empty();
        no_workers && self.request_to_lora.is_empty()
    }

    /// Number of requests that currently have an owning worker.
    #[cfg(any(test, feature = "bench"))]
    pub(super) fn worker_len(&self) -> usize {
        let tracked = self.request_to_worker.len();
        tracked
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned request id from a literal.
    fn rid(raw: &str) -> RequestId {
        raw.to_string()
    }

    /// A second `try_insert_request` for the same id must be rejected and must
    /// leave the original worker and LoRA mapping in place.
    #[test]
    fn duplicate_insert_returns_existing_worker() {
        let index = RequestIndex::default();
        let worker = WorkerWithDpRank::new(1, 0);

        let first = index.try_insert_request(rid("req-1"), worker, Some("adapter".to_string()));
        first.unwrap();

        let second = index.try_insert_request(rid("req-1"), WorkerWithDpRank::new(2, 0), None);
        assert_eq!(second, Err(worker));
        assert_eq!(index.worker_for(&rid("req-1")), Some(worker));
        assert_eq!(index.lora_for(&rid("req-1")), Some("adapter".to_string()));
    }

    /// Removing twice yields the worker once, then `None`, and empties the index.
    #[test]
    fn remove_request_is_idempotent() {
        let index = RequestIndex::default();
        let worker = WorkerWithDpRank::new(1, 0);
        let request_id = rid("req-1");
        index.set_request(request_id.clone(), worker, Some("adapter".to_string()));

        let first_removal = index.remove_request(&request_id);
        assert_eq!(first_removal, Some(worker));
        let second_removal = index.remove_request(&request_id);
        assert_eq!(second_removal, None);
        assert!(index.is_empty());
    }

    /// Re-setting a request without a LoRA name must drop the old LoRA entry.
    #[test]
    fn set_request_without_lora_clears_stale_lora_mapping() {
        let index = RequestIndex::default();
        let request_id = rid("req-1");

        index.set_request(
            request_id.clone(),
            WorkerWithDpRank::new(1, 0),
            Some("adapter".to_string()),
        );
        index.set_request(request_id.clone(), WorkerWithDpRank::new(2, 0), None);

        let reassigned = WorkerWithDpRank::new(2, 0);
        assert_eq!(index.worker_for(&request_id), Some(reassigned));
        assert_eq!(index.lora_for(&request_id), None);
    }

    /// Clearing one worker removes its requests from both maps and leaves the
    /// other worker's request and LoRA count untouched.
    #[test]
    fn remove_worker_requests_clears_both_maps() {
        let index = RequestIndex::default();
        let worker_a = WorkerWithDpRank::new(1, 0);
        let worker_b = WorkerWithDpRank::new(2, 0);
        index.set_request(rid("req-a"), worker_a, Some("adapter-a".to_string()));
        index.set_request(rid("req-b"), worker_b, Some("adapter-b".to_string()));
        index.set_request(rid("req-c"), worker_a, None);

        // Iteration order of the map is unspecified, so sort before comparing.
        let mut removed = index.remove_worker_requests(worker_a);
        removed.sort();
        assert_eq!(removed, vec![rid("req-a"), rid("req-c")]);

        assert_eq!(index.worker_for(&rid("req-b")), Some(worker_b));
        let expected_counts = HashMap::from([("adapter-b".to_string(), 1)]);
        assert_eq!(index.active_lora_counts(), expected_counts);
    }
}
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple //! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! requests share common prefixes (e.g., system prompts, few-shot examples). //! requests share common prefixes (e.g., system prompts, few-shot examples).
use derive_getters::Getters;
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
...@@ -26,8 +25,12 @@ use std::time::Duration; ...@@ -26,8 +25,12 @@ use std::time::Duration;
use tokio::time::Instant; use tokio::time::Instant;
use uuid::Uuid; use uuid::Uuid;
#[cfg(test)]
use rustc_hash::FxHashSet;
use super::block_tracker::BlockTracker; use super::block_tracker::BlockTracker;
use super::prefill_tracker::{PrefillLoadState, PrefillLoadTracker}; use super::prefill_tracker::{PrefillLoadState, PrefillLoadTracker, added_prefill_tokens};
use super::prompt_registry::WorkerLoadSnapshot;
use crate::protocols::PrefillLoadHint; use crate::protocols::PrefillLoadHint;
/// Duration after which stale requests may be expired (5 minutes). /// Duration after which stale requests may be expired (5 minutes).
...@@ -42,28 +45,75 @@ pub type RequestId = String; ...@@ -42,28 +45,75 @@ pub type RequestId = String;
#[derive(Debug)] #[derive(Debug)]
pub(super) struct RequestState { pub(super) struct RequestState {
blocks: Vec<(SequenceHash, Arc<()>)>, prompt_blocks: Vec<(SequenceHash, Arc<()>)>,
output_blocks: Vec<(SequenceHash, Arc<()>)>,
started_at: Instant, started_at: Instant,
prefill: Option<PrefillLoadState>,
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
} }
impl RequestState {
    /// Iterates over every block held by the request: prompt blocks first,
    /// followed by output blocks.
    fn all_blocks(&self) -> impl Iterator<Item = &(SequenceHash, Arc<()>)> {
        let prompt = self.prompt_blocks.iter();
        prompt.chain(&self.output_blocks)
    }
}
/// A run of prompt block hashes reported as newly stored.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub(super) struct PromptMembershipStore {
    /// Hash of the prompt block immediately preceding `hashes` in the request,
    /// or `None` when the run starts at the beginning of the prompt.
    pub parent: Option<SequenceHash>,
    /// Block hashes in the run, in prompt order.
    pub hashes: Vec<SequenceHash>,
}
/// A run of prompt block hashes reported as removed.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub(super) struct PromptMembershipRemove {
    /// Block hashes in the run, in prompt order.
    pub hashes: Vec<SequenceHash>,
}
/// Accumulated prompt-membership changes produced by one or more mutations.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub(super) struct PromptMembershipDelta {
    /// Store events, in the order they were recorded.
    pub stores: Vec<PromptMembershipStore>,
    /// Remove events, in the order they were recorded.
    pub removes: Vec<PromptMembershipRemove>,
}
impl PromptMembershipDelta {
    /// Appends all store and remove events of `other` after the ones already held.
    fn extend(&mut self, other: Self) {
        let Self { stores, removes } = other;
        self.stores.extend(stores);
        self.removes.extend(removes);
    }

    /// Records a store event for `hashes`; an empty list is ignored.
    fn push_store(&mut self, parent: Option<SequenceHash>, hashes: Vec<SequenceHash>) {
        if !hashes.is_empty() {
            let event = PromptMembershipStore { parent, hashes };
            self.stores.push(event);
        }
    }

    /// Records a remove event for `hashes`; an empty list is ignored.
    fn push_remove(&mut self, hashes: Vec<SequenceHash>) {
        if !hashes.is_empty() {
            let event = PromptMembershipRemove { hashes };
            self.removes.push(event);
        }
    }
}
/// Result of a mutating operation on the sequence tracker.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub(super) struct SequenceMutationOutcome {
    /// Prompt-membership store/remove events produced by the operation.
    pub membership_delta: PromptMembershipDelta,
    /// Ids of stale requests that were expired as part of the operation.
    pub expired_request_ids: HashSet<RequestId>,
}
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache /// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)] #[derive(Debug)]
pub struct ActiveSequences { pub struct ActiveSequences {
requests: HashMap<RequestId, RequestState>, requests: HashMap<RequestId, RequestState>,
prefill: PrefillLoadTracker, prefill: PrefillLoadTracker,
blocks: BlockTracker, blocks: BlockTracker,
#[getter(copy)]
block_size: usize, block_size: usize,
last_expiry_check_time: Instant, last_expiry_check_time: Instant,
} }
impl ActiveSequences { impl ActiveSequences {
/// Create a new SharedSequenceManager instance /// Create a new SharedSequenceManager instance
pub fn new(block_size: usize) -> Self { pub(super) fn new(block_size: usize) -> Self {
assert!(block_size > 1, "block_size must be greater than 1"); assert!(block_size > 1, "block_size must be greater than 1");
Self { Self {
...@@ -77,53 +127,13 @@ impl ActiveSequences { ...@@ -77,53 +127,13 @@ impl ActiveSequences {
#[cfg(any(test, debug_assertions))] #[cfg(any(test, debug_assertions))]
fn assert_consistent(&self) { fn assert_consistent(&self) {
let active_prefills: HashSet<RequestId> = self self.prefill.assert_consistent();
.requests let active_prefills: HashSet<RequestId> = self.prefill.prefills.keys().cloned().collect();
.iter() let active_requests: HashSet<RequestId> = self.requests.keys().cloned().collect();
.filter(|(_, state)| state.prefill.is_some()) assert!(
.map(|(request_id, _)| request_id.clone()) active_prefills.is_subset(&active_requests),
.collect(); "prefill tracker cannot reference missing request state",
let ordered_prefills: HashSet<RequestId> =
self.prefill.prefill_order.iter().cloned().collect();
let recomputed_prefill_sum: usize = self
.requests
.values()
.filter_map(|state| state.prefill)
.map(|prefill| prefill.initial_effective_prefill_tokens)
.sum();
assert_eq!(
ordered_prefills.len(),
self.prefill.prefill_order.len(),
"prefill_order contains duplicate request ids",
);
assert_eq!(
ordered_prefills, active_prefills,
"prefill_order must match requests with active prefill load",
); );
assert_eq!(
self.prefill.prefill_full_tokens_sum, recomputed_prefill_sum,
"prefill_full_tokens_sum drifted from request state",
);
if let Some(oldest_request_id) = self.prefill.prefill_order.front() {
let Some((anchored_request_id, _)) = self.prefill.anchored_prefill.as_ref() else {
panic!("anchored_prefill must exist when prefill_order is non-empty");
};
assert!(
self.requests
.get(oldest_request_id)
.is_some_and(|state| state.prefill.is_some()),
"prefill_order front must point to an active prefill request",
);
assert_eq!(
anchored_request_id, oldest_request_id,
"anchored_prefill must match prefill_order.front()",
);
} else {
assert!(
self.prefill.anchored_prefill.is_none(),
"anchored_prefill must be absent when no active prefills remain",
);
}
assert!( assert!(
self.blocks self.blocks
.fractional_blocks .fractional_blocks
...@@ -139,85 +149,19 @@ impl ActiveSequences { ...@@ -139,85 +149,19 @@ impl ActiveSequences {
self.assert_consistent(); self.assert_consistent();
} }
pub fn active_blocks(&self) -> usize { pub(super) fn active_blocks(&self) -> usize {
self.blocks.active_blocks() self.blocks.active_blocks()
} }
fn insert_prefill_load( #[cfg(test)]
&mut self, pub(super) fn active_tokens(&self, decay_now: Instant) -> usize {
request_id: &RequestId, self.prefill.snapshot().active_tokens_at(decay_now)
prefill: PrefillLoadState,
decay_now: Instant,
) {
self.prefill.insert(request_id, prefill, decay_now);
}
fn remove_prefill_load(
&mut self,
request_id: &RequestId,
decay_now: Instant,
) -> Option<PrefillLoadState> {
let prefill = {
let state = self.requests.get_mut(request_id)?;
state.prefill.take()?
};
self.prefill.remove(request_id, prefill, decay_now);
Some(prefill)
}
fn active_prefill_tokens_at(&self, now: Instant) -> usize {
let Some((oldest_request_id, oldest_since)) = self.prefill.anchored_prefill.as_ref() else {
return 0;
};
let prefill = self
.requests
.get(oldest_request_id)
.and_then(|state| state.prefill)
.expect("prefill_order front missing prefill load");
let oldest_full = prefill.initial_effective_prefill_tokens;
let oldest_remaining = match prefill.expected_prefill_duration {
None => oldest_full,
Some(expected_prefill_duration) if expected_prefill_duration.is_zero() => 0,
Some(expected_prefill_duration) => {
let elapsed = now.saturating_duration_since(*oldest_since);
let remaining_fraction = (1.0
- (elapsed.as_secs_f64() / expected_prefill_duration.as_secs_f64()))
.clamp(0.0, 1.0);
((oldest_full as f64) * remaining_fraction).ceil() as usize
}
};
self.prefill
.prefill_full_tokens_sum
.checked_sub(oldest_full)
.expect("prefill_full_tokens_sum smaller than oldest load")
+ oldest_remaining
}
pub fn active_tokens(&self, decay_now: Instant) -> usize {
self.active_prefill_tokens_at(decay_now)
}
/// Find all blocks in a request that have only a single strong reference (only used by this request)
/// and insert them into fractional_blocks with the given fraction value.
pub fn set_single_ref_blocks_as_fractional(&mut self, request_id: &RequestId, fraction: f64) {
let Some(request_state) = self.requests.get(request_id) else {
tracing::warn!(
"Request {request_id} not found for set_single_ref_blocks_as_fractional"
);
return;
};
for (hash, rc) in &request_state.blocks {
if Arc::strong_count(rc) == 1 {
self.blocks.fractional_blocks.insert(*hash, fraction);
}
}
} }
/// Add a new request with its initial tokens. /// Add a new request with its initial tokens.
/// Returns the set of expired request IDs that were removed during cleanup. /// Returns block membership transitions plus any expired request IDs removed during cleanup.
pub fn add_request( #[cfg(test)]
pub(super) fn add_request(
&mut self, &mut self,
request_id: RequestId, request_id: RequestId,
token_sequence: Option<Vec<SequenceHash>>, token_sequence: Option<Vec<SequenceHash>>,
...@@ -225,7 +169,7 @@ impl ActiveSequences { ...@@ -225,7 +169,7 @@ impl ActiveSequences {
overlap: u32, overlap: u32,
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
decay_now: Instant, decay_now: Instant,
) -> HashSet<RequestId> { ) -> SequenceMutationOutcome {
self.add_request_with_prefill_tracking( self.add_request_with_prefill_tracking(
request_id, request_id,
token_sequence, token_sequence,
...@@ -239,9 +183,9 @@ impl ActiveSequences { ...@@ -239,9 +183,9 @@ impl ActiveSequences {
} }
/// Add a new request with optional prompt-token load accounting. /// Add a new request with optional prompt-token load accounting.
/// Returns the set of expired request IDs that were removed during cleanup. /// Returns block membership transitions plus any expired request IDs removed during cleanup.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn add_request_with_prefill_tracking( pub(super) fn add_request_with_prefill_tracking(
&mut self, &mut self,
request_id: RequestId, request_id: RequestId,
token_sequence: Option<Vec<SequenceHash>>, token_sequence: Option<Vec<SequenceHash>>,
...@@ -251,23 +195,48 @@ impl ActiveSequences { ...@@ -251,23 +195,48 @@ impl ActiveSequences {
track_prefill_tokens: bool, track_prefill_tokens: bool,
prefill_load_hint: Option<PrefillLoadHint>, prefill_load_hint: Option<PrefillLoadHint>,
decay_now: Instant, decay_now: Instant,
) -> HashSet<RequestId> { ) -> SequenceMutationOutcome {
if self.requests.contains_key(&request_id) { if self.requests.contains_key(&request_id) {
tracing::error!("Request {request_id} is already active. Ignoring duplicate add."); tracing::error!("Request {request_id} is already active. Ignoring duplicate add.");
return HashSet::new(); return SequenceMutationOutcome::default();
} }
let removed_requests = self.force_expiry(); let mut outcome = self.force_expiry();
let started_at = Instant::now(); let started_at = Instant::now();
let blocks = match token_sequence { let prompt_blocks = match token_sequence {
Some(sequence) => sequence Some(sequence) => {
.into_iter() let mut first_new_prompt_idx = None;
.map(|block| { let prompt_blocks: Vec<_> = sequence
let rc = self.blocks.touch_block(&block); .into_iter()
(block, rc) .enumerate()
}) .map(|(idx, block)| {
.collect(), let acquire = self.blocks.touch_block(&block);
if acquire.became_present_on_worker && first_new_prompt_idx.is_none() {
first_new_prompt_idx = Some(idx);
}
(block, acquire.rc)
})
.collect();
if let Some(first_new_prompt_idx) = first_new_prompt_idx {
debug_assert!(
prompt_blocks[first_new_prompt_idx..]
.iter()
.all(|(hash, _)| self.blocks.unique_blocks.contains_key(hash))
);
let parent = first_new_prompt_idx
.checked_sub(1)
.map(|idx| prompt_blocks[idx].0);
let hashes = prompt_blocks[first_new_prompt_idx..]
.iter()
.map(|(hash, _)| *hash)
.collect();
outcome.membership_delta.push_store(parent, hashes);
}
prompt_blocks
}
None => Vec::new(), None => Vec::new(),
}; };
...@@ -289,166 +258,212 @@ impl ActiveSequences { ...@@ -289,166 +258,212 @@ impl ActiveSequences {
self.requests.insert( self.requests.insert(
request_id.clone(), request_id.clone(),
RequestState { RequestState {
blocks, prompt_blocks,
output_blocks: Vec::new(),
started_at, started_at,
prefill,
expected_output_tokens, expected_output_tokens,
}, },
); );
if let Some(prefill) = prefill { if let Some(prefill) = prefill {
self.insert_prefill_load(&request_id, prefill, decay_now); self.prefill.insert(&request_id, prefill, decay_now);
} }
self.validate_state(); self.validate_state();
removed_requests outcome
} }
/// Mark prefill as completed for a request, removing it from prompt-load tracking. /// Mark prefill as completed for a request, removing it from prompt-load tracking.
pub fn mark_prefill_completed(&mut self, request_id: &RequestId, decay_now: Instant) { pub(super) fn mark_prefill_completed(&mut self, request_id: &RequestId, decay_now: Instant) {
let _ = self.remove_prefill_load(request_id, decay_now); let _ = self.prefill.remove(request_id, decay_now);
self.validate_state(); self.validate_state();
} }
pub fn new_tokens(&self, isl: usize, overlap: u32) -> usize {
let cached_tokens = (overlap as usize) * self.block_size;
isl.checked_sub(cached_tokens).unwrap_or_else(|| {
tracing::error!(
"prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens} (overlap {overlap} * block_size {}), returning 0",
self.block_size
);
0
})
}
pub fn potential_blocks_and_tokens(
&self,
token_sequence: Option<&[SequenceHash]>,
isl: usize,
overlap: u32,
decay_now: Instant,
) -> (usize, usize) {
self.potential_blocks_and_tokens_with_prefill_tracking(
token_sequence,
isl,
overlap,
true,
decay_now,
)
}
pub fn potential_blocks_and_tokens_with_prefill_tracking(
&self,
token_sequence: Option<&[SequenceHash]>,
isl: usize,
overlap: u32,
track_prefill_tokens: bool,
decay_now: Instant,
) -> (usize, usize) {
let potential_blocks = if let Some(token_seq) = token_sequence {
self.new_blocks(token_seq) + self.active_blocks()
} else {
self.active_blocks()
};
let active_tokens = self.active_tokens(decay_now);
let potential_tokens = if track_prefill_tokens {
self.new_tokens(isl, overlap) + active_tokens
} else {
active_tokens
};
(potential_blocks, potential_tokens)
}
/// Match a request against existing blocks and return the number of new blocks that would be added
pub fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
token_sequence
.iter()
.filter(|block| !self.blocks.unique_blocks.contains_key(block))
.count()
}
/// Return the total number of blocks that would be used if the token sequence was added.
pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
self.new_blocks(token_sequence) + self.active_blocks()
}
/// Free all blocks associated with a request. /// Free all blocks associated with a request.
/// ///
/// This implicitly calls [`Self::mark_prefill_completed`] first, so callers do not need /// This implicitly calls [`Self::mark_prefill_completed`] first, so callers do not need
/// to invoke both when the request is finishing. /// to invoke both when the request is finishing.
pub fn free(&mut self, request_id: &RequestId, decay_now: Instant) -> usize { pub(super) fn free(
self.mark_prefill_completed(request_id, decay_now); &mut self,
request_id: &RequestId,
decay_now: Instant,
) -> PromptMembershipDelta {
let _ = self.prefill.remove(request_id, decay_now);
let Some(request_state) = self.requests.remove(request_id) else { let Some(request_state) = self.requests.remove(request_id) else {
tracing::warn!("Trying to free non-existent request {request_id}"); tracing::warn!("Trying to free non-existent request {request_id}");
return self.active_blocks(); return PromptMembershipDelta::default();
}; };
let _ = request_state.expected_output_tokens; let _ = request_state.expected_output_tokens;
for (block_hash, rc) in request_state.blocks { let mut membership_delta = PromptMembershipDelta::default();
let mut first_absent_prompt_idx = None;
let prompt_hashes: Vec<_> = request_state
.prompt_blocks
.iter()
.map(|(hash, _)| *hash)
.collect();
for (idx, (block_hash, rc)) in request_state.prompt_blocks.into_iter().enumerate() {
drop(rc);
if self.blocks.try_remove_block(&block_hash) && first_absent_prompt_idx.is_none() {
first_absent_prompt_idx = Some(idx);
}
}
if let Some(first_absent_prompt_idx) = first_absent_prompt_idx {
let prompt_remove = prompt_hashes[first_absent_prompt_idx..].to_vec();
membership_delta.push_remove(prompt_remove);
}
for (block_hash, rc) in request_state.output_blocks {
drop(rc); drop(rc);
self.blocks.try_remove_block(&block_hash); self.blocks.try_remove_block(&block_hash);
} }
self.validate_state(); self.validate_state();
self.active_blocks() membership_delta
} }
/// Add an output block with a random hash and optional fractional decay weight. /// Add an output block with a random hash and optional fractional decay weight.
/// ///
/// This is used during generation to track output blocks as they are created. /// This is used during generation to track output blocks as they are created.
pub fn add_output_block( pub(super) fn add_output_block(
&mut self, &mut self,
request_id: &RequestId, request_id: &RequestId,
decay_fraction: Option<f64>, decay_fraction: Option<f64>,
) -> bool { ) -> Option<SequenceHash> {
if !self.requests.contains_key(request_id) { if !self.requests.contains_key(request_id) {
tracing::warn!("Request {request_id} not found for add_output_block"); tracing::warn!("Request {request_id} not found for add_output_block");
return false; return None;
} }
// TODO: Output blocks still use random hashes, so indexing them mainly simplifies
// generic block bookkeeping and usually adds little real reuse signal.
let random_hash: SequenceHash = Uuid::new_v4().as_u64_pair().0; let random_hash: SequenceHash = Uuid::new_v4().as_u64_pair().0;
let rc = self.blocks.touch_block(&random_hash); let acquire = self.blocks.touch_block(&random_hash);
self.requests self.requests
.get_mut(request_id) .get_mut(request_id)
.expect("request existence was checked above") .expect("request existence was checked above")
.blocks .output_blocks
.push((random_hash, rc)); .push((random_hash, acquire.rc));
if let Some(frac) = decay_fraction { if let Some(frac) = decay_fraction {
self.set_single_ref_blocks_as_fractional(request_id, frac); self.set_single_ref_blocks_as_fractional(request_id, frac);
} }
self.validate_state(); self.validate_state();
true acquire.became_present_on_worker.then_some(random_hash)
}
/// Prefill tokens a request with input length `isl` and `overlap` cached
/// blocks would add, computed by `added_prefill_tokens` using this tracker's
/// block size.
pub(super) fn new_tokens(&self, isl: usize, overlap: u32) -> usize {
    added_prefill_tokens(self.block_size, isl, overlap)
}
/// Test helper: projected `(blocks, tokens)` load if the given request were added.
///
/// Blocks are the currently active blocks plus the blocks of `token_sequence`
/// that are not yet tracked. Tokens are the currently active tokens at
/// `decay_now`, plus the request's new prefill tokens when
/// `track_prefill_tokens` is set.
#[cfg(test)]
fn potential_blocks_and_tokens_with_prefill_tracking(
    &self,
    token_sequence: Option<&[SequenceHash]>,
    isl: usize,
    overlap: u32,
    track_prefill_tokens: bool,
    decay_now: Instant,
) -> (usize, usize) {
    let potential_blocks = match token_sequence {
        Some(hashes) => self.new_blocks(hashes) + self.active_blocks(),
        None => self.active_blocks(),
    };

    let active_tokens = self.active_tokens(decay_now);
    let added_tokens = if track_prefill_tokens {
        self.new_tokens(isl, overlap)
    } else {
        0
    };

    (potential_blocks, added_tokens + active_tokens)
}
/// Counts the entries of `token_sequence` whose hash is not currently a key of
/// the tracker's unique-block map, i.e. the blocks the request would add.
pub(super) fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
    let tracked = &self.blocks.unique_blocks;
    let mut missing = 0;
    for hash in token_sequence {
        if !tracked.contains_key(hash) {
            missing += 1;
        }
    }
    missing
}
/// Return the total number of blocks that would be used if the token sequence was added.
pub(super) fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
self.new_blocks(token_sequence) + self.active_blocks()
} }
/// Force expiry of stale requests if the timer has elapsed. /// Force expiry of stale requests if the timer has elapsed.
/// Returns the set of expired request IDs that were removed. /// Returns block membership transitions plus the set of expired request IDs that were removed.
pub fn force_expiry(&mut self) -> HashSet<RequestId> { pub(super) fn force_expiry(&mut self) -> SequenceMutationOutcome {
let now = Instant::now(); let now = Instant::now();
if now < self.last_expiry_check_time + CHECK_EXPIRY_FREQUENCY { if now < self.last_expiry_check_time + CHECK_EXPIRY_FREQUENCY {
return HashSet::new(); return SequenceMutationOutcome::default();
} }
self.last_expiry_check_time = now; self.last_expiry_check_time = now;
let expired_requests_time = now - EXPIRY_DURATION; let expired_requests_time = now - EXPIRY_DURATION;
let expired_requests: HashSet<RequestId> = self let expired_request_ids: HashSet<RequestId> = self
.requests .requests
.iter() .iter()
.filter(|(_, state)| state.started_at < expired_requests_time) .filter(|(_, state)| state.started_at < expired_requests_time)
.map(|(request_id, _)| request_id.clone()) .map(|(request_id, _)| request_id.clone())
.collect(); .collect();
for request_id in &expired_requests { let mut outcome = SequenceMutationOutcome {
expired_request_ids,
..Default::default()
};
for request_id in &outcome.expired_request_ids {
tracing::warn!("Expiring stale request: {}", request_id); tracing::warn!("Expiring stale request: {}", request_id);
self.free(request_id, now); outcome.membership_delta.extend(self.free(request_id, now));
} }
self.validate_state(); self.validate_state();
expired_requests outcome
}
/// Marks every block that only this request references (strong count of one)
/// as fractional, storing `fraction` for its hash in `fractional_blocks`.
/// Logs a warning and does nothing when the request is unknown.
fn set_single_ref_blocks_as_fractional(&mut self, request_id: &RequestId, fraction: f64) {
    match self.requests.get(request_id) {
        None => {
            tracing::warn!(
                "Request {request_id} not found for set_single_ref_blocks_as_fractional"
            );
        }
        Some(state) => {
            let exclusive = state
                .all_blocks()
                .filter(|(_, rc)| Arc::strong_count(rc) == 1)
                .map(|(hash, _)| (*hash, fraction));
            self.blocks.fractional_blocks.extend(exclusive);
        }
    }
}
/// Captures the current active-block count together with a snapshot of the
/// prefill tracker.
pub(super) fn worker_load_snapshot(&self) -> WorkerLoadSnapshot {
    let active_blocks = self.active_blocks();
    let prefill = self.prefill.snapshot();
    WorkerLoadSnapshot {
        active_blocks,
        prefill,
    }
}
/// Test helper: the set of all hashes currently keyed in the unique-block map.
#[cfg(test)]
pub(super) fn active_block_hashes(&self) -> FxHashSet<SequenceHash> {
    let mut hashes = FxHashSet::default();
    for hash in self.blocks.unique_blocks.keys() {
        hashes.insert(*hash);
    }
    hashes
}
#[cfg(test)]
pub(super) fn active_prompt_hashes(&self) -> FxHashSet<SequenceHash> {
self.requests
.values()
.flat_map(|state| state.prompt_blocks.iter().map(|(hash, _)| *hash))
.collect()
} }
} }
...@@ -464,6 +479,119 @@ mod tests { ...@@ -464,6 +479,119 @@ mod tests {
} }
} }
/// Store events must only cover blocks that were not yet present, and remove
/// events must only appear once the last request holding the blocks is freed.
#[test]
fn test_prompt_membership_delta_only_reports_first_add_and_last_remove() {
    let mut seq_manager = ActiveSequences::new(4);
    let decay_now = Instant::now();
    let r1 = "r1".to_string();
    let r2 = "r2".to_string();

    // r1 introduces blocks 1 and 2, so both are reported with no parent.
    let first = seq_manager.add_request_with_prefill_tracking(
        r1.clone(),
        Some(vec![1, 2]),
        8,
        0,
        None,
        true,
        None,
        decay_now,
    );
    let expected_first = PromptMembershipDelta {
        stores: vec![PromptMembershipStore {
            parent: None,
            hashes: vec![1, 2],
        }],
        removes: Vec::new(),
    };
    assert_eq!(first.membership_delta, expected_first);
    assert!(first.expired_request_ids.is_empty());

    // r2 shares blocks 1 and 2; only block 3 is new and it hangs off block 2.
    let second = seq_manager.add_request_with_prefill_tracking(
        r2.clone(),
        Some(vec![1, 2, 3]),
        12,
        0,
        None,
        true,
        None,
        decay_now,
    );
    let expected_second = PromptMembershipDelta {
        stores: vec![PromptMembershipStore {
            parent: Some(2),
            hashes: vec![3],
        }],
        removes: Vec::new(),
    };
    assert_eq!(second.membership_delta, expected_second);

    // Freeing r1 changes nothing: r2 still holds blocks 1 and 2.
    let first_free = seq_manager.free(&r1, decay_now);
    assert!(first_free.removes.is_empty());
    assert!(first_free.stores.is_empty());

    // Freeing r2 releases the last references to all three blocks.
    let second_free = seq_manager.free(&r2, decay_now);
    assert!(second_free.stores.is_empty());
    let expected_removes = vec![PromptMembershipRemove {
        hashes: vec![1, 2, 3],
    }];
    assert_eq!(second_free.removes, expected_removes);
}
/// Output blocks show up in the generic active-block set, but the membership
/// delta returned by `free` only lists the prompt blocks.
#[test]
fn test_generic_block_membership_includes_output_blocks() {
    let mut seq_manager = ActiveSequences::new(4);
    let decay_now = Instant::now();
    let request_id = "r1".to_string();
    let prompt_hashes = vec![1, 2, 3];

    let outcome = seq_manager.add_request_with_prefill_tracking(
        request_id.clone(),
        Some(prompt_hashes.clone()),
        12,
        0,
        None,
        true,
        None,
        decay_now,
    );
    let expected_stores = vec![PromptMembershipStore {
        parent: None,
        hashes: vec![1, 2, 3],
    }];
    assert_eq!(outcome.membership_delta.stores, expected_stores);
    assert_eq!(
        seq_manager.active_block_hashes(),
        prompt_hashes.iter().copied().collect()
    );

    // The output block gets a random hash, so read it back from the call.
    let output_hash = seq_manager
        .add_output_block(&request_id, Some(0.5))
        .expect("request exists");
    assert_eq!(
        seq_manager.active_block_hashes(),
        [1, 2, 3, output_hash].into_iter().collect()
    );

    // Completing prefill clears token load but keeps every block active.
    seq_manager.mark_prefill_completed(&request_id, decay_now);
    assert_eq!(seq_manager.active_tokens(decay_now), 0);
    assert_eq!(
        seq_manager.active_block_hashes(),
        [1, 2, 3, output_hash].into_iter().collect()
    );

    let free_delta = seq_manager.free(&request_id, decay_now);
    let expected_removes = vec![PromptMembershipRemove {
        hashes: vec![1, 2, 3],
    }];
    assert_eq!(free_delta.removes, expected_removes);
}
#[test] #[test]
fn test_active_sequences_shared_blocks() { fn test_active_sequences_shared_blocks() {
let block_size = 4; let block_size = 4;
...@@ -532,13 +660,21 @@ mod tests { ...@@ -532,13 +660,21 @@ mod tests {
); );
assert_eq!(seq_manager.active_blocks(), 3); assert_eq!(seq_manager.active_blocks(), 3);
assert!(seq_manager.add_output_block(&"r1".to_string(), Some(0.5))); assert!(
seq_manager
.add_output_block(&"r1".to_string(), Some(0.5))
.is_some()
);
assert_eq!(seq_manager.active_blocks(), 2); assert_eq!(seq_manager.active_blocks(), 2);
seq_manager.add_request("r2".to_string(), Some(vec![1, 2]), 8, 0, None, decay_now); seq_manager.add_request("r2".to_string(), Some(vec![1, 2]), 8, 0, None, decay_now);
assert_eq!(seq_manager.active_blocks(), 2); assert_eq!(seq_manager.active_blocks(), 2);
assert!(seq_manager.add_output_block(&"r1".to_string(), Some(0.0))); assert!(
seq_manager
.add_output_block(&"r1".to_string(), Some(0.0))
.is_some()
);
assert_eq!(seq_manager.active_blocks(), 1); assert_eq!(seq_manager.active_blocks(), 1);
seq_manager.free(&"r2".to_string(), decay_now); seq_manager.free(&"r2".to_string(), decay_now);
...@@ -628,181 +764,6 @@ mod tests { ...@@ -628,181 +764,6 @@ mod tests {
assert_eq!(tokens, 0); assert_eq!(tokens, 0);
} }
#[test]
fn test_prefill_decay_only_applies_to_oldest_request() {
let mut seq_manager = ActiveSequences::new(4);
let epoch = Instant::now();
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
Some(vec![1]),
100,
0,
None,
true,
Some(prefill_hint(100, 10)),
epoch,
);
seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![2]),
60,
0,
None,
true,
Some(prefill_hint(60, 6)),
epoch + Duration::from_secs(2),
);
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(2)),
140
);
let decayed = seq_manager.active_tokens(epoch + Duration::from_secs(5));
assert_eq!(decayed, 110);
assert!(decayed <= 160);
assert!(decayed >= 60);
}
#[test]
fn test_prefill_decay_hands_off_to_next_oldest_request() {
let mut seq_manager = ActiveSequences::new(4);
let epoch = Instant::now();
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
Some(vec![1]),
100,
0,
None,
true,
Some(prefill_hint(100, 10)),
epoch,
);
seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![2]),
40,
0,
None,
true,
Some(prefill_hint(40, 8)),
epoch,
);
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(3)),
110
);
seq_manager.mark_prefill_completed(&"r1".to_string(), epoch + Duration::from_secs(3));
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(3)),
40
);
assert_eq!(
seq_manager.prefill.prefill_order,
VecDeque::from(vec!["r2".to_string()])
);
assert!(
seq_manager
.prefill
.anchored_prefill
.as_ref()
.is_some_and(|(request_id, _)| request_id == "r2")
);
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(5)),
30
);
}
#[test]
fn test_prefill_decay_resets_when_request_becomes_oldest() {
let mut seq_manager = ActiveSequences::new(4);
let epoch = Instant::now();
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
Some(vec![1]),
100,
0,
None,
true,
Some(prefill_hint(100, 10)),
epoch,
);
seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![2]),
80,
0,
None,
true,
Some(prefill_hint(80, 8)),
epoch + Duration::from_secs(4),
);
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(8)),
100
);
seq_manager.mark_prefill_completed(&"r1".to_string(), epoch + Duration::from_secs(8));
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(8)),
80
);
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(10)),
60
);
}
#[test]
fn test_prefill_front_removal_reanchors_queue_front() {
let mut seq_manager = ActiveSequences::new(4);
let epoch = Instant::now();
seq_manager.add_request_with_prefill_tracking(
"r1".to_string(),
Some(vec![1]),
30,
0,
None,
true,
Some(prefill_hint(30, 6)),
epoch,
);
seq_manager.add_request_with_prefill_tracking(
"r2".to_string(),
Some(vec![2]),
20,
0,
None,
true,
Some(prefill_hint(20, 4)),
epoch,
);
seq_manager.mark_prefill_completed(&"r1".to_string(), epoch + Duration::from_secs(2));
assert!(
seq_manager
.prefill
.anchored_prefill
.as_ref()
.is_some_and(|(request_id, _)| request_id == "r2")
);
assert_eq!(
seq_manager.active_tokens(epoch + Duration::from_secs(2)),
20
);
}
#[test] #[test]
fn test_prefill_queue_and_sum_invariants_survive_idempotent_cleanup() { fn test_prefill_queue_and_sum_invariants_survive_idempotent_cleanup() {
let mut seq_manager = ActiveSequences::new(4); let mut seq_manager = ActiveSequences::new(4);
...@@ -882,18 +843,27 @@ mod tests { ...@@ -882,18 +843,27 @@ mod tests {
tokio::time::advance(Duration::from_secs(20)).await; tokio::time::advance(Duration::from_secs(20)).await;
let expired = seq_manager.force_expiry(); let expired = seq_manager.force_expiry();
assert!(expired.is_empty(), "no check before CHECK_EXPIRY_FREQUENCY"); assert!(
expired.expired_request_ids.is_empty(),
"no check before CHECK_EXPIRY_FREQUENCY"
);
assert_eq!(seq_manager.active_blocks(), 4); assert_eq!(seq_manager.active_blocks(), 4);
tokio::time::advance(Duration::from_secs(11)).await; tokio::time::advance(Duration::from_secs(11)).await;
let expired = seq_manager.force_expiry(); let expired = seq_manager.force_expiry();
assert!(expired.is_empty(), "requests not old enough to expire"); assert!(
expired.expired_request_ids.is_empty(),
"requests not old enough to expire"
);
assert_eq!(seq_manager.active_blocks(), 4); assert_eq!(seq_manager.active_blocks(), 4);
seq_manager.assert_consistent(); seq_manager.assert_consistent();
tokio::time::advance(Duration::from_secs(270)).await; tokio::time::advance(Duration::from_secs(270)).await;
let expired = seq_manager.force_expiry(); let expired = seq_manager.force_expiry();
assert_eq!(expired, HashSet::from(["r1".to_string(), "r2".to_string()])); assert_eq!(
expired.expired_request_ids,
HashSet::from(["r1".to_string(), "r2".to_string()])
);
assert_eq!(seq_manager.active_blocks(), 0); assert_eq!(seq_manager.active_blocks(), 0);
assert_eq!(seq_manager.active_tokens(Instant::now()), 0); assert_eq!(seq_manager.active_tokens(Instant::now()), 0);
seq_manager.assert_consistent(); seq_manager.assert_consistent();
...@@ -901,7 +871,7 @@ mod tests { ...@@ -901,7 +871,7 @@ mod tests {
tokio::time::advance(Duration::from_secs(31)).await; tokio::time::advance(Duration::from_secs(31)).await;
let expired = let expired =
seq_manager.add_request("r3".to_string(), Some(vec![5]), 4, 0, None, Instant::now()); seq_manager.add_request("r3".to_string(), Some(vec![5]), 4, 0, None, Instant::now());
assert!(expired.is_empty()); assert!(expired.expired_request_ids.is_empty());
assert_eq!(seq_manager.active_blocks(), 1); assert_eq!(seq_manager.active_blocks(), 1);
assert_eq!(seq_manager.active_tokens(Instant::now()), 4); assert_eq!(seq_manager.active_tokens(Instant::now()), 4);
seq_manager.assert_consistent(); seq_manager.assert_consistent();
...@@ -936,7 +906,10 @@ mod tests { ...@@ -936,7 +906,10 @@ mod tests {
tokio::time::advance(Duration::from_secs(60)).await; tokio::time::advance(Duration::from_secs(60)).await;
let expired = seq_manager.force_expiry(); let expired = seq_manager.force_expiry();
assert_eq!(expired, HashSet::from(["r1".to_string()])); assert_eq!(
expired.expired_request_ids,
HashSet::from(["r1".to_string()])
);
assert_eq!(seq_manager.active_tokens(Instant::now()), 30); assert_eq!(seq_manager.active_tokens(Instant::now()), 30);
assert!( assert!(
seq_manager seq_manager
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use parking_lot::RwLock;
use rustc_hash::{FxHashMap, FxHashSet};
use std::collections::HashMap;
use super::prompt_membership_trie::WorkerLookup;
use super::single::ActiveSequences;
use crate::protocols::WorkerWithDpRank;
/// What is kept of a worker after its slot has been dropped from a
/// [`WorkerTable`] by `reconcile`.
#[derive(Clone)]
pub(super) struct RemovedWorkerState {
    /// Identity of the worker that was removed.
    pub(super) worker: WorkerWithDpRank,
    /// The shared lookup handle taken from the removed slot.
    pub(super) trie_lookup: Arc<RwLock<WorkerLookup>>,
}
/// Manual `Debug` that prints only the worker identity; the lookup handle is
/// omitted and the output is marked as non-exhaustive.
impl std::fmt::Debug for RemovedWorkerState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("RemovedWorkerState");
        builder.field("worker", &self.worker);
        builder.finish_non_exhaustive()
    }
}
/// Delta reported by the mutating [`WorkerTable`] operations; the default
/// value (both lists empty) means nothing changed.
#[derive(Debug, Default, Clone)]
pub(super) struct WorkerTopologyChange {
    /// Workers for which a new slot was created.
    pub(super) added: Vec<WorkerWithDpRank>,
    /// State retained from workers whose slots were dropped.
    pub(super) removed: Vec<RemovedWorkerState>,
}
/// One entry of a [`WorkerTable`]: a worker together with its per-worker state.
pub(super) struct WorkerSlot {
    /// Identity of the worker that owns this slot.
    pub(super) worker: WorkerWithDpRank,
    /// The worker's `ActiveSequences`, guarded by a read/write lock.
    pub(super) sequences: RwLock<ActiveSequences>,
    /// Shared handle to the worker's `WorkerLookup`; it is handed over to
    /// `RemovedWorkerState` when the slot is dropped during `reconcile`.
    pub(super) trie_lookup: Arc<RwLock<WorkerLookup>>,
}
impl WorkerSlot {
    /// Creates a slot for `worker` with a fresh `ActiveSequences` built from
    /// `block_size` and a default `WorkerLookup`.
    fn new(worker: WorkerWithDpRank, block_size: usize) -> Self {
        let sequences = RwLock::new(ActiveSequences::new(block_size));
        let trie_lookup = Arc::new(RwLock::new(WorkerLookup::default()));
        Self {
            worker,
            sequences,
            trie_lookup,
        }
    }
}
/// Dense table of worker slots plus a lookup from worker identity to slot
/// position.
pub(super) struct WorkerTable {
    /// Slots in insertion order.
    pub(super) slots: Vec<WorkerSlot>,
    /// Maps each worker to the position of its slot in `slots`.
    pub(super) index: FxHashMap<WorkerWithDpRank, usize>,
}
impl WorkerTable {
pub(super) fn new(block_size: usize, dp_range: &HashMap<u64, (u32, u32)>) -> Self {
let mut slots = Vec::new();
let mut index = FxHashMap::default();
for worker in workers_from_dp_range(dp_range) {
let idx = slots.len();
slots.push(WorkerSlot::new(worker, block_size));
index.insert(worker, idx);
}
Self { slots, index }
}
pub(super) fn workers(&self) -> impl Iterator<Item = WorkerWithDpRank> + '_ {
self.slots.iter().map(|slot| slot.worker)
}
pub(super) fn register_external(
&mut self,
block_size: usize,
dp_range: &HashMap<u64, (u32, u32)>,
) -> WorkerTopologyChange {
let mut change = WorkerTopologyChange::default();
for worker in workers_from_dp_range(dp_range) {
if self.index.contains_key(&worker) {
continue;
}
let idx = self.slots.len();
self.slots.push(WorkerSlot::new(worker, block_size));
self.index.insert(worker, idx);
change.added.push(worker);
}
change
}
pub(super) fn reconcile(
&mut self,
block_size: usize,
new_dp_range: &HashMap<u64, (u32, u32)>,
) -> WorkerTopologyChange {
let target_workers: FxHashSet<WorkerWithDpRank> =
workers_from_dp_range(new_dp_range).into_iter().collect();
let mut old: FxHashMap<WorkerWithDpRank, WorkerSlot> = self
.slots
.drain(..)
.map(|slot| (slot.worker, slot))
.collect();
self.index.clear();
let mut added = Vec::new();
for worker in target_workers {
if !old.contains_key(&worker) {
added.push(worker);
}
let idx = self.slots.len();
let slot = old
.remove(&worker)
.unwrap_or_else(|| WorkerSlot::new(worker, block_size));
self.slots.push(slot);
self.index.insert(worker, idx);
}
let removed = old
.into_values()
.map(|slot| RemovedWorkerState {
worker: slot.worker,
trie_lookup: slot.trie_lookup,
})
.collect();
WorkerTopologyChange { added, removed }
}
pub(super) fn ensure_worker(
&mut self,
block_size: usize,
worker: WorkerWithDpRank,
) -> WorkerTopologyChange {
if self.index.contains_key(&worker) {
return WorkerTopologyChange::default();
}
let idx = self.slots.len();
self.slots.push(WorkerSlot::new(worker, block_size));
self.index.insert(worker, idx);
WorkerTopologyChange {
added: vec![worker],
removed: Vec::new(),
}
}
}
/// Expands every `worker_id -> (dp_start, dp_size)` entry into one
/// `WorkerWithDpRank` per rank in `dp_start..dp_start + dp_size`.
///
/// Entries are visited in the map's iteration order; ranks within an entry
/// are ascending.
fn workers_from_dp_range(dp_range: &HashMap<u64, (u32, u32)>) -> Vec<WorkerWithDpRank> {
    dp_range
        .iter()
        .flat_map(|(&worker_id, &(dp_start, dp_size))| {
            (dp_start..(dp_start + dp_size))
                .map(move |dp_rank| WorkerWithDpRank::new(worker_id, dp_rank))
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use tokio::time::Instant;

    use super::*;

    /// Shorthand for building a `WorkerWithDpRank` in fixtures and assertions.
    fn worker(worker_id: u64, dp_rank: u32) -> WorkerWithDpRank {
        WorkerWithDpRank::new(worker_id, dp_rank)
    }

    /// `new` must create one slot and one index entry per expanded dp rank.
    #[test]
    fn new_expands_dp_ranges_into_slots_and_index() {
        let dp_range = HashMap::from([(7, (2, 3)), (9, (0, 1))]);
        let table = WorkerTable::new(4, &dp_range);

        let listed: FxHashSet<_> = table.workers().collect();
        let expected =
            FxHashSet::from_iter([worker(7, 2), worker(7, 3), worker(7, 4), worker(9, 0)]);
        assert_eq!(listed, expected);

        assert_eq!(table.index.len(), 4);
        assert_eq!(table.slots.len(), 4);
        for present in listed {
            assert!(table.index.contains_key(&present));
        }
    }

    /// `register_external` reports only the workers it had to create.
    #[test]
    fn register_external_only_adds_missing_workers() {
        let initial = HashMap::from([(1, (0, 1))]);
        let mut table = WorkerTable::new(4, &initial);

        let external = HashMap::from([(1, (0, 2)), (2, (0, 1))]);
        let change = table.register_external(4, &external);

        let newly_added = change.added.into_iter().collect::<FxHashSet<_>>();
        assert_eq!(
            newly_added,
            FxHashSet::from_iter([worker(1, 1), worker(2, 0)])
        );
        assert!(change.removed.is_empty());
        assert_eq!(table.index.len(), 3);
    }

    /// A second `ensure_worker` call for the same worker changes nothing.
    #[test]
    fn ensure_worker_is_idempotent() {
        let mut table = WorkerTable::new(4, &HashMap::from([(1, (0, 1))]));
        let newcomer = worker(2, 0);

        let initial_call = table.ensure_worker(4, newcomer);
        let repeat_call = table.ensure_worker(4, newcomer);

        assert_eq!(initial_call.added, vec![newcomer]);
        assert!(initial_call.removed.is_empty());
        assert!(repeat_call.added.is_empty());
        assert!(repeat_call.removed.is_empty());
        assert_eq!(table.index.len(), 2);
    }

    /// `reconcile` keeps the state of surviving workers, creates empty slots
    /// for new ones, and reports both the additions and the removals.
    #[test]
    fn reconcile_preserves_existing_worker_state_and_reports_delta() {
        let mut table = WorkerTable::new(4, &HashMap::from([(1, (0, 1)), (2, (0, 1))]));
        let kept = worker(1, 0);
        let dropped = worker(2, 0);
        let fresh = worker(3, 0);

        // Give the surviving worker some state to carry across the reconcile.
        {
            let mut sequences = table.slots[table.index[&kept]].sequences.write();
            let outcome = sequences.add_request(
                "req-1".to_string(),
                Some(vec![1, 2, 3]),
                12,
                0,
                None,
                Instant::now(),
            );
            assert_eq!(outcome.membership_delta.stores[0].hashes, vec![1, 2, 3],);
        }

        let target = HashMap::from([(1, (0, 1)), (3, (0, 1))]);
        let change = table.reconcile(4, &target);

        assert_eq!(change.added, vec![fresh]);
        let removed_workers: Vec<_> = change.removed.iter().map(|state| state.worker).collect();
        assert_eq!(removed_workers, vec![dropped]);

        assert!(table.index.contains_key(&kept));
        assert!(table.index.contains_key(&fresh));
        assert!(!table.index.contains_key(&dropped));

        let kept_slot = &table.slots[table.index[&kept]];
        assert_eq!(kept_slot.sequences.read().active_blocks(), 3);

        let fresh_slot = &table.slots[table.index[&fresh]];
        assert_eq!(fresh_slot.sequences.read().active_blocks(), 0);
    }
}
...@@ -145,58 +145,6 @@ mod tests { ...@@ -145,58 +145,6 @@ mod tests {
use dynamo_runtime::{DistributedRuntime, Runtime}; use dynamo_runtime::{DistributedRuntime, Runtime};
use tokio::time::Instant; use tokio::time::Instant;
/// Walks three requests with overlapping block lists through add and free and
/// checks the block and token totals after every step.
#[test]
fn test_active_sequences_shared_blocks() {
    let block_size = 4;
    let mut tracker = ActiveSequences::new(block_size);
    let now = Instant::now();

    let first = "request_1".to_string();
    let second = "request_2".to_string();
    let third = "request_3".to_string();

    // request_1 brings blocks 1, 2 and 3.
    tracker.add_request(first.clone(), Some(vec![1, 2, 3]), 12, 0, None, now);
    assert_eq!(tracker.active_blocks(), 3);
    assert_eq!(tracker.active_tokens(now), 12);

    // request_2 brings block 4.
    tracker.add_request(second.clone(), Some(vec![4]), 4, 0, None, now);
    assert_eq!(tracker.active_blocks(), 4);
    assert_eq!(tracker.active_tokens(now), 16);

    // request_3 lists blocks 1 through 4; neither total moves.
    tracker.add_request(third.clone(), Some(vec![1, 2, 3, 4]), 16, 4, None, now);
    assert_eq!(tracker.active_blocks(), 4);
    assert_eq!(tracker.active_tokens(now), 16);

    // Freeing request_2 keeps all four blocks counted while request_3 is live.
    tracker.free(&second, now);
    assert_eq!(tracker.active_blocks(), 4);
    assert_eq!(tracker.active_tokens(now), 12);

    // Freeing request_3 drops the count back to request_1's three blocks.
    tracker.free(&third, now);
    assert_eq!(tracker.active_blocks(), 3);
    assert_eq!(tracker.active_tokens(now), 12);

    // Freeing the last request empties the tracker.
    tracker.free(&first, now);
    assert_eq!(tracker.active_blocks(), 0);
    assert_eq!(tracker.active_tokens(now), 0);
}
#[tokio::test] #[tokio::test]
#[ignore] #[ignore]
async fn test_multi_worker_cross_instance_sync() -> Result<()> { async fn test_multi_worker_cross_instance_sync() -> Result<()> {
......
...@@ -19,6 +19,7 @@ use dynamo_kv_router::{ ...@@ -19,6 +19,7 @@ use dynamo_kv_router::{
SchedulingPolicy, SchedulingRequest, SequenceRequest, WorkerSelector, SchedulingPolicy, SchedulingRequest, SequenceRequest, WorkerSelector,
}; };
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
use rustc_hash::FxHashMap;
use tokio::time::Instant; use tokio::time::Instant;
use uuid::Uuid; use uuid::Uuid;
...@@ -124,8 +125,8 @@ impl PendingRequest { ...@@ -124,8 +125,8 @@ impl PendingRequest {
fn scheduling_request( fn scheduling_request(
&self, &self,
decode_blocks: HashMap<WorkerWithDpRank, usize>, decode_blocks: FxHashMap<WorkerWithDpRank, usize>,
prefill_tokens: HashMap<WorkerWithDpRank, usize>, prefill_tokens: FxHashMap<WorkerWithDpRank, usize>,
) -> SchedulingRequest { ) -> SchedulingRequest {
SchedulingRequest { SchedulingRequest {
maybe_request_id: Some(self.request_id()), maybe_request_id: Some(self.request_id()),
...@@ -408,7 +409,7 @@ impl OfflineReplayRouter { ...@@ -408,7 +409,7 @@ impl OfflineReplayRouter {
let arrival_offset = Duration::from_secs_f64((now_ms.max(0.0)) / 1000.0); let arrival_offset = Duration::from_secs_f64((now_ms.max(0.0)) / 1000.0);
self.policy.enqueue_key( self.policy.enqueue_key(
arrival_offset, arrival_offset,
&request.scheduling_request(HashMap::new(), HashMap::new()), &request.scheduling_request(FxHashMap::default(), FxHashMap::default()),
) )
} }
......
...@@ -272,12 +272,15 @@ fn simulate_prefill_duration( ...@@ -272,12 +272,15 @@ fn simulate_prefill_duration(
} }
fn debug_assert_sglang_scheduler_state( fn debug_assert_sglang_scheduler_state(
waiting: &VecDeque<SglangRequest>, _waiting: &VecDeque<SglangRequest>,
running: &[SglangRequest], _running: &[SglangRequest],
block_size: usize, _block_size: usize,
) { ) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
let waiting = _waiting;
let running = _running;
let block_size = _block_size;
let mut seen = std::collections::HashSet::new(); let mut seen = std::collections::HashSet::new();
for req in waiting { for req in waiting {
debug_assert!( debug_assert!(
......
...@@ -87,9 +87,10 @@ impl SglangRequest { ...@@ -87,9 +87,10 @@ impl SglangRequest {
self.materialized_tokens += 1; self.materialized_tokens += 1;
} }
pub(super) fn debug_assert_invariants(&self, block_size: usize) { pub(super) fn debug_assert_invariants(&self, _block_size: usize) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
let block_size = _block_size;
let sequence_len = self.current_sequence_len(); let sequence_len = self.current_sequence_len();
debug_assert!( debug_assert!(
self.cached_tokens <= self.materialized_tokens, self.cached_tokens <= self.materialized_tokens,
......
...@@ -759,9 +759,11 @@ fn request_sequence_len(requests: &FxHashMap<Uuid, VllmRequestState>, uuid: Uuid ...@@ -759,9 +759,11 @@ fn request_sequence_len(requests: &FxHashMap<Uuid, VllmRequestState>, uuid: Uuid
.unwrap_or_default() .unwrap_or_default()
} }
fn debug_assert_vllm_request_invariants(uuid: Uuid, request: &VllmRequestState) { fn debug_assert_vllm_request_invariants(_uuid: Uuid, _request: &VllmRequestState) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
let uuid = _uuid;
let request = _request;
let seq_len = request.sequence.len(); let seq_len = request.sequence.len();
let allocated = request.sequence.num_allocated_tokens(); let allocated = request.sequence.num_allocated_tokens();
debug_assert!( debug_assert!(
...@@ -776,9 +778,11 @@ fn debug_assert_vllm_request_invariants(uuid: Uuid, request: &VllmRequestState) ...@@ -776,9 +778,11 @@ fn debug_assert_vllm_request_invariants(uuid: Uuid, request: &VllmRequestState)
} }
} }
fn debug_assert_vllm_request_progress(uuid: Uuid, request: &VllmRequestState) { fn debug_assert_vllm_request_progress(_uuid: Uuid, _request: &VllmRequestState) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
let uuid = _uuid;
let request = _request;
debug_assert_vllm_request_invariants(uuid, request); debug_assert_vllm_request_invariants(uuid, request);
let allocated = request.sequence.num_allocated_tokens(); let allocated = request.sequence.num_allocated_tokens();
debug_assert!( debug_assert!(
...@@ -789,9 +793,11 @@ fn debug_assert_vllm_request_progress(uuid: Uuid, request: &VllmRequestState) { ...@@ -789,9 +793,11 @@ fn debug_assert_vllm_request_progress(uuid: Uuid, request: &VllmRequestState) {
} }
} }
fn debug_assert_vllm_ready_to_decode(requests: &FxHashMap<Uuid, VllmRequestState>, uuid: Uuid) { fn debug_assert_vllm_ready_to_decode(_requests: &FxHashMap<Uuid, VllmRequestState>, _uuid: Uuid) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
let requests = _requests;
let uuid = _uuid;
let Some(request) = requests.get(&uuid) else { let Some(request) = requests.get(&uuid) else {
return; return;
}; };
...@@ -807,9 +813,10 @@ fn debug_assert_vllm_ready_to_decode(requests: &FxHashMap<Uuid, VllmRequestState ...@@ -807,9 +813,10 @@ fn debug_assert_vllm_ready_to_decode(requests: &FxHashMap<Uuid, VllmRequestState
} }
} }
fn debug_assert_vllm_scheduler_state(state: &SchedulerState) { fn debug_assert_vllm_scheduler_state(_state: &SchedulerState) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
let state = _state;
let mut seen = std::collections::HashSet::new(); let mut seen = std::collections::HashSet::new();
for uuid in &state.waiting_members { for uuid in &state.waiting_members {
debug_assert!( debug_assert!(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment