// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Pruning and TTL utilities for KV Indexers
//!
//! This module provides utilities for managing TTL-based expiration and size-based pruning
//! of blocks in the radix tree. These utilities are used by the KvIndexer to manage
//! memory usage and keep the cache fresh.

use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};
use std::hash::Hash;
use tokio::time::{Duration, Instant};

use super::KvRouterError;
use crate::protocols::{ExternalSequenceBlockHash, WorkerWithDpRank};
/// Block entry to be inserted in the [`PruneManager::expirations`] heap.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
20
pub struct BlockEntry {
21
    /// The key of the block entry.
22
    pub key: ExternalSequenceBlockHash,
Yan Ru Pei's avatar
Yan Ru Pei committed
23
    /// The worker (with dp_rank) that stored this block.
24
    pub worker: WorkerWithDpRank,
25
    /// The position of this block in the sequence (0-indexed).
26
    pub seq_position: usize,
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
}

impl PartialOrd for BlockEntry {
    /// Never returns `None`: `BlockEntry` has a total order, so this simply wraps
    /// the result of [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}

impl Ord for BlockEntry {
    /// Orders entries lexicographically by `(seq_position, key, worker)`.
    ///
    /// Sequence position is the most significant component (it matters for pruning);
    /// `key` and `worker` follow so that the order is total and agrees with `Eq`.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let lhs = (&self.seq_position, &self.key, &self.worker);
        let rhs = (&other.seq_position, &other.key, &other.worker);
        lhs.cmp(&rhs)
    }
}

/// Settings for TTL-based expiry and size-based pruning, consumed by [`PruneManager`].
#[derive(Debug, Clone)]
pub struct PruneConfig {
    /// Time-to-live duration for blocks before they expire.
    pub ttl: Duration,
    /// The tree size above which pruning is triggered (a size equal to this value is
    /// still considered within bounds).
    pub max_tree_size: usize,
    /// Fraction of `max_tree_size` to prune down to once `max_tree_size` is exceeded.
    /// For example, if `max_tree_size` is 100 and `prune_target_ratio` is 0.5,
    /// the tree is pruned down to 50 nodes when it grows beyond 100.
    /// A value above 1.0 would place the target above `max_tree_size` itself.
    pub prune_target_ratio: f64,
}

impl Default for PruneConfig {
    fn default() -> Self {
        Self {
            ttl: Duration::from_secs(120), // 120 seconds
            max_tree_size: 2usize.pow(20), // 2^20 = 1048576
            prune_target_ratio: 0.8,       // Prune down to 80% of max
        }
    }
}

67
68
69
/// A data structure to manage a collection of timers, addressable by a key.
/// This is structured as a sort of "priority queue" of keys, where the priority is the expiration time.
/// It supports insertion as well as updating the expiration time of a key.
70
/// The [`PruneManager::expirations`] heap is lazily updated to reflect the true expiration times in [`PruneManager::timers`]
71
72
/// For now, we have a fixed expiration time for all keys.
#[derive(Debug)]
73
pub struct PruneManager<K: Clone + Hash + Eq + Ord> {
74
75
    /// The source of truth. Maps a key to its current expiration instant.
    timers: HashMap<K, Instant>,
76

77
78
79
80
    /// A max-heap of (Reverse<expiration_instant>, key) used to efficiently find the
    /// next expiring timer. Reverse<Instant> makes earlier times pop first.
    /// An entry in this heap is "stale" if the instant does not match the one in the `timers` map.
    expirations: BinaryHeap<(Reverse<Instant>, K)>,
81
82
83
84

    /// Threshold for rebuilding the heap.
    /// The heap will be rebuilt from scratch to remove stale entries.
    threshold: usize,
85
86
87
88
89

    /// The expiration duration of the timers.
    ttl: Duration,

    /// The configuration for tree-size pruning.
90
    pub prune_config: Option<PruneConfig>,
91
92
}

93
94
impl<K: Clone + Hash + Eq + Ord> PruneManager<K> {
    /// Creates a new, empty PruneManager.
95
96
    pub fn new(threshold: usize, prune_config: PruneConfig) -> Self {
        let ttl = prune_config.ttl;
97
        PruneManager {
98
99
100
            timers: HashMap::new(),
            expirations: BinaryHeap::new(),
            ttl,
101
            threshold,
102
            prune_config: Some(prune_config),
103
104
105
        }
    }

106
107
108
109
110
    /// Rebuilds the expirations heap from the timers map, removing all stale entries.
    fn rebuild_heap(&mut self) {
        self.expirations = self
            .timers
            .iter()
111
            .map(|(key, &expiry)| (Reverse(expiry), key.clone()))
112
113
114
            .collect();
    }

115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
    /// Inserts a new timer or updates an existing one for the given key.
    ///
    /// # Arguments
    /// * `key` - The unique key for the timer.
    /// * `duration` - The duration from now when the timer should expire.
    pub fn insert(&mut self, keys: Vec<K>) {
        let expiry_time = Instant::now() + self.ttl;

        for key in keys {
            // Insert or update the authoritative time in the map.
            self.timers.insert(key.clone(), expiry_time);

            // Push the new expiration onto the heap. If the key was updated,
            // this leaves a "stale" entry on the heap for the old time,
            // which will be ignored when it's popped.
130
            self.expirations.push((Reverse(expiry_time), key));
131
        }
132
133
134
135
136

        // Check if we should rebuild the heap to remove stale entries
        if self.expirations.len() > self.timers.len() * self.threshold {
            self.rebuild_heap();
        }
137
138
139
140
141
142
143
144
    }

    /// Polls for expired timers and returns a list of keys for all timers
    /// that have expired up to the current moment.
    pub fn pop_expired(&mut self) -> Vec<K> {
        let mut expired_keys = Vec::new();
        let now = Instant::now();

145
        while let Some((Reverse(expiry_time), _)) = self.expirations.peek() {
146
147
148
149
150
151
            // If the next timer in the heap is not yet expired, we can stop.
            if *expiry_time > now {
                break;
            }

            // The timer might be expired, so pop it from the heap.
152
            let (Reverse(expiry_time), key) = self.expirations.pop().unwrap();
153

154
155
156
157
            if self.timers.get(&key) == Some(&expiry_time) {
                // This is a valid, non-stale, expired timer.
                self.timers.remove(&key);
                expired_keys.push(key);
158
159
160
161
162
163
164
165
166
167
            }
        }

        expired_keys
    }

    /// Returns the next expiry time, if it exists.
    pub fn peek_next_expiry(&self) -> Option<Instant> {
        self.expirations
            .peek()
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
            .map(|(Reverse(expiry_time), _)| *expiry_time)
    }

    /// Prunes the tree if the current size is greater than the max tree size.
    pub fn prune(&mut self, current_size: usize) -> Result<Vec<K>, KvRouterError> {
        let max_tree_size: usize;
        let prune_target_ratio: f64;

        if let Some(prune_config) = &self.prune_config {
            max_tree_size = prune_config.max_tree_size;
            prune_target_ratio = prune_config.prune_target_ratio;
        } else {
            tracing::error!("Prune was called but prune config is None. This should never happen");
            return Err(KvRouterError::PruneFailed(
                "prune config is missing".to_string(),
            ));
        }

        if current_size <= max_tree_size {
            // Tree size within bounds, no pruning needed.
            return Ok(Vec::new());
        }

        tracing::info!(
            "Pruning: tree size ({}) exceeded max tree size ({}), starting pruning",
            current_size,
            max_tree_size
        );

        // Number of blocks that will be kept after pruning.
        let target_size = (max_tree_size as f64 * prune_target_ratio) as usize;

        let mut pruned_keys = Vec::new();
        let mut num_pruned = 0;

        while num_pruned < current_size.saturating_sub(target_size) {
            if let Some((Reverse(expiry_time), key)) = self.expirations.pop() {
                if self.timers.get(&key) == Some(&expiry_time) {
                    // This is a valid, non-stale timer.
                    self.timers.remove(&key);
                    pruned_keys.push(key);
                    num_pruned += 1;
                }
            } else {
                break;
            }
        }

        tracing::info!("Pruning: pruned ({}) blocks from tree", num_pruned);

        Ok(pruned_keys)
219
220
221
222
223
224
    }
}

#[cfg(test)]
mod tests {
    use super::*;
225
226
    use crate::indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics};
    use crate::protocols::{TokensWithHashes, WorkerId, WorkerWithDpRank};
227
    use std::sync::Arc;
228
229
230
    use tokio::time::{self, Duration, Instant};
    use tokio_util::sync::CancellationToken;

jthomson04's avatar
jthomson04 committed
231
232
    const KV_BLOCK_SIZE: u32 = 4;

233
    impl<T: Clone + Hash + Eq + Ord> PruneManager<T> {
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
        pub fn get_expiry(&self, key: &T) -> Option<&Instant> {
            self.timers.get(key)
        }
    }

    /// Helper to spin until a future evaluates to `true`, or a timeout is reached.
    async fn spin_until<F, Fut>(timeout: Duration, mut predicate: F)
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = bool>,
    {
        let start = Instant::now();
        const POLL: Duration = Duration::from_millis(1);
        loop {
            if predicate().await {
                return;
            }
            if Instant::now().duration_since(start) >= timeout {
                panic!("timeout waiting for condition");
            }
            time::sleep(POLL).await;
        }
    }

258
    /// Validate basic insert / expiry behaviour of [`PruneManager`].
259
    #[tokio::test]
260
    async fn test_prune_manager_expiry() {
261
        const TTL: Duration = Duration::from_millis(50);
262
263
264
265
266
267
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX, // Effectively disable size-based pruning
            prune_target_ratio: 0.5,
        };
        let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);
268

269
270
271
272
        pm.insert(vec![1, 2, 3]);
        assert!(pm.get_expiry(&1).is_some());
        assert!(pm.get_expiry(&2).is_some());
        assert!(pm.get_expiry(&3).is_some());
273
274
275

        // Wait until after the TTL
        time::sleep(TTL + Duration::from_millis(20)).await;
276
        let expired = pm.pop_expired();
277
        assert_eq!(expired.len(), 3);
278
279
280
        assert!(pm.get_expiry(&1).is_none());
        assert!(pm.get_expiry(&2).is_none());
        assert!(pm.get_expiry(&3).is_none());
281
282
283
284
    }

    /// Validate that reinserting an existing key extends its TTL and prevents premature expiry.
    #[tokio::test]
285
    async fn test_prune_manager_update_resets_ttl() {
286
287
        // Validate that reinserting an existing key extends its TTL and prevents premature expiry.
        const TTL: Duration = Duration::from_millis(50);
288
289
290
291
292
293
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);
294
295

        // Initial insert and capture the original expiry.
296
297
        pm.insert(vec![42]);
        let first_expiry = *pm
298
299
300
301
302
            .get_expiry(&42)
            .expect("expiry missing after first insert");

        // Wait for half of the original TTL before reinserting.
        time::sleep(Duration::from_millis(25)).await;
303
304
        pm.insert(vec![42]);
        let second_expiry = *pm
305
306
307
308
309
310
311
312
            .get_expiry(&42)
            .expect("expiry missing after reinsertion");

        // The expiry after reinsertion must be strictly later than the first one.
        assert!(second_expiry > first_expiry);

        // Wait until *after* the first expiry would have fired, but *before* the new expiry.
        time::sleep(Duration::from_millis(30)).await; // 25ms already elapsed, +30ms = 55ms > first TTL
313
        let expired = pm.pop_expired();
314
315
316
317
318
319
320
        assert!(
            expired.is_empty(),
            "key expired prematurely despite TTL refresh"
        );

        // Now wait until after the second expiry should have occurred.
        time::sleep(Duration::from_millis(30)).await; // Ensure we pass the refreshed TTL
321
        let expired_after = pm.pop_expired();
322
323
324
        assert_eq!(expired_after, vec![42]);
    }

325
    /// End-to-end test for [`KvIndexer`] with TTL:
326
327
328
329
330
331
332
    ///   1. No matches before routing decision
    ///   2. Matches appear after `process_routing_decision`
    ///   3. Matches disappear after TTL expiry
    #[tokio::test]
    async fn test_approx_kv_indexer_basic_flow() {
        const TTL: Duration = Duration::from_millis(200);
        let cancel = CancellationToken::new();
333
334
335
336
337
338
339
340
341
342
343
344
345
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let indexer = KvIndexer::new_with_frequency(
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
346
347
348
349
350
351

        let tokens: Vec<u32> = vec![1, 2, 3, 4]; // Exactly one KV block
        let worker_id: WorkerId = 0;

        // 1. Before routing decision there should be no matches
        let pre_scores = indexer
352
            .find_matches_for_request(&tokens, None)
353
354
355
356
357
            .await
            .expect("indexer offline");
        assert!(pre_scores.scores.is_empty());

        // 2. Inform indexer about routing decision
358
        let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
359
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
360
            .process_routing_decision_for_request(
361
                &mut tokens_with_hashes,
Yan Ru Pei's avatar
Yan Ru Pei committed
362
363
                WorkerWithDpRank::from_worker_id(worker_id),
            )
364
365
366
367
            .await
            .unwrap();

        // Poll until we observe the match being registered
368
        spin_until(Duration::from_millis(100), async || {
369
370
371
372
            let s = indexer
                .find_matches_for_request(&tokens, None)
                .await
                .unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
373
374
375
376
            s.scores
                .get(&WorkerWithDpRank::from_worker_id(worker_id))
                .copied()
                == Some(1)
377
378
379
380
381
        })
        .await;

        // 3. After the TTL has passed the entry should expire automatically
        time::sleep(TTL + Duration::from_millis(50)).await;
382
383
384
385
        let post_scores = indexer
            .find_matches_for_request(&tokens, None)
            .await
            .unwrap();
386
387
388
389
390
391
392
393
        assert!(post_scores.scores.is_empty());
    }

    /// Verify that `remove_worker` clears all entries for the specified worker.
    #[tokio::test]
    async fn test_remove_worker() {
        const TTL: Duration = Duration::from_secs(5); // Large enough to avoid expiry during test
        let cancel = CancellationToken::new();
394
395
396
397
398
399
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
Yan Ru Pei's avatar
Yan Ru Pei committed
400
        let indexer = KvIndexer::new_with_frequency(
401
402
403
404
405
406
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
407
408
409
410

        let tokens: Vec<u32> = vec![10, 11, 12, 13];
        let worker_id: WorkerId = 7;

411
        let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
412
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
413
            .process_routing_decision_for_request(
414
                &mut tokens_with_hashes,
Yan Ru Pei's avatar
Yan Ru Pei committed
415
416
                WorkerWithDpRank::from_worker_id(worker_id),
            )
417
418
419
420
            .await
            .unwrap();

        // Wait until the worker is registered
421
        spin_until(Duration::from_millis(100), async || {
422
423
424
425
            let s = indexer
                .find_matches_for_request(&tokens, None)
                .await
                .unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
426
427
            s.scores
                .contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
428
429
430
431
432
433
434
        })
        .await;

        // Remove the worker
        indexer.remove_worker(worker_id).await;

        // Ensure the worker's entries are gone
435
        spin_until(Duration::from_millis(100), async || {
436
437
438
439
            let s = indexer
                .find_matches_for_request(&tokens, None)
                .await
                .unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
440
441
            !s.scores
                .contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
442
443
444
445
446
447
448
449
450
451
        })
        .await;
    }

    /// After removing one of multiple workers that share the same block, the remaining worker's entries should persist.
    #[tokio::test]
    async fn test_remove_worker_preserves_other_workers() {
        const TTL: Duration = Duration::from_secs(5); // Large enough to avoid expiry during test

        let cancel = CancellationToken::new();
452
453
454
455
456
457
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
Yan Ru Pei's avatar
Yan Ru Pei committed
458
        let indexer = KvIndexer::new_with_frequency(
459
460
461
462
463
464
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
465
466
467
468
469
470

        let tokens: Vec<u32> = vec![100, 101, 102, 103];
        let worker_0: WorkerId = 30;
        let worker_1: WorkerId = 31;

        // Register on both workers
471
        let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
472
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
473
            .process_routing_decision_for_request(
474
                &mut tokens_with_hashes,
Yan Ru Pei's avatar
Yan Ru Pei committed
475
476
                WorkerWithDpRank::from_worker_id(worker_0),
            )
477
478
            .await
            .unwrap();
479
        let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
480
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
481
            .process_routing_decision_for_request(
482
                &mut tokens_with_hashes,
Yan Ru Pei's avatar
Yan Ru Pei committed
483
484
                WorkerWithDpRank::from_worker_id(worker_1),
            )
485
486
487
488
            .await
            .unwrap();

        // Ensure both workers are registered
489
        spin_until(Duration::from_millis(100), async || {
490
491
492
493
            let s = indexer
                .find_matches_for_request(&tokens, None)
                .await
                .unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
494
495
496
497
498
499
500
501
            s.scores
                .get(&WorkerWithDpRank::from_worker_id(worker_0))
                .copied()
                == Some(1)
                && s.scores
                    .get(&WorkerWithDpRank::from_worker_id(worker_1))
                    .copied()
                    == Some(1)
502
503
504
505
506
507
508
        })
        .await;

        // Remove one worker
        indexer.remove_worker(worker_0).await;

        // Confirm the removed worker is gone, and the other remains.
509
        spin_until(Duration::from_millis(100), async || {
510
511
512
513
            let s = indexer
                .find_matches_for_request(&tokens, None)
                .await
                .unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
514
515
516
517
518
519
            !s.scores
                .contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
                && s.scores
                    .get(&WorkerWithDpRank::from_worker_id(worker_1))
                    .copied()
                    == Some(1)
520
521
522
523
524
525
526
527
528
529
        })
        .await;
    }

    /// Two sequences with a shared prefix should yield overlap scores reflecting the common blocks.
    #[tokio::test]
    async fn test_common_prefix_overlap() {
        const TTL: Duration = Duration::from_secs(5);

        let cancel = CancellationToken::new();
530
531
532
533
534
535
536
537
538
539
540
541
542
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let indexer = KvIndexer::new_with_frequency(
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
543
544
545
546
547
548

        // Sequence A : single block
        let seq_a: Vec<u32> = vec![1, 2, 3, 4];
        let worker_a: WorkerId = 11;

        // Register Sequence A on worker A
549
        let mut tokens_with_hashes = TokensWithHashes::new(seq_a.clone(), KV_BLOCK_SIZE);
550
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
551
            .process_routing_decision_for_request(
552
                &mut tokens_with_hashes,
Yan Ru Pei's avatar
Yan Ru Pei committed
553
554
                WorkerWithDpRank::from_worker_id(worker_a),
            )
555
556
557
558
            .await
            .unwrap();

        // Ensure the indexer has registered the block
559
        spin_until(Duration::from_millis(100), async || {
560
561
562
563
            let s = indexer
                .find_matches_for_request(&seq_a, None)
                .await
                .unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
564
565
566
567
            s.scores
                .get(&WorkerWithDpRank::from_worker_id(worker_a))
                .copied()
                == Some(1)
568
569
570
571
572
573
574
        })
        .await;

        // Sequence B : shares the first block with Sequence A, plus an extra block
        let seq_b: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

        // Query the indexer for overlaps of Sequence B (before it has been routed anywhere)
575
576
577
578
        let overlap = indexer
            .find_matches_for_request(&seq_b, None)
            .await
            .unwrap();
579
580

        // Expect worker A to have an overlap score of 1 (shared first block)
Yan Ru Pei's avatar
Yan Ru Pei committed
581
582
583
584
585
586
        assert_eq!(
            overlap
                .scores
                .get(&WorkerWithDpRank::from_worker_id(worker_a)),
            Some(&1)
        );
587
588
589
590
591
592
593
594
    }

    /// When the same block resides on multiple workers, all should appear in the overlap scores.
    #[tokio::test]
    async fn test_multiple_workers_same_block() {
        const TTL: Duration = Duration::from_secs(5);

        let cancel = CancellationToken::new();
595
596
597
598
599
600
601
602
603
604
605
606
607
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let indexer = KvIndexer::new_with_frequency(
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
608
609
610
611
612
613

        let tokens: Vec<u32> = vec![9, 8, 7, 6];
        let worker_0: WorkerId = 21;
        let worker_1: WorkerId = 22;

        // Register the same sequence on two different workers
614
        let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
615
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
616
            .process_routing_decision_for_request(
617
                &mut tokens_with_hashes,
Yan Ru Pei's avatar
Yan Ru Pei committed
618
619
                WorkerWithDpRank::from_worker_id(worker_0),
            )
620
621
            .await
            .unwrap();
622
        let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
623
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
624
            .process_routing_decision_for_request(
625
                &mut tokens_with_hashes,
Yan Ru Pei's avatar
Yan Ru Pei committed
626
627
                WorkerWithDpRank::from_worker_id(worker_1),
            )
628
629
630
631
            .await
            .unwrap();

        // Wait until both workers are reflected in overlap scores
632
        spin_until(Duration::from_millis(100), async || {
633
634
635
636
            let s = indexer
                .find_matches_for_request(&tokens, None)
                .await
                .unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
637
638
639
640
641
642
643
644
            s.scores
                .get(&WorkerWithDpRank::from_worker_id(worker_0))
                .copied()
                == Some(1)
                && s.scores
                    .get(&WorkerWithDpRank::from_worker_id(worker_1))
                    .copied()
                    == Some(1)
645
646
647
        })
        .await;

648
649
650
651
        let scores = indexer
            .find_matches_for_request(&tokens, None)
            .await
            .unwrap();
652

Yan Ru Pei's avatar
Yan Ru Pei committed
653
654
655
656
657
658
659
660
661
662
663
664
        assert_eq!(
            scores
                .scores
                .get(&WorkerWithDpRank::from_worker_id(worker_0)),
            Some(&1)
        );
        assert_eq!(
            scores
                .scores
                .get(&WorkerWithDpRank::from_worker_id(worker_1)),
            Some(&1)
        );
665
    }
666
667
668
669
670
671

    /// Test that pruning returns empty when tree size is within the max tree size.
    #[tokio::test]
    async fn test_prune_manager_no_prune_when_within_bounds() {
        const TTL: Duration = Duration::from_secs(10);
        let prune_config = PruneConfig {
672
            ttl: TTL,
673
674
675
676
            max_tree_size: 100,
            prune_target_ratio: 0.5,
        };

677
        let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696

        // Insert 50 keys (well below max_tree_size of 100)
        pm.insert((0..50).collect());

        // Pruning should return empty vec when size is within bounds
        let pruned = pm.prune(50).unwrap();
        assert!(pruned.is_empty());

        // All keys should still be present
        for i in 0..50 {
            assert!(pm.get_expiry(&i).is_some());
        }
    }

    /// Test that pruning removes the oldest entries first.
    #[tokio::test]
    async fn test_prune_manager_prune_removes_oldest_first() {
        const TTL: Duration = Duration::from_secs(10);
        let prune_config = PruneConfig {
697
            ttl: TTL,
698
699
700
701
            max_tree_size: 10,
            prune_target_ratio: 0.5,
        };

702
        let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730

        // Insert keys one at a time with delays to ensure different timestamps
        for i in 1..=15 {
            pm.insert(vec![i]);
            time::sleep(Duration::from_millis(1)).await;
        }

        // Total: 15 keys. Trigger pruning with current_size = 15
        let pruned = pm.prune(15).unwrap();

        // Should prune down to 5 (10 * 0.5), so 10 keys should be pruned (15 - 5)
        assert_eq!(pruned.len(), 10);

        // The oldest keys should be pruned first
        for i in 1..=10 {
            assert!(pruned.contains(&i));
        }

        // The newer keys should still be present
        for i in 11..=15 {
            assert!(pm.get_expiry(&i).is_some());
        }
    }

    /// Test that pruning fails gracefully when config is None.
    #[tokio::test]
    async fn test_prune_manager_prune_fails_without_config() {
        const TTL: Duration = Duration::from_secs(10);
731
732
733
734
735
736
737
738
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);
        // Temporarily set prune_config to None to test the error case
        pm.prune_config = None;
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767

        pm.insert(vec![1, 2, 3]);

        // Pruning should fail when prune_config is None
        let result = pm.prune(150);
        assert!(result.is_err());
        assert!(matches!(result, Err(KvRouterError::PruneFailed(_))));
    }

    /// Test that BlockEntry ordering prioritizes sequence position.
    #[test]
    fn test_block_entry_ordering() {
        let worker = WorkerWithDpRank::from_worker_id(0);

        let entry1 = BlockEntry {
            key: ExternalSequenceBlockHash(100),
            worker,
            seq_position: 0,
        };
        let entry2 = BlockEntry {
            key: ExternalSequenceBlockHash(50),
            worker,
            seq_position: 1,
        };

        // entry1 < entry2 because seq_position 0 < 1
        assert!(entry1 < entry2);
    }

768
    /// End-to-end test for [`KvIndexer`] with TTL and pruning
769
770
771
772
773
774
775
776
777
778
    ///   0. Max tree size is 5, target size is 2 (prune_target_ratio = 0.4)
    ///   1. Insert 5 blocks (at max_tree_size but not exceeding)
    ///   2. Verify all 5 blocks are present
    ///   3. Insert 6th block (exceeds threshold, triggers reactive pruning)
    ///   4. Verify pruning occurred: 4 oldest blocks removed
    ///   5. Verify 2 newest blocks remain
    #[tokio::test]
    async fn test_approx_indexer_e2e_pruning() {
        const TTL: Duration = Duration::from_secs(60); // Long TTL to avoid expiry
        let prune_config = PruneConfig {
779
            ttl: TTL,
780
781
782
783
784
            max_tree_size: 5,        // Very small to trigger pruning quickly
            prune_target_ratio: 0.4, // target size is 5 * 0.4 = 2
        };

        let cancel = CancellationToken::new();
785
786
787
788
789
790
791
792
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let indexer = KvIndexer::new_with_frequency(
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
793
794
795
796
797
798

        let worker = WorkerWithDpRank::from_worker_id(42);

        // Insert 5 sequences (5 blocks total, at max_tree_size but not exceeding)
        for i in 0..5 {
            let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
799
            let mut tokens_with_hashes = TokensWithHashes::new(tokens, KV_BLOCK_SIZE);
800
            indexer
801
                .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
802
803
804
805
806
807
808
809
                .await
                .unwrap();
            time::sleep(Duration::from_millis(1)).await; // Ensure different timestamps
        }

        // Verify all 5 blocks are present (no pruning yet)
        for i in 0..5 {
            let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
810
811
812
813
            let scores = indexer
                .find_matches_for_request(&tokens, None)
                .await
                .unwrap();
814
815
816
817
818
819
820
821
822
823
            assert_eq!(
                scores.scores.get(&worker).copied(),
                Some(1),
                "Block {} should be present before threshold is exceeded",
                i
            );
        }

        // Insert 6th block - this exceeds max_tree_size and should trigger reactive pruning
        let tokens: Vec<u32> = vec![50, 51, 52, 53];
824
        let mut tokens_with_hashes = TokensWithHashes::new(tokens, KV_BLOCK_SIZE);
825
        indexer
826
            .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
827
828
829
830
831
832
833
834
835
836
837
838
            .await
            .unwrap();

        // Wait for pruning to complete
        time::sleep(Duration::from_millis(100)).await;

        // After pruning, we will have exactly 2 blocks (5 * 0.4 = 2)
        // The 2 newest blocks (i=4, i=5) will remain, oldest 4 blocks (i=0,1,2,3) will be pruned

        // Verify that the 4 oldest blocks are pruned
        for i in 0..4 {
            let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
839
840
841
842
            let scores = indexer
                .find_matches_for_request(&tokens, None)
                .await
                .unwrap();
843
844
845
846
847
848
849
850
851
852
            assert!(
                scores.scores.get(&worker).copied().unwrap_or(0) == 0,
                "Block {} should have been pruned but is still present",
                i
            );
        }

        // Verify the 2 newest blocks are present
        for i in 4..6 {
            let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
853
854
855
856
            let scores = indexer
                .find_matches_for_request(&tokens, None)
                .await
                .unwrap();
857
858
859
860
861
862
863
864
865
866
867
868
869
870
            assert_eq!(
                scores.scores.get(&worker).copied(),
                Some(1),
                "Block {} should have been present but was pruned",
                i
            );
        }
    }

    /// Test that re-inserting a key updates its position in the pruning queue.
    #[tokio::test]
    async fn test_prune_manager_prune_reinsertion_updates_position() {
        const TTL: Duration = Duration::from_secs(10);
        let prune_config = PruneConfig {
871
            ttl: TTL,
872
873
874
875
            max_tree_size: 5,
            prune_target_ratio: 0.8,
        };

876
        let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904

        // Insert keys
        for i in 1..=10 {
            pm.insert(vec![i]);
            time::sleep(Duration::from_millis(1)).await;
        }

        // Re-insert key 1 (should move it to the back of the queue)
        pm.insert(vec![1]);

        // Total: 10 unique keys. Trigger pruning: current_size = 10, target = 4, so prune 6 keys
        // Order by expiry (oldest first): 2, 3, 4, 5, 6, 7, 8, 9, 10, 1 (re-inserted)
        let pruned = pm.prune(10).unwrap();
        assert_eq!(pruned.len(), 6);

        // The oldest keys (2-7) should be pruned
        for i in 2..=7 {
            assert!(pruned.contains(&i));
        }

        // The newest keys (8-10) should still be present
        for i in 8..=10 {
            assert!(pm.get_expiry(&i).is_some());
        }

        // Key 1 should still be present (it was refreshed and is now near the end)
        assert!(pm.get_expiry(&1).is_some());
    }
905
}