Commit 3ea22fcf (unverified), authored by Waël Boukhobza, committed by GitHub
Browse files

feat(router): max tree size based pruning (#4057)


Signed-off-by: Wael Boukhobza <wawa_wael@live.fr>
parent a207b4be
......@@ -726,10 +726,15 @@ impl ApproxKvIndexer {
#[new]
fn new(component: Component, kv_block_size: usize, ttl_secs: f64) -> PyResult<Self> {
let ttl = tokio::time::Duration::from_secs_f64(ttl_secs);
let prune_config = Some(llm_rs::kv_router::approx::PruneConfig {
max_tree_size: 2usize.pow(14), // 2** 14 = 16384
prune_target_ratio: 0.8,
});
let inner = Arc::new(llm_rs::kv_router::approx::ApproxKvIndexer::new(
component.inner.drt().runtime().child_token(),
kv_block_size as u32,
ttl,
prune_config,
));
Ok(Self { inner })
}
......
......@@ -36,6 +36,7 @@ pub use prefill_router::PrefillRouter;
use crate::{
kv_router::{
approx::ApproxKvIndexer,
approx::PruneConfig,
indexer::{
KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
compute_block_hash_for_seq, compute_seq_hash_for_block,
......@@ -259,6 +260,10 @@ impl KvRouter {
cancellation_token.clone(),
block_size,
Duration::from_secs(120),
Some(PruneConfig {
max_tree_size: 2usize.pow(14), // 2** 14 = 16384
prune_target_ratio: 0.8,
}),
))
};
......
......@@ -13,13 +13,15 @@
//!
//! - The thinking behind this is that if we send a request to a worker, and shortly after get a request with a similar prefix, odds
//! are that routing to the same worker will result in a large cache hit.
//! - Another benefit is the ability to bound the size of the radix tree, which would not be possible if we were trying to accurately
//! represent the state of each worker.
use async_trait::async_trait;
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};
use std::hash::Hash;
use std::sync::OnceLock;
use tokio::sync::{mpsc, oneshot};
use tokio::sync::{mpsc, oneshot, watch};
use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
......@@ -54,45 +56,78 @@ struct RouterResult {
sequence_hashes: Vec<u64>,
}
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct TimerEntry {
/// The key of the timer.
/// Block entry to be inserted in the [`PruneManager::expirations`] heap.
///
/// Identifies one block stored by one worker, together with the block's
/// position inside the sequence it was inserted with. Equality and hashing are
/// derived over all three fields; ordering is implemented by hand so that the
/// sequence position is compared first.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
struct BlockEntry {
/// The key of the block entry.
key: ExternalSequenceBlockHash,
/// The worker (with dp_rank) that stored this block.
worker: WorkerWithDpRank,
/// The position of this block in the sequence (0-indexed).
seq_position: usize,
}
impl PartialOrd for BlockEntry {
    /// Always `Some`: the partial order is exactly the total order from [`Ord`].
    fn partial_cmp(&self, rhs: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, rhs))
    }
}
impl Ord for BlockEntry {
    /// Lexicographic order over `(seq_position, key, worker)`.
    ///
    /// The sequence position is the most significant component (important for
    /// pruning); key and worker only break the remaining ties.
    fn cmp(&self, rhs: &Self) -> std::cmp::Ordering {
        let lhs_rank = (&self.seq_position, &self.key, &self.worker);
        let rhs_rank = (&rhs.seq_position, &rhs.key, &rhs.worker);
        lhs_rank.cmp(&rhs_rank)
    }
}
/// Configuration for size-based pruning of the tree.
#[derive(Debug, Clone)]
pub struct PruneConfig {
/// The maximum tree size before pruning is considered.
pub max_tree_size: usize,
/// The target size ratio to prune down to when max_tree_size is exceeded.
/// For example, if max_tree_size is 100 and prune_target_ratio is 0.5,
/// we will prune down to 50 nodes when max_tree_size is exceeded.
pub prune_target_ratio: f64,
}
/// A data structure to manage a collection of timers, addressable by a key.
/// This is structured as a sort of "priority queue" of keys, where the priority is the expiration time.
/// It supports insertion as well as updating the expiration time of a key.
/// The [`TimerManager::expirations`] heap is lazily updated to reflect the true expiration times in [`TimerManager::timers`]
/// The [`PruneManager::expirations`] heap is lazily updated to reflect the true expiration times in [`PruneManager::timers`]
/// For now, we have a fixed expiration time for all keys.
#[derive(Debug)]
struct TimerManager<K: Clone + Hash + Eq + Ord> {
struct PruneManager<K: Clone + Hash + Eq + Ord> {
/// The source of truth. Maps a key to its current expiration instant.
timers: HashMap<K, Instant>,
/// A min-heap of (expiration_instant, key) used to efficiently find the
/// next expiring timer. An entry in this heap is "stale" if the instant
/// does not match the one in the `timers` map.
expirations: BinaryHeap<Reverse<(Instant, K)>>,
/// The expiration duration of the timers.
ttl: Duration,
/// A max-heap of (Reverse<expiration_instant>, key) used to efficiently find the
/// next expiring timer. Reverse<Instant> makes earlier times pop first.
/// An entry in this heap is "stale" if the instant does not match the one in the `timers` map.
expirations: BinaryHeap<(Reverse<Instant>, K)>,
/// Threshold for rebuilding the heap.
/// The heap will be rebuilt from scratch to remove stale entries.
threshold: usize,
/// The expiration duration of the timers.
ttl: Duration,
/// The configuration for tree-size pruning.
prune_config: Option<PruneConfig>,
}
impl<K: Clone + Hash + Eq + Ord> TimerManager<K> {
/// Creates a new, empty TimerManager.
pub fn new(ttl: Duration, threshold: usize) -> Self {
TimerManager {
impl<K: Clone + Hash + Eq + Ord> PruneManager<K> {
/// Creates a new, empty PruneManager.
pub fn new(ttl: Duration, threshold: usize, prune_config: Option<PruneConfig>) -> Self {
PruneManager {
timers: HashMap::new(),
expirations: BinaryHeap::new(),
ttl,
threshold,
prune_config,
}
}
......@@ -101,7 +136,7 @@ impl<K: Clone + Hash + Eq + Ord> TimerManager<K> {
self.expirations = self
.timers
.iter()
.map(|(key, &expiry)| Reverse((expiry, key.clone())))
.map(|(key, &expiry)| (Reverse(expiry), key.clone()))
.collect();
}
......@@ -120,7 +155,7 @@ impl<K: Clone + Hash + Eq + Ord> TimerManager<K> {
// Push the new expiration onto the heap. If the key was updated,
// this leaves a "stale" entry on the heap for the old time,
// which will be ignored when it's popped.
self.expirations.push(Reverse((expiry_time, key)));
self.expirations.push((Reverse(expiry_time), key));
}
// Check if we should rebuild the heap to remove stale entries
......@@ -135,14 +170,14 @@ impl<K: Clone + Hash + Eq + Ord> TimerManager<K> {
let mut expired_keys = Vec::new();
let now = Instant::now();
while let Some(Reverse((expiry_time, _))) = self.expirations.peek() {
while let Some((Reverse(expiry_time), _)) = self.expirations.peek() {
// If the next timer in the heap is not yet expired, we can stop.
if *expiry_time > now {
break;
}
// The timer might be expired, so pop it from the heap.
let Reverse((expiry_time, key)) = self.expirations.pop().unwrap();
let (Reverse(expiry_time), key) = self.expirations.pop().unwrap();
if self.timers.get(&key) == Some(&expiry_time) {
// This is a valid, non-stale, expired timer.
......@@ -158,7 +193,57 @@ impl<K: Clone + Hash + Eq + Ord> TimerManager<K> {
pub fn peek_next_expiry(&self) -> Option<Instant> {
self.expirations
.peek()
.map(|Reverse((expiry_time, _))| *expiry_time)
.map(|(Reverse(expiry_time), _)| *expiry_time)
}
/// Prunes the tree if the current size is greater than the max tree size.
///
/// Pops entries from the `expirations` heap (earliest expiry first) until
/// `current_size` would be brought down to
/// `max_tree_size * prune_target_ratio`, or until the heap is exhausted.
/// Stale heap entries (whose instant no longer matches `timers`) are
/// discarded without counting towards the pruned total.
///
/// The target size is capped at `max_tree_size`, so a `prune_target_ratio`
/// above 1.0 cannot leave an over-limit tree permanently above the limit.
///
/// # Returns
///
/// The keys that were removed from `timers`; the caller is responsible for
/// removing them from the tree. Empty when `current_size <= max_tree_size`.
///
/// # Errors
///
/// Returns [`KvRouterError::PruneFailed`] when no prune config is set.
pub fn prune(&mut self, current_size: usize) -> Result<Vec<K>, KvRouterError> {
    let Some(prune_config) = &self.prune_config else {
        tracing::error!("Prune was called but prune config is None. This should never happen");
        return Err(KvRouterError::PruneFailed(
            "prune config is missing".to_string(),
        ));
    };
    let max_tree_size = prune_config.max_tree_size;
    let prune_target_ratio = prune_config.prune_target_ratio;

    if current_size <= max_tree_size {
        // Tree size within bounds, no pruning needed.
        return Ok(Vec::new());
    }

    tracing::info!(
        "Pruning: tree size ({}) exceeded max tree size ({}), starting pruning",
        current_size,
        max_tree_size
    );

    // Number of blocks that will be kept after pruning. The float-to-int cast
    // saturates; the `min` keeps the target at or below the configured maximum.
    let target_size = ((max_tree_size as f64 * prune_target_ratio) as usize).min(max_tree_size);
    // Loop-invariant: how many live entries must be evicted.
    let to_prune = current_size.saturating_sub(target_size);

    let mut pruned_keys = Vec::with_capacity(to_prune.min(self.timers.len()));
    while pruned_keys.len() < to_prune {
        let Some((Reverse(expiry_time), key)) = self.expirations.pop() else {
            // Heap exhausted before reaching the target.
            break;
        };
        if self.timers.get(&key) == Some(&expiry_time) {
            // This is a valid, non-stale timer.
            self.timers.remove(&key);
            pruned_keys.push(key);
        }
    }

    tracing::info!("Pruning: pruned ({}) blocks from tree", pruned_keys.len());
    Ok(pruned_keys)
}
}
......@@ -180,13 +265,19 @@ pub struct ApproxKvIndexer {
}
impl ApproxKvIndexer {
pub fn new(token: CancellationToken, kv_block_size: u32, ttl: Duration) -> Self {
pub fn new(
token: CancellationToken,
kv_block_size: u32,
ttl: Duration,
prune_config: Option<PruneConfig>,
) -> Self {
let (match_tx, mut match_rx) = mpsc::channel::<MatchRequest>(2048);
let (route_tx, mut route_rx) = mpsc::channel::<RouterResult>(2048);
let (remove_worker_tx, mut remove_worker_rx) = mpsc::channel::<WorkerId>(16);
let (_get_workers_tx, mut get_workers_rx) =
mpsc::channel::<super::indexer::GetWorkersRequest>(16);
let (dump_tx, mut dump_rx) = mpsc::channel::<DumpRequest>(16);
let (prune_tx, mut prune_rx) = watch::channel(false);
let cancel_clone = token.clone();
let task = std::thread::spawn(move || {
// create a new tokio runtime which will only perform work on a single thread
......@@ -197,12 +288,13 @@ impl ApproxKvIndexer {
runtime.block_on(async move {
let mut trie = RadixTree::new();
// Use a reasonable threshold - can be made configurable if needed
let mut timer_manager: TimerManager<TimerEntry> = TimerManager::new(ttl, 50);
// Use a reasonable threshold for ttl - can be made configurable if needed
let mut prune_manager: PruneManager<BlockEntry> = PruneManager::new(ttl, 50, prune_config.clone());
let mut event_id = 0;
loop {
// Create a future that sleeps until the next expiration time.
let expiry_fut = if let Some(next_expiry) = timer_manager.peek_next_expiry() {
let expiry_fut = if let Some(next_expiry) = prune_manager.peek_next_expiry() {
tokio::time::sleep_until(next_expiry)
} else {
// If there are no timers, sleep forever.
......@@ -245,12 +337,29 @@ impl ApproxKvIndexer {
}
);
let _ = trie.apply_event(event);
timer_manager.insert(result.sequence_hashes.iter().map(|h| TimerEntry {
if trie.apply_event(event).is_ok() {
prune_manager.insert(result.sequence_hashes.iter().enumerate().map(|(idx, h)| BlockEntry {
key: ExternalSequenceBlockHash(*h),
worker: result.worker,
seq_position: idx,
}).collect());
// Check if we need to prune due to tree size exceeding max threshold.
if let Some(prune_config) = &prune_manager.prune_config {
let current_size = trie.current_size();
if current_size > prune_config.max_tree_size {
tracing::info!(
"Pruning: tree size ({}) exceeded max tree size ({}), scheduling pruning",
current_size,
prune_config.max_tree_size
);
// Send a signal to the pruning watcher to schedule pruning.
if let Err(e) = prune_tx.send(true) {
tracing::error!("Failed to send prune schedule signal: {:?}", e);
}
}
}
}
}
Some(dump_req) = dump_rx.recv() => {
......@@ -263,8 +372,33 @@ impl ApproxKvIndexer {
request.resp.send(scores).unwrap();
}
Ok(_) = prune_rx.changed() => {
    // The tree has exceeded the max tree size, so proceed with pruning.
    if let Ok(pruned) = prune_manager.prune(trie.current_size()) {
        // Remove every pruned block from the tree via a `Removed` event.
        for p in &pruned {
            event_id += 1;
            let event = RouterEvent::new(
                p.worker.worker_id,
                KvCacheEvent {
                    event_id,
                    data: KvCacheEventData::Removed(KvCacheRemoveData {
                        block_hashes: vec![p.key],
                    }),
                    dp_rank: p.worker.dp_rank,
                }
            );
            let _ = trie.apply_event(event);
        }
        // Reset the pruning watcher to false to indicate that pruning is complete.
        // This must NOT notify the receiver: a plain `send` marks the value as
        // changed, which would immediately re-trigger this branch and make the
        // loop spin on no-op prunes.
        prune_tx.send_if_modified(|scheduled| {
            *scheduled = false;
            false
        });
    }
}
_ = expiry_fut => {
let expired = timer_manager.pop_expired();
let expired = prune_manager.pop_expired();
expired.iter().for_each(|e| {
event_id += 1;
......@@ -424,7 +558,7 @@ mod tests {
const KV_BLOCK_SIZE: u32 = 4;
impl<T: Clone + Hash + Eq + Ord> TimerManager<T> {
impl<T: Clone + Hash + Eq + Ord> PruneManager<T> {
pub fn get_expiry(&self, key: &T) -> Option<&Instant> {
self.timers.get(key)
}
......@@ -449,43 +583,43 @@ mod tests {
}
}
/// Validate basic insert / expiry behaviour of [`TimerManager`].
/// Validate basic insert / expiry behaviour of [`PruneManager`].
#[tokio::test]
async fn test_timer_manager_expiry() {
async fn test_prune_manager_expiry() {
const TTL: Duration = Duration::from_millis(50);
let mut tm: TimerManager<u32> = TimerManager::new(TTL, 50);
let mut pm: PruneManager<u32> = PruneManager::new(TTL, 50, None);
tm.insert(vec![1, 2, 3]);
assert!(tm.get_expiry(&1).is_some());
assert!(tm.get_expiry(&2).is_some());
assert!(tm.get_expiry(&3).is_some());
pm.insert(vec![1, 2, 3]);
assert!(pm.get_expiry(&1).is_some());
assert!(pm.get_expiry(&2).is_some());
assert!(pm.get_expiry(&3).is_some());
// Wait until after the TTL
time::sleep(TTL + Duration::from_millis(20)).await;
let expired = tm.pop_expired();
let expired = pm.pop_expired();
assert_eq!(expired.len(), 3);
assert!(tm.get_expiry(&1).is_none());
assert!(tm.get_expiry(&2).is_none());
assert!(tm.get_expiry(&3).is_none());
assert!(pm.get_expiry(&1).is_none());
assert!(pm.get_expiry(&2).is_none());
assert!(pm.get_expiry(&3).is_none());
}
/// Validate that reinserting an existing key extends its TTL and prevents premature expiry.
#[tokio::test]
async fn test_timer_manager_update_resets_ttl() {
async fn test_prune_manager_update_resets_ttl() {
// Validate that reinserting an existing key extends its TTL and prevents premature expiry.
const TTL: Duration = Duration::from_millis(50);
let mut tm: TimerManager<u32> = TimerManager::new(TTL, 50);
let mut pm: PruneManager<u32> = PruneManager::new(TTL, 50, None);
// Initial insert and capture the original expiry.
tm.insert(vec![42]);
let first_expiry = *tm
pm.insert(vec![42]);
let first_expiry = *pm
.get_expiry(&42)
.expect("expiry missing after first insert");
// Wait for half of the original TTL before reinserting.
time::sleep(Duration::from_millis(25)).await;
tm.insert(vec![42]);
let second_expiry = *tm
pm.insert(vec![42]);
let second_expiry = *pm
.get_expiry(&42)
.expect("expiry missing after reinsertion");
......@@ -494,7 +628,7 @@ mod tests {
// Wait until *after* the first expiry would have fired, but *before* the new expiry.
time::sleep(Duration::from_millis(30)).await; // 25ms already elapsed, +30ms = 55ms > first TTL
let expired = tm.pop_expired();
let expired = pm.pop_expired();
assert!(
expired.is_empty(),
"key expired prematurely despite TTL refresh"
......@@ -502,7 +636,7 @@ mod tests {
// Now wait until after the second expiry should have occurred.
time::sleep(Duration::from_millis(30)).await; // Ensure we pass the refreshed TTL
let expired_after = tm.pop_expired();
let expired_after = pm.pop_expired();
assert_eq!(expired_after, vec![42]);
}
......@@ -514,7 +648,7 @@ mod tests {
async fn test_approx_kv_indexer_basic_flow() {
const TTL: Duration = Duration::from_millis(200);
let cancel = CancellationToken::new();
let indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL);
let indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL, None);
let tokens: Vec<u32> = vec![1, 2, 3, 4]; // Exactly one KV block
let worker_id: WorkerId = 0;
......@@ -556,7 +690,7 @@ mod tests {
async fn test_remove_worker() {
const TTL: Duration = Duration::from_secs(5); // Large enough to avoid expiry during test
let cancel = CancellationToken::new();
let mut indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL);
let mut indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL, None);
let tokens: Vec<u32> = vec![10, 11, 12, 13];
let worker_id: WorkerId = 7;
......@@ -595,7 +729,7 @@ mod tests {
const TTL: Duration = Duration::from_secs(5); // Large enough to avoid expiry during test
let cancel = CancellationToken::new();
let mut indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL);
let mut indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL, None);
let tokens: Vec<u32> = vec![100, 101, 102, 103];
let worker_0: WorkerId = 30;
......@@ -653,7 +787,7 @@ mod tests {
const TTL: Duration = Duration::from_secs(5);
let cancel = CancellationToken::new();
let indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL);
let indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL, None);
// Sequence A : single block
let seq_a: Vec<u32> = vec![1, 2, 3, 4];
......@@ -699,7 +833,7 @@ mod tests {
const TTL: Duration = Duration::from_secs(5);
let cancel = CancellationToken::new();
let indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL);
let indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL, None);
let tokens: Vec<u32> = vec![9, 8, 7, 6];
let worker_0: WorkerId = 21;
......@@ -750,4 +884,214 @@ mod tests {
Some(&1)
);
}
/// Test that pruning returns empty when tree size is within the max tree size.
#[tokio::test]
async fn test_prune_manager_no_prune_when_within_bounds() {
const TTL: Duration = Duration::from_secs(10);
let prune_config = PruneConfig {
max_tree_size: 100,
prune_target_ratio: 0.5,
};
let mut pm: PruneManager<u32> = PruneManager::new(TTL, 50, Some(prune_config));
// Insert 50 keys (well below max_tree_size of 100)
pm.insert((0..50).collect());
// Pruning should return empty vec when size is within bounds
let pruned = pm.prune(50).unwrap();
assert!(pruned.is_empty());
// All keys should still be present
for i in 0..50 {
assert!(pm.get_expiry(&i).is_some());
}
}
/// Test that pruning removes the oldest entries first.
#[tokio::test]
async fn test_prune_manager_prune_removes_oldest_first() {
const TTL: Duration = Duration::from_secs(10);
let prune_config = PruneConfig {
max_tree_size: 10,
prune_target_ratio: 0.5,
};
let mut pm: PruneManager<u32> = PruneManager::new(TTL, 50, Some(prune_config));
// Insert keys one at a time with delays to ensure different timestamps
for i in 1..=15 {
pm.insert(vec![i]);
time::sleep(Duration::from_millis(1)).await;
}
// Total: 15 keys. Trigger pruning with current_size = 15
let pruned = pm.prune(15).unwrap();
// Should prune down to 5 (10 * 0.5), so 10 keys should be pruned (15 - 5)
assert_eq!(pruned.len(), 10);
// The oldest keys should be pruned first
for i in 1..=10 {
assert!(pruned.contains(&i));
}
// The newer keys should still be present
for i in 11..=15 {
assert!(pm.get_expiry(&i).is_some());
}
}
/// Test that pruning fails gracefully when config is None.
#[tokio::test]
async fn test_prune_manager_prune_fails_without_config() {
const TTL: Duration = Duration::from_secs(10);
let mut pm: PruneManager<u32> = PruneManager::new(TTL, 50, None);
pm.insert(vec![1, 2, 3]);
// Pruning should fail when prune_config is None
let result = pm.prune(150);
assert!(result.is_err());
assert!(matches!(result, Err(KvRouterError::PruneFailed(_))));
}
/// `BlockEntry` ordering compares the sequence position before the key.
#[test]
fn test_block_entry_ordering() {
    let worker = WorkerWithDpRank::from_worker_id(0);
    let entry_at = |hash: u64, seq_position: usize| BlockEntry {
        key: ExternalSequenceBlockHash(hash),
        worker,
        seq_position,
    };

    let earlier = entry_at(100, 0);
    let later = entry_at(50, 1);

    // Position 0 sorts before position 1 even though its key is larger.
    assert!(earlier < later);
}
/// End-to-end test for [`ApproxKvIndexer`] with pruning
/// 0. Max tree size is 5, target size is 2 (prune_target_ratio = 0.4)
/// 1. Insert 5 blocks (at max_tree_size but not exceeding)
/// 2. Verify all 5 blocks are present
/// 3. Insert 6th block (exceeds threshold, triggers reactive pruning)
/// 4. Verify pruning occurred: 4 oldest blocks removed
/// 5. Verify 2 newest blocks remain
#[tokio::test]
async fn test_approx_indexer_e2e_pruning() {
const TTL: Duration = Duration::from_secs(60); // Long TTL to avoid expiry
let prune_config = PruneConfig {
max_tree_size: 5, // Very small to trigger pruning quickly
prune_target_ratio: 0.4, // target size is 5 * 0.4 = 2
};
let cancel = CancellationToken::new();
let indexer = ApproxKvIndexer::new(cancel.clone(), KV_BLOCK_SIZE, TTL, Some(prune_config));
let worker = WorkerWithDpRank::from_worker_id(42);
// Insert 5 sequences (5 blocks total, at max_tree_size but not exceeding)
for i in 0..5 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
indexer
.process_routing_decision_for_request(&tokens, worker)
.await
.unwrap();
time::sleep(Duration::from_millis(1)).await; // Ensure different timestamps
}
// Verify all 5 blocks are present (no pruning yet)
for i in 0..5 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
assert_eq!(
scores.scores.get(&worker).copied(),
Some(1),
"Block {} should be present before threshold is exceeded",
i
);
}
// Insert 6th block - this exceeds max_tree_size and should trigger reactive pruning
let tokens: Vec<u32> = vec![50, 51, 52, 53];
indexer
.process_routing_decision_for_request(&tokens, worker)
.await
.unwrap();
// Wait for pruning to complete
time::sleep(Duration::from_millis(100)).await;
// After pruning, we will have exactly 2 blocks (5 * 0.4 = 2)
// The 2 newest blocks (i=4, i=5) will remain, oldest 4 blocks (i=0,1,2,3) will be pruned
// Verify that the 4 oldest blocks are pruned
for i in 0..4 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
assert!(
scores.scores.get(&worker).copied().unwrap_or(0) == 0,
"Block {} should have been pruned but is still present",
i
);
}
// Verify the 2 newest blocks are present
for i in 4..6 {
let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
assert_eq!(
scores.scores.get(&worker).copied(),
Some(1),
"Block {} should have been present but was pruned",
i
);
}
}
/// Test that re-inserting a key updates its position in the pruning queue.
#[tokio::test]
async fn test_prune_manager_prune_reinsertion_updates_position() {
const TTL: Duration = Duration::from_secs(10);
let prune_config = PruneConfig {
max_tree_size: 5,
prune_target_ratio: 0.8,
};
let mut pm: PruneManager<u32> = PruneManager::new(TTL, 50, Some(prune_config));
// Insert keys
for i in 1..=10 {
pm.insert(vec![i]);
time::sleep(Duration::from_millis(1)).await;
}
// Re-insert key 1 (should move it to the back of the queue)
pm.insert(vec![1]);
// Total: 10 unique keys. Trigger pruning: current_size = 10, target = 4, so prune 6 keys
// Order by expiry (oldest first): 2, 3, 4, 5, 6, 7, 8, 9, 10, 1 (re-inserted)
let pruned = pm.prune(10).unwrap();
assert_eq!(pruned.len(), 6);
// The oldest keys (2-7) should be pruned
for i in 2..=7 {
assert!(pruned.contains(&i));
}
// The newest keys (8-10) should still be present
for i in 8..=10 {
assert!(pm.get_expiry(&i).is_some());
}
// Key 1 should still be present (it was refreshed and is now near the end)
assert!(pm.get_expiry(&1).is_some());
}
}
......@@ -68,6 +68,9 @@ pub enum KvRouterError {
#[error("Indexer is dropped request")]
IndexerDroppedRequest,
#[error("Prune operation failed: {0}")]
PruneFailed(String),
}
/// Errors that can occur during KV Cache Event processing.
......@@ -235,6 +238,8 @@ pub struct RadixTree {
lookup: HashMap<WorkerWithDpRank, HashMap<ExternalSequenceBlockHash, SharedRadixBlock>>,
/// The time buffer the radix tree should check when considering the frequency of block accesses
expiration_duration: Option<Duration>,
/// The tree current size.
current_size: usize,
}
impl Default for RadixTree {
......@@ -254,6 +259,7 @@ impl RadixTree {
root: Rc::new(RefCell::new(RadixBlock::new())),
lookup: HashMap::new(),
expiration_duration,
current_size: 0,
}
}
......@@ -380,6 +386,9 @@ impl RadixTree {
.children
.insert(block_id.tokens_hash, new_block.clone());
// increment the current size when creating a new block
self.current_size = self.current_size.saturating_add(1);
new_block
}
};
......@@ -428,6 +437,9 @@ impl RadixTree {
if guard.workers.is_empty() {
// if no workers are using this block, that is true for all children
guard.children.clear();
// Decrement the current size when removing the last worker from a node
self.current_size = self.current_size.saturating_sub(1);
}
// remove the block from the lookup table
worker_lookup.remove(&block);
......@@ -460,6 +472,9 @@ impl RadixTree {
// If no workers are using this block, that is true for all children
if block.borrow().workers.is_empty() {
block.borrow_mut().children.clear();
// Decrement the current size when removing the last worker from a node
self.current_size = self.current_size.saturating_sub(1);
}
});
......@@ -560,6 +575,10 @@ impl RadixTree {
events
}
/// Returns the tree's current size counter.
///
/// The counter is incremented when a new block is created and decremented
/// when the last worker is removed from a node.
pub fn current_size(&self) -> usize {
self.current_size
}
}
/// Metrics for the KV Indexer.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment