approx.rs 29.9 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
//! Pruning and TTL utilities for KV indexers.
//!
//! This module provides utilities for managing TTL-based expiration and size-based pruning
//! of blocks in the radix tree. These utilities are used by the `KvIndexer` to manage
//! memory usage and keep the cache fresh.
9
10
11
12
13
14

use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};
use std::hash::Hash;
use tokio::time::{Duration, Instant};

15
16
use crate::kv_router::indexer::KvRouterError;
use crate::kv_router::protocols::{ExternalSequenceBlockHash, WorkerWithDpRank};
17

18
19
/// Block entry to be inserted in the [`PruneManager::expirations`] heap.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
20
pub struct BlockEntry {
21
    /// The key of the block entry.
22
    pub key: ExternalSequenceBlockHash,
Yan Ru Pei's avatar
Yan Ru Pei committed
23
    /// The worker (with dp_rank) that stored this block.
24
    pub worker: WorkerWithDpRank,
25
    /// The position of this block in the sequence (0-indexed).
26
    pub seq_position: usize,
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
}

/// Partial ordering for [`BlockEntry`]; it simply wraps the total order defined by `Ord`.
impl PartialOrd for BlockEntry {
    /// Always returns `Some`, delegating to [`Ord::cmp`] so both orderings agree.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for BlockEntry {
    /// Total order over block entries.
    ///
    /// Entries are compared lexicographically by sequence position first (this is what
    /// matters for pruning), then by key, and finally by worker.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let lhs = (self.seq_position, &self.key, &self.worker);
        let rhs = (other.seq_position, &other.key, &other.worker);
        lhs.cmp(&rhs)
    }
}

/// Settings for TTL-based expiration and size-based pruning.
#[derive(Debug, Clone)]
pub struct PruneConfig {
    /// Time-to-live duration for blocks before they expire.
    pub ttl: Duration,
    /// The maximum tree size before pruning is considered.
    pub max_tree_size: usize,
    /// Fraction of `max_tree_size` to prune down to once `max_tree_size` is exceeded.
    /// For example, if `max_tree_size` is 100 and `prune_target_ratio` is 0.5,
    /// the tree is pruned down to 50 nodes when `max_tree_size` is exceeded.
    pub prune_target_ratio: f64,
}

57
58
59
60
61
62
63
64
65
66
impl Default for PruneConfig {
    /// Default settings: a 120 second TTL, a maximum tree size of 2^20 and pruning down to
    /// 80% of that maximum.
    fn default() -> Self {
        const DEFAULT_TTL: Duration = Duration::from_secs(120);
        const DEFAULT_MAX_TREE_SIZE: usize = 1 << 20; // 1048576
        const DEFAULT_PRUNE_TARGET_RATIO: f64 = 0.8;

        Self {
            ttl: DEFAULT_TTL,
            max_tree_size: DEFAULT_MAX_TREE_SIZE,
            prune_target_ratio: DEFAULT_PRUNE_TARGET_RATIO,
        }
    }
}

67
68
69
/// A data structure to manage a collection of timers, addressable by a key.
/// This is structured as a sort of "priority queue" of keys, where the priority is the expiration time.
/// It supports insertion as well as updating the expiration time of a key.
70
/// The [`PruneManager::expirations`] heap is lazily updated to reflect the true expiration times in [`PruneManager::timers`]
71
72
/// For now, we have a fixed expiration time for all keys.
#[derive(Debug)]
73
pub struct PruneManager<K: Clone + Hash + Eq + Ord> {
74
75
    /// The source of truth. Maps a key to its current expiration instant.
    timers: HashMap<K, Instant>,
76

77
78
79
80
    /// A max-heap of (Reverse<expiration_instant>, key) used to efficiently find the
    /// next expiring timer. Reverse<Instant> makes earlier times pop first.
    /// An entry in this heap is "stale" if the instant does not match the one in the `timers` map.
    expirations: BinaryHeap<(Reverse<Instant>, K)>,
81
82
83
84

    /// Threshold for rebuilding the heap.
    /// The heap will be rebuilt from scratch to remove stale entries.
    threshold: usize,
85
86
87
88
89

    /// The expiration duration of the timers.
    ttl: Duration,

    /// The configuration for tree-size pruning.
90
    pub prune_config: Option<PruneConfig>,
91
92
}

93
94
impl<K: Clone + Hash + Eq + Ord> PruneManager<K> {
    /// Creates a new, empty PruneManager.
95
96
    pub fn new(threshold: usize, prune_config: PruneConfig) -> Self {
        let ttl = prune_config.ttl;
97
        PruneManager {
98
99
100
            timers: HashMap::new(),
            expirations: BinaryHeap::new(),
            ttl,
101
            threshold,
102
            prune_config: Some(prune_config),
103
104
105
        }
    }

106
107
108
109
110
    /// Rebuilds the expirations heap from the timers map, removing all stale entries.
    fn rebuild_heap(&mut self) {
        self.expirations = self
            .timers
            .iter()
111
            .map(|(key, &expiry)| (Reverse(expiry), key.clone()))
112
113
114
            .collect();
    }

115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
    /// Inserts a new timer or updates an existing one for the given key.
    ///
    /// # Arguments
    /// * `key` - The unique key for the timer.
    /// * `duration` - The duration from now when the timer should expire.
    pub fn insert(&mut self, keys: Vec<K>) {
        let expiry_time = Instant::now() + self.ttl;

        for key in keys {
            // Insert or update the authoritative time in the map.
            self.timers.insert(key.clone(), expiry_time);

            // Push the new expiration onto the heap. If the key was updated,
            // this leaves a "stale" entry on the heap for the old time,
            // which will be ignored when it's popped.
130
            self.expirations.push((Reverse(expiry_time), key));
131
        }
132
133
134
135
136

        // Check if we should rebuild the heap to remove stale entries
        if self.expirations.len() > self.timers.len() * self.threshold {
            self.rebuild_heap();
        }
137
138
139
140
141
142
143
144
    }

    /// Polls for expired timers and returns a list of keys for all timers
    /// that have expired up to the current moment.
    pub fn pop_expired(&mut self) -> Vec<K> {
        let mut expired_keys = Vec::new();
        let now = Instant::now();

145
        while let Some((Reverse(expiry_time), _)) = self.expirations.peek() {
146
147
148
149
150
151
            // If the next timer in the heap is not yet expired, we can stop.
            if *expiry_time > now {
                break;
            }

            // The timer might be expired, so pop it from the heap.
152
            let (Reverse(expiry_time), key) = self.expirations.pop().unwrap();
153

154
155
156
157
            if self.timers.get(&key) == Some(&expiry_time) {
                // This is a valid, non-stale, expired timer.
                self.timers.remove(&key);
                expired_keys.push(key);
158
159
160
161
162
163
164
165
166
167
            }
        }

        expired_keys
    }

    /// Returns the next expiry time, if it exists.
    pub fn peek_next_expiry(&self) -> Option<Instant> {
        self.expirations
            .peek()
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
            .map(|(Reverse(expiry_time), _)| *expiry_time)
    }

    /// Prunes the tree if the current size is greater than the max tree size.
    pub fn prune(&mut self, current_size: usize) -> Result<Vec<K>, KvRouterError> {
        let max_tree_size: usize;
        let prune_target_ratio: f64;

        if let Some(prune_config) = &self.prune_config {
            max_tree_size = prune_config.max_tree_size;
            prune_target_ratio = prune_config.prune_target_ratio;
        } else {
            tracing::error!("Prune was called but prune config is None. This should never happen");
            return Err(KvRouterError::PruneFailed(
                "prune config is missing".to_string(),
            ));
        }

        if current_size <= max_tree_size {
            // Tree size within bounds, no pruning needed.
            return Ok(Vec::new());
        }

        tracing::info!(
            "Pruning: tree size ({}) exceeded max tree size ({}), starting pruning",
            current_size,
            max_tree_size
        );

        // Number of blocks that will be kept after pruning.
        let target_size = (max_tree_size as f64 * prune_target_ratio) as usize;

        let mut pruned_keys = Vec::new();
        let mut num_pruned = 0;

        while num_pruned < current_size.saturating_sub(target_size) {
            if let Some((Reverse(expiry_time), key)) = self.expirations.pop() {
                if self.timers.get(&key) == Some(&expiry_time) {
                    // This is a valid, non-stale timer.
                    self.timers.remove(&key);
                    pruned_keys.push(key);
                    num_pruned += 1;
                }
            } else {
                break;
            }
        }

        tracing::info!("Pruning: pruned ({}) blocks from tree", num_pruned);

        Ok(pruned_keys)
219
220
221
222
223
224
    }
}

#[cfg(test)]
mod tests {
    use super::*;
225
226
227
    use crate::kv_router::indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics};
    use crate::kv_router::protocols::{WorkerId, WorkerWithDpRank};
    use std::sync::Arc;
228
229
230
    use tokio::time::{self, Duration, Instant};
    use tokio_util::sync::CancellationToken;

jthomson04's avatar
jthomson04 committed
231
232
    const KV_BLOCK_SIZE: u32 = 4;

233
    impl<T: Clone + Hash + Eq + Ord> PruneManager<T> {
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
        pub fn get_expiry(&self, key: &T) -> Option<&Instant> {
            self.timers.get(key)
        }
    }

    /// Repeatedly evaluates `predicate` until it resolves to `true`.
    ///
    /// The predicate is polled roughly once per millisecond. Panics if `timeout` has elapsed
    /// and the predicate is still `false`.
    async fn spin_until<F, Fut>(timeout: Duration, mut predicate: F)
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = bool>,
    {
        const POLL: Duration = Duration::from_millis(1);
        let started = Instant::now();

        while !predicate().await {
            if started.elapsed() >= timeout {
                panic!("timeout waiting for condition");
            }
            time::sleep(POLL).await;
        }
    }

258
    /// Validate basic insert / expiry behaviour of [`PruneManager`].
259
    #[tokio::test]
260
    async fn test_prune_manager_expiry() {
261
        const TTL: Duration = Duration::from_millis(50);
262
263
264
265
266
267
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX, // Effectively disable size-based pruning
            prune_target_ratio: 0.5,
        };
        let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);
268

269
270
271
272
        pm.insert(vec![1, 2, 3]);
        assert!(pm.get_expiry(&1).is_some());
        assert!(pm.get_expiry(&2).is_some());
        assert!(pm.get_expiry(&3).is_some());
273
274
275

        // Wait until after the TTL
        time::sleep(TTL + Duration::from_millis(20)).await;
276
        let expired = pm.pop_expired();
277
        assert_eq!(expired.len(), 3);
278
279
280
        assert!(pm.get_expiry(&1).is_none());
        assert!(pm.get_expiry(&2).is_none());
        assert!(pm.get_expiry(&3).is_none());
281
282
283
284
    }

    /// Validate that reinserting an existing key extends its TTL and prevents premature expiry.
    #[tokio::test]
285
    async fn test_prune_manager_update_resets_ttl() {
286
287
        // Validate that reinserting an existing key extends its TTL and prevents premature expiry.
        const TTL: Duration = Duration::from_millis(50);
288
289
290
291
292
293
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);
294
295

        // Initial insert and capture the original expiry.
296
297
        pm.insert(vec![42]);
        let first_expiry = *pm
298
299
300
301
302
            .get_expiry(&42)
            .expect("expiry missing after first insert");

        // Wait for half of the original TTL before reinserting.
        time::sleep(Duration::from_millis(25)).await;
303
304
        pm.insert(vec![42]);
        let second_expiry = *pm
305
306
307
308
309
310
311
312
            .get_expiry(&42)
            .expect("expiry missing after reinsertion");

        // The expiry after reinsertion must be strictly later than the first one.
        assert!(second_expiry > first_expiry);

        // Wait until *after* the first expiry would have fired, but *before* the new expiry.
        time::sleep(Duration::from_millis(30)).await; // 25ms already elapsed, +30ms = 55ms > first TTL
313
        let expired = pm.pop_expired();
314
315
316
317
318
319
320
        assert!(
            expired.is_empty(),
            "key expired prematurely despite TTL refresh"
        );

        // Now wait until after the second expiry should have occurred.
        time::sleep(Duration::from_millis(30)).await; // Ensure we pass the refreshed TTL
321
        let expired_after = pm.pop_expired();
322
323
324
        assert_eq!(expired_after, vec![42]);
    }

325
    /// End-to-end test for [`KvIndexer`] with TTL:
326
327
328
329
330
331
332
    ///   1. No matches before routing decision
    ///   2. Matches appear after `process_routing_decision`
    ///   3. Matches disappear after TTL expiry
    #[tokio::test]
    async fn test_approx_kv_indexer_basic_flow() {
        const TTL: Duration = Duration::from_millis(200);
        let cancel = CancellationToken::new();
333
334
335
336
337
338
339
340
341
342
343
344
345
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let indexer = KvIndexer::new_with_frequency(
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
346
347
348
349
350
351
352
353
354
355
356
357
358

        let tokens: Vec<u32> = vec![1, 2, 3, 4]; // Exactly one KV block
        let worker_id: WorkerId = 0;

        // 1. Before routing decision there should be no matches
        let pre_scores = indexer
            .find_matches_for_request(&tokens)
            .await
            .expect("indexer offline");
        assert!(pre_scores.scores.is_empty());

        // 2. Inform indexer about routing decision
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
359
360
361
362
            .process_routing_decision_for_request(
                &tokens,
                WorkerWithDpRank::from_worker_id(worker_id),
            )
363
364
365
366
367
368
            .await
            .unwrap();

        // Poll until we observe the match being registered
        spin_until(Duration::from_millis(100), || async {
            let s = indexer.find_matches_for_request(&tokens).await.unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
369
370
371
372
            s.scores
                .get(&WorkerWithDpRank::from_worker_id(worker_id))
                .copied()
                == Some(1)
373
374
375
376
377
378
379
380
381
382
383
384
385
386
        })
        .await;

        // 3. After the TTL has passed the entry should expire automatically
        time::sleep(TTL + Duration::from_millis(50)).await;
        let post_scores = indexer.find_matches_for_request(&tokens).await.unwrap();
        assert!(post_scores.scores.is_empty());
    }

    /// Verify that `remove_worker` clears all entries for the specified worker.
    #[tokio::test]
    async fn test_remove_worker() {
        const TTL: Duration = Duration::from_secs(5); // Large enough to avoid expiry during test
        let cancel = CancellationToken::new();
387
388
389
390
391
392
393
394
395
396
397
398
399
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let mut indexer = KvIndexer::new_with_frequency(
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
400
401
402
403
404

        let tokens: Vec<u32> = vec![10, 11, 12, 13];
        let worker_id: WorkerId = 7;

        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
405
406
407
408
            .process_routing_decision_for_request(
                &tokens,
                WorkerWithDpRank::from_worker_id(worker_id),
            )
409
410
411
412
413
414
            .await
            .unwrap();

        // Wait until the worker is registered
        spin_until(Duration::from_millis(100), || async {
            let s = indexer.find_matches_for_request(&tokens).await.unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
415
416
            s.scores
                .contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
417
418
419
420
421
422
423
424
425
        })
        .await;

        // Remove the worker
        indexer.remove_worker(worker_id).await;

        // Ensure the worker's entries are gone
        spin_until(Duration::from_millis(100), || async {
            let s = indexer.find_matches_for_request(&tokens).await.unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
426
427
            !s.scores
                .contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
428
429
430
431
432
433
434
435
436
437
        })
        .await;
    }

    /// After removing one of multiple workers that share the same block, the remaining worker's entries should persist.
    #[tokio::test]
    async fn test_remove_worker_preserves_other_workers() {
        const TTL: Duration = Duration::from_secs(5); // Large enough to avoid expiry during test

        let cancel = CancellationToken::new();
438
439
440
441
442
443
444
445
446
447
448
449
450
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let mut indexer = KvIndexer::new_with_frequency(
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
451
452
453
454
455
456
457

        let tokens: Vec<u32> = vec![100, 101, 102, 103];
        let worker_0: WorkerId = 30;
        let worker_1: WorkerId = 31;

        // Register on both workers
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
458
459
460
461
            .process_routing_decision_for_request(
                &tokens,
                WorkerWithDpRank::from_worker_id(worker_0),
            )
462
463
464
            .await
            .unwrap();
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
465
466
467
468
            .process_routing_decision_for_request(
                &tokens,
                WorkerWithDpRank::from_worker_id(worker_1),
            )
469
470
471
472
473
474
            .await
            .unwrap();

        // Ensure both workers are registered
        spin_until(Duration::from_millis(100), || async {
            let s = indexer.find_matches_for_request(&tokens).await.unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
475
476
477
478
479
480
481
482
            s.scores
                .get(&WorkerWithDpRank::from_worker_id(worker_0))
                .copied()
                == Some(1)
                && s.scores
                    .get(&WorkerWithDpRank::from_worker_id(worker_1))
                    .copied()
                    == Some(1)
483
484
485
486
487
488
489
490
491
        })
        .await;

        // Remove one worker
        indexer.remove_worker(worker_0).await;

        // Confirm the removed worker is gone, and the other remains.
        spin_until(Duration::from_millis(100), || async {
            let s = indexer.find_matches_for_request(&tokens).await.unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
492
493
494
495
496
497
            !s.scores
                .contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
                && s.scores
                    .get(&WorkerWithDpRank::from_worker_id(worker_1))
                    .copied()
                    == Some(1)
498
499
500
501
502
503
504
505
506
507
        })
        .await;
    }

    /// Two sequences with a shared prefix should yield overlap scores reflecting the common blocks.
    #[tokio::test]
    async fn test_common_prefix_overlap() {
        const TTL: Duration = Duration::from_secs(5);

        let cancel = CancellationToken::new();
508
509
510
511
512
513
514
515
516
517
518
519
520
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let indexer = KvIndexer::new_with_frequency(
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
521
522
523
524
525
526
527

        // Sequence A : single block
        let seq_a: Vec<u32> = vec![1, 2, 3, 4];
        let worker_a: WorkerId = 11;

        // Register Sequence A on worker A
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
528
529
530
531
            .process_routing_decision_for_request(
                &seq_a,
                WorkerWithDpRank::from_worker_id(worker_a),
            )
532
533
534
535
536
537
            .await
            .unwrap();

        // Ensure the indexer has registered the block
        spin_until(Duration::from_millis(100), || async {
            let s = indexer.find_matches_for_request(&seq_a).await.unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
538
539
540
541
            s.scores
                .get(&WorkerWithDpRank::from_worker_id(worker_a))
                .copied()
                == Some(1)
542
543
544
545
546
547
548
549
550
551
        })
        .await;

        // Sequence B : shares the first block with Sequence A, plus an extra block
        let seq_b: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

        // Query the indexer for overlaps of Sequence B (before it has been routed anywhere)
        let overlap = indexer.find_matches_for_request(&seq_b).await.unwrap();

        // Expect worker A to have an overlap score of 1 (shared first block)
Yan Ru Pei's avatar
Yan Ru Pei committed
552
553
554
555
556
557
        assert_eq!(
            overlap
                .scores
                .get(&WorkerWithDpRank::from_worker_id(worker_a)),
            Some(&1)
        );
558
559
560
561
562
563
564
565
    }

    /// When the same block resides on multiple workers, all should appear in the overlap scores.
    #[tokio::test]
    async fn test_multiple_workers_same_block() {
        const TTL: Duration = Duration::from_secs(5);

        let cancel = CancellationToken::new();
566
567
568
569
570
571
572
573
574
575
576
577
578
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let indexer = KvIndexer::new_with_frequency(
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
579
580
581
582
583
584
585

        let tokens: Vec<u32> = vec![9, 8, 7, 6];
        let worker_0: WorkerId = 21;
        let worker_1: WorkerId = 22;

        // Register the same sequence on two different workers
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
586
587
588
589
            .process_routing_decision_for_request(
                &tokens,
                WorkerWithDpRank::from_worker_id(worker_0),
            )
590
591
592
            .await
            .unwrap();
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
593
594
595
596
            .process_routing_decision_for_request(
                &tokens,
                WorkerWithDpRank::from_worker_id(worker_1),
            )
597
598
599
600
601
602
            .await
            .unwrap();

        // Wait until both workers are reflected in overlap scores
        spin_until(Duration::from_millis(100), || async {
            let s = indexer.find_matches_for_request(&tokens).await.unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
603
604
605
606
607
608
609
610
            s.scores
                .get(&WorkerWithDpRank::from_worker_id(worker_0))
                .copied()
                == Some(1)
                && s.scores
                    .get(&WorkerWithDpRank::from_worker_id(worker_1))
                    .copied()
                    == Some(1)
611
612
613
614
615
        })
        .await;

        let scores = indexer.find_matches_for_request(&tokens).await.unwrap();

Yan Ru Pei's avatar
Yan Ru Pei committed
616
617
618
619
620
621
622
623
624
625
626
627
        assert_eq!(
            scores
                .scores
                .get(&WorkerWithDpRank::from_worker_id(worker_0)),
            Some(&1)
        );
        assert_eq!(
            scores
                .scores
                .get(&WorkerWithDpRank::from_worker_id(worker_1)),
            Some(&1)
        );
628
    }
629
630
631
632
633
634

    /// `prune` must be a no-op while the reported size does not exceed `max_tree_size`.
    #[tokio::test]
    async fn test_prune_manager_no_prune_when_within_bounds() {
        const TTL: Duration = Duration::from_secs(10);

        let mut pm: PruneManager<u32> = PruneManager::new(
            50,
            PruneConfig {
                ttl: TTL,
                max_tree_size: 100,
                prune_target_ratio: 0.5,
            },
        );

        // 50 keys is well below the max_tree_size of 100.
        pm.insert((0..50).collect());

        // Within bounds: nothing may be pruned.
        let pruned = pm.prune(50).unwrap();
        assert!(pruned.is_empty());

        // Every key must still be tracked.
        assert!((0..50).all(|key| pm.get_expiry(&key).is_some()));
    }

    /// Size-based pruning must evict the entries that were inserted earliest.
    #[tokio::test]
    async fn test_prune_manager_prune_removes_oldest_first() {
        const TTL: Duration = Duration::from_secs(10);

        let mut pm: PruneManager<u32> = PruneManager::new(
            50,
            PruneConfig {
                ttl: TTL,
                max_tree_size: 10,
                prune_target_ratio: 0.5,
            },
        );

        // Insert one key at a time, pausing in between so the timestamps differ.
        for key in 1..=15 {
            pm.insert(vec![key]);
            time::sleep(Duration::from_millis(1)).await;
        }

        // 15 keys with a target of 10 * 0.5 = 5 means 10 keys get pruned.
        let pruned = pm.prune(15).unwrap();
        assert_eq!(pruned.len(), 10);

        // The ten oldest keys are the ones that were evicted...
        assert!((1..=10).all(|key| pruned.contains(&key)));

        // ...and the five newest are still tracked.
        assert!((11..=15).all(|key| pm.get_expiry(&key).is_some()));
    }

    /// `prune` must report `PruneFailed` rather than panic when no prune config is set.
    #[tokio::test]
    async fn test_prune_manager_prune_fails_without_config() {
        const TTL: Duration = Duration::from_secs(10);

        let mut pm: PruneManager<u32> = PruneManager::new(
            50,
            PruneConfig {
                ttl: TTL,
                max_tree_size: usize::MAX,
                prune_target_ratio: 0.5,
            },
        );
        // Clear the config after construction to reach the error path.
        pm.prune_config = None;

        pm.insert(vec![1, 2, 3]);

        // Without a config, pruning has to fail with the dedicated error variant.
        let result = pm.prune(150);
        assert!(result.is_err());
        assert!(matches!(result, Err(KvRouterError::PruneFailed(_))));
    }

    /// The sequence position takes precedence over the key when ordering `BlockEntry` values.
    #[test]
    fn test_block_entry_ordering() {
        let worker = WorkerWithDpRank::from_worker_id(0);
        let entry_at = |hash, seq_position| BlockEntry {
            key: ExternalSequenceBlockHash(hash),
            worker,
            seq_position,
        };

        let earlier = entry_at(100, 0);
        let later = entry_at(50, 1);

        // `earlier` sorts first because position 0 < 1, even though its key is larger.
        assert!(earlier < later);
    }

731
    /// End-to-end test for [`KvIndexer`] with TTL and pruning
732
733
734
735
736
737
738
739
740
741
    ///   0. Max tree size is 5, target size is 2 (prune_target_ratio = 0.4)
    ///   1. Insert 5 blocks (at max_tree_size but not exceeding)
    ///   2. Verify all 5 blocks are present
    ///   3. Insert 6th block (exceeds threshold, triggers reactive pruning)
    ///   4. Verify pruning occurred: 4 oldest blocks removed
    ///   5. Verify 2 newest blocks remain
    #[tokio::test]
    async fn test_approx_indexer_e2e_pruning() {
        const TTL: Duration = Duration::from_secs(60); // Long TTL to avoid expiry
        let prune_config = PruneConfig {
742
            ttl: TTL,
743
744
745
746
747
            max_tree_size: 5,        // Very small to trigger pruning quickly
            prune_target_ratio: 0.4, // target size is 5 * 0.4 = 2
        };

        let cancel = CancellationToken::new();
748
749
750
751
752
753
754
755
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let indexer = KvIndexer::new_with_frequency(
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822

        let worker = WorkerWithDpRank::from_worker_id(42);

        // Insert 5 sequences (5 blocks total, at max_tree_size but not exceeding)
        for i in 0..5 {
            let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
            indexer
                .process_routing_decision_for_request(&tokens, worker)
                .await
                .unwrap();
            time::sleep(Duration::from_millis(1)).await; // Ensure different timestamps
        }

        // Verify all 5 blocks are present (no pruning yet)
        for i in 0..5 {
            let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
            let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
            assert_eq!(
                scores.scores.get(&worker).copied(),
                Some(1),
                "Block {} should be present before threshold is exceeded",
                i
            );
        }

        // Insert 6th block - this exceeds max_tree_size and should trigger reactive pruning
        let tokens: Vec<u32> = vec![50, 51, 52, 53];
        indexer
            .process_routing_decision_for_request(&tokens, worker)
            .await
            .unwrap();

        // Wait for pruning to complete
        time::sleep(Duration::from_millis(100)).await;

        // After pruning, we will have exactly 2 blocks (5 * 0.4 = 2)
        // The 2 newest blocks (i=4, i=5) will remain, oldest 4 blocks (i=0,1,2,3) will be pruned

        // Verify that the 4 oldest blocks are pruned
        for i in 0..4 {
            let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
            let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
            assert!(
                scores.scores.get(&worker).copied().unwrap_or(0) == 0,
                "Block {} should have been pruned but is still present",
                i
            );
        }

        // Verify the 2 newest blocks are present
        for i in 4..6 {
            let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
            let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
            assert_eq!(
                scores.scores.get(&worker).copied(),
                Some(1),
                "Block {} should have been present but was pruned",
                i
            );
        }
    }

    /// A key that is re-inserted must move to the back of the pruning order.
    #[tokio::test]
    async fn test_prune_manager_prune_reinsertion_updates_position() {
        const TTL: Duration = Duration::from_secs(10);

        let mut pm: PruneManager<u32> = PruneManager::new(
            50,
            PruneConfig {
                ttl: TTL,
                max_tree_size: 5,
                prune_target_ratio: 0.8,
            },
        );

        // Insert ten keys, one at a time, with distinct timestamps.
        for key in 1..=10 {
            pm.insert(vec![key]);
            time::sleep(Duration::from_millis(1)).await;
        }

        // Refresh key 1 so that it becomes the newest entry.
        pm.insert(vec![1]);

        // 10 unique keys with a target of 5 * 0.8 = 4 means 6 keys get pruned.
        // Oldest-first order is now: 2, 3, 4, 5, 6, 7, 8, 9, 10, 1 (re-inserted).
        let pruned = pm.prune(10).unwrap();
        assert_eq!(pruned.len(), 6);

        // Keys 2-7 are the oldest and must have been evicted.
        assert!((2..=7).all(|key| pruned.contains(&key)));

        // Keys 8-10 are newer and must still be tracked.
        assert!((8..=10).all(|key| pm.get_expiry(&key).is_some()));

        // Key 1 was refreshed, so it survives as well.
        assert!(pm.get_expiry(&1).is_some());
    }
857
}