# KV Router Index Data Structures

This document explains the KV cache index implementations: `RadixTree`, `ConcurrentRadixTree`, and `PositionalIndexer` (NestedMap).

## Motivation: The Four Block Identifiers

Every cached KV block in a distributed LLM system needs four pieces of information:

### 1. Local Block Hash (`LocalBlockHash`, u64)

**What**: Hash of the tokens *within* a single block (e.g., 64 tokens).

**Why**: Identifies the content of this specific block, independent of context. Two blocks with the same tokens have the same local hash.

```
Block at position 5: tokens [101, 102, 103, ...]
LocalBlockHash = hash(tokens) = 0xABCD1234
```

### 2. External Sequence Block Hash (`ExternalSequenceBlockHash`, u64)

**What**: Cumulative hash of the entire sequence up to and including this block.

**Why**: Uniquely identifies a block's position in a *specific* sequence history. Two blocks with the same local content but different prefixes have different sequence hashes.

```
Sequence A: [block0, block1, block2]
Sequence B: [block0', block1', block2]  // block2 has same content but different prefix

block2 in A: seq_hash = hash(hash(hash(block0) + block1) + block2) = 0x1111
block2 in B: seq_hash = hash(hash(hash(block0') + block1') + block2) = 0x2222
```

**Computation**: `seq_hash[i] = hash(seq_hash[i-1] || local_hash[i])` where `seq_hash[0] = local_hash[0]`
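The recurrence above can be sketched in a few lines. This is a minimal illustration only: it assumes the standard library's `DefaultHasher` as a stand-in for whatever hasher the engine or router actually uses, and the helper names `local_hash` and `sequence_hashes` are hypothetical.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

type LocalBlockHash = u64;
type ExternalSequenceBlockHash = u64;

/// Hash the tokens inside one block (content only, no prefix context).
/// Assumption: DefaultHasher stands in for the real hashing algorithm.
fn local_hash(tokens: &[u32]) -> LocalBlockHash {
    let mut h = DefaultHasher::new();
    tokens.hash(&mut h);
    h.finish()
}

/// seq_hash[0] = local_hash[0];
/// seq_hash[i] = hash(seq_hash[i-1] || local_hash[i]).
fn sequence_hashes(local_hashes: &[LocalBlockHash]) -> Vec<ExternalSequenceBlockHash> {
    let mut out: Vec<u64> = Vec::with_capacity(local_hashes.len());
    for (i, &local) in local_hashes.iter().enumerate() {
        let seq = if i == 0 {
            local
        } else {
            // Chain the previous sequence hash with this block's local hash.
            let mut h = DefaultHasher::new();
            out[i - 1].hash(&mut h);
            local.hash(&mut h);
            h.finish()
        };
        out.push(seq);
    }
    out
}
```

Two sequences that end in the same block content but start differently get the same local hash for that final block and different sequence hashes, which is the property the examples above rely on.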

> **Important: Engine-Provided Hashes**
>
> In practice, the `ExternalSequenceBlockHash` may come directly from the inference engine (e.g., vLLM, TensorRT-LLM) using a rolling hash algorithm that we don't know or control. The engine computes these hashes internally and reports them via KV cache events.
>
> **Implications for index implementations:**
>
> - **RadixTree**: Can handle engine-provided hashes because it traverses the tree structure using `LocalBlockHash` for navigation and only uses `ExternalSequenceBlockHash` as an opaque identifier for lookups. It doesn't need to recompute hashes.
>
> - **NestedMap**: Requires the ability to compute `ExternalSequenceBlockHash` incrementally for its lazy hash optimization in `find_matches`. To use NestedMap, one of the following is required:
>   1. **Force a known hasher**: Configure the engine to use a specific hashing algorithm that the router can replicate, OR
>   2. **Recompute on the relay**: Have the publisher/relay layer recompute the rolling hash using a known algorithm before forwarding events to the router.
>
> Without this, NestedMap's `find_matches` will fail when encountering `SeqEntry::Multi` cases (multiple seq_hashes at the same position+local_hash) because it cannot disambiguate which entry to use.

### 3. Worker ID (`WorkerWithDpRank`)

**What**: Identifies which worker (inference server) has this block cached.

**Why**: The router needs to know which workers can serve a request based on their cached blocks.

### 4. Position (`u64`)

**What**: The block's index in the sequence (0, 1, 2, ...).

**Why**: Enables efficient prefix matching. Position 0 is the first block, position N-1 is the last.

---

## The Core Operations

Each index implementation supports three core operations:

| Operation | Description | Hot Path? |
|-----------|-------------|-----------|
| `store_blocks` | Add blocks for a worker | No (background) |
| `remove_blocks` | Remove blocks for a worker | No (background) |
| `find_matches` | Find workers with matching prefix | **Yes** (per-request) |

The key insight: **reads (find_matches) are far more frequent than writes (store/remove)**. This motivates different structural tradeoffs.

---

## RadixTree: Tree-Based Index

### Structure

```
RadixTree
├── root: SharedRadixBlock (Rc<RefCell<RadixBlock>>)
└── lookup: HashMap<Worker, HashMap<SeqHash, SharedRadixBlock>>

RadixBlock
├── children: HashMap<LocalBlockHash, SharedRadixBlock>
├── workers: HashSet<Worker>
├── block_hash: Option<SeqHash>
└── recent_uses: VecDeque<Instant>
```

### Visual Representation

```
                    [root]
                   /      \
            local=0xA    local=0xB
               ↓            ↓
           [block0]     [block0']
           workers:     workers:
           {W0,W1}      {W2}
              |
         local=0xC
              ↓
          [block1]
          workers:
          {W0,W1}
              |
         local=0xD
              ↓
          [block2]
          workers:
          {W0}         ← W1 diverged here
```

### How Operations Work

**store_blocks(worker, parent_hash, blocks)**:
1. Find parent via `lookup[worker][parent_hash]`
2. For each block, traverse/create child nodes using `local_hash`
3. Add worker to each node's `workers` set
4. Update `lookup[worker][seq_hash] = node`

**remove_blocks(worker, block_hashes)**:
1. For each hash, find node via `lookup[worker][hash]`
2. Remove worker from node's `workers` set
3. If `workers` is now empty, clear the node's children (cascading cleanup)
4. Remove from `lookup[worker]`

**find_matches(local_hashes, early_exit)**:
1. Start at root with all workers as candidates
2. For each position, traverse to child matching `local_hash`
3. Intersect candidates with node's `workers`
4. Track depth where each worker drops out
5. Return `{worker -> depth}` scores
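The `store_blocks` and `find_matches` steps can be sketched as follows. This is a simplified, single-threaded illustration, not the actual implementation: `MiniRadixTree` and `Block` are hypothetical names, workers are plain `u32` values, and `remove_blocks`, `recent_uses`, and `early_exit` are omitted.

```rust
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::rc::Rc;

type Worker = u32;
type LocalHash = u64;
type SeqHash = u64;
type SharedBlock = Rc<RefCell<Block>>;

#[derive(Default)]
struct Block {
    children: HashMap<LocalHash, SharedBlock>,
    workers: HashSet<Worker>,
}

#[derive(Default)]
struct MiniRadixTree {
    root: SharedBlock,
    lookup: HashMap<Worker, HashMap<SeqHash, SharedBlock>>,
}

impl MiniRadixTree {
    /// Attach `blocks` (local_hash, seq_hash pairs) below `parent` for `worker`.
    /// Returns false if `parent` is not known for this worker.
    fn store_blocks(
        &mut self,
        worker: Worker,
        parent: Option<SeqHash>,
        blocks: &[(LocalHash, SeqHash)],
    ) -> bool {
        let per_worker = self.lookup.entry(worker).or_default();
        let mut node = match parent {
            None => Rc::clone(&self.root),
            Some(p) => match per_worker.get(&p) {
                Some(n) => Rc::clone(n),
                None => return false,
            },
        };
        for &(local, seq) in blocks {
            // Traverse to the child for this local hash, creating it if needed.
            let child = Rc::clone(node.borrow_mut().children.entry(local).or_default());
            child.borrow_mut().workers.insert(worker);
            per_worker.insert(seq, Rc::clone(&child));
            node = child;
        }
        true
    }

    /// Walk the query from the root; report how many leading blocks each worker holds.
    fn find_matches(&self, local_hashes: &[LocalHash]) -> HashMap<Worker, usize> {
        let mut scores = HashMap::new();
        // None means "every worker is still a candidate".
        let mut active: Option<HashSet<Worker>> = None;
        let mut node = Rc::clone(&self.root);
        for (depth, local) in local_hashes.iter().enumerate() {
            let child = match node.borrow().children.get(local) {
                Some(c) => Rc::clone(c),
                None => break,
            };
            let here = child.borrow().workers.clone();
            let next: HashSet<Worker> = match active {
                None => here,
                Some(prev) => prev.intersection(&here).copied().collect(),
            };
            if next.is_empty() {
                break;
            }
            // Workers that drop out keep the score recorded at their last match.
            for &w in &next {
                scores.insert(w, depth + 1);
            }
            active = Some(next);
            node = child;
        }
        scores
    }
}
```

Cloning the worker set at every level is what makes the traversal cost O(D × W) in this sketch.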

### Complexity

| Operation | Time | Space |
|-----------|------|-------|
| store_blocks (N blocks) | O(N) | O(N) nodes |
| remove_blocks (N blocks) | O(N) | - |
| find_matches (depth D) | O(D × W) | O(W) |

Where W = number of workers.

---

## PositionalIndexer (NestedMap): Position-First HashMap Index

### Structure

```
PositionalIndexer
├── index: DashMap<(Position, LocalHash), SeqEntry>
├── worker_blocks: DashMap<Worker, RwLock<HashMap<SeqHash, (Position, LocalHash)>>>
└── jump_size: usize

SeqEntry (enum for memory optimization)
├── Single(SeqHash, HashSet<Worker>)  // Common case: one seq_hash
└── Multi(HashMap<SeqHash, HashSet<Worker>>)  // Rare: multiple prefixes
```

`PositionalIndexer` implements `SyncIndexer` and is thread-safe via `DashMap` (sharded
concurrent map) and `RwLock`. It is designed to be wrapped in a `ThreadPoolIndexer` which
routes write events to dedicated OS threads and executes reads inline.

The `index` uses a flat compound key `(position, local_hash)` in a `DashMap`, which
distributes lock contention across shards while enabling O(1) random-position access for
the jump optimization. The `worker_blocks` reverse lookup uses `DashMap` for the outer
per-worker map and `RwLock<HashMap>` for each worker's block set, since writes to a
given worker are serialized by sticky routing in `ThreadPoolIndexer`.

### Visual Representation

```
index (DashMap with compound keys):
┌──────────────────────┬──────────────────────────────────────┐
│ (pos=0, local=0xA)   │ Single(seq=0x1111, {W0,W1})          │
│ (pos=0, local=0xB)   │ Single(seq=0x2222, {W2})             │
│ (pos=1, local=0xC)   │ Single(seq=0x3333, {W0,W1})          │
│ (pos=2, local=0xD)   │ Multi{                               │
│                      │    seq=0x4444 → {W0},                │
│                      │    seq=0x5555 → {W1}   ← diverged    │
│                      │  }                                   │
└──────────────────────┴──────────────────────────────────────┘

worker_blocks (DashMap<Worker, RwLock<HashMap>>):
┌─────────┬─────────────────────────────────────────────────┐
│ W0      │ seq=0x1111 → (pos=0, local=0xA)                 │
│         │ seq=0x3333 → (pos=1, local=0xC)                 │
│         │ seq=0x4444 → (pos=2, local=0xD)                 │
├─────────┼─────────────────────────────────────────────────┤
│ W1      │ seq=0x1111 → (pos=0, local=0xA)                 │
│         │ seq=0x3333 → (pos=1, local=0xC)                 │
│         │ seq=0x5555 → (pos=2, local=0xD)                 │
└─────────┴─────────────────────────────────────────────────┘
```

### How Operations Work

**store_blocks(worker, parent_hash, blocks)**:
1. Find starting position: `pos = worker_blocks[worker][parent_hash].position + 1`
2. For each block at position `i`:
   - Insert into `index[(pos+i, local_hash)]` → add worker to SeqEntry
   - Insert into `worker_blocks[worker][seq_hash] = (pos+i, local_hash)`

**remove_blocks(worker, block_hashes)**:
1. For each hash, lookup `(pos, local_hash) = worker_blocks[worker][hash]`
2. Remove worker from `index[(pos, local_hash)]`
3. Remove from `worker_blocks[worker]`
4. Cleanup empty SeqEntry entries from the DashMap
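The `store_blocks` and `remove_blocks` steps above can be sketched with plain standard-library maps. This is a simplified illustration under stated assumptions: `std::collections::HashMap` stands in for `DashMap` and `RwLock` (so it is single-threaded), the `SeqEntry` enum is replaced by a nested map, and `MiniPositionalIndex` is a hypothetical name.

```rust
use std::collections::{HashMap, HashSet};

type Worker = u32;
type LocalHash = u64;
type SeqHash = u64;
type Key = (usize, LocalHash); // (position, local_hash)

#[derive(Default)]
struct MiniPositionalIndex {
    // Forward index: which workers hold each (position, local_hash), per seq_hash.
    index: HashMap<Key, HashMap<SeqHash, HashSet<Worker>>>,
    // Reverse lookup: where each of a worker's seq_hashes lives in `index`.
    worker_blocks: HashMap<Worker, HashMap<SeqHash, Key>>,
}

impl MiniPositionalIndex {
    /// Returns false if `parent` is not known for this worker.
    fn store_blocks(
        &mut self,
        worker: Worker,
        parent: Option<SeqHash>,
        blocks: &[(LocalHash, SeqHash)],
    ) -> bool {
        let owned = self.worker_blocks.entry(worker).or_default();
        // New blocks start one position after the parent, or at 0 without one.
        let start = match parent {
            None => 0,
            Some(p) => match owned.get(&p) {
                Some(&(pos, _)) => pos + 1,
                None => return false,
            },
        };
        for (i, &(local, seq)) in blocks.iter().enumerate() {
            let key = (start + i, local);
            self.index
                .entry(key)
                .or_default()
                .entry(seq)
                .or_default()
                .insert(worker);
            owned.insert(seq, key);
        }
        true
    }

    fn remove_blocks(&mut self, worker: Worker, seq_hashes: &[SeqHash]) {
        let Some(owned) = self.worker_blocks.get_mut(&worker) else {
            return;
        };
        for seq in seq_hashes {
            let Some(key) = owned.remove(seq) else {
                continue;
            };
            let mut drop_key = false;
            if let Some(by_seq) = self.index.get_mut(&key) {
                let mut drop_seq = false;
                if let Some(workers) = by_seq.get_mut(seq) {
                    workers.remove(&worker);
                    drop_seq = workers.is_empty();
                }
                if drop_seq {
                    by_seq.remove(seq);
                }
                drop_key = by_seq.is_empty();
            }
            // Clean up entries that no longer reference any worker.
            if drop_key {
                self.index.remove(&key);
            }
        }
    }
}
```

Because removal goes through the reverse lookup, it never has to scan the forward index.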

**find_matches(local_hashes, early_exit)** with Jump Optimization:
1. Start at position 0, initialize candidates from first block
2. **Jump**: Skip ahead by `jump_size` positions (e.g., 32; the diagram below uses 64)
3. At each jump point, check if candidates still match (count-only, no clone)
4. If workers dropped, **scan back** to find exact drain points
5. Continue until sequence exhausted or one worker remains

```
Query: [b0, b1, b2, ..., b63, b64, ..., b127, ...]
        ↑                   ↑                  ↑
       pos=0              pos=64             pos=128
        │                   │                  │
        └── jump ──────────→└── jump ─────────→│
                           all match?         some dropped?
                              ↓                   ↓
                           continue          scan [64,128]
```

**Lazy Hash Optimization**:
- Most (position, local_hash) pairs have only ONE seq_hash (SeqEntry::Single)
- Skip seq_hash computation entirely in this case
- Only compute when disambiguation needed (SeqEntry::Multi)
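The jump-and-scan-back loop can be sketched as below. This is a simplified sketch, not the actual implementation. It assumes every `(position, local_hash)` key maps to a single worker set (the `SeqEntry::Single` case, so no seq_hash is ever computed), that a worker holding a block also holds all earlier blocks of that prefix, and it uses a plain `HashMap` in place of `DashMap`. The `early_exit` flag is omitted.

```rust
use std::collections::{HashMap, HashSet};

type Worker = u32;
type LocalHash = u64;
type Index = HashMap<(usize, LocalHash), HashSet<Worker>>;

/// Returns, per worker, how many leading blocks of `query` it holds.
fn find_matches(index: &Index, query: &[LocalHash], jump_size: usize) -> HashMap<Worker, usize> {
    let mut scores: HashMap<Worker, usize> = HashMap::new();
    if query.is_empty() {
        return scores;
    }
    let empty: HashSet<Worker> = HashSet::new();
    // O(1) random access to the worker set at any position.
    let at = |pos: usize| index.get(&(pos, query[pos])).unwrap_or(&empty);

    let mut active: HashSet<Worker> = at(0).clone();
    let mut pos = 0; // last position verified for every active worker
    let last = query.len() - 1;
    let jump = jump_size.max(1);

    while pos < last && !active.is_empty() {
        let next = (pos + jump).min(last);
        let here = at(next);
        // Count-only check at the jump point: no set is cloned here.
        let survivors = active.iter().filter(|w| here.contains(*w)).count();
        if survivors < active.len() {
            // Some worker dropped inside (pos, next]: scan back for exact drain points.
            for p in pos + 1..=next {
                let present = at(p);
                let dropped: Vec<Worker> = active
                    .iter()
                    .copied()
                    .filter(|w| !present.contains(w))
                    .collect();
                for w in dropped {
                    active.remove(&w);
                    scores.insert(w, p); // matched positions 0..p
                }
            }
        }
        pos = next;
    }
    // Workers still active matched everything up to and including `pos`.
    for w in active {
        scores.insert(w, pos + 1);
    }
    scores
}
```

For a 10-block query where three workers hold 10, 5, and 1 leading blocks, the scores come out as 10, 5, and 1 for any `jump_size`. Only the number of index lookups changes.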

**dump_events()**:
1. Iterate `worker_blocks`, collecting all blocks per worker
2. Sort each worker's blocks by position (parents before children)
3. Emit one single-block `RouterEvent::Stored` per block, synthesizing
   `parent_hash` from any seq_hash at the prior position
4. Events can be replayed into a fresh `PositionalIndexer` to reconstruct
   the same index state
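A minimal sketch of the dump procedure, under stated assumptions: `StoredEvent` is a hypothetical stand-in for a single-block `RouterEvent::Stored`, the reverse lookup is a plain `HashMap`, and workers are plain `u32` values.

```rust
use std::collections::HashMap;

type Worker = u32;
type LocalHash = u64;
type SeqHash = u64;

/// Hypothetical stand-in for a single-block stored event.
#[derive(Debug, PartialEq)]
struct StoredEvent {
    worker: Worker,
    parent_hash: Option<SeqHash>,
    local_hash: LocalHash,
    seq_hash: SeqHash,
}

fn dump_events(
    worker_blocks: &HashMap<Worker, HashMap<SeqHash, (usize, LocalHash)>>,
) -> Vec<StoredEvent> {
    let mut events = Vec::new();
    let mut workers: Vec<&Worker> = worker_blocks.keys().collect();
    workers.sort(); // deterministic output order
    for &worker in workers {
        // Sort by position so parents are emitted before children.
        let mut blocks: Vec<(usize, SeqHash, LocalHash)> = worker_blocks[&worker]
            .iter()
            .map(|(&seq, &(pos, local))| (pos, seq, local))
            .collect();
        blocks.sort();

        // Remember one seq_hash per position to synthesize parent hashes.
        let mut seq_at: HashMap<usize, SeqHash> = HashMap::new();
        for &(pos, seq, _) in &blocks {
            seq_at.entry(pos).or_insert(seq);
        }

        for (pos, seq, local) in blocks {
            // Position 0 has no parent. In this sketch, a block whose prior
            // position is missing also gets None.
            let parent_hash = if pos == 0 {
                None
            } else {
                seq_at.get(&(pos - 1)).copied()
            };
            events.push(StoredEvent {
                worker,
                parent_hash,
                local_hash: local,
                seq_hash: seq,
            });
        }
    }
    events
}
```

Replaying these events in order only needs the parent to exist at the prior position, which is why any seq_hash at that position is enough.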

### Complexity

| Operation | Time | Space |
|-----------|------|-------|
| store_blocks (N blocks) | O(N) | O(N) entries |
| remove_blocks (N blocks) | O(N) | - |
| find_matches (depth D) | O(D/J + J×W) | O(W) |

Where J = jump_size, W = number of workers. The jump optimization reduces D iterations to D/J jumps plus occasional scans.

---

## Comparison

| Aspect | RadixTree | PositionalIndexer |
|--------|-----------|-------------------|
| **Structure** | Tree with `Rc<RefCell<>>` nodes | `DashMap` with compound keys |
| **Concurrent variant** | `ConcurrentRadixTree` (`Arc<RwLock<>>` + `DashMap`) | Thread-safe by default (`DashMap` + `RwLock`) |
| **find_matches** | O(D×W) tree traversal | O(D/J + J×W) with jump optimization |
| **store_blocks** | O(N) node creation | O(N) DashMap inserts |
| **remove_blocks** | O(N) with cascading cleanup | O(N) with entry cleanup |
| **dump_events** | BFS traversal of tree | Sort by position per worker |
| **Memory** | Higher (Rc/Arc overhead per node) | Lower (flat entries) |
| **Cache locality** | Poor (pointer chasing) | Better (position-first) |

### Benchmark Results (1M blocks, depth 1024, 128 workers)

| Operation | RadixTree | NestedMap | Winner |
|-----------|-----------|-----------|--------|
| STORE_BLOCK | 90µs | 98µs | RadixTree (1.1x) |
| REMOVE_BLOCK | 91µs | 233µs | RadixTree (2.5x) |
| FIND_MATCHES (HIT) | 227µs | **44µs** | **NestedMap (5.2x)** |
| FIND_MATCHES (PARTIAL) | 216µs | **44µs** | **NestedMap (4.9x)** |

**Recommendation**: Use NestedMap for read-heavy workloads (typical router usage).

---

## Why Position Matters for PositionalIndexer

The compound key `(position, local_hash)` in the DashMap enables the jump optimization:

```rust
// Without position-first: must traverse entire tree
for pos in 0..depth {
    node = node.children[local_hashes[pos]];  // O(depth) traversals
}

// With position-first: can jump directly to any position
let workers_at_64 = index.get(&(64, local_hashes[64]));  // O(1) lookup
let workers_at_128 = index.get(&(128, local_hashes[128]));  // O(1) lookup
// Skip positions 1-63, 65-127 entirely!
```

---

## SeqEntry Optimization

The value stored under each `(position, local_hash)` key is an enum, which avoids a `HashMap` allocation in the common case:

```rust
enum SeqEntry {
    // Common: one prefix leads to this (position, local_hash)
    Single(SeqHash, HashSet<Worker>),

    // Rare: different prefixes converge on same (position, local_hash)
    Multi(HashMap<SeqHash, HashSet<Worker>>),
}
```

**When does Multi occur?**

Only when two different sequences have:
1. Same local block content at position P
2. Different prefix histories (different seq_hash)

Example:
```
Sequence A: [blk1, blk2, blk3] → positions 0,1,2
Sequence B: [blk4, blk5, blk3] → positions 0,1,2
                         ^^^^
                 Same local content at pos=2
                 but different seq_hash!
```

This is rare in practice, so `Single` saves ~48 bytes per entry.
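A minimal sketch of how an insert might upgrade `Single` to `Multi` only when a second seq_hash appears at the same key. The `new` and `insert` methods are hypothetical helpers written for illustration, and workers are plain `u32` values.

```rust
use std::collections::{HashMap, HashSet};

type Worker = u32;
type SeqHash = u64;

enum SeqEntry {
    Single(SeqHash, HashSet<Worker>),
    Multi(HashMap<SeqHash, HashSet<Worker>>),
}

impl SeqEntry {
    fn new(seq: SeqHash, worker: Worker) -> Self {
        SeqEntry::Single(seq, HashSet::from([worker]))
    }

    /// Add `worker` under `seq`, switching to `Multi` only when a second,
    /// different seq_hash shows up for this (position, local_hash).
    fn insert(&mut self, seq: SeqHash, worker: Worker) {
        let upgraded = match self {
            SeqEntry::Single(existing, workers) if *existing != seq => {
                // A different prefix converged on this key: move the existing
                // worker set into a map and add the new seq_hash next to it.
                let mut map: HashMap<SeqHash, HashSet<Worker>> = HashMap::new();
                map.insert(*existing, std::mem::take(workers));
                map.entry(seq).or_default().insert(worker);
                Some(map)
            }
            SeqEntry::Single(_, workers) => {
                workers.insert(worker);
                None
            }
            SeqEntry::Multi(map) => {
                map.entry(seq).or_default().insert(worker);
                None
            }
        };
        if let Some(map) = upgraded {
            *self = SeqEntry::Multi(map);
        }
    }
}
```

Entries that only ever see one seq_hash stay in the `Single` variant and never allocate the inner `HashMap`.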