"benchmarks/vscode:/vscode.git/clone" did not exist on "24cde76a152fbffde30fa2be0d08dcbad490530e"
Unverified Commit 134d484d authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(kv-router): add prompt membership index for scheduler reads (#8175)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent d0d9c030
...@@ -2,35 +2,50 @@ ...@@ -2,35 +2,50 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
use std::collections::HashMap; use rustc_hash::FxHashMap;
use std::sync::{Arc, Weak}; use std::sync::{Arc, Weak};
#[derive(Debug)]
pub(super) struct BlockAcquire {
pub(super) rc: Arc<()>,
pub(super) became_present_on_worker: bool,
}
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub(super) struct BlockTracker { pub(super) struct BlockTracker {
pub(super) unique_blocks: HashMap<SequenceHash, Weak<()>>, pub(super) unique_blocks: FxHashMap<SequenceHash, Weak<()>>,
pub(super) fractional_blocks: HashMap<SequenceHash, f64>, pub(super) fractional_blocks: FxHashMap<SequenceHash, f64>,
} }
impl BlockTracker { impl BlockTracker {
pub(super) fn touch_block(&mut self, block: &SequenceHash) -> Arc<()> { pub(super) fn touch_block(&mut self, block: &SequenceHash) -> BlockAcquire {
if let Some(weak) = self.unique_blocks.get(block) if let Some(weak) = self.unique_blocks.get(block)
&& let Some(rc) = weak.upgrade() && let Some(rc) = weak.upgrade()
{ {
return rc; return BlockAcquire {
rc,
became_present_on_worker: false,
};
} }
let rc = Arc::new(()); let rc = Arc::new(());
self.unique_blocks.insert(*block, Arc::downgrade(&rc)); self.unique_blocks.insert(*block, Arc::downgrade(&rc));
rc BlockAcquire {
rc,
became_present_on_worker: true,
}
} }
pub(super) fn try_remove_block(&mut self, block: &SequenceHash) { pub(super) fn try_remove_block(&mut self, block: &SequenceHash) -> bool {
if let Some(weak) = self.unique_blocks.get(block) if let Some(weak) = self.unique_blocks.get(block)
&& weak.strong_count() == 0 && weak.strong_count() == 0
{ {
self.unique_blocks.remove(block); self.unique_blocks.remove(block);
self.fractional_blocks.remove(block); self.fractional_blocks.remove(block);
return true;
} }
false
} }
pub(super) fn active_blocks(&self) -> usize { pub(super) fn active_blocks(&self) -> usize {
...@@ -43,3 +58,68 @@ impl BlockTracker { ...@@ -43,3 +58,68 @@ impl BlockTracker {
count.round() as usize count.round() as usize
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn first_touch_and_last_remove_report_presence_transitions() {
let mut tracker = BlockTracker::default();
let first = tracker.touch_block(&1);
let second = tracker.touch_block(&1);
assert!(first.became_present_on_worker);
assert!(!second.became_present_on_worker);
assert_eq!(tracker.active_blocks(), 1);
drop(first.rc);
assert!(!tracker.try_remove_block(&1));
assert_eq!(tracker.active_blocks(), 1);
drop(second.rc);
assert!(tracker.try_remove_block(&1));
assert_eq!(tracker.active_blocks(), 0);
}
#[test]
fn fractional_blocks_adjust_active_block_count() {
let mut tracker = BlockTracker::default();
let first = tracker.touch_block(&1);
let second = tracker.touch_block(&2);
tracker.fractional_blocks.insert(1, 0.5);
tracker.fractional_blocks.insert(2, 0.5);
assert_eq!(tracker.active_blocks(), 1);
drop(first.rc);
assert!(tracker.try_remove_block(&1));
assert!(!tracker.fractional_blocks.contains_key(&1));
assert_eq!(tracker.active_blocks(), 1);
drop(second.rc);
assert!(tracker.try_remove_block(&2));
assert!(tracker.fractional_blocks.is_empty());
assert_eq!(tracker.active_blocks(), 0);
}
#[test]
fn shared_block_counts_once_until_last_reference_drops() {
let mut tracker = BlockTracker::default();
let first = tracker.touch_block(&7);
let second = tracker.touch_block(&7);
let third = tracker.touch_block(&7);
assert_eq!(tracker.active_blocks(), 1);
drop(first.rc);
drop(second.rc);
assert!(!tracker.try_remove_block(&7));
assert_eq!(tracker.active_blocks(), 1);
drop(third.rc);
assert!(tracker.try_remove_block(&7));
assert_eq!(tracker.active_blocks(), 0);
}
}
...@@ -4,7 +4,11 @@ ...@@ -4,7 +4,11 @@
mod block_tracker; mod block_tracker;
pub mod multi_worker; pub mod multi_worker;
mod prefill_tracker; mod prefill_tracker;
mod prompt_membership_trie;
mod prompt_registry;
mod request_maps;
pub mod single; pub mod single;
mod topology;
pub use multi_worker::*; pub use multi_worker::*;
pub use single::*; pub use single::*;
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque; use std::collections::{HashMap, VecDeque};
use std::time::Duration; use std::time::Duration;
use tokio::time::Instant; use tokio::time::Instant;
use super::single::RequestId; use super::single::RequestId;
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) struct PrefillLoadState { pub(super) struct PrefillLoadState {
pub(super) initial_effective_prefill_tokens: usize, pub(super) initial_effective_prefill_tokens: usize,
pub(super) expected_prefill_duration: Option<Duration>, pub(super) expected_prefill_duration: Option<Duration>,
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) struct AnchoredPrefillSnapshot {
pub(super) initial_effective_prefill_tokens: usize,
pub(super) expected_prefill_duration: Option<Duration>,
pub(super) anchored_since: Instant,
}
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub(super) struct PrefillLoadSnapshot {
pub(super) prefill_full_tokens_sum: usize,
pub(super) anchored_prefill: Option<AnchoredPrefillSnapshot>,
}
impl PrefillLoadSnapshot {
pub(super) fn active_tokens_at(&self, now: Instant) -> usize {
let Some(anchored_prefill) = self.anchored_prefill else {
return 0;
};
let anchored_full = anchored_prefill.initial_effective_prefill_tokens;
let anchored_remaining = match anchored_prefill.expected_prefill_duration {
None => anchored_full,
Some(expected_prefill_duration) if expected_prefill_duration.is_zero() => 0,
Some(expected_prefill_duration) => {
let elapsed = now.saturating_duration_since(anchored_prefill.anchored_since);
let remaining_fraction = (1.0
- (elapsed.as_secs_f64() / expected_prefill_duration.as_secs_f64()))
.clamp(0.0, 1.0);
((anchored_full as f64) * remaining_fraction).ceil() as usize
}
};
self.prefill_full_tokens_sum
.checked_sub(anchored_full)
.expect("prefill_full_tokens_sum smaller than anchored load")
+ anchored_remaining
}
}
pub(super) fn added_prefill_tokens(block_size: usize, isl: usize, overlap: u32) -> usize {
let cached_tokens = (overlap as usize) * block_size;
isl.checked_sub(cached_tokens).unwrap_or_else(|| {
tracing::error!(
"prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens} (overlap {overlap} * block_size {block_size}), returning 0",
);
0
})
}
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub(super) struct PrefillLoadTracker { pub(super) struct PrefillLoadTracker {
pub(super) prefills: HashMap<RequestId, PrefillLoadState>,
pub(super) prefill_order: VecDeque<RequestId>, pub(super) prefill_order: VecDeque<RequestId>,
pub(super) prefill_full_tokens_sum: usize, pub(super) prefill_full_tokens_sum: usize,
pub(super) anchored_prefill: Option<(RequestId, Instant)>, pub(super) anchored_prefill: Option<(RequestId, Instant)>,
...@@ -27,6 +76,7 @@ impl PrefillLoadTracker { ...@@ -27,6 +76,7 @@ impl PrefillLoadTracker {
prefill: PrefillLoadState, prefill: PrefillLoadState,
decay_now: Instant, decay_now: Instant,
) { ) {
self.prefills.insert(request_id.clone(), prefill);
self.prefill_full_tokens_sum += prefill.initial_effective_prefill_tokens; self.prefill_full_tokens_sum += prefill.initial_effective_prefill_tokens;
let should_anchor = self.anchored_prefill.is_none(); let should_anchor = self.anchored_prefill.is_none();
self.prefill_order.push_back(request_id.clone()); self.prefill_order.push_back(request_id.clone());
...@@ -38,9 +88,9 @@ impl PrefillLoadTracker { ...@@ -38,9 +88,9 @@ impl PrefillLoadTracker {
pub(super) fn remove( pub(super) fn remove(
&mut self, &mut self,
request_id: &RequestId, request_id: &RequestId,
prefill: PrefillLoadState,
decay_now: Instant, decay_now: Instant,
) { ) -> Option<PrefillLoadState> {
let prefill = self.prefills.remove(request_id)?;
self.prefill_full_tokens_sum = self self.prefill_full_tokens_sum = self
.prefill_full_tokens_sum .prefill_full_tokens_sum
.checked_sub(prefill.initial_effective_prefill_tokens) .checked_sub(prefill.initial_effective_prefill_tokens)
...@@ -60,6 +110,7 @@ impl PrefillLoadTracker { ...@@ -60,6 +110,7 @@ impl PrefillLoadTracker {
{ {
self.set_anchor_to_front(decay_now); self.set_anchor_to_front(decay_now);
} }
Some(prefill)
} }
pub(super) fn set_anchor_to_front(&mut self, now: Instant) { pub(super) fn set_anchor_to_front(&mut self, now: Instant) {
...@@ -69,4 +120,209 @@ impl PrefillLoadTracker { ...@@ -69,4 +120,209 @@ impl PrefillLoadTracker {
.cloned() .cloned()
.map(|request_id| (request_id, now)); .map(|request_id| (request_id, now));
} }
pub(super) fn snapshot(&self) -> PrefillLoadSnapshot {
PrefillLoadSnapshot {
prefill_full_tokens_sum: self.prefill_full_tokens_sum,
anchored_prefill: self
.anchored_prefill
.as_ref()
.map(|(request_id, anchored_since)| {
let prefill = self
.prefills
.get(request_id)
.copied()
.expect("anchored prefill missing request state");
AnchoredPrefillSnapshot {
initial_effective_prefill_tokens: prefill.initial_effective_prefill_tokens,
expected_prefill_duration: prefill.expected_prefill_duration,
anchored_since: *anchored_since,
}
}),
}
}
#[cfg(any(test, debug_assertions))]
pub(super) fn assert_consistent(&self) {
let active_prefills: std::collections::HashSet<RequestId> =
self.prefills.keys().cloned().collect();
let ordered_prefills: std::collections::HashSet<RequestId> =
self.prefill_order.iter().cloned().collect();
let recomputed_prefill_sum: usize = self
.prefills
.values()
.map(|prefill| prefill.initial_effective_prefill_tokens)
.sum();
assert_eq!(
ordered_prefills.len(),
self.prefill_order.len(),
"prefill_order contains duplicate request ids",
);
assert_eq!(
ordered_prefills, active_prefills,
"prefill_order must match active prefill requests",
);
assert_eq!(
self.prefill_full_tokens_sum, recomputed_prefill_sum,
"prefill_full_tokens_sum drifted from tracker state",
);
if let Some(oldest_request_id) = self.prefill_order.front() {
let Some((anchored_request_id, _)) = self.anchored_prefill.as_ref() else {
panic!("anchored_prefill must exist when prefill_order is non-empty");
};
assert!(
self.prefills.contains_key(oldest_request_id),
"prefill_order front must point to an active prefill request",
);
assert_eq!(
anchored_request_id, oldest_request_id,
"anchored_prefill must match prefill_order.front()",
);
} else {
assert!(
self.anchored_prefill.is_none(),
"anchored_prefill must be absent when no active prefills remain",
);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn prefill_state(tokens: usize, duration_secs: u64) -> PrefillLoadState {
PrefillLoadState {
initial_effective_prefill_tokens: tokens,
expected_prefill_duration: Some(Duration::from_secs(duration_secs)),
}
}
#[test]
fn snapshot_without_anchor_reports_zero_active_tokens() {
let tracker = PrefillLoadTracker::default();
let snapshot = tracker.snapshot();
assert_eq!(snapshot.active_tokens_at(Instant::now()), 0);
}
#[test]
fn snapshot_only_decays_oldest_prefill() {
let epoch = Instant::now();
let mut tracker = PrefillLoadTracker::default();
let r1 = "r1".to_string();
let r2 = "r2".to_string();
let p1 = prefill_state(100, 10);
let p2 = prefill_state(60, 6);
tracker.insert(&r1, p1, epoch);
tracker.insert(&r2, p2, epoch + Duration::from_secs(2));
let snapshot = tracker.snapshot();
assert_eq!(
snapshot.active_tokens_at(epoch + Duration::from_secs(2)),
140
);
assert_eq!(
snapshot.active_tokens_at(epoch + Duration::from_secs(5)),
110
);
}
#[test]
fn removing_anchored_prefill_reanchors_front_and_resets_decay() {
let epoch = Instant::now();
let mut tracker = PrefillLoadTracker::default();
let r1 = "r1".to_string();
let r2 = "r2".to_string();
let p1 = prefill_state(100, 10);
let p2 = prefill_state(40, 8);
tracker.insert(&r1, p1, epoch);
tracker.insert(&r2, p2, epoch);
assert_eq!(
tracker.remove(&r1, epoch + Duration::from_secs(3)),
Some(p1)
);
assert_eq!(tracker.prefill_order, VecDeque::from([r2.clone()]));
assert!(
tracker
.anchored_prefill
.as_ref()
.is_some_and(|(request_id, _)| request_id == &r2)
);
let snapshot = tracker.snapshot();
assert_eq!(
snapshot.active_tokens_at(epoch + Duration::from_secs(3)),
40
);
assert_eq!(
snapshot.active_tokens_at(epoch + Duration::from_secs(5)),
30
);
}
#[test]
fn removing_nonfront_prefill_preserves_existing_anchor() {
let epoch = Instant::now();
let mut tracker = PrefillLoadTracker::default();
let r1 = "r1".to_string();
let r2 = "r2".to_string();
let p1 = prefill_state(30, 6);
let p2 = prefill_state(20, 4);
tracker.insert(&r1, p1, epoch);
tracker.insert(&r2, p2, epoch);
assert_eq!(
tracker.remove(&r2, epoch + Duration::from_secs(2)),
Some(p2)
);
assert_eq!(tracker.prefill_order, VecDeque::from([r1.clone()]));
assert!(
tracker
.anchored_prefill
.as_ref()
.is_some_and(|(request_id, anchored_since)| {
request_id == &r1 && *anchored_since == epoch
})
);
let snapshot = tracker.snapshot();
assert_eq!(
snapshot.active_tokens_at(epoch + Duration::from_secs(2)),
21
);
}
#[test]
fn duplicate_cleanup_is_idempotent() {
let epoch = Instant::now();
let mut tracker = PrefillLoadTracker::default();
let r1 = "r1".to_string();
let r2 = "r2".to_string();
let p1 = prefill_state(50, 10);
let p2 = prefill_state(30, 10);
tracker.insert(&r1, p1, epoch);
tracker.insert(&r2, p2, epoch);
tracker.assert_consistent();
assert_eq!(tracker.remove(&r1, epoch), Some(p1));
assert_eq!(tracker.remove(&r1, epoch), None);
assert_eq!(tracker.prefill_full_tokens_sum, 30);
assert_eq!(tracker.prefill_order, VecDeque::from([r2.clone()]));
assert_eq!(tracker.remove(&r2, epoch), Some(p2));
assert_eq!(tracker.remove(&r2, epoch), None);
tracker.assert_consistent();
assert_eq!(tracker.prefill_full_tokens_sum, 0);
assert!(tracker.prefill_order.is_empty());
assert!(tracker.prefills.is_empty());
}
} }
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -145,58 +145,6 @@ mod tests { ...@@ -145,58 +145,6 @@ mod tests {
use dynamo_runtime::{DistributedRuntime, Runtime}; use dynamo_runtime::{DistributedRuntime, Runtime};
use tokio::time::Instant; use tokio::time::Instant;
#[test]
fn test_active_sequences_shared_blocks() {
let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size);
let decay_now = Instant::now();
seq_manager.add_request(
"request_1".to_string(),
Some(vec![1, 2, 3]),
12,
0,
None,
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 3);
assert_eq!(seq_manager.active_tokens(decay_now), 12);
seq_manager.add_request(
"request_2".to_string(),
Some(vec![4]),
4,
0,
None,
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(decay_now), 16);
seq_manager.add_request(
"request_3".to_string(),
Some(vec![1, 2, 3, 4]),
16,
4,
None,
decay_now,
);
assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(decay_now), 16);
seq_manager.free(&"request_2".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(decay_now), 12);
seq_manager.free(&"request_3".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 3);
assert_eq!(seq_manager.active_tokens(decay_now), 12);
seq_manager.free(&"request_1".to_string(), decay_now);
assert_eq!(seq_manager.active_blocks(), 0);
assert_eq!(seq_manager.active_tokens(decay_now), 0);
}
#[tokio::test] #[tokio::test]
#[ignore] #[ignore]
async fn test_multi_worker_cross_instance_sync() -> Result<()> { async fn test_multi_worker_cross_instance_sync() -> Result<()> {
......
...@@ -19,6 +19,7 @@ use dynamo_kv_router::{ ...@@ -19,6 +19,7 @@ use dynamo_kv_router::{
SchedulingPolicy, SchedulingRequest, SequenceRequest, WorkerSelector, SchedulingPolicy, SchedulingRequest, SequenceRequest, WorkerSelector,
}; };
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
use rustc_hash::FxHashMap;
use tokio::time::Instant; use tokio::time::Instant;
use uuid::Uuid; use uuid::Uuid;
...@@ -124,8 +125,8 @@ impl PendingRequest { ...@@ -124,8 +125,8 @@ impl PendingRequest {
fn scheduling_request( fn scheduling_request(
&self, &self,
decode_blocks: HashMap<WorkerWithDpRank, usize>, decode_blocks: FxHashMap<WorkerWithDpRank, usize>,
prefill_tokens: HashMap<WorkerWithDpRank, usize>, prefill_tokens: FxHashMap<WorkerWithDpRank, usize>,
) -> SchedulingRequest { ) -> SchedulingRequest {
SchedulingRequest { SchedulingRequest {
maybe_request_id: Some(self.request_id()), maybe_request_id: Some(self.request_id()),
...@@ -408,7 +409,7 @@ impl OfflineReplayRouter { ...@@ -408,7 +409,7 @@ impl OfflineReplayRouter {
let arrival_offset = Duration::from_secs_f64((now_ms.max(0.0)) / 1000.0); let arrival_offset = Duration::from_secs_f64((now_ms.max(0.0)) / 1000.0);
self.policy.enqueue_key( self.policy.enqueue_key(
arrival_offset, arrival_offset,
&request.scheduling_request(HashMap::new(), HashMap::new()), &request.scheduling_request(FxHashMap::default(), FxHashMap::default()),
) )
} }
......
...@@ -272,12 +272,15 @@ fn simulate_prefill_duration( ...@@ -272,12 +272,15 @@ fn simulate_prefill_duration(
} }
fn debug_assert_sglang_scheduler_state( fn debug_assert_sglang_scheduler_state(
waiting: &VecDeque<SglangRequest>, _waiting: &VecDeque<SglangRequest>,
running: &[SglangRequest], _running: &[SglangRequest],
block_size: usize, _block_size: usize,
) { ) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
let waiting = _waiting;
let running = _running;
let block_size = _block_size;
let mut seen = std::collections::HashSet::new(); let mut seen = std::collections::HashSet::new();
for req in waiting { for req in waiting {
debug_assert!( debug_assert!(
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment