// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Pruning and TTL utilities for KV Indexers
//!
//! This module provides utilities for managing TTL-based expiration and size-based pruning
//! of blocks in the radix tree. These utilities are used by the KvIndexer to manage
//! memory usage and keep the cache fresh.

use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};
use std::hash::Hash;
use tokio::time::{Duration, Instant};

15
16
use crate::indexer::KvRouterError;
use crate::protocols::{ExternalSequenceBlockHash, WorkerWithDpRank};
17

18
19
/// Block entry to be inserted in the [`PruneManager::expirations`] heap.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
20
pub struct BlockEntry {
21
    /// The key of the block entry.
22
    pub key: ExternalSequenceBlockHash,
Yan Ru Pei's avatar
Yan Ru Pei committed
23
    /// The worker (with dp_rank) that stored this block.
24
    pub worker: WorkerWithDpRank,
25
    /// The position of this block in the sequence (0-indexed).
26
    pub seq_position: usize,
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
}

impl PartialOrd for BlockEntry {
    /// Delegates to the total order defined by the [`Ord`] implementation, so the partial
    /// and total orders can never disagree.
    fn partial_cmp(&self, rhs: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, rhs))
    }
}

impl Ord for BlockEntry {
    /// Orders entries lexicographically by `(seq_position, key, worker)`.
    fn cmp(&self, rhs: &Self) -> std::cmp::Ordering {
        // Sequence position is the primary criterion (important for pruning); the key and
        // then the worker only break ties.
        let lhs_fields = (&self.seq_position, &self.key, &self.worker);
        let rhs_fields = (&rhs.seq_position, &rhs.key, &rhs.worker);
        lhs_fields.cmp(&rhs_fields)
    }
}

/// Settings for TTL-based expiration and size-based pruning.
#[derive(Debug, Clone)]
pub struct PruneConfig {
    /// Time-to-live duration for blocks before they expire.
    pub ttl: Duration,
    /// The maximum tree size before pruning is considered.
    pub max_tree_size: usize,
    /// The target size ratio to prune down to when `max_tree_size` is exceeded.
    /// For example, if `max_tree_size` is 100 and `prune_target_ratio` is 0.5,
    /// we will prune down to 50 nodes when `max_tree_size` is exceeded.
    pub prune_target_ratio: f64,
}

impl Default for PruneConfig {
    /// Returns a configuration with a 120 second TTL, a maximum tree size of 2^20 and a
    /// prune target of 80% of that maximum.
    fn default() -> Self {
        Self {
            ttl: Duration::from_secs(120),
            max_tree_size: 1 << 20, // 2^20 = 1048576
            prune_target_ratio: 0.8, // Prune down to 80% of max
        }
    }
}

67
68
69
/// A data structure to manage a collection of timers, addressable by a key.
/// This is structured as a sort of "priority queue" of keys, where the priority is the expiration time.
/// It supports insertion as well as updating the expiration time of a key.
70
/// The [`PruneManager::expirations`] heap is lazily updated to reflect the true expiration times in [`PruneManager::timers`]
71
72
/// For now, we have a fixed expiration time for all keys.
#[derive(Debug)]
73
pub struct PruneManager<K: Clone + Hash + Eq + Ord> {
74
75
    /// The source of truth. Maps a key to its current expiration instant.
    timers: HashMap<K, Instant>,
76

77
78
79
80
    /// A max-heap of (Reverse<expiration_instant>, key) used to efficiently find the
    /// next expiring timer. Reverse<Instant> makes earlier times pop first.
    /// An entry in this heap is "stale" if the instant does not match the one in the `timers` map.
    expirations: BinaryHeap<(Reverse<Instant>, K)>,
81
82
83
84

    /// Threshold for rebuilding the heap.
    /// The heap will be rebuilt from scratch to remove stale entries.
    threshold: usize,
85
86
87
88
89

    /// The expiration duration of the timers.
    ttl: Duration,

    /// The configuration for tree-size pruning.
90
    pub prune_config: Option<PruneConfig>,
91
92
}

