traits.rs 7.68 KB
Newer Older
1
2
3
4
5
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use async_trait::async_trait;

6
7
8
use std::sync::Arc;

use super::{KvIndexerMetrics, KvRouterError, WorkerTask};
9
10
use crate::protocols::*;

11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/// Trait for querying an external shared KV cache pool.
///
/// Implementations check which blocks/pages from a request's token sequence
/// exist in the shared cache. The returned `SharedCacheHits` describes which
/// block positions are available externally (and thus cheaper to prefill).
#[async_trait]
pub trait SharedKvCache: Send + Sync {
    /// Query which blocks exist in the shared cache for the given token sequence.
    async fn check_blocks(
        &self,
        tokens: &[u32],
        block_size: u32,
    ) -> Result<SharedCacheHits, KvRouterError>;
}

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
/// Per-shard size snapshot returned by [`KvIndexerInterface::shard_sizes`].
///
/// `worker_count` and `block_count` are always populated.
/// `node_count` is populated only when the `shard-metrics` feature is enabled
/// on the `dynamo-kv-router` crate; otherwise it is `0`.
#[derive(Debug, Clone)]
pub struct ShardSizeSnapshot {
    /// Zero-based shard index.
    pub shard_idx: usize,
    /// Distinct `(worker_id, dp_rank)` pairs stored in this shard.
    pub worker_count: usize,
    /// Total cached blocks across all workers in this shard.
    pub block_count: usize,
    /// Radix-tree node count (only non-zero with `shard-metrics` feature).
    pub node_count: usize,
}

43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#[async_trait]
pub trait KvIndexerInterface {
    /// Find matches for a given sequence of `LocalBlockHash`es.
    ///
    /// ### Arguments
    ///
    /// * `sequence` - A vector of `LocalBlockHash` representing the sequence to match.
    ///
    /// ### Returns
    ///
    /// An `OverlapScores` representing the match scores.
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError>;

    /// Find matches for a given sequence of tokens.
    ///
    /// ### Arguments
    ///
    /// * `tokens` - A vector of `u32` tokens.
    /// * `lora_name` - Optional LoRA adapter name to include in block hash computation.
    ///
    /// ### Returns
    ///
    /// An `OverlapScores` representing the match scores.
    async fn find_matches_for_request(
        &self,
        tokens: &[u32],
        lora_name: Option<&str>,
73
        is_eagle: Option<bool>,
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
    ) -> Result<OverlapScores, KvRouterError>;

    /// Apply a `RouterEvent` to the KV store.
    ///
    /// ### Arguments
    ///
    /// * `event` - The `RouterEvent` to apply.
    async fn apply_event(&self, event: RouterEvent);

    /// Remove a worker's entries from the trie.
    ///
    /// ### Arguments
    ///
    /// * `worker` - The worker to remove from the trie.
    async fn remove_worker(&self, worker: WorkerId);

    /// Remove a single dp_rank for a worker from the trie.
    ///
    /// Default implementation falls back to removing the entire worker.
    /// Indexers that track dp_rank-level granularity should override this.
    async fn remove_worker_dp_rank(&self, worker: WorkerId, _dp_rank: DpRank) {
        self.remove_worker(worker).await;
    }

    /// Shutdown the KV Indexer.
    fn shutdown(&self);

    /// Dump the entire tree as RouterEvents.
    ///
    /// ### Returns
    ///
    /// A vector of RouterEvents representing the current state of the tree.
    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError>;

    /// Process a routing decision for a request with tokens.
    ///
    /// Uses TokensWithHashes for lazy hash computation - if hashes were already
    /// computed (e.g., by find_best_match), they will be reused.
    ///
    /// ### Arguments
    ///
    /// * `tokens_with_hashes` - Tokens with lazily computed hashes.
    /// * `worker` - The worker (with dp_rank) that was selected.
    async fn process_routing_decision_for_request(
        &self,
        tokens_with_hashes: &mut TokensWithHashes,
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError>;

    /// Async task that returns when all pending events have been processed.
    /// For now, we assume that no requests or events are being sent in the meantime.
    /// Returns the amount of events still in the queue at the time of the flush.
    /// Used primarily for debugging.
    async fn flush(&self) -> usize;
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153

    /// Return a human-readable timing breakdown of `find_matches` overhead.
    ///
    /// Implementations that track per-phase timing (e.g. scatter/gather overhead
    /// vs. actual shard work) override this to return a multi-line report string.
    /// The default returns an empty string so callers can skip printing it.
    fn timing_report(&self) -> String {
        String::new()
    }

    /// Return a size snapshot for each shard.
    ///
    /// Single-shard indexers return one entry (shard 0).  Multi-shard indexers
    /// return one entry per shard.  Non-sharded indexers (and implementations
    /// that don't override this) return an empty `Vec`.
    ///
    /// See [`ShardSizeSnapshot`] for the fields exposed per shard.
    fn shard_sizes(&self) -> Vec<ShardSizeSnapshot> {
        vec![]
    }

    /// Edge lengths (hashes per node) for every non-root node.
    /// Returns an empty vec for backends that don't support this.
    fn node_edge_lengths(&self) -> Vec<usize> {
        vec![]
    }
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
}

// ============================================================================
// SyncIndexer trait
// ============================================================================

/// Trait for thread-safe data structures that support KV cache indexing operations.
///
/// All methods take `&self` and are synchronous. Implementations must be safe for
/// concurrent access (via internal locking, DashMap, etc).
///
/// This trait is used with [`ThreadPoolIndexer`](super::ThreadPoolIndexer), which wraps a `SyncIndexer` to
/// provide the async [`KvIndexerInterface`] with:
/// - Sticky event routing to N worker threads
/// - Inline reads on the caller's thread (no channel dispatch for find_matches)
pub trait SyncIndexer: Send + Sync + 'static {
170
171
172
173
174
    fn worker(
        &self,
        event_receiver: flume::Receiver<WorkerTask>,
        metrics: Option<Arc<KvIndexerMetrics>>,
    ) -> anyhow::Result<()>;
175
176
177
178

    /// Find matches for a sequence of block hashes.
    fn find_matches(&self, sequence: &[LocalBlockHash], early_exit: bool) -> OverlapScores;

179
180
181
182
183
184
185
186
187
188
189
    /// Returns true when a maintenance task should be enqueued.
    fn try_schedule_cleanup(&self) -> bool {
        false
    }

    /// Rolls back a scheduled cleanup when enqueueing the task fails.
    fn cancel_scheduled_cleanup(&self) {}

    /// Executes a maintenance task on a worker thread.
    fn run_cleanup_task(&self) {}

190
191
192
193
194
195
196
    /// Dump events directly from the shared structure, bypassing worker channels.
    /// Returns `Some(events)` for backends whose tree state is fully shared (e.g.
    /// ConcurrentRadixTree). Returns `None` for backends that keep per-thread
    /// state and must dump via the worker channel.
    fn dump_events(&self) -> Option<Vec<RouterEvent>> {
        None
    }
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218

    /// Number of distinct workers registered in this backend.
    fn worker_count(&self) -> usize {
        0
    }

    /// Total cached blocks across all workers.
    fn block_count(&self) -> usize {
        0
    }

    /// Number of radix-tree nodes created since construction.
    /// Only meaningful when the `shard-metrics` feature is enabled; returns 0 otherwise.
    fn node_count(&self) -> usize {
        0
    }

    /// Edge lengths (hashes per node) for every non-root node in the tree.
    /// Returns an empty vec for backends that don't support this.
    fn node_edge_lengths(&self) -> Vec<usize> {
        vec![]
    }
219
}