indexer.rs 45.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! KV RadixTree
//!
//! This module implements a key-value (KV) store using a Radix Tree structure to efficiently manage and retrieve data blocks.
//! It is designed to support LLM (Large Language Model) inference by re-using a global KV cache.
//!
//! # Overview
//!
//! The main components of this module include:
//!
//! - **Radix Tree Structure**:
//!   - The `RadixTree` struct represents the main data structure, with nodes (`RadixBlock`) containing children and associated worker IDs.
//!   - It allows efficient storage and retrieval of data blocks based on their hashes.
//!
//! - **Event Handling**:
//!   - The `RouterEvent` struct represents events emitted by LLM workers, which can be applied to the Radix Tree to update its state.
//!   - The `KvIndexer` struct manages these events and match requests asynchronously using Tokio channels.
//!
//! - **Hash Computation**:
//!   - Functions like `compute_block_hash` and `compute_block_hash_for_seq` compute hashes for data blocks and sequences of tokens, facilitating quick lookups.
//!
//! - **Concurrency and Asynchronous Operations**:
//!   - The `KvIndexer` uses a single-threaded Tokio runtime to handle events and match requests concurrently, ensuring efficient processing without blocking.
//!
//! - **Match Requests**:
//!   - The `MatchRequest` struct represents requests to find matches in the Radix Tree, returning overlap scores indicating the best matches.
//!
//! # Purpose
//!
//! This module provides a scalable and efficient way to manage and retrieve data blocks for LLM inference, leveraging a global KV cache to optimize performance.

use bytes::Bytes;
// use prometheus::{IntCounter, IntGauge};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::{
    cell::RefCell,
    collections::{HashMap, HashSet, VecDeque},
    iter,
    rc::Rc,
    sync::OnceLock,
    thread::JoinHandle,
    time::{Duration, Instant},
};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use tracing as log;
use xxhash_rust::xxh3;

/// Seed passed to `xxh3_64_with_seed` by [`compute_block_hash`].
pub const XXH3_SEED: u64 = 1337;

use crate::kv_router::protocols::*;

/// Errors that can occur in the KV Router.
#[derive(Debug, thiserror::Error)]
pub enum KvRouterError {
    #[error("Block not found")]
    BlockNotFound,

    #[error("Indexer is offline")]
    IndexerOffline,

    #[error("Indexer is dropped request")]
    IndexerDroppedRequest,
}

/// Identifier of an LLM worker that emits events to the router.
pub type WorkerId = i64;
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118

/// A shared, interior-mutable handle to a [`RadixBlock`].
///
/// `Rc<RefCell<_>>` is neither `Send` nor `Sync`, so values of this type have to stay on the
/// thread that created them.
type SharedRadixBlock = Rc<RefCell<RadixBlock>>;

/// Hash the raw bytes of a single local block.
///
/// ### Arguments
///
/// * `data` - The bytes to hash.
///
/// ### Returns
///
/// The seeded 64-bit XXH3 digest of `data` (seed [`XXH3_SEED`]), wrapped in a `LocalBlockHash`.
pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
    let digest = xxh3::xxh3_64_with_seed(data, XXH3_SEED);
    LocalBlockHash(digest)
}

// /// Updated version of the `compute_block_hash` function that included the lora_id
// pub fn compute_block_hash_v2(token_id: &[u32], lora_id: u64) {
//     let mut bytes = Vec::new();
//     for token in token_id {
//         bytes.extend_from_slice(&token.to_le_bytes());
//     }
//     bytes.extend_from_slice(&lora_id.to_le_bytes());
//     let hash = xxh3::xxh3_64_with_seed(&bytes, XXH3_SEED);
// }

/// Compute the hash for a sequence of tokens.
///
/// ### Arguments
///
/// * `tokens` - A vector of `u32` tokens.
///
/// ### Returns
///
/// A vector of `LocalBlockHash` representing the computed hashes for each chunk of tokens.
119
pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: usize) -> Vec<LocalBlockHash> {
120
    tokens
121
        .chunks_exact(kv_block_size) // Split into chunks of kv_block_size elements
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
        .map(|chunk| {
            let bytes: Vec<u8> = chunk
                .iter()
                .flat_map(|&num| num.to_le_bytes()) // Convert each i32 to its little-endian bytes
                .collect();

            compute_block_hash(&Bytes::from(bytes)) // Convert the byte Vec to Bytes
        })
        .collect()
}

/// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`].
///
/// Both fields are private; values are built with [`RouterEvent::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterEvent {
    /// The ID of the worker emitting the event.
    worker_id: WorkerId,
    /// The cache event associated with the worker.
    event: KvCacheEvent,
}

impl RouterEvent {
    /// Create a new `RouterEvent`.
    ///
    /// ### Arguments
    ///
    /// * `worker_id` - The ID of the worker emitting the event.
    /// * `event` - The cache event.
    ///
    /// ### Returns
    ///
    /// A new `RouterEvent`.
    pub fn new(worker_id: WorkerId, event: KvCacheEvent) -> Self {
        Self { worker_id, event }
    }
}

/// A block in the Radix Tree.
///
/// A node is shared through [`SharedRadixBlock`] handles, which are stored both in the parent's
/// `children` map and in the per-worker lookup table of [`RadixTree`].
struct RadixBlock {
    /// A map of child blocks, keyed by their local block hash.
    children: HashMap<LocalBlockHash, SharedRadixBlock>,
    /// A set of worker IDs associated with this block.
    workers: HashSet<WorkerId>,
    /// Timestamps of recent traversals of this block, oldest first.
    /// Only written by `RadixTree::find_matches` when an expiration duration is configured.
    recent_uses: VecDeque<Instant>,
}

impl RadixBlock {
    /// Create an empty `RadixBlock`.
    ///
    /// ### Returns
    ///
    /// A block with no children, no workers and no recorded uses.
    pub fn new() -> Self {
        let children = HashMap::new();
        let workers = HashSet::new();
        let recent_uses = VecDeque::new();

        Self {
            children,
            workers,
            recent_uses,
        }
    }
}