93
94
impl<K: Clone + Hash + Eq + Ord> PruneManager<K> {
    /// Creates a new, empty PruneManager.
95
96
    pub fn new(threshold: usize, prune_config: PruneConfig) -> Self {
        let ttl = prune_config.ttl;
97
        PruneManager {
98
99
100
            timers: HashMap::new(),
            expirations: BinaryHeap::new(),
            ttl,
101
            threshold,
102
            prune_config: Some(prune_config),
103
104
105
        }
    }

106
107
108
109
110
    /// Rebuilds the expirations heap from the timers map, removing all stale entries.
    fn rebuild_heap(&mut self) {
        self.expirations = self
            .timers
            .iter()
111
            .map(|(key, &expiry)| (Reverse(expiry), key.clone()))
112
113
114
            .collect();
    }

115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
    /// Inserts a new timer or updates an existing one for the given key.
    ///
    /// # Arguments
    /// * `key` - The unique key for the timer.
    /// * `duration` - The duration from now when the timer should expire.
    pub fn insert(&mut self, keys: Vec<K>) {
        let expiry_time = Instant::now() + self.ttl;

        for key in keys {
            // Insert or update the authoritative time in the map.
            self.timers.insert(key.clone(), expiry_time);

            // Push the new expiration onto the heap. If the key was updated,
            // this leaves a "stale" entry on the heap for the old time,
            // which will be ignored when it's popped.
130
            self.expirations.push((Reverse(expiry_time), key));
131
        }
132
133
134
135
136

        // Check if we should rebuild the heap to remove stale entries
        if self.expirations.len() > self.timers.len() * self.threshold {
            self.rebuild_heap();
        }
137
138
139
140
141
142
143
144
    }

    /// Polls for expired timers and returns a list of keys for all timers
    /// that have expired up to the current moment.
    pub fn pop_expired(&mut self) -> Vec<K> {
        let mut expired_keys = Vec::new();
        let now = Instant::now();

145
        while let Some((Reverse(expiry_time), _)) = self.expirations.peek() {
146
147
148
149
150
151
            // If the next timer in the heap is not yet expired, we can stop.
            if *expiry_time > now {
                break;
            }

            // The timer might be expired, so pop it from the heap.
152
            let (Reverse(expiry_time), key) = self.expirations.pop().unwrap();
153

154
155
156
157
            if self.timers.get(&key) == Some(&expiry_time) {
                // This is a valid, non-stale, expired timer.
                self.timers.remove(&key);
                expired_keys.push(key);
158
159
160
161
162
163
164
165
166
167
            }
        }

        expired_keys
    }

    /// Returns the next expiry time, if it exists.
    pub fn peek_next_expiry(&self) -> Option<Instant> {
        self.expirations
            .peek()
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
            .map(|(Reverse(expiry_time), _)| *expiry_time)
    }

    /// Prunes the tree if the current size is greater than the max tree size.
    pub fn prune(&mut self, current_size: usize) -> Result<Vec<K>, KvRouterError> {
        let max_tree_size: usize;
        let prune_target_ratio: f64;

        if let Some(prune_config) = &self.prune_config {
            max_tree_size = prune_config.max_tree_size;
            prune_target_ratio = prune_config.prune_target_ratio;
        } else {
            tracing::error!("Prune was called but prune config is None. This should never happen");
            return Err(KvRouterError::PruneFailed(
                "prune config is missing".to_string(),
            ));
        }

        if current_size <= max_tree_size {
            // Tree size within bounds, no pruning needed.
            return Ok(Vec::new());
        }

        tracing::info!(
            "Pruning: tree size ({}) exceeded max tree size ({}), starting pruning",
            current_size,
            max_tree_size
        );

        // Number of blocks that will be kept after pruning.
        let target_size = (max_tree_size as f64 * prune_target_ratio) as usize;

        let mut pruned_keys = Vec::new();
        let mut num_pruned = 0;

        while num_pruned < current_size.saturating_sub(target_size) {
            if let Some((Reverse(expiry_time), key)) = self.expirations.pop() {
                if self.timers.get(&key) == Some(&expiry_time) {
                    // This is a valid, non-stale timer.
                    self.timers.remove(&key);
                    pruned_keys.push(key);
                    num_pruned += 1;
                }
            } else {
                break;
            }
        }

        tracing::info!("Pruning: pruned ({}) blocks from tree", num_pruned);

        Ok(pruned_keys)
219
220
221
222
223
224
    }
}

#[cfg(test)]
mod tests {
    use super::*;
225
226
    use crate::indexer::{KvIndexer, KvIndexerInterface, KvIndexerMetrics};
    use crate::protocols::{TokensWithHashes, WorkerId, WorkerWithDpRank};
227
    use std::sync::Arc;
228
229
230
    use tokio::time::{self, Duration, Instant};
    use tokio_util::sync::CancellationToken;

jthomson04's avatar
jthomson04 committed
231
232
    const KV_BLOCK_SIZE: u32 = 4;

233
    impl<T: Clone + Hash + Eq + Ord> PruneManager<T> {
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
        pub fn get_expiry(&self, key: &T) -> Option<&Instant> {
            self.timers.get(key)
        }
    }

    /// Repeatedly awaits `predicate`, sleeping one millisecond between attempts, until it
    /// yields `true`.
    ///
    /// # Panics
    /// Panics with "timeout waiting for condition" if `timeout` has elapsed after a failed
    /// attempt.
    async fn spin_until<F, Fut>(timeout: Duration, mut predicate: F)
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = bool>,
    {
        const POLL: Duration = Duration::from_millis(1);
        let started_at = Instant::now();
        while !predicate().await {
            if started_at.elapsed() >= timeout {
                panic!("timeout waiting for condition");
            }
            time::sleep(POLL).await;
        }
    }

258
    /// Validate basic insert / expiry behaviour of [`PruneManager`].
259
    #[tokio::test]
260
    async fn test_prune_manager_expiry() {
261
        const TTL: Duration = Duration::from_millis(50);
262
263
264
265
266
267
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX, // Effectively disable size-based pruning
            prune_target_ratio: 0.5,
        };
        let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);
268

269
270
271
272
        pm.insert(vec![1, 2, 3]);
        assert!(pm.get_expiry(&1).is_some());
        assert!(pm.get_expiry(&2).is_some());
        assert!(pm.get_expiry(&3).is_some());
273
274
275

        // Wait until after the TTL
        time::sleep(TTL + Duration::from_millis(20)).await;
276
        let expired = pm.pop_expired();
277
        assert_eq!(expired.len(), 3);
278
279
280
        assert!(pm.get_expiry(&1).is_none());
        assert!(pm.get_expiry(&2).is_none());
        assert!(pm.get_expiry(&3).is_none());
281
282
283
284
    }

    /// Validate that reinserting an existing key extends its TTL and prevents premature expiry.
    #[tokio::test]
285
    async fn test_prune_manager_update_resets_ttl() {
286
287
        // Validate that reinserting an existing key extends its TTL and prevents premature expiry.
        const TTL: Duration = Duration::from_millis(50);
288
289
290
291
292
293
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);
294
295

        // Initial insert and capture the original expiry.
296
297
        pm.insert(vec![42]);
        let first_expiry = *pm
298
299
300
301
302
            .get_expiry(&42)
            .expect("expiry missing after first insert");

        // Wait for half of the original TTL before reinserting.
        time::sleep(Duration::from_millis(25)).await;
303
304
        pm.insert(vec![42]);
        let second_expiry = *pm
305
306
307
308
309
310
311
312
            .get_expiry(&42)
            .expect("expiry missing after reinsertion");

        // The expiry after reinsertion must be strictly later than the first one.
        assert!(second_expiry > first_expiry);

        // Wait until *after* the first expiry would have fired, but *before* the new expiry.
        time::sleep(Duration::from_millis(30)).await; // 25ms already elapsed, +30ms = 55ms > first TTL
313
        let expired = pm.pop_expired();
314
315
316
317
318
319
320
        assert!(
            expired.is_empty(),
            "key expired prematurely despite TTL refresh"
        );

        // Now wait until after the second expiry should have occurred.
        time::sleep(Duration::from_millis(30)).await; // Ensure we pass the refreshed TTL
321
        let expired_after = pm.pop_expired();
322
323
324
        assert_eq!(expired_after, vec![42]);
    }

325
    /// End-to-end test for [`KvIndexer`] with TTL:
326
327
328
329
330
331
332
    ///   1. No matches before routing decision
    ///   2. Matches appear after `process_routing_decision`
    ///   3. Matches disappear after TTL expiry
    #[tokio::test]
    async fn test_approx_kv_indexer_basic_flow() {
        const TTL: Duration = Duration::from_millis(200);
        let cancel = CancellationToken::new();
333
334
335
336
337
338
339
340
341
342
343
344
345
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let indexer = KvIndexer::new_with_frequency(
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
346
347
348
349
350
351
352
353
354
355
356
357

        let tokens: Vec<u32> = vec![1, 2, 3, 4]; // Exactly one KV block
        let worker_id: WorkerId = 0;

        // 1. Before routing decision there should be no matches
        let pre_scores = indexer
            .find_matches_for_request(&tokens)
            .await
            .expect("indexer offline");
        assert!(pre_scores.scores.is_empty());

        // 2. Inform indexer about routing decision
358
        let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
359
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
360
            .process_routing_decision_for_request(
361
                &mut tokens_with_hashes,
Yan Ru Pei's avatar
Yan Ru Pei committed
362
363
                WorkerWithDpRank::from_worker_id(worker_id),
            )
364
365
366
367
368
369
            .await
            .unwrap();

        // Poll until we observe the match being registered
        spin_until(Duration::from_millis(100), || async {
            let s = indexer.find_matches_for_request(&tokens).await.unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
370
371
372
373
            s.scores
                .get(&WorkerWithDpRank::from_worker_id(worker_id))
                .copied()
                == Some(1)
374
375
376
377
378
379
380
381
382
383
384
385
386
387
        })
        .await;

        // 3. After the TTL has passed the entry should expire automatically
        time::sleep(TTL + Duration::from_millis(50)).await;
        let post_scores = indexer.find_matches_for_request(&tokens).await.unwrap();
        assert!(post_scores.scores.is_empty());
    }

    /// Verify that `remove_worker` clears all entries for the specified worker.
    #[tokio::test]
    async fn test_remove_worker() {
        const TTL: Duration = Duration::from_secs(5); // Large enough to avoid expiry during test
        let cancel = CancellationToken::new();
388
389
390
391
392
393
394
395
396
397
398
399
400
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let mut indexer = KvIndexer::new_with_frequency(
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
401
402
403
404

        let tokens: Vec<u32> = vec![10, 11, 12, 13];
        let worker_id: WorkerId = 7;

405
        let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
406
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
407
            .process_routing_decision_for_request(
408
                &mut tokens_with_hashes,
Yan Ru Pei's avatar
Yan Ru Pei committed
409
410
                WorkerWithDpRank::from_worker_id(worker_id),
            )
411
412
413
414
415
416
            .await
            .unwrap();

        // Wait until the worker is registered
        spin_until(Duration::from_millis(100), || async {
            let s = indexer.find_matches_for_request(&tokens).await.unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
417
418
            s.scores
                .contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
419
420
421
422
423
424
425
426
427
        })
        .await;

        // Remove the worker
        indexer.remove_worker(worker_id).await;

        // Ensure the worker's entries are gone
        spin_until(Duration::from_millis(100), || async {
            let s = indexer.find_matches_for_request(&tokens).await.unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
428
429
            !s.scores
                .contains_key(&WorkerWithDpRank::from_worker_id(worker_id))
430
431
432
433
434
435
436
437
438
439
        })
        .await;
    }

    /// After removing one of multiple workers that share the same block, the remaining worker's entries should persist.
    #[tokio::test]
    async fn test_remove_worker_preserves_other_workers() {
        const TTL: Duration = Duration::from_secs(5); // Large enough to avoid expiry during test

        let cancel = CancellationToken::new();
440
441
442
443
444
445
446
447
448
449
450
451
452
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let mut indexer = KvIndexer::new_with_frequency(
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
453
454
455
456
457
458

        let tokens: Vec<u32> = vec![100, 101, 102, 103];
        let worker_0: WorkerId = 30;
        let worker_1: WorkerId = 31;

        // Register on both workers
459
        let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
460
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
461
            .process_routing_decision_for_request(
462
                &mut tokens_with_hashes,
Yan Ru Pei's avatar
Yan Ru Pei committed
463
464
                WorkerWithDpRank::from_worker_id(worker_0),
            )
465
466
            .await
            .unwrap();
467
        let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
468
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
469
            .process_routing_decision_for_request(
470
                &mut tokens_with_hashes,
Yan Ru Pei's avatar
Yan Ru Pei committed
471
472
                WorkerWithDpRank::from_worker_id(worker_1),
            )
473
474
475
476
477
478
            .await
            .unwrap();

        // Ensure both workers are registered
        spin_until(Duration::from_millis(100), || async {
            let s = indexer.find_matches_for_request(&tokens).await.unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
479
480
481
482
483
484
485
486
            s.scores
                .get(&WorkerWithDpRank::from_worker_id(worker_0))
                .copied()
                == Some(1)
                && s.scores
                    .get(&WorkerWithDpRank::from_worker_id(worker_1))
                    .copied()
                    == Some(1)
487
488
489
490
491
492
493
494
495
        })
        .await;

        // Remove one worker
        indexer.remove_worker(worker_0).await;

        // Confirm the removed worker is gone, and the other remains.
        spin_until(Duration::from_millis(100), || async {
            let s = indexer.find_matches_for_request(&tokens).await.unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
496
497
498
499
500
501
            !s.scores
                .contains_key(&WorkerWithDpRank::from_worker_id(worker_0))
                && s.scores
                    .get(&WorkerWithDpRank::from_worker_id(worker_1))
                    .copied()
                    == Some(1)
502
503
504
505
506
507
508
509
510
511
        })
        .await;
    }

    /// Two sequences with a shared prefix should yield overlap scores reflecting the common blocks.
    #[tokio::test]
    async fn test_common_prefix_overlap() {
        const TTL: Duration = Duration::from_secs(5);

        let cancel = CancellationToken::new();
512
513
514
515
516
517
518
519
520
521
522
523
524
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let indexer = KvIndexer::new_with_frequency(
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
525
526
527
528
529
530

        // Sequence A : single block
        let seq_a: Vec<u32> = vec![1, 2, 3, 4];
        let worker_a: WorkerId = 11;

        // Register Sequence A on worker A
531
        let mut tokens_with_hashes = TokensWithHashes::new(seq_a.clone(), KV_BLOCK_SIZE);
532
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
533
            .process_routing_decision_for_request(
534
                &mut tokens_with_hashes,
Yan Ru Pei's avatar
Yan Ru Pei committed
535
536
                WorkerWithDpRank::from_worker_id(worker_a),
            )
537
538
539
540
541
542
            .await
            .unwrap();

        // Ensure the indexer has registered the block
        spin_until(Duration::from_millis(100), || async {
            let s = indexer.find_matches_for_request(&seq_a).await.unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
543
544
545
546
            s.scores
                .get(&WorkerWithDpRank::from_worker_id(worker_a))
                .copied()
                == Some(1)
547
548
549
550
551
552
553
554
555
556
        })
        .await;

        // Sequence B : shares the first block with Sequence A, plus an extra block
        let seq_b: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

        // Query the indexer for overlaps of Sequence B (before it has been routed anywhere)
        let overlap = indexer.find_matches_for_request(&seq_b).await.unwrap();

        // Expect worker A to have an overlap score of 1 (shared first block)
Yan Ru Pei's avatar
Yan Ru Pei committed
557
558
559
560
561
562
        assert_eq!(
            overlap
                .scores
                .get(&WorkerWithDpRank::from_worker_id(worker_a)),
            Some(&1)
        );
563
564
565
566
567
568
569
570
    }

    /// When the same block resides on multiple workers, all should appear in the overlap scores.
    #[tokio::test]
    async fn test_multiple_workers_same_block() {
        const TTL: Duration = Duration::from_secs(5);

        let cancel = CancellationToken::new();
571
572
573
574
575
576
577
578
579
580
581
582
583
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let indexer = KvIndexer::new_with_frequency(
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
584
585
586
587
588
589

        let tokens: Vec<u32> = vec![9, 8, 7, 6];
        let worker_0: WorkerId = 21;
        let worker_1: WorkerId = 22;

        // Register the same sequence on two different workers
590
        let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
591
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
592
            .process_routing_decision_for_request(
593
                &mut tokens_with_hashes,
Yan Ru Pei's avatar
Yan Ru Pei committed
594
595
                WorkerWithDpRank::from_worker_id(worker_0),
            )
596
597
            .await
            .unwrap();
598
        let mut tokens_with_hashes = TokensWithHashes::new(tokens.clone(), KV_BLOCK_SIZE);
599
        indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
600
            .process_routing_decision_for_request(
601
                &mut tokens_with_hashes,
Yan Ru Pei's avatar
Yan Ru Pei committed
602
603
                WorkerWithDpRank::from_worker_id(worker_1),
            )
604
605
606
607
608
609
            .await
            .unwrap();

        // Wait until both workers are reflected in overlap scores
        spin_until(Duration::from_millis(100), || async {
            let s = indexer.find_matches_for_request(&tokens).await.unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
610
611
612
613
614
615
616
617
            s.scores
                .get(&WorkerWithDpRank::from_worker_id(worker_0))
                .copied()
                == Some(1)
                && s.scores
                    .get(&WorkerWithDpRank::from_worker_id(worker_1))
                    .copied()
                    == Some(1)
618
619
620
621
622
        })
        .await;

        let scores = indexer.find_matches_for_request(&tokens).await.unwrap();

Yan Ru Pei's avatar
Yan Ru Pei committed
623
624
625
626
627
628
629
630
631
632
633
634
        assert_eq!(
            scores
                .scores
                .get(&WorkerWithDpRank::from_worker_id(worker_0)),
            Some(&1)
        );
        assert_eq!(
            scores
                .scores
                .get(&WorkerWithDpRank::from_worker_id(worker_1)),
            Some(&1)
        );
635
    }
636
637
638
639
640
641

    /// Test that pruning returns empty when tree size is within the max tree size.
    #[tokio::test]
    async fn test_prune_manager_no_prune_when_within_bounds() {
        const TTL: Duration = Duration::from_secs(10);
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: 100,
            prune_target_ratio: 0.5,
        };

        let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);

        // Insert 50 keys (well below max_tree_size of 100)
        pm.insert((0..50).collect());

        // Pruning should return empty vec when size is within bounds
        let pruned = pm.prune(50).unwrap();
        assert!(pruned.is_empty());

        // All keys should still be present
        for i in 0..50 {
            assert!(pm.get_expiry(&i).is_some());
        }
    }

    /// Test that pruning removes the oldest entries first.
    #[tokio::test]
    async fn test_prune_manager_prune_removes_oldest_first() {
        const TTL: Duration = Duration::from_secs(10);
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: 10,
            prune_target_ratio: 0.5,
        };

        let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);

        // Insert keys one at a time with delays to ensure different timestamps
        for i in 1..=15 {
            pm.insert(vec![i]);
            time::sleep(Duration::from_millis(1)).await;
        }

        // Total: 15 keys. Trigger pruning with current_size = 15
        let pruned = pm.prune(15).unwrap();

        // Should prune down to 5 (10 * 0.5), so 10 keys should be pruned (15 - 5)
        assert_eq!(pruned.len(), 10);

        // The oldest keys should be pruned first
        for i in 1..=10 {
            assert!(pruned.contains(&i));
        }

        // The newer keys should still be present
        for i in 11..=15 {
            assert!(pm.get_expiry(&i).is_some());
        }
    }

    /// Test that pruning fails gracefully when config is None.
    #[tokio::test]
    async fn test_prune_manager_prune_fails_without_config() {
        const TTL: Duration = Duration::from_secs(10);
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: usize::MAX,
            prune_target_ratio: 0.5,
        };
        let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);
        // Temporarily set prune_config to None to test the error case
        pm.prune_config = None;

        pm.insert(vec![1, 2, 3]);

        // Pruning should fail when prune_config is None
        let result = pm.prune(150);
        assert!(result.is_err());
        assert!(matches!(result, Err(KvRouterError::PruneFailed(_))));
    }

    /// Checks that [`BlockEntry`] ordering looks at the sequence position before the key:
    /// an entry with a larger key but a smaller position must still sort first.
    #[test]
    fn test_block_entry_ordering() {
        let worker = WorkerWithDpRank::from_worker_id(0);
        let make_entry = |hash, seq_position| BlockEntry {
            key: ExternalSequenceBlockHash(hash),
            worker,
            seq_position,
        };

        let earlier = make_entry(100, 0);
        let later = make_entry(50, 1);

        // `earlier` wins on seq_position (0 < 1) even though its key is larger.
        assert!(earlier < later);
    }

738
    /// End-to-end test for [`KvIndexer`] with TTL and pruning
739
740
741
742
743
744
745
746
747
748
    ///   0. Max tree size is 5, target size is 2 (prune_target_ratio = 0.4)
    ///   1. Insert 5 blocks (at max_tree_size but not exceeding)
    ///   2. Verify all 5 blocks are present
    ///   3. Insert 6th block (exceeds threshold, triggers reactive pruning)
    ///   4. Verify pruning occurred: 4 oldest blocks removed
    ///   5. Verify 2 newest blocks remain
    #[tokio::test]
    async fn test_approx_indexer_e2e_pruning() {
        const TTL: Duration = Duration::from_secs(60); // Long TTL to avoid expiry
        let prune_config = PruneConfig {
749
            ttl: TTL,
750
751
752
753
754
            max_tree_size: 5,        // Very small to trigger pruning quickly
            prune_target_ratio: 0.4, // target size is 5 * 0.4 = 2
        };

        let cancel = CancellationToken::new();
755
756
757
758
759
760
761
762
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let indexer = KvIndexer::new_with_frequency(
            cancel.clone(),
            None,
            KV_BLOCK_SIZE,
            metrics,
            Some(prune_config),
        );
763
764
765
766
767
768

        let worker = WorkerWithDpRank::from_worker_id(42);

        // Insert 5 sequences (5 blocks total, at max_tree_size but not exceeding)
        for i in 0..5 {
            let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
769
            let mut tokens_with_hashes = TokensWithHashes::new(tokens, KV_BLOCK_SIZE);
770
            indexer
771
                .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
                .await
                .unwrap();
            time::sleep(Duration::from_millis(1)).await; // Ensure different timestamps
        }

        // Verify all 5 blocks are present (no pruning yet)
        for i in 0..5 {
            let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
            let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
            assert_eq!(
                scores.scores.get(&worker).copied(),
                Some(1),
                "Block {} should be present before threshold is exceeded",
                i
            );
        }

        // Insert 6th block - this exceeds max_tree_size and should trigger reactive pruning
        let tokens: Vec<u32> = vec![50, 51, 52, 53];
791
        let mut tokens_with_hashes = TokensWithHashes::new(tokens, KV_BLOCK_SIZE);
792
        indexer
793
            .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
            .await
            .unwrap();

        // Wait for pruning to complete
        time::sleep(Duration::from_millis(100)).await;

        // After pruning, we will have exactly 2 blocks (5 * 0.4 = 2)
        // The 2 newest blocks (i=4, i=5) will remain, oldest 4 blocks (i=0,1,2,3) will be pruned

        // Verify that the 4 oldest blocks are pruned
        for i in 0..4 {
            let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
            let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
            assert!(
                scores.scores.get(&worker).copied().unwrap_or(0) == 0,
                "Block {} should have been pruned but is still present",
                i
            );
        }

        // Verify the 2 newest blocks are present
        for i in 4..6 {
            let tokens: Vec<u32> = vec![i * 10, i * 10 + 1, i * 10 + 2, i * 10 + 3];
            let scores = indexer.find_matches_for_request(&tokens).await.unwrap();
            assert_eq!(
                scores.scores.get(&worker).copied(),
                Some(1),
                "Block {} should have been present but was pruned",
                i
            );
        }
    }

    /// Test that re-inserting a key updates its position in the pruning queue.
    #[tokio::test]
    async fn test_prune_manager_prune_reinsertion_updates_position() {
        const TTL: Duration = Duration::from_secs(10);
        let prune_config = PruneConfig {
            ttl: TTL,
            max_tree_size: 5,
            prune_target_ratio: 0.8,
        };

        let mut pm: PruneManager<u32> = PruneManager::new(50, prune_config);

        // Insert keys
        for i in 1..=10 {
            pm.insert(vec![i]);
            time::sleep(Duration::from_millis(1)).await;
        }

        // Re-insert key 1 (should move it to the back of the queue)
        pm.insert(vec![1]);

        // Total: 10 unique keys. Trigger pruning: current_size = 10, target = 4, so prune 6 keys
        // Order by expiry (oldest first): 2, 3, 4, 5, 6, 7, 8, 9, 10, 1 (re-inserted)
        let pruned = pm.prune(10).unwrap();
        assert_eq!(pruned.len(), 6);

        // The oldest keys (2-7) should be pruned
        for i in 2..=7 {
            assert!(pruned.contains(&i));
        }

        // The newest keys (8-10) should still be present
        for i in 8..=10 {
            assert!(pm.get_expiry(&i).is_some());
        }

        // Key 1 should still be present (it was refreshed and is now near the end)
        assert!(pm.get_expiry(&1).is_some());
    }
866
}