/// A prefix tree of KV blocks shared by all workers, plus a per-worker index into that tree.
///
/// The nodes are `Rc<RefCell<_>>` handles, so a `RadixTree` is neither `Send` nor `Sync`.
pub struct RadixTree {
    /// This is the root of the radix/prefix tree
    /// This will only contain root blocks
    root: SharedRadixBlock,

    /// This is a global lookup table for all blocks which will let you jump into
    /// the radix tree at any point
    /// Lookup is best case O(1) and worst case O(N); however, even constant-time
    /// lookups could be expensive if N is large
    /// We should monitor the size of this table and consider using a proper radix tree.
    /// Transitioning to a radix tree only would require a change in the messaging structure
    /// as the entire prefix would need to be sent. Alternatively, we could use block_depth
    /// integers to indicate how many blocks to skip and use a radix/prefix tree at each level.
    lookup: HashMap<WorkerId, HashMap<ExternalSequenceBlockHash, SharedRadixBlock>>,
    /// The time window the radix tree uses when counting how frequently a block is accessed.
    /// `None` disables access tracking.
    expiration_duration: Option<Duration>,
}

impl Default for RadixTree {
    fn default() -> Self {
        Self::new()
    }
}

impl RadixTree {
    /// Create an empty `RadixTree`, optionally tracking how often blocks are traversed.
    ///
    /// ### Arguments
    ///
    /// * `expiration_duration` - When `Some`, `find_matches` records a timestamp on every block
    ///   it traverses, discards timestamps older than this duration, and reports the remaining
    ///   count per block. When `None`, no usage is recorded.
    ///
    /// ### Returns
    ///
    /// A new `RadixTree` with an empty root block and an empty lookup table.
    pub fn new_with_frequency(expiration_duration: Option<Duration>) -> Self {
        Self {
            root: Rc::new(RefCell::new(RadixBlock::new())),
            lookup: HashMap::new(),
            expiration_duration,
        }
    }

    /// Create an empty `RadixTree` with `expiration_duration` set to `None`.
    pub fn new() -> Self {
        Self::new_with_frequency(None)
    }

    /// Traverse the radix tree to find the best match for a given sequence of [`LocalBlockHash`]es.
    ///
    /// ### Arguments
    ///
    /// * `sequence` - A vector of `LocalBlockHash` representing the sequence to match.
    /// * `early_exit` - A boolean indicating whether to exit early if a single match is found.
    ///
    /// ### Returns
    ///
    /// An `OverlapScores` representing the match scores.
    pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
        let mut scores = OverlapScores::new();
        let mut current = self.root.clone();
        let now = Instant::now();
        for block_hash in sequence {
            let next_block = {
                let current_borrow = current.borrow();
                current_borrow.children.get(&block_hash).cloned()
            };

            if let Some(block) = next_block {
                scores.update_scores(&block.borrow().workers);

                if let Some(expiration_duration) = self.expiration_duration {
                    let mut block_mut = block.borrow_mut();

                    while let Some(access_time) = block_mut.recent_uses.front() {
                        if now.duration_since(*access_time) > expiration_duration {
                            block_mut.recent_uses.pop_front();
                        } else {
                            break;
                        }
                    }
                    scores.add_frequency(block_mut.recent_uses.len());
                    block_mut.recent_uses.push_back(now);
                }

                if early_exit && block.borrow().workers.len() == 1 {
                    break;
                }

                current = block;
            } else {
                break;
            }
        }

        scores
    }

    /// Apply a [`RouterEvent`] to the radix tree.
    ///
    /// * `Stored` events attach the listed blocks as a chain below the parent block (or below the
    ///   root when no parent hash is given), add the worker to every block on the chain, and
    ///   register each block in the worker's lookup table. The whole event is skipped with a
    ///   warning when the parent is unknown for this worker. The rest of the event is skipped
    ///   when a block resolves to its own parent.
    /// * `Removed` events drop the worker from each listed block and from the worker's lookup
    ///   table. A block left without workers also releases its children. Unknown blocks are
    ///   skipped with a warning.
    ///
    /// An (initially empty) lookup table is created for the worker if it has none yet.
    ///
    /// ### Arguments
    ///
    /// * `event` - The `RouterEvent` to apply.
    pub fn apply_event(&mut self, event: RouterEvent) {
        let (worker_id, event) = (event.worker_id, event.event);
        let (id, op) = (event.event_id, event.data);
        log::debug!(id, "Store operation: {:?}", op);

        let worker_lookup = self.lookup.entry(worker_id).or_default();

        match op {
            KvCacheEventData::Stored(op) => {
                // find the parent block - if the parent exists it must be on our worker, if not,
                // we check the radix tree's root to find it.
                // this is the single most expensive lookup
                let current = match op.parent_hash {
                    Some(parent) => worker_lookup.get(&parent),
                    None => Some(&self.root),
                };

                let mut current = match current {
                    Some(current) => current.clone(),
                    None => {
                        log::warn!(
                            worker_id = worker_id.to_string(),
                            id,
                            parent_hash = ?op.parent_hash,
                            "Failed to find parent block; skipping store operation"
                        );
                        return;
                    }
                };

                for block_id in op.blocks {
                    let mut inner = current.borrow_mut();
                    let block = match inner.children.get(&block_id.tokens_hash) {
                        Some(block) => block.clone(),
                        None => {
                            // reuse the worker's existing node for this hash, or create a new one
                            let new_block = worker_lookup
                                .get(&block_id.block_hash)
                                .cloned()
                                .unwrap_or_else(|| Rc::new(RefCell::new(RadixBlock::new())));

                            // A block whose hash resolves to its own parent would become its own
                            // child; mutably borrowing it below while the parent is borrowed
                            // would panic, so treat the event as malformed instead.
                            if Rc::ptr_eq(&new_block, &current) {
                                log::warn!(
                                    worker_id = worker_id.to_string(),
                                    id,
                                    block_hash = ?block_id.block_hash,
                                    "Block resolves to its own parent; skipping store operation"
                                );
                                return;
                            }

                            // insert into radix tree
                            inner
                                .children
                                .insert(block_id.tokens_hash, new_block.clone());

                            new_block
                        }
                    };

                    // release the parent before touching the child, so the two borrows can
                    // never overlap
                    drop(inner);

                    // add our worker_id to the block
                    block.borrow_mut().workers.insert(worker_id);

                    // add the block to the worker_id lookup table
                    worker_lookup.insert(block_id.block_hash, block.clone());

                    current = block;
                }
            }
            KvCacheEventData::Removed(remove) => {
                for block in remove.block_hashes {
                    // entry in radix tree
                    // a small optimization would be to get the next block from the reduced set of children
                    // in order to apply this optimization, we would need to know the list of blocks is always sorted
                    // by parent -> child relationship
                    let entry = match worker_lookup.get(&block) {
                        Some(entry) => entry.clone(),
                        None => {
                            log::warn!(
                                worker_id = worker_id.to_string(),
                                id,
                                "Failed to find block to remove; skipping remove operation"
                            );
                            continue;
                        }
                    };

                    let mut guard = entry.borrow_mut();
                    guard.workers.remove(&worker_id);
                    if guard.workers.is_empty() {
                        // if no workers are using this block, that is true for all children
                        guard.children.clear();
                    }
                    // remove the block from the lookup table
                    worker_lookup.remove(&block);
                }
            }
        }
    }

    /// Remove every trace of `worker` from the tree.
    ///
    /// The worker's lookup table is dropped and the worker is removed from each block that the
    /// table referenced. As in the `Removed` event path of `apply_event`, a block left without
    /// workers releases its children so the orphaned subtree can be freed.
    ///
    /// Does nothing if the worker is unknown.
    ///
    /// ### Arguments
    ///
    /// * `worker` - The worker to remove.
    pub fn remove_worker(&mut self, worker: WorkerId) {
        let Some(blocks) = self.lookup.remove(&worker) else {
            return;
        };

        for block in blocks.values() {
            let mut node = block.borrow_mut();
            node.workers.remove(&worker);
            if node.workers.is_empty() {
                // if no workers are using this block, that is true for all children
                node.children.clear();
            }
        }
    }
}

/// Scores representing the overlap of workers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OverlapScores {
    /// Map of worker_id to score. `update_scores` adds one point per call for every worker in
    /// the set it is given.
    pub scores: HashMap<WorkerId, u32>,
    /// List of frequencies that the blocks have been accessed. Entries with value 0 are omitted
    /// (`add_frequency` ignores zero).
    pub frequencies: Vec<usize>,
}

impl Default for OverlapScores {
    fn default() -> Self {
        Self::new()
    }
}

impl OverlapScores {
    /// Create an empty `OverlapScores`.
    ///
    /// ### Returns
    ///
    /// An `OverlapScores` with no scores and an empty frequency list whose capacity is
    /// pre-allocated for 32 entries.
    pub fn new() -> Self {
        let scores = HashMap::new();
        let frequencies = Vec::with_capacity(32);

        Self {
            scores,
            frequencies,
        }
    }

    /// Give one point to every worker in `workers`.
    ///
    /// Workers that have no score yet start from 0 before the point is added.
    ///
    /// ### Arguments
    ///
    /// * `workers` - A reference to a `HashSet` of `WorkerId`s.
    pub fn update_scores(&mut self, workers: &HashSet<WorkerId>) {
        workers.iter().copied().for_each(|worker| {
            *self.scores.entry(worker).or_insert(0) += 1;
        });
    }

    /// Append `frequency` to the frequency list; a value of 0 is ignored.
    ///
    /// In debug builds this asserts that the previous entry, if any, is not smaller than the new
    /// one.
    pub fn add_frequency(&mut self, frequency: usize) {
        if frequency == 0 {
            return;
        }

        self.frequencies
            .last()
            .inspect(|elem| debug_assert!(**elem >= frequency));
        self.frequencies.push(frequency);
    }
}

/// A request to find matches in the Radix Tree.
///
/// Sent to the indexer's background task, which answers on `resp`.
pub struct MatchRequest {
    /// A vector of `LocalBlockHash` representing the sequence to match.
    sequence: Vec<LocalBlockHash>,
    /// Forwarded to `RadixTree::find_matches`: when `true`, the walk stops after the first
    /// matched block that is held by exactly one worker.
    early_exit: bool,
    /// A channel sender to send the `OverlapScores` response.
    resp: oneshot::Sender<OverlapScores>,
}

/// Common interface of the KV indexer implementations in this module.
#[async_trait]
pub trait KvIndexerInterface {
    /// Find matches for a given sequence of `LocalBlockHash`es.
    ///
    /// ### Arguments
    ///
    /// * `sequence` - A vector of `LocalBlockHash` representing the sequence to match.
    ///
    /// ### Returns
    ///
    /// An `OverlapScores` representing the match scores, or a `KvRouterError` if the indexer
    /// could not be reached or did not answer.
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError>;

    /// Find matches for a given sequence of tokens.
    ///
    /// ### Arguments
    ///
    /// * `tokens` - A slice of `u32` tokens.
    ///
    /// ### Returns
    ///
    /// An `OverlapScores` representing the match scores, or a `KvRouterError` if the indexer
    /// could not be reached or did not answer.
    async fn find_matches_for_request(
        &self,
        tokens: &[u32],
    ) -> Result<OverlapScores, KvRouterError>;

    /// Apply a `RouterEvent` to the KV store.
    ///
    /// ### Arguments
    ///
    /// * `event` - The `RouterEvent` to apply.
    async fn apply_event(&mut self, event: RouterEvent);

    /// Remove a worker's entries from the trie.
    ///
    /// ### Arguments
    ///
    /// * `worker` - The worker to remove from the trie.
    async fn remove_worker(&mut self, worker: WorkerId);

    /// Shutdown the KV Indexer.
    fn shutdown(&mut self);
}

/// The KV Indexer, managing the KV store and handling events and match requests.
///
/// The store itself lives on a background thread. This handle only holds the channels used to
/// talk to that thread.
pub struct KvIndexer {
    /// A `CancellationToken` for managing shutdown.
    cancel: CancellationToken,
    /// A sender for `RouterEvent`s.
    event_tx: mpsc::Sender<RouterEvent>,
    /// A sender for `MatchRequest`s.
    match_tx: mpsc::Sender<MatchRequest>,
    /// A sender for remove worker requests.
    remove_worker_tx: mpsc::Sender<WorkerId>,
    /// A handle to the background thread managing the KV store; taken by `shutdown`.
    task: OnceLock<std::thread::JoinHandle<()>>,
    /// The size of the KV block this indexer can handle.
    kv_block_size: usize,
}

impl KvIndexer {
    /// Create a new `KvIndexer`.
    ///
    /// Spawns a dedicated OS thread that owns the [`RadixTree`] and serves remove-worker
    /// requests, match requests and router events (polled in that order of priority, with
    /// cancellation checked before events) until `token` is cancelled.
    ///
    /// ### Arguments
    ///
    /// * `token` - A `CancellationToken` for managing shutdown.
    /// * `expiration_duration` - The amount of time that block usage should be buffered.
    /// * `kv_block_size` - The number of tokens per KV block, used when hashing request tokens.
    ///
    /// ### Returns
    ///
    /// A new `KvIndexer`.
    ///
    /// ### Panics
    ///
    /// The spawned thread panics if its Tokio runtime cannot be built.
    pub fn new_with_frequency(
        token: CancellationToken,
        expiration_duration: Option<Duration>,
        kv_block_size: usize,
    ) -> Self {
        let (event_tx, mut event_rx) = mpsc::channel::<RouterEvent>(2048);
        let (match_tx, mut match_rx) = mpsc::channel::<MatchRequest>(128);
        let (remove_worker_tx, mut remove_worker_rx) = mpsc::channel::<WorkerId>(16);
        let cancel = token.clone();
        let task = std::thread::spawn(move || {
            // Multi-thread flavor limited to a single worker thread. The trie is built on `Rc`,
            // so its event loop is spawned on a `LocalSet` rather than onto the runtime.
            let runtime = tokio::runtime::Builder::new_multi_thread()
                .worker_threads(1)
                .enable_all()
                .build()
                .unwrap();

            let local_set = tokio::task::LocalSet::new();

            runtime.block_on(local_set.run_until(async move {
                tokio::task::spawn_local(async move {
                    let mut trie = RadixTree::new_with_frequency(expiration_duration);
                    loop {
                        tokio::select! {
                            biased;

                            Some(worker) = remove_worker_rx.recv() => {
                                trie.remove_worker(worker);
                            }

                            Some(req) = match_rx.recv() => {
                                let matches = trie.find_matches(req.sequence, req.early_exit);
                                // The requester may have gone away; nothing to do in that case.
                                let _ = req.resp.send(matches);
                            }

                            _ = cancel.cancelled() => {
                                log::debug!("KvCacheIndexer progress loop shutting down");
                                return;
                            }

                            Some(event) = event_rx.recv() => {
                                trie.apply_event(event);
                            }
                        }
                    }
                })
                .await
                .unwrap()
            }));

            log::debug!("KvCacheIndexer task completed");
        });

        Self {
            cancel: token,
            event_tx,
            match_tx,
            remove_worker_tx,
            task: OnceLock::from(task),
            kv_block_size,
        }
    }

591
592
    pub fn new(token: CancellationToken, kv_block_size: usize) -> Self {
        Self::new_with_frequency(token, None, kv_block_size)
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
    }

    /// Hand out another sender for the indexer's event channel.
    ///
    /// ### Returns
    ///
    /// A clone of the `mpsc::Sender` used to deliver `RouterEvent`s to the indexer.
    pub fn event_sender(&self) -> mpsc::Sender<RouterEvent> {
        let sender = &self.event_tx;
        sender.clone()
    }
}

#[async_trait]
impl KvIndexerInterface for KvIndexer {
    /// Send a match request to the indexer and wait for its answer.
    ///
    /// Returns `IndexerOffline` if the request cannot be sent, and `IndexerDroppedRequest` if
    /// the response channel is closed without an answer. `early_exit` is always `false`.
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        let (resp, scores_rx) = oneshot::channel();
        let request = MatchRequest {
            sequence,
            early_exit: false,
            resp,
        };

        self.match_tx.send(request).await.map_err(|e| {
            log::error!(
                "Failed to send match request: {:?}; the indexer maybe offline",
                e
            );
            KvRouterError::IndexerOffline
        })?;

        scores_rx
            .await
            .map_err(|_| KvRouterError::IndexerDroppedRequest)
    }

    /// Hash `tokens` into blocks of `kv_block_size` tokens and look the result up.
    ///
    /// Tokens beyond the last complete block take no part in the match. Errors are those of
    /// `find_matches`.
    async fn find_matches_for_request(
        &self,
        tokens: &[u32],
    ) -> Result<OverlapScores, KvRouterError> {
        log::debug!(
            "Finding matches for request tokens: {:?} / len: {}",
            tokens,
            tokens.len()
        );
        let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
        log::debug!("Computed sequence: {:?}", sequence);
        self.find_matches(sequence).await
    }

    async fn apply_event(&mut self, event: RouterEvent) {
        self.event_tx.send(event).await.unwrap();
    }

    async fn remove_worker(&mut self, worker: WorkerId) {
        self.remove_worker_tx.send(worker).await.unwrap();
    }

    /// Cancel the background task and wait for its thread to finish.
    ///
    /// The thread handle is taken on the first call, so later calls only cancel the token again.
    ///
    /// ### Panics
    ///
    /// Panics if joining the indexer thread fails.
    fn shutdown(&mut self) {
        self.cancel.cancel();

        let Some(handle) = self.task.take() else {
            return;
        };
        handle.join().expect("Failed to join kv indexer task");
    }
}

/// A match request that is broadcast to every shard of a `KvIndexerSharded`.
///
/// Each shard answers on its own clone of `resp`.
#[derive(Debug, Clone)]
pub struct ShardedMatchRequest {
    /// The sequence of `LocalBlockHash`es to match.
    sequence: Vec<LocalBlockHash>,
    /// Forwarded to `RadixTree::find_matches` on every shard.
    early_exit: bool,
    /// The channel on which each shard sends its `OverlapScores`.
    resp: mpsc::Sender<OverlapScores>,
}

/// The KV Indexer, managing the KV store and handling events and match requests.
pub struct KvIndexerSharded {
    /// A `CancellationToken` for managing shutdown.
    cancel: CancellationToken,
672
673
    /// The size of the KV block this indexer can handle.
    kv_block_size: usize,
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
    worker_assignments: HashMap<WorkerId, usize>,
    worker_counts: Vec<usize>,

    event_tx: Vec<mpsc::Sender<RouterEvent>>,
    request_broadcast_tx: broadcast::Sender<ShardedMatchRequest>,
    remove_worker_tx: Vec<mpsc::Sender<WorkerId>>,
    tasks: Vec<JoinHandle<()>>,
}

impl KvIndexerSharded {
    /// Create a new `KvIndexerSharded`.
    ///
    /// Spawns one OS thread per shard. Each thread owns a [`RadixTree`] and serves remove-worker
    /// requests, broadcast match requests and router events (polled in that order of priority,
    /// with cancellation checked before events) until `token` is cancelled.
    ///
    /// ### Arguments
    ///
    /// * `token` - A `CancellationToken` for managing shutdown.
    /// * `num_shards` - The number of shards (and threads) to create.
    /// * `expiration_duration` - The amount of time that block usage should be buffered.
    /// * `kv_block_size` - The number of tokens per KV block.
    ///
    /// ### Returns
    ///
    /// A new `KvIndexerSharded`.
    ///
    /// ### Panics
    ///
    /// Panics if the Tokio runtime of a shard cannot be built.
    pub fn new_with_frequency(
        token: CancellationToken,
        num_shards: usize,
        expiration_duration: Option<Duration>,
        kv_block_size: usize,
    ) -> Self {
        let worker_assignments: HashMap<WorkerId, usize> = HashMap::new();
        let worker_counts: Vec<usize> = vec![0; num_shards];

        let mut event_tx = Vec::new();
        let mut remove_worker_tx = Vec::new();
        let mut tasks = Vec::new();

        let (request_broadcast_tx, _) = broadcast::channel::<ShardedMatchRequest>(1_048_576);

        for _ in 0..num_shards {
            let (shard_event_tx, mut shard_event_rx) = mpsc::channel::<RouterEvent>(2048);
            let (shard_remove_worker_tx, mut shard_remove_worker_rx) =
                mpsc::channel::<WorkerId>(16);
            let mut shard_broadcast_rx = request_broadcast_tx.subscribe();
            let cancel = token.clone();

            event_tx.push(shard_event_tx);
            remove_worker_tx.push(shard_remove_worker_tx);

            let runtime = tokio::runtime::Builder::new_multi_thread()
                .worker_threads(1)
                .enable_all()
                .build()
                .unwrap();

            tasks.push(std::thread::spawn(move || {
                // The trie is built on `Rc`, so its event loop is spawned on a `LocalSet`.
                let local_set = tokio::task::LocalSet::new();

                runtime.block_on(local_set.run_until(async move {
                    tokio::task::spawn_local(async move {
                        let mut trie = RadixTree::new_with_frequency(expiration_duration);
                        loop {
                            tokio::select! {
                                biased;

                                Some(worker) = shard_remove_worker_rx.recv() => {
                                    trie.remove_worker(worker);
                                }

                                Ok(req) = shard_broadcast_rx.recv() => {
                                    let matches = trie.find_matches(req.sequence, req.early_exit);
                                    if let Err(e) = req.resp.send(matches).await {
                                        log::trace!("Failed to send match response: {:?}", e);
                                    }
                                }

                                _ = cancel.cancelled() => {
                                    log::debug!("KvCacheIndexer progress loop shutting down");
                                    return;
                                }

                                Some(event) = shard_event_rx.recv() => {
                                    trie.apply_event(event);
                                }
                            }
                        }
                    })
                    .await
                    .unwrap()
                }));

                log::debug!("KvCacheIndexer task completed");
            }));
        }

        Self {
            cancel: token,
            kv_block_size,
            worker_assignments,
            worker_counts,
            event_tx,
            request_broadcast_tx,
            remove_worker_tx,
            tasks,
        }
    }

778
779
    pub fn new(token: CancellationToken, num_shards: usize, kv_block_size: usize) -> Self {
        Self::new_with_frequency(token, num_shards, None, kv_block_size)
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
    }
}

#[async_trait]
impl KvIndexerInterface for KvIndexerSharded {
    /// Broadcasts a match request for `sequence` and merges the responses.
    ///
    /// One response is awaited per entry in `self.event_tx`. Each response's
    /// `scores` are extended into the result, and the `frequencies` vectors are
    /// summed element-wise; the accumulator is zero-padded whenever a response
    /// carries a longer vector.
    ///
    /// If the response channel closes before every expected response arrived,
    /// the partial result is discarded and the whole request is re-broadcast.
    ///
    /// # Errors
    ///
    /// Returns [`KvRouterError::IndexerOffline`] when sending on the broadcast
    /// channel fails.
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        // Used both as the response-channel capacity and as the number of
        // responses to wait for.
        let expected_responses = self.event_tx.len();

        'match_loop: loop {
            let (match_tx, mut match_rx) = mpsc::channel(expected_responses);
            self.request_broadcast_tx
                .send(ShardedMatchRequest {
                    sequence: sequence.clone(),
                    early_exit: false,
                    resp: match_tx,
                })
                .map_err(|_| KvRouterError::IndexerOffline)?;

            let mut scores = OverlapScores::new();

            for response_num in 0..expected_responses {
                let response = match match_rx.recv().await {
                    Some(response) => response,
                    // This can only happen if the broadcast channel overflows.
                    // Retry by looping rather than by calling `find_matches`
                    // recursively, which could overflow the stack.
                    None => continue 'match_loop,
                };

                scores.scores.extend(response.scores);

                if response_num == 0 {
                    // The first response seeds the accumulator directly.
                    scores.frequencies = response.frequencies;
                    continue;
                }

                // Grow the accumulator so every index of this response has a slot.
                let missing = response
                    .frequencies
                    .len()
                    .saturating_sub(scores.frequencies.len());
                scores.frequencies.extend(iter::repeat(0).take(missing));

                // After padding the accumulator is at least as long as the
                // response, so `zip` visits every response entry.
                for (total, count) in scores.frequencies.iter_mut().zip(response.frequencies) {
                    *total += count;
                }
            }
            return Ok(scores);
        }
    }

    /// Hashes `tokens` into block hashes using `self.kv_block_size` and
    /// returns the result of [`Self::find_matches`] on that sequence.
    async fn find_matches_for_request(
        &self,
        tokens: &[u32],
    ) -> Result<OverlapScores, KvRouterError> {
        let block_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size);
        let overlap = self.find_matches(block_hashes).await;
        overlap
    }

    /// Routes `event` to the shard assigned to `event.worker_id`.
    ///
    /// A worker seen for the first time is assigned to the shard with the
    /// lowest entry in `worker_counts`, and that count is incremented.
    ///
    /// # Panics
    ///
    /// Panics if `worker_counts` is empty or if sending to the shard's event
    /// channel fails.
    async fn apply_event(&mut self, event: RouterEvent) {
        let worker_id = event.worker_id;

        let shard = match self.worker_assignments.get(&worker_id).copied() {
            Some(assigned) => assigned,
            None => {
                // Get the shard with the smallest amount of workers.
                let (least_loaded, _) = self
                    .worker_counts
                    .iter()
                    .enumerate()
                    .min_by_key(|&(_, count)| count)
                    .unwrap();

                self.worker_assignments.insert(worker_id, least_loaded);
                self.worker_counts[least_loaded] += 1;
                least_loaded
            }
        };

        self.event_tx[shard].send(event).await.unwrap();
    }

    /// Drops `worker` from the assignment table and tells its shard to remove it.
    ///
    /// Does nothing when the worker has no assignment.
    ///
    /// # Panics
    ///
    /// Panics if sending on the shard's remove-worker channel fails.
    async fn remove_worker(&mut self, worker: WorkerId) {
        let shard = match self.worker_assignments.remove(&worker) {
            Some(shard) => shard,
            None => return,
        };

        self.worker_counts[shard] -= 1;
        self.remove_worker_tx[shard].send(worker).await.unwrap();
    }

    /// Shutdown the KV Indexer.
    ///
    /// Cancels the token, then joins every handle in `tasks`, most recently
    /// pushed first.
    ///
    /// # Panics
    ///
    /// Panics if joining any task returns an error.
    fn shutdown(&mut self) {
        self.cancel.cancel();
        while let Some(task) = self.tasks.pop() {
            task.join().unwrap();
        }
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use rstest::rstest;
886
    use rstest_reuse::{self, *};
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
    use tokio::time;
    use tokio_util::sync::CancellationToken;

    /// Turns raw hash values into stored-block records: each value `h` becomes
    /// a block with `tokens_hash = LocalBlockHash(h)` and
    /// `block_hash = ExternalSequenceBlockHash(h * 100)`.
    fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
        let mut blocks = Vec::with_capacity(hashes.len());
        for hash in hashes {
            blocks.push(KvCacheStoredBlockData {
                tokens_hash: LocalBlockHash(hash),
                block_hash: ExternalSequenceBlockHash(hash * 100),
            });
        }
        blocks
    }

    /// Wraps the blocks built from `hashes` in a `Stored` event payload with
    /// the given `parent_hash`.
    fn add_blocks(
        hashes: Vec<u64>,
        parent_hash: Option<ExternalSequenceBlockHash>,
    ) -> KvCacheEventData {
        let blocks = make_blocks(hashes);
        let stored = KvCacheStoreData {
            parent_hash,
            blocks,
        };
        KvCacheEventData::Stored(stored)
    }

    /// Builds a `RouterEvent` from `worker_id` that stores the blocks derived
    /// from `hashes` under the optional `parent` block.
    fn create_store_event(
        worker_id: WorkerId,
        event_id: u64,
        hashes: Vec<u64>,
        parent: Option<ExternalSequenceBlockHash>,
    ) -> RouterEvent {
        let data = add_blocks(hashes, parent);
        let event = KvCacheEvent { event_id, data };
        RouterEvent { worker_id, event }
    }

    /// Builds a `RouterEvent` from `worker_id` that removes the blocks whose
    /// external hashes are `h * 100` for each `h` in `hashes`.
    fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec<u64>) -> RouterEvent {
        let block_hashes = hashes
            .into_iter()
            .map(|hash| ExternalSequenceBlockHash(hash * 100))
            .collect();
        let data = KvCacheEventData::Removed(KvCacheRemoveData { block_hashes });

        RouterEvent {
            worker_id,
            event: KvCacheEvent { event_id, data },
        }
    }

    #[test]
    fn test_radix_tree() {
        // `(workers, children)` counts on the root block.
        fn root_counts(trie: &RadixTree) -> (usize, usize) {
            let root = trie.root.borrow();
            (root.workers.len(), root.children.len())
        }

        // `(workers, children)` counts on the root's child keyed by `LocalBlockHash(1)`.
        fn first_block_counts(trie: &RadixTree) -> (usize, usize) {
            let root = trie.root.borrow();
            let first = root.children.get(&LocalBlockHash(1)).unwrap().borrow();
            (first.workers.len(), first.children.len())
        }

        // Number of workers on the block that `worker` has registered under
        // external hash 200.
        fn workers_on_block_200(trie: &RadixTree, worker: WorkerId) -> usize {
            trie.lookup
                .get(&worker)
                .unwrap()
                .get(&ExternalSequenceBlockHash(200))
                .unwrap()
                .borrow()
                .workers
                .len()
        }

        // The query used for every `find_matches` call below.
        let query = || vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)];

        let mut trie = RadixTree::new();
        let worker_1 = 0;
        let worker_2 = 1;

        // Worker 1 stores blocks 1, 2, 3.
        trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None));

        let scores = trie.find_matches(query(), false);
        assert_eq!(scores.scores[&worker_1], 3);

        assert_eq!(trie.lookup.len(), 1);
        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
        assert_eq!(root_counts(&trie), (0, 1));
        assert_eq!(first_block_counts(&trie), (1, 1));

        // Worker 2 stores blocks 1, 4, 5; only block 1 overlaps the query.
        trie.apply_event(create_store_event(worker_2, 1, vec![1, 4, 5], None));

        let scores = trie.find_matches(query(), false);
        assert_eq!(scores.scores[&worker_1], 3);
        assert_eq!(scores.scores[&worker_2], 1);

        assert_eq!(trie.lookup.len(), 2);
        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
        assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 3);
        assert_eq!(root_counts(&trie), (0, 1));
        assert_eq!(first_block_counts(&trie), (2, 2));

        // Worker 2 removes block 5.
        trie.apply_event(create_remove_event(worker_2, 2, vec![5]));

        assert_eq!(trie.lookup.len(), 2);
        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
        assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 2);
        assert_eq!(root_counts(&trie), (0, 1));
        assert_eq!(first_block_counts(&trie), (2, 2));

        // Worker 2 removes block 4.
        trie.apply_event(create_remove_event(worker_2, 3, vec![4]));

        assert_eq!(trie.lookup.len(), 2);
        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
        assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 1);
        assert_eq!(root_counts(&trie), (0, 1));
        assert_eq!(first_block_counts(&trie), (2, 2));

        // Worker 2 stores blocks 2, 6, 7 under the parent with external hash 100
        // (the external hash `make_blocks` assigns to block 1).
        trie.apply_event(create_store_event(
            worker_2,
            4,
            vec![2, 6, 7],
            Some(ExternalSequenceBlockHash(100)),
        ));

        let scores = trie.find_matches(query(), false);
        assert_eq!(scores.scores[&worker_1], 3);
        assert_eq!(scores.scores[&worker_2], 2);

        assert_eq!(trie.lookup.len(), 2);
        assert_eq!(trie.lookup.get(&worker_1).unwrap().len(), 3);
        assert_eq!(trie.lookup.get(&worker_2).unwrap().len(), 4);
        assert_eq!(root_counts(&trie), (0, 1));
        assert_eq!(first_block_counts(&trie), (2, 2));
        assert_eq!(workers_on_block_200(&trie, worker_1), 2);
        assert_eq!(workers_on_block_200(&trie, worker_2), 2);
    }

    #[test]
    fn test_remove_worker() {
        let worker_0 = 0;
        let worker_1 = 1;
        // Single-block query reused for every lookup.
        let probe = || vec![LocalBlockHash(0)];

        let mut trie = RadixTree::new();

        // Nothing stored yet: no worker is scored.
        let before_store = trie.find_matches(probe(), false).scores;
        assert!(before_store.is_empty());

        trie.apply_event(create_store_event(worker_0, 0, vec![0], None));
        trie.apply_event(create_store_event(worker_1, 0, vec![0], None));

        // Both workers hold block 0.
        let with_both = trie.find_matches(probe(), false).scores;
        assert_eq!(with_both.len(), 2);
        assert_eq!(with_both[&worker_0], 1);
        assert_eq!(with_both[&worker_1], 1);

        trie.remove_worker(worker_0);

        // Only worker 1 remains after removing worker 0.
        let after_removal = trie.find_matches(probe(), false).scores;
        assert_eq!(after_removal.len(), 1);
        assert_eq!(after_removal[&worker_1], 1);
    }

    #[test]
    fn test_early_stopping() {
        let worker_0 = 0;
        let worker_1 = 1;

        let mut trie = RadixTree::new();
        trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 2], None));
        trie.apply_event(create_store_event(worker_1, 0, vec![0], None));

        // With `early_exit = true`, the three-block and two-block queries are
        // both expected to produce the same scores.
        let queries = vec![
            vec![LocalBlockHash(0), LocalBlockHash(1), LocalBlockHash(2)],
            vec![LocalBlockHash(0), LocalBlockHash(1)],
        ];

        for query in queries {
            let result = trie.find_matches(query, true).scores;

            assert_eq!(result.len(), 2);
            assert_eq!(result[&worker_0], 2);
            assert_eq!(result[&worker_1], 1);
        }
    }

1192
1193
1194
1195
1196
    #[rstest]
    #[case(11)]
    #[case(32)]
    #[case(64)]
    fn test_compute_block_hash_for_seq(#[case] kv_block_size: usize) {
1197
        // create a sequence of 64 elements
1198
1199
        let sequence = (0..kv_block_size).map(|i| i as u32).collect::<Vec<u32>>();
        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
1200
1201
1202
        assert_eq!(hashes.len(), 1);

        // create a sequence of 65 elements
1203
        let sequence = (0..(kv_block_size + 1))
Ryan Olson's avatar
Ryan Olson committed
1204
1205
            .map(|i| i as u32)
            .collect::<Vec<u32>>();
1206
        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
1207
1208
1209
        assert_eq!(hashes.len(), 1);

        // create a sequence of 129 elements
1210
        let sequence = (0..(2 * kv_block_size + 1))
Ryan Olson's avatar
Ryan Olson committed
1211
1212
            .map(|i| i as u32)
            .collect::<Vec<u32>>();
1213
        let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
1214
1215
1216
        assert_eq!(hashes.len(), 2);
    }

1217
1218
1219
1220
1221
    fn make_indexer(
        token: &CancellationToken,
        num_shards: usize,
        kv_block_size: usize,
    ) -> Box<dyn KvIndexerInterface> {
1222
        if num_shards == 1 {
1223
            Box::new(KvIndexer::new(token.clone(), kv_block_size))
1224
        } else {
1225
1226
1227
1228
1229
            Box::new(KvIndexerSharded::new(
                token.clone(),
                num_shards,
                kv_block_size,
            ))
1230
1231
1232
        }
    }

1233
    #[template]
1234
    #[rstest]
1235
1236
1237
1238
1239
1240
    fn indexer_template(
        #[values(1, 3, 8)] num_shards: usize,
        #[values(11, 32, 64)] kv_block_size: usize,
    ) {
    }

1241
    #[tokio::test]
1242
1243
1244
1245
    #[apply(indexer_template)]
    async fn test_kv_indexer_new(num_shards: usize, kv_block_size: usize) {
        let token: CancellationToken = CancellationToken::new();
        let _ = make_indexer(&token, num_shards, kv_block_size);
1246
1247
1248
    }

    #[tokio::test]
1249
1250
    #[apply(indexer_template)]
    async fn test_find_matches(num_shards: usize, kv_block_size: usize) {
1251
        let token = CancellationToken::new();
1252
        let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
1253
1254
1255
1256
1257
1258
1259
1260

        let sequence = vec![compute_block_hash(b"test data")];
        let scores = kv_indexer.find_matches(sequence).await;

        assert!(scores.unwrap().scores.is_empty());
    }

    #[tokio::test]
1261
1262
    #[apply(indexer_template)]
    async fn test_find_matches_for_request(num_shards: usize, kv_block_size: usize) {
1263
        let token = CancellationToken::new();
1264
        let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
1265
1266
1267
1268
1269
1270
1271
1272

        let tokens = vec![1, 2, 3, 4];
        let scores = kv_indexer.find_matches_for_request(&tokens).await;

        assert!(scores.unwrap().scores.is_empty());
    }

    #[tokio::test]
1273
1274
    #[apply(indexer_template)]
    async fn test_apply_event(num_shards: usize, kv_block_size: usize) {
1275
        let worker_id = 0;
1276
1277

        let token = CancellationToken::new();
1278
        let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
1279
1280
1281
1282
1283
1284
1285
1286

        let event = create_store_event(worker_id, 1, vec![1, 2, 3], None);
        kv_indexer.apply_event(event).await;

        // No assertion here, just ensuring it runs without panic
    }

    #[tokio::test]
1287
1288
    #[apply(indexer_template)]
    async fn test_shutdown(num_shards: usize, kv_block_size: usize) {
1289
        let token = CancellationToken::new();
1290
        let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
1291
1292
1293
1294
1295

        kv_indexer.shutdown();
    }

    #[tokio::test]
1296
1297
    #[apply(indexer_template)]
    async fn test_frequency(num_shards: usize, kv_block_size: usize) {
1298
1299
1300
1301
1302
        let mut kv_indexer: Box<dyn KvIndexerInterface>;
        let token = CancellationToken::new();
        let duration = Some(Duration::from_millis(50));

        if num_shards == 1 {
1303
1304
1305
1306
1307
            kv_indexer = Box::new(KvIndexer::new_with_frequency(
                token,
                duration,
                kv_block_size,
            ));
1308
1309
        } else {
            kv_indexer = Box::new(KvIndexerSharded::new_with_frequency(
1310
1311
1312
1313
                token,
                num_shards,
                duration,
                kv_block_size,
1314
1315
1316
            ));
        }

1317
        let worker_id = 0;
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356

        let event = create_store_event(worker_id, 0, vec![1, 2, 3, 4], None);
        kv_indexer.apply_event(event).await;

        time::sleep(Duration::from_millis(5)).await;

        let block_hashes = vec![
            LocalBlockHash(1),
            LocalBlockHash(2),
            LocalBlockHash(3),
            LocalBlockHash(4),
        ];
        let scores = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();

        assert_eq!(scores.frequencies.len(), 0);

        let scores = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(scores.frequencies, vec![1, 1, 1, 1]);

        time::sleep(Duration::from_millis(100)).await;

        let scores = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(scores.frequencies.len(), 0);

        let scores = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(scores.frequencies, vec![1, 1, 1, 1]);

        let scores = kv_indexer
            .find_matches(block_hashes[0..3].to_vec())
            .await
            .unwrap();
        assert_eq!(scores.frequencies, vec![2, 2, 2]);

        let scores = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
        assert_eq!(scores.frequencies, vec![3, 3, 3, 2]);
    }

    #[test]
    fn test_router_event_new() {
        let worker_id = 0;

        // A store event carrying a single block.
        let block = KvCacheStoredBlockData {
            block_hash: ExternalSequenceBlockHash(0),
            tokens_hash: LocalBlockHash(13226331709069118873),
        };
        let kv_cache_event = KvCacheEvent {
            event_id: 1,
            data: KvCacheEventData::Stored(KvCacheStoreData {
                parent_hash: None,
                blocks: vec![block],
            }),
        };

        let router_event = RouterEvent::new(worker_id, kv_cache_event);

        assert_eq!(router_event.worker_id, worker_id);
        assert_eq!(router_event.event.event_id, 1);

        match &router_event.event.data {
            KvCacheEventData::Stored(store_op) => {
                assert_eq!(store_op.blocks.len(), 1);

                let stored = &store_op.blocks[0];
                // The literal token hash above must equal the hash of b"test data".
                assert_eq!(stored.tokens_hash, compute_block_hash(b"test data"));
                assert_eq!(stored.block_hash, ExternalSequenceBlockHash(0));
            }
            _ => panic!("Expected KvCacheEventData::Stored"),
        }
    }

    #[test]
    fn test_radix_tree_default() {
        // A default tree has an empty root and an empty lookup table.
        let tree = RadixTree::default();
        let root = tree.root.borrow();

        assert!(root.children.is_empty());
        assert!(root.workers.is_empty());
        assert!(tree.lookup.is_empty());
    }

    #[test]
    fn test_overlap_scores_default() {
        // Default overlap scores contain no per-worker entries.
        let defaults = OverlapScores::default();
        assert!(defaults.scores.is_empty());
    }
}