Unverified Commit 937398cf authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: Flash Indexer (#5785)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Signed-off-by: default avatarjthomson04 <jothomson@nvidia.com>
Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Signed-off-by: default avatarJanelle Cai <jcai18@mit.edu>
Co-authored-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Co-authored-by: default avatarJanelle Cai <jcai18@mit.edu>
parent de27efe6
......@@ -1735,6 +1735,16 @@ dependencies = [
"memchr",
]
[[package]]
name = "ctor"
version = "0.1.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d2301688392eb071b0bf1a37be05c469d3cc4dbbd95df672fe28ab021e6a096"
dependencies = [
"quote",
"syn 1.0.109",
]
[[package]]
name = "cudarc"
version = "0.17.8"
......@@ -2295,9 +2305,15 @@ dependencies = [
"anyhow",
"async-trait",
"clap 4.5.53",
"dashmap 6.1.0",
"dynamo-mocker",
"dynamo-runtime",
"dynamo-tokens",
"flume",
"futures",
"indicatif 0.18.3",
"minstant",
"parking_lot",
"prometheus",
"rand 0.9.2",
"rstest 0.18.2",
......@@ -2308,6 +2324,7 @@ dependencies = [
"tokio",
"tokio-util",
"tracing",
"uuid 1.18.1",
"xxhash-rust",
]
......@@ -2899,6 +2916,9 @@ name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
dependencies = [
"getrandom 0.2.16",
]
[[package]]
name = "fax"
......@@ -3023,6 +3043,18 @@ dependencies = [
"num-traits",
]
[[package]]
name = "flume"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e139bc46ca777eb5efaf62df0ab8cc5fd400866427e56c68b22e414e53bd3be"
dependencies = [
"fastrand",
"futures-core",
"futures-sink",
"spin",
]
[[package]]
name = "fnv"
version = "1.0.7"
......@@ -5182,6 +5214,16 @@ dependencies = [
"simd-adler32",
]
[[package]]
name = "minstant"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fb9b5c752f145ac5046bccc3c4f62892e3c950c1d1eab80c5949cd68a2078db"
dependencies = [
"ctor",
"web-time",
]
[[package]]
name = "mio"
version = "0.6.23"
......@@ -8674,6 +8716,15 @@ dependencies = [
"vob",
]
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
dependencies = [
"lock_api",
]
[[package]]
name = "spki"
version = "0.7.3"
......
......@@ -250,6 +250,12 @@ def parse_args():
default=False,
help="KV Router: Track output blocks during generation. When enabled, the router adds placeholder blocks as tokens are generated and applies fractional decay based on progress toward expected_output_tokens. By default, output blocks are not tracked.",
)
parser.add_argument(
"--router-event-threads",
type=int,
default=int(os.environ.get("DYN_ROUTER_EVENT_THREADS", "1")),
help="KV Router: Number of event processing threads. When > 1, uses a concurrent radix tree with a thread pool for higher throughput. Can be set via DYN_ROUTER_EVENT_THREADS env var (default: 1).",
)
parser.add_argument(
"--enforce-disagg",
action="store_true",
......@@ -436,6 +442,7 @@ async def async_main():
router_ttl_secs=flags.router_ttl,
router_max_tree_size=flags.router_max_tree_size,
router_prune_target_ratio=flags.router_prune_target_ratio,
router_event_threads=flags.router_event_threads,
)
elif flags.router_mode == "random":
router_mode = RouterMode.Random
......
......@@ -15,6 +15,7 @@ routing decisions.
import argparse
import asyncio
import logging
import os
from typing import Optional
import uvloop
......@@ -256,6 +257,13 @@ def parse_args():
help="KV Router: Target size ratio after pruning (0.0-1.0). Only used when --no-kv-events is set. Determines how aggressively to prune the tree (default: 0.8)",
)
parser.add_argument(
"--router-event-threads",
type=int,
default=int(os.environ.get("DYN_ROUTER_EVENT_THREADS", "1")),
help="KV Router: Number of event processing threads. When > 1, uses a concurrent radix tree with a thread pool for higher throughput. Can be set via DYN_ROUTER_EVENT_THREADS env var (default: 1).",
)
return parser.parse_args()
......@@ -302,6 +310,7 @@ async def worker(runtime: DistributedRuntime):
router_ttl_secs=args.router_ttl_secs,
router_max_tree_size=args.router_max_tree_size,
router_prune_target_ratio=args.router_prune_target_ratio,
router_event_threads=args.router_event_threads,
)
# Create service component - use "router" as component name
......
......@@ -174,6 +174,8 @@ The main KV-aware routing arguments:
- `--router-prune-target-ratio`: Target size ratio to prune down to when `--router-max-tree-size` is exceeded. For example, with a value of 0.8 (default) and max tree size of 1048576, the router will prune down to approximately 838860 blocks when the threshold is exceeded. Defaults to 0.8 when `--no-kv-events` is used. This creates headroom before the next pruning cycle.
- `--router-event-threads`: Number of event processing threads for the KV indexer. When set to 1 (default), the router uses a single-threaded radix tree with channel-based event processing, which supports TTL-based expiration and pruning. When set to a value greater than 1, the router uses a concurrent radix tree with a thread pool of the specified size for higher event throughput. Note: the concurrent indexer does not support TTL/pruning (`--router-ttl`, `--router-max-tree-size`, `--router-prune-target-ratio` are ignored when `--router-event-threads > 1`). Can be set via `DYN_ROUTER_EVENT_THREADS` env var.
>[!Note]
> **State persistence** depends on the event transport mode:
> - **NATS Core / Event Plane mode** (default): State persists on workers—router rebuilds state by querying workers on startup. This is the default when workers have `local_indexer` enabled (which is the default). Works with both NATS Core and ZMQ event planes.
......
......@@ -130,6 +130,12 @@ The KVIndexer builds and maintains a global view of cached blocks in a prefix tr
The KVIndexer exposes a method, `find_matches_for_request`, which takes a token sequence and returns a dictionary mapping each worker ID to the number of KV blocks matched on that worker.
The KVIndexer supports two backend implementations, selected via `--router-event-threads`:
- **Single-threaded RadixTree** (default, `--router-event-threads 1`): Events are processed in a dedicated single-threaded tokio runtime via channel-based dispatch. Supports TTL-based expiration and size-based pruning (for `--no-kv-events` approximate mode).
- **ConcurrentRadixTree** (`--router-event-threads N` where N > 1): A thread-safe radix tree with a pool of N worker threads for event processing. Uses sticky worker routing (events for the same worker always go to the same thread) to ensure per-worker event serialization. Read operations (`find_matches`) execute concurrently with writes. Does not support TTL/pruning.
### Inter-Router Communication
In distributed deployments with multiple routers, each router maintains visibility over only a portion of the total requests. To ensure consistent routing decisions, routers synchronize their states through three event types:
......
......@@ -1613,8 +1613,11 @@ version = "0.9.0"
dependencies = [
"anyhow",
"async-trait",
"dashmap 6.1.0",
"dynamo-runtime",
"dynamo-tokens",
"flume",
"parking_lot",
"prometheus",
"rand 0.9.2",
"serde",
......@@ -2139,6 +2142,9 @@ name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
dependencies = [
"getrandom 0.2.16",
]
[[package]]
name = "fax"
......@@ -2252,6 +2258,18 @@ dependencies = [
"miniz_oxide",
]
[[package]]
name = "flume"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e139bc46ca777eb5efaf62df0ab8cc5fd400866427e56c68b22e414e53bd3be"
dependencies = [
"fastrand",
"futures-core",
"futures-sink",
"spin",
]
[[package]]
name = "fnv"
version = "1.0.7"
......@@ -6757,6 +6775,15 @@ dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
dependencies = [
"lock_api",
]
[[package]]
name = "spki"
version = "0.7.3"
......
......@@ -52,7 +52,7 @@ impl KvRouterConfig {
#[pymethods]
impl KvRouterConfig {
#[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8))]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_event_threads=1))]
#[allow(clippy::too_many_arguments)]
fn new(
overlap_score_weight: f64,
......@@ -68,6 +68,7 @@ impl KvRouterConfig {
router_ttl_secs: f64,
router_max_tree_size: usize,
router_prune_target_ratio: f64,
router_event_threads: u32,
) -> Self {
KvRouterConfig {
inner: RsKvRouterConfig {
......@@ -84,6 +85,7 @@ impl KvRouterConfig {
router_ttl_secs,
router_max_tree_size,
router_prune_target_ratio,
router_event_threads,
},
}
}
......
......@@ -995,6 +995,7 @@ class KvRouterConfig:
router_ttl_secs: float = 120.0,
router_max_tree_size: int = 1048576,
router_prune_target_ratio: float = 0.8,
router_event_threads: int = 1,
) -> None:
"""
Create a KV router configuration.
......@@ -1018,6 +1019,8 @@ class KvRouterConfig:
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20)
router_prune_target_ratio: Target size ratio after pruning (default: 0.8)
router_event_threads: Number of event processing threads (default: 1).
When > 1, uses a concurrent radix tree with a thread pool.
"""
...
......
......@@ -13,7 +13,7 @@ repository.workspace = true
[features]
default = []
metrics = ["dep:dynamo-runtime"]
bench = ["dep:clap", "dep:indicatif"]
bench = ["dep:clap", "dep:indicatif", "dep:serde_json", "dynamo-runtime/integration", "dep:uuid"]
[dependencies]
# repo
......@@ -23,24 +23,35 @@ dynamo-tokens = { workspace = true }
# workspace
anyhow = { workspace = true }
async-trait = { workspace = true }
dashmap = { workspace = true }
prometheus = { workspace = true }
rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true, optional = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true }
tracing = { workspace = true }
xxhash-rust = { workspace = true }
# dependencies
flume = "0.12.0"
parking_lot = { workspace = true }
# bench (optional)
clap = { version = "4.5", features = ["derive"], optional = true }
indicatif = { version = "0.18.0", optional = true }
uuid = { workspace = true, optional = true }
[dev-dependencies]
rstest = "0.18.2"
rstest_reuse = "0.7.0"
serde_json = { workspace = true }
tokio = { workspace = true, features = ["rt", "macros", "time"] }
dynamo-mocker = { workspace = true }
minstant = "0.1.7"
futures = "0.3"
[[bench]]
name = "radix_tree_microbench"
......@@ -51,3 +62,8 @@ required-features = ["bench"]
name = "kv_indexer_bench"
harness = false
required-features = ["bench"]
[[bench]]
name = "mooncake_bench"
harness = false
required-features = ["bench"]
# KV Router Index Data Structures
This document explains the KV cache index implementations: `RadixTree`, `ConcurrentRadixTree`, and `PositionalIndexer` (NestedMap).
## Motivation: The Four Block Identifiers
Every cached KV block in a distributed LLM system needs four pieces of information:
### 1. Local Block Hash (`LocalBlockHash`, u64)
**What**: Hash of the tokens *within* a single block (e.g., 64 tokens).
**Why**: Identifies the content of this specific block, independent of context. Two blocks with the same tokens have the same local hash.
```
Block at position 5: tokens [101, 102, 103, ...]
LocalBlockHash = hash(tokens) = 0xABCD1234
```
### 2. External Sequence Block Hash (`ExternalSequenceBlockHash`, u64)
**What**: Cumulative hash of the entire sequence up to and including this block.
**Why**: Uniquely identifies a block's position in a *specific* sequence history. Two blocks with the same local content but different prefixes have different sequence hashes.
```
Sequence A: [block0, block1, block2]
Sequence B: [block0', block1', block2] // block2 has same content but different prefix
block2 in A: seq_hash = hash(hash(hash(block0) + block1) + block2) = 0x1111
block2 in B: seq_hash = hash(hash(hash(block0') + block1') + block2) = 0x2222
```
**Computation**: `seq_hash[i] = hash(seq_hash[i-1] || local_hash[i])`, where `||` denotes concatenation (written as `+` in the example above) and `seq_hash[0] = local_hash[0]`.
> **Important: Engine-Provided Hashes**
>
> In practice, the `ExternalSequenceBlockHash` may come directly from the inference engine (e.g., vLLM, TensorRT-LLM) using a rolling hash algorithm that we don't know or control. The engine computes these hashes internally and reports them via KV cache events.
>
> **Implications for index implementations:**
>
> - **RadixTree**: Can handle engine-provided hashes because it traverses the tree structure using `LocalBlockHash` for navigation and only uses `ExternalSequenceBlockHash` as an opaque identifier for lookups. It doesn't need to recompute hashes.
>
> - **NestedMap**: Requires the ability to compute `ExternalSequenceBlockHash` incrementally for its lazy hash optimization in `find_matches`. To use NestedMap, one of the following is required:
> 1. **Force a known hasher**: Configure the engine to use a specific hashing algorithm that the router can replicate, OR
> 2. **Recompute on the relay**: Have the publisher/relay layer recompute the rolling hash using a known algorithm before forwarding events to the router.
>
> Without this, NestedMap's `find_matches` will fail when encountering `SeqEntry::Multi` cases (multiple seq_hashes at the same position+local_hash) because it cannot disambiguate which entry to use.
### 3. Worker ID (`WorkerWithDpRank`)
**What**: Identifies which worker (inference server) has this block cached.
**Why**: The router needs to know which workers can serve a request based on their cached blocks.
### 4. Position (`u64`)
**What**: The block's index in the sequence (0, 1, 2, ...).
**Why**: Enables efficient prefix matching. Position 0 is the first block, position N-1 is the last.
---
## The Core Operations
All of the index implementations support the same three core operations:
| Operation | Description | Hot Path? |
|-----------|-------------|-----------|
| `store_blocks` | Add blocks for a worker | No (background) |
| `remove_blocks` | Remove blocks for a worker | No (background) |
| `find_matches` | Find workers with matching prefix | **Yes** (per-request) |
The key insight: **reads (find_matches) are far more frequent than writes (store/remove)**. This motivates different structural tradeoffs.
---
## RadixTree: Tree-Based Index
### Structure
```
RadixTree
├── root: SharedRadixBlock (Rc<RefCell<RadixBlock>>)
└── lookup: HashMap<Worker, HashMap<SeqHash, SharedRadixBlock>>
RadixBlock
├── children: HashMap<LocalBlockHash, SharedRadixBlock>
├── workers: HashSet<Worker>
├── block_hash: Option<SeqHash>
└── recent_uses: VecDeque<Instant>
```
### Visual Representation
```
[root]
/ \
local=0xA local=0xB
↓ ↓
[block0] [block0']
workers: workers:
{W0,W1} {W2}
|
local=0xC
[block1]
workers:
{W0,W1}
|
local=0xD
[block2]
workers:
{W0} ← W1 diverged here
```
### How Operations Work
**store_blocks(worker, parent_hash, blocks)**:
1. Find parent via `lookup[worker][parent_hash]`
2. For each block, traverse/create child nodes using `local_hash`
3. Add worker to each node's `workers` set
4. Update `lookup[worker][seq_hash] = node`
**remove_blocks(worker, block_hashes)**:
1. For each hash, find node via `lookup[worker][hash]`
2. Remove worker from node's `workers` set
3. If `workers` empty, clear children (cascading cleanup)
4. Remove from `lookup[worker]`
**find_matches(local_hashes, early_exit)**:
1. Start at root with all workers as candidates
2. For each position, traverse to child matching `local_hash`
3. Intersect candidates with node's `workers`
4. Track depth where each worker drops out
5. Return `{worker -> depth}` scores
### Complexity
| Operation | Time | Space |
|-----------|------|-------|
| store_blocks (N blocks) | O(N) | O(N) nodes |
| remove_blocks (N blocks) | O(N) | - |
| find_matches (depth D) | O(D × W) | O(W) |
Where W = number of workers.
---
## PositionalIndexer (NestedMap): Position-First HashMap Index
### Structure
```
PositionalIndexer
├── index: DashMap<(Position, LocalHash), SeqEntry>
├── worker_blocks: DashMap<Worker, RwLock<HashMap<SeqHash, (Position, LocalHash)>>>
└── jump_size: usize
SeqEntry (enum for memory optimization)
├── Single(SeqHash, HashSet<Worker>) // Common case: one seq_hash
└── Multi(HashMap<SeqHash, HashSet<Worker>>) // Rare: multiple prefixes
```
`PositionalIndexer` implements `SyncIndexer` and is thread-safe via `DashMap` (sharded
concurrent map) and `RwLock`. It is designed to be wrapped in a `ThreadPoolIndexer` which
routes write events to dedicated OS threads and executes reads inline.
The `index` uses a flat compound key `(position, local_hash)` in a `DashMap`, which
distributes lock contention across shards while enabling O(1) random-position access for
the jump optimization. The `worker_blocks` reverse lookup uses `DashMap` for the outer
per-worker map and `RwLock<HashMap>` for each worker's block set, since writes to a
given worker are serialized by sticky routing in `ThreadPoolIndexer`.
### Visual Representation
```
index (DashMap with compound keys):
┌──────────────────────┬──────────────────────────────────────┐
│ (pos=0, local=0xA) │ Single(seq=0x1111, {W0,W1}) │
│ (pos=0, local=0xB) │ Single(seq=0x2222, {W2}) │
│ (pos=1, local=0xC) │ Single(seq=0x3333, {W0,W1}) │
│ (pos=2, local=0xD) │ Multi{ │
│ │ seq=0x4444 → {W0}, │
│ │ seq=0x5555 → {W1} ← diverged │
│ │ } │
└──────────────────────┴──────────────────────────────────────┘
worker_blocks (DashMap<Worker, RwLock<HashMap>>):
┌─────────┬─────────────────────────────────────────────────┐
│ W0 │ seq=0x1111 → (pos=0, local=0xA) │
│ │ seq=0x3333 → (pos=1, local=0xC) │
│ │ seq=0x4444 → (pos=2, local=0xD) │
├─────────┼─────────────────────────────────────────────────┤
│ W1 │ seq=0x1111 → (pos=0, local=0xA) │
│ │ seq=0x3333 → (pos=1, local=0xC) │
│ │ seq=0x5555 → (pos=2, local=0xD) │
└─────────┴─────────────────────────────────────────────────┘
```
### How Operations Work
**store_blocks(worker, parent_hash, blocks)**:
1. Find starting position: `pos = worker_blocks[worker][parent_hash].position + 1`
2. For each block at position `i`:
- Insert into `index[(pos+i, local_hash)]` → add worker to SeqEntry
- Insert into `worker_blocks[worker][seq_hash] = (pos+i, local_hash)`
**remove_blocks(worker, block_hashes)**:
1. For each hash, lookup `(pos, local_hash) = worker_blocks[worker][hash]`
2. Remove worker from `index[(pos, local_hash)]`
3. Remove from `worker_blocks[worker]`
4. Cleanup empty SeqEntry entries from the DashMap
**find_matches(local_hashes, early_exit)** with Jump Optimization:
1. Start at position 0, initialize candidates from first block
2. **Jump**: Skip ahead by `jump_size` positions (e.g., 32; the diagram below uses a jump size of 64 for illustration)
3. At each jump point, check if candidates still match (count-only, no clone)
4. If workers dropped, **scan back** to find exact drain points
5. Continue until sequence exhausted or one worker remains
```
Query: [b0, b1, b2, ..., b63, b64, ..., b127, ...]
↑ ↑ ↑
pos=0 pos=64 pos=128
│ │ │
└── jump ──────────→└── jump ─────────→│
all match? some dropped?
↓ ↓
continue scan [64,128]
```
**Lazy Hash Optimization**:
- Most (position, local_hash) pairs have only ONE seq_hash (SeqEntry::Single)
- Skip seq_hash computation entirely in this case
- Only compute when disambiguation needed (SeqEntry::Multi)
**dump_events()**:
1. Iterate `worker_blocks`, collecting all blocks per worker
2. Sort each worker's blocks by position (parents before children)
3. Emit one single-block `RouterEvent::Stored` per block, synthesizing
`parent_hash` from any seq_hash at the prior position
4. Events can be replayed into a fresh `PositionalIndexer` to reconstruct
the same index state
### Complexity
| Operation | Time | Space |
|-----------|------|-------|
| store_blocks (N blocks) | O(N) | O(N) entries |
| remove_blocks (N blocks) | O(N) | - |
| find_matches (depth D) | O(D/J + J×W) | O(W) |
Where J = jump_size, W = number of workers. The jump optimization reduces D iterations to D/J jumps plus occasional scans.
---
## Comparison
| Aspect | RadixTree | PositionalIndexer |
|--------|-----------|-------------------|
| **Structure** | Tree with Rc<RefCell<>> nodes | DashMap with compound keys |
| **Concurrent variant** | ConcurrentRadixTree (Arc<RwLock<>> + DashMap) | Thread-safe by default (DashMap + RwLock) |
| **find_matches** | O(D×W) tree traversal | O(D/J + J×W) with jump optimization |
| **store_blocks** | O(N) node creation | O(N) DashMap inserts |
| **remove_blocks** | O(N) with cascading cleanup | O(N) with entry cleanup |
| **dump_events** | BFS traversal of tree | Sort by position per worker |
| **Memory** | Higher (Rc/Arc overhead per node) | Lower (flat entries) |
| **Cache locality** | Poor (pointer chasing) | Better (position-first) |
### Benchmark Results (1M blocks, depth 1024, 128 workers)
| Operation | RadixTree | NestedMap | Winner |
|-----------|-----------|-----------|--------|
| STORE_BLOCK | 90µs | 98µs | RadixTree (1.1x) |
| REMOVE_BLOCK | 91µs | 233µs | RadixTree (2.5x) |
| FIND_MATCHES (HIT) | 227µs | **44µs** | **NestedMap (5.2x)** |
| FIND_MATCHES (PARTIAL) | 216µs | **44µs** | **NestedMap (4.9x)** |
**Recommendation**: Use NestedMap for read-heavy workloads (typical router usage).
---
## Why Position Matters for PositionalIndexer
The compound key `(position, local_hash)` in the DashMap enables the jump optimization:
```rust
// Without position-first: must traverse entire tree
for pos in 0..depth {
node = node.children[local_hashes[pos]]; // O(depth) traversals
}
// With position-first: can jump directly to any position
let workers_at_64 = index.get(&(64, local_hashes[64])); // O(1) lookup
let workers_at_128 = index.get(&(128, local_hashes[128])); // O(1) lookup
// Skip positions 1-63, 65-127 entirely!
```
---
## SeqEntry Optimization
The value stored under each `(position, local_hash)` key is an enum, which avoids allocating a `HashMap` in the common case:
```rust
enum SeqEntry {
// Common: one prefix leads to this (position, local_hash)
Single(SeqHash, HashSet<Worker>),
// Rare: different prefixes converge on same (position, local_hash)
Multi(HashMap<SeqHash, HashSet<Worker>>),
}
```
**When does Multi occur?**
Only when two different sequences have:
1. Same local block content at position P
2. Different prefix histories (different seq_hash)
Example:
```
Sequence A: [tok1, tok2, tok3] → positions 0,1,2
Sequence B: [tok4, tok5, tok3] → positions 0,1,2
^^^^
Same local content at pos=2
but different seq_hash!
```
This is rare in practice, so `Single` saves ~48 bytes per entry.
This diff is collapsed.
This diff is collapsed.
......@@ -15,30 +15,31 @@
use clap::{Parser, ValueEnum};
use dynamo_kv_router::{
RadixTree, RouterEvent,
ConcurrentRadixTree, OverlapScores, PositionalIndexer, RadixTree, RouterEvent, SyncIndexer,
bench_utils::{LatencyStats, SequenceData, generate_sequences},
compute_block_hash_for_seq,
flat_hashmap::FlatHashMap,
protocols::LocalBlockHash,
};
use std::time::{Duration, Instant};
/// Unified interface for RadixTree and FlatHashMap benchmarking.
/// Unified interface for RadixTree, ConcurrentRadixTree, and PositionalIndexer benchmarking.
///
/// Both structures have feature parity for store, remove, find_matches, and current_size.
/// All structures have feature parity for store, remove, find_matches, and current_size.
/// The key difference is find_matches input:
/// - RadixTree: uses LocalBlockHash (tokens_hash)
/// - FlatHashMap: uses ExternalSequenceBlockHash (cumulative sequence hash)
/// - RadixTree/ConcurrentRadixTree: uses LocalBlockHash (tokens_hash)
/// - PositionalIndexer: uses LocalBlockHash (same as tree; internal mapping uses sequence hashes)
enum KvIndex {
Tree(RadixTree),
Flat(FlatHashMap),
Concurrent(ConcurrentRadixTree),
Nested(PositionalIndexer),
}
impl KvIndex {
fn name(&self) -> &'static str {
match self {
KvIndex::Tree(_) => "RadixTree",
KvIndex::Flat(_) => "FlatHashMap",
KvIndex::Concurrent(_) => "ConcurrentRadixTree",
KvIndex::Nested(_) => "PositionalIndexer",
}
}
......@@ -47,8 +48,11 @@ impl KvIndex {
KvIndex::Tree(tree) => {
let _ = tree.apply_event(event);
}
KvIndex::Flat(map) => {
map.apply_event(event);
KvIndex::Concurrent(tree) => {
let _ = tree.apply_event(event);
}
KvIndex::Nested(map) => {
let _ = map.apply_event(event).ok();
}
}
}
......@@ -58,7 +62,8 @@ impl KvIndex {
let start = Instant::now();
let _ = match self {
KvIndex::Tree(tree) => tree.find_matches(local_hashes, early_exit),
KvIndex::Flat(map) => map.find_matches(local_hashes, early_exit),
KvIndex::Concurrent(tree) => tree.find_matches_impl(&local_hashes, early_exit),
KvIndex::Nested(map) => map.find_matches(&local_hashes, early_exit),
};
start.elapsed()
}
......@@ -70,7 +75,8 @@ impl KvIndex {
let start = Instant::now();
let _ = match self {
KvIndex::Tree(tree) => tree.find_matches(miss_hashes, early_exit),
KvIndex::Flat(map) => map.find_matches(miss_hashes, early_exit),
KvIndex::Concurrent(tree) => tree.find_matches_impl(&miss_hashes, early_exit),
KvIndex::Nested(map) => map.find_matches(&miss_hashes, early_exit),
};
start.elapsed()
}
......@@ -89,7 +95,8 @@ impl KvIndex {
let start = Instant::now();
let _ = match self {
KvIndex::Tree(tree) => tree.find_matches(partial, early_exit),
KvIndex::Flat(map) => map.find_matches(partial, early_exit),
KvIndex::Concurrent(tree) => tree.find_matches_impl(&partial, early_exit),
KvIndex::Nested(map) => map.find_matches(&partial, early_exit),
};
start.elapsed()
}
......@@ -97,7 +104,27 @@ impl KvIndex {
fn current_size(&self) -> usize {
match self {
KvIndex::Tree(tree) => tree.current_size(),
KvIndex::Flat(map) => map.current_size(),
KvIndex::Concurrent(tree) => tree.current_size(),
KvIndex::Nested(map) => map.current_size(),
}
}
/// Runs a prefix-match query against whichever index variant is wrapped and
/// returns that backend's `OverlapScores`.
///
/// `local_hashes` and `early_exit` are forwarded unchanged. The arms differ
/// only in call shape: `RadixTree::find_matches` takes the vector by value,
/// `ConcurrentRadixTree` is queried through `find_matches_impl` with a borrow,
/// and `PositionalIndexer::find_matches` also borrows the vector.
fn find_matches(&self, local_hashes: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
match self {
KvIndex::Tree(tree) => tree.find_matches(local_hashes, early_exit),
KvIndex::Concurrent(tree) => tree.find_matches_impl(&local_hashes, early_exit),
KvIndex::Nested(map) => map.find_matches(&local_hashes, early_exit),
}
}
/// Dumps the wrapped index as a list of `RouterEvent`s.
///
/// The two radix-tree variants delegate to their own `dump_tree_as_events`.
/// The `Nested` variant always yields an empty vector, so callers that time
/// or count the dump get no meaningful data for it.
fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
match self {
KvIndex::Tree(tree) => tree.dump_tree_as_events(),
KvIndex::Concurrent(tree) => tree.dump_tree_as_events(),
KvIndex::Nested(_) => {
// NestedMap does not support dump_tree_as_events
// NOTE(review): the index design doc describes a `dump_events()` operation
// for `PositionalIndexer`; confirm whether it should be wired in here
// instead of returning an empty list.
vec![]
}
}
}
}
......@@ -197,44 +224,22 @@ struct Args {
#[arg(long, default_value = "42")]
seed: u64,
/// Use flat HashMap baseline instead of radix tree (for comparison)
/// Use nested map instead of radix tree (for comparison)
#[arg(long)]
flat_hashmap: bool,
}
/// Build a pre-populated RadixTree (for sweep/dump benchmarks that specifically need RadixTree)
fn build_tree(sequences: &[SequenceData]) -> RadixTree {
let num_blocks: usize = sequences.iter().map(|s| s.local_hashes.len()).sum();
print!(
" Building tree with {} sequences ({} blocks)... ",
sequences.len(),
num_blocks
);
std::io::Write::flush(&mut std::io::stdout()).unwrap();
nested_map: bool,
let start = Instant::now();
let mut tree = RadixTree::new();
for (event_id, seq) in sequences.iter().enumerate() {
let event = seq.to_store_event(event_id as u64);
let _ = tree.apply_event(event);
}
let elapsed = start.elapsed();
println!(
"done in {:.2?} ({:.2} sequences/sec, {:.2} blocks/sec)",
elapsed,
sequences.len() as f64 / elapsed.as_secs_f64(),
num_blocks as f64 / elapsed.as_secs_f64()
);
tree
/// Use concurrent radix tree instead of single-threaded radix tree
#[arg(long)]
concurrent: bool,
}
/// Build a pre-populated KvIndex (prints timing info)
fn build_index(sequences: &[SequenceData], use_flat_hashmap: bool) -> KvIndex {
fn build_index(sequences: &[SequenceData], use_nested_map: bool, use_concurrent: bool) -> KvIndex {
let num_blocks: usize = sequences.iter().map(|s| s.local_hashes.len()).sum();
let name = if use_flat_hashmap {
"FlatHashMap"
let name = if use_nested_map {
"NestedMap"
} else if use_concurrent {
"ConcurrentRadixTree"
} else {
"RadixTree"
};
......@@ -247,8 +252,10 @@ fn build_index(sequences: &[SequenceData], use_flat_hashmap: bool) -> KvIndex {
std::io::Write::flush(&mut std::io::stdout()).unwrap();
let start = Instant::now();
let mut index = if use_flat_hashmap {
KvIndex::Flat(FlatHashMap::new())
let mut index = if use_nested_map {
KvIndex::Nested(PositionalIndexer::new(32))
} else if use_concurrent {
KvIndex::Concurrent(ConcurrentRadixTree::new())
} else {
KvIndex::Tree(RadixTree::new())
};
......@@ -329,10 +336,10 @@ fn bench_store_remove_cycle(args: &Args, time_store: bool) {
args.prefix_prompt_ratio,
args.num_prefix_prompts,
args.seed,
true, // use_cumulative_hash
true,
);
let mut index = build_index(&sequences, args.flat_hashmap);
let mut index = build_index(&sequences, args.nested_map, args.concurrent);
println!("\n=== Benchmarking {} ({}) ===", op_name, index.name());
println!(" Size: {} blocks", index.current_size());
......@@ -391,10 +398,10 @@ fn bench_find_matches(args: &Args) {
args.prefix_prompt_ratio,
args.num_prefix_prompts,
args.seed,
true, // use_cumulative_hash
true,
);
let index = build_index(&sequences, args.flat_hashmap);
let index = build_index(&sequences, args.nested_map, args.concurrent);
println!("\n=== Benchmarking FIND_MATCHES ({}) ===", index.name());
println!(
" Built with {} sequences, {} total blocks",
......@@ -696,12 +703,12 @@ fn bench_sweep(args: &Args) {
args.prefix_prompt_ratio,
num_prefix_prompts,
args.seed,
true, // use_cumulative_hash
true,
);
let tree_sequences = &all_sequences[..num_sequences];
let extra_sequences = &all_sequences[num_sequences..];
let mut tree = build_tree(tree_sequences);
let mut index = build_index(tree_sequences, args.nested_map, args.concurrent);
// --- STORE benchmark ---
let mut store_durations = Vec::with_capacity(args.sweep_iterations);
......@@ -718,12 +725,12 @@ fn bench_sweep(args: &Args) {
let store_event = truncated.to_store_event(i as u64);
let start = Instant::now();
let _ = tree.apply_event(store_event);
index.apply_event(store_event);
store_durations.push(start.elapsed());
// Remove to restore tree state (untimed)
// Remove to restore index state (untimed)
let remove_event = truncated.to_remove_event(i as u64);
let _ = tree.apply_event(remove_event);
index.apply_event(remove_event);
}
// --- REMOVE benchmark ---
......@@ -738,12 +745,12 @@ fn bench_sweep(args: &Args) {
let remove_event = truncated.to_remove_event(i as u64);
let start = Instant::now();
let _ = tree.apply_event(remove_event);
index.apply_event(remove_event);
remove_durations.push(start.elapsed());
// Re-add to restore state (untimed)
let store_event = truncated.to_store_event(i as u64 + 1000000);
let _ = tree.apply_event(store_event);
index.apply_event(store_event);
}
// --- FIND_MATCHES benchmark ---
......@@ -768,7 +775,7 @@ fn bench_sweep(args: &Args) {
};
let start = Instant::now();
let _ = tree.find_matches(query, false);
let _ = index.find_matches(query, false);
find_matches_durations.push(start.elapsed());
}
......@@ -808,19 +815,20 @@ fn bench_dump(args: &Args) {
args.prefix_prompt_ratio,
args.num_prefix_prompts,
args.seed,
true, // use_cumulative_hash
true,
);
let tree = build_tree(&sequences);
let index = build_index(&sequences, args.nested_map, args.concurrent);
println!(
" Tree built with {} sequences, {} total blocks",
" {} built with {} sequences, {} total blocks",
index.name(),
sequences.len(),
tree.current_size()
index.current_size()
);
// Single iteration timing
let start = Instant::now();
let events = tree.dump_tree_as_events();
let events = index.dump_tree_as_events();
let elapsed = start.elapsed();
println!("\nDUMP_TREE_AS_EVENTS Results:");
......
......@@ -391,7 +391,7 @@ mod tests {
max_tree_size: usize::MAX,
prune_target_ratio: 0.5,
};
let mut indexer = KvIndexer::new_with_frequency(
let indexer = KvIndexer::new_with_frequency(
cancel.clone(),
None,
KV_BLOCK_SIZE,
......@@ -443,7 +443,7 @@ mod tests {
max_tree_size: usize::MAX,
prune_target_ratio: 0.5,
};
let mut indexer = KvIndexer::new_with_frequency(
let indexer = KvIndexer::new_with_frequency(
cancel.clone(),
None,
KV_BLOCK_SIZE,
......
......@@ -188,11 +188,11 @@ pub fn generate_sequences(
let local_hashes: Vec<LocalBlockHash> = (0..depth)
.map(|block_idx| {
let block_idx_u64 = block_idx as u64;
if let Some(gid) = group_id {
if block_idx < prefix_length {
// Shared prefix based on group_id
return LocalBlockHash(0xDEAD_BEEF_0000_0000 | (gid << 32) | block_idx_u64);
}
if let Some(gid) = group_id
&& block_idx < prefix_length
{
// Shared prefix based on group_id
return LocalBlockHash(0xDEAD_BEEF_0000_0000 | (gid << 32) | block_idx_u64);
}
// Unique suffix (or no shared prefix)
LocalBlockHash((seq_id_u64 << 32) | block_idx_u64)
......@@ -205,12 +205,12 @@ pub fn generate_sequences(
let external_hashes: Vec<ExternalSequenceBlockHash> = (0..depth)
.map(|block_idx| {
let block_idx_u64 = block_idx as u64;
if let Some(gid) = group_id {
if block_idx < prefix_length {
return ExternalSequenceBlockHash(
0xDEAD_BEEF_0000_0000 | (gid << 32) | block_idx_u64,
);
}
if let Some(gid) = group_id
&& block_idx < prefix_length
{
return ExternalSequenceBlockHash(
0xDEAD_BEEF_0000_0000 | (gid << 32) | block_idx_u64,
);
}
ExternalSequenceBlockHash((seq_id_u64 << 32) | block_idx_u64)
})
......
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Flat HashMap baseline for benchmarking comparison with RadixTree.
//!
//! This module provides a `FlatHashMap` structure that has full feature parity with `RadixTree`
//! but uses flat HashMaps instead of a tree structure. This isolates the overhead of
//! tree traversal (pointer chasing) from pure HashMap operations.
//!
//! The `find_matches` API matches RadixTree exactly: it takes `LocalBlockHash` values
//! and internally computes the cumulative sequence hashes for lookup.
use std::collections::{HashMap, HashSet};
use crate::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, OverlapScores, RouterEvent, WorkerId, WorkerWithDpRank,
compute_seq_hash_for_block,
};
/// A flat, HashMap-based KV cache index.
///
/// Where `RadixTree` links nodes into a tree, `FlatHashMap` keeps two plain
/// hash maps that mirror each other. It exposes the same `find_matches`
/// semantics and exists as a baseline for benchmarking against the tree.
///
/// # Structure
///
/// - `block_to_workers`: maps an `ExternalSequenceBlockHash` to the set of workers holding
///   that block. Drives `find_matches`.
/// - `worker_to_blocks`: maps a worker to the set of `ExternalSequenceBlockHash` it holds.
///   Drives removal and `current_size`.
///
/// # Invariants maintained by the methods of this type
///
/// - A `(worker, block)` pair is recorded in both maps or in neither.
/// - A `block_to_workers` entry is dropped as soon as its worker set becomes empty.
/// - A `worker_to_blocks` entry may hold an empty set: `clear_all_blocks` re-inserts a
///   tracked worker with no blocks, and `store` with an empty slice creates one.
pub struct FlatHashMap {
/// Forward index: block hash -> workers holding it (read by `find_matches`).
block_to_workers: HashMap<ExternalSequenceBlockHash, HashSet<WorkerWithDpRank>>,
/// Reverse index: worker -> block hashes it holds (read by removal and `current_size`).
worker_to_blocks: HashMap<WorkerWithDpRank, HashSet<ExternalSequenceBlockHash>>,
}
impl FlatHashMap {
/// Builds an index that tracks no workers and no blocks.
pub fn new() -> Self {
    let block_to_workers = HashMap::new();
    let worker_to_blocks = HashMap::new();
    Self {
        block_to_workers,
        worker_to_blocks,
    }
}
/// Records that `worker` holds every hash in `block_hashes`.
///
/// Both indexes are updated. Hashes the worker already holds are ignored by the
/// underlying sets, so repeated calls are harmless. The worker becomes tracked even
/// when `block_hashes` is empty.
pub fn store(&mut self, worker: WorkerWithDpRank, block_hashes: &[ExternalSequenceBlockHash]) {
    // Forward index: every block learns about this worker.
    for hash in block_hashes.iter().copied() {
        self.block_to_workers
            .entry(hash)
            .or_default()
            .insert(worker);
    }

    // Reverse index: the entry is created unconditionally, then filled in one go.
    self.worker_to_blocks
        .entry(worker)
        .or_default()
        .extend(block_hashes.iter().copied());
}
/// Remove blocks for a worker.
///
/// Updates both indexes for each block.
pub fn remove(&mut self, worker: WorkerWithDpRank, block_hashes: &[ExternalSequenceBlockHash]) {
let Some(worker_blocks) = self.worker_to_blocks.get_mut(&worker) else {
return;
};
for &block_hash in block_hashes {
// Remove from worker -> blocks index
worker_blocks.remove(&block_hash);
// Remove from block -> workers index
if let Some(workers) = self.block_to_workers.get_mut(&block_hash) {
workers.remove(&worker);
if workers.is_empty() {
self.block_to_workers.remove(&block_hash);
}
}
}
// Clean up empty worker entry
if worker_blocks.is_empty() {
self.worker_to_blocks.remove(&worker);
}
}
/// Scores every worker by how many leading blocks of `sequence` it holds.
///
/// The signature mirrors `RadixTree::find_matches`: callers pass `LocalBlockHash`
/// values and the cumulative sequence hashes used as lookup keys are derived here.
///
/// # Algorithm
///
/// Blocks are visited in order. The workers holding the first block form the candidate
/// set; each later block narrows that set to the workers that also hold it. A worker
/// that falls out is scored with the number of blocks matched before the miss. The walk
/// ends at the first block nobody holds, when no candidate is left, or, if `early_exit`
/// is set, as soon as exactly one candidate remains. Candidates still standing at the
/// end are scored with the number of blocks visited.
///
/// Each visited block costs one hash lookup plus one pass over the current candidates.
///
/// # Returns
///
/// `OverlapScores` whose `scores` hold the per-worker match depth and whose
/// `tree_sizes` hold the stored block count of every scored worker. An empty
/// `sequence` yields empty scores.
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
    let mut scores = OverlapScores::new();
    if sequence.is_empty() {
        return scores;
    }

    // Lookups are keyed by the cumulative sequence hash, not the per-block local hash.
    let seq_hashes = compute_seq_hash_for_block(&sequence);

    // `None` until the first block has been matched.
    let mut survivors: Option<HashSet<WorkerWithDpRank>> = None;
    let mut matched = 0u32;

    for seq_hash in seq_hashes {
        let key = ExternalSequenceBlockHash(seq_hash);
        let Some(holders) = self.block_to_workers.get(&key) else {
            // Nobody holds this block: the shared prefix ends here.
            break;
        };

        let remaining = match &mut survivors {
            Some(current) => {
                // Narrow the candidates and score the ones that fall out, in one pass.
                current.retain(|candidate| {
                    let still_matching = holders.contains(candidate);
                    if !still_matching {
                        scores.scores.insert(*candidate, matched);
                    }
                    still_matching
                });
                current.len()
            }
            None => {
                // First block: everyone holding it is a candidate.
                survivors = Some(holders.clone());
                holders.len()
            }
        };
        matched += 1;

        // Stop when nobody is left, or when a single winner suffices.
        if remaining == 0 || (early_exit && remaining == 1) {
            break;
        }
    }

    // Candidates that never dropped out matched every visited block.
    if let Some(finalists) = survivors {
        for worker in finalists {
            scores.scores.insert(worker, matched);
        }
    }

    // Attach the stored block count of every scored worker.
    for &worker in scores.scores.keys() {
        let Some(owned) = self.worker_to_blocks.get(&worker) else {
            continue;
        };
        scores.tree_sizes.insert(worker, owned.len());
    }

    scores
}
/// Apply a RouterEvent (for API compatibility with RadixTree).
pub fn apply_event(&mut self, event: RouterEvent) {
let worker = WorkerWithDpRank::new(event.worker_id, event.event.dp_rank);
match event.event.data {
KvCacheEventData::Stored(store_data) => {
let hashes: Vec<_> = store_data.blocks.iter().map(|b| b.block_hash).collect();
self.store(worker, &hashes);
}
KvCacheEventData::Removed(remove_data) => {
self.remove(worker, &remove_data.block_hashes);
}
KvCacheEventData::Cleared => {
self.clear_all_blocks(worker.worker_id);
}
}
}
/// Shared implementation of `remove_worker` and `clear_all_blocks`.
///
/// Every tracked `WorkerWithDpRank` whose `worker_id` equals the argument loses all of
/// its blocks in both indexes. With `keep_worker == true` each such worker is then
/// re-inserted with an empty block set so it stays tracked; with `false` it disappears
/// from the index entirely. Workers that were not tracked are never added.
fn remove_or_clear_worker_blocks(&mut self, worker_id: WorkerId, keep_worker: bool) {
    // Snapshot the matching keys first; the map is mutated while walking them.
    let matching: Vec<WorkerWithDpRank> = self
        .worker_to_blocks
        .keys()
        .copied()
        .filter(|candidate| candidate.worker_id == worker_id)
        .collect();

    for rank in matching {
        let Some(owned) = self.worker_to_blocks.remove(&rank) else {
            continue;
        };

        for hash in owned {
            let Some(holders) = self.block_to_workers.get_mut(&hash) else {
                continue;
            };
            holders.remove(&rank);
            if holders.is_empty() {
                self.block_to_workers.remove(&hash);
            }
        }

        if keep_worker {
            // Keep the worker visible to `get_workers`, just with nothing stored.
            self.worker_to_blocks.insert(rank, HashSet::new());
        }
    }
}
/// Drops every block held under `worker_id` (all `dp_rank`s) and stops tracking the worker.
pub fn remove_worker(&mut self, worker_id: WorkerId) {
    let keep_worker = false;
    self.remove_or_clear_worker_blocks(worker_id, keep_worker);
}
/// Drops every block held under `worker_id` (all `dp_rank`s) while keeping the worker tracked.
pub fn clear_all_blocks(&mut self, worker_id: WorkerId) {
    let keep_worker = true;
    self.remove_or_clear_worker_blocks(worker_id, keep_worker);
}
/// Lists the worker IDs currently tracked by the index.
///
/// Entries that differ only in `dp_rank` collapse to a single ID; the result is sorted
/// in ascending order.
pub fn get_workers(&self) -> Vec<WorkerId> {
    let mut ids: Vec<WorkerId> = self
        .worker_to_blocks
        .keys()
        .map(|tracked| tracked.worker_id)
        .collect();
    // Sorting first makes duplicates adjacent, so `dedup` removes all of them.
    ids.sort_unstable();
    ids.dedup();
    ids
}
/// Serializes the index as `RouterEvent`s that rebuild the same state when replayed.
///
/// Named after the `RadixTree` method it stands in for. One `Stored` event is produced
/// per `(worker, block)` pair, with `event_id`s counting up from 0 in map iteration
/// order. The flat index keeps neither parent links nor the original token hashes, so
/// `parent_hash` is `None` and `tokens_hash` is the placeholder `LocalBlockHash(0)`.
pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
    self.worker_to_blocks
        .iter()
        // Flatten the reverse index into (worker, block) pairs.
        .flat_map(|(worker, blocks)| blocks.iter().map(move |hash| (*worker, *hash)))
        // Number the pairs sequentially starting at 0.
        .zip(0u64..)
        .map(|((worker, block_hash), event_id)| RouterEvent {
            worker_id: worker.worker_id,
            event: KvCacheEvent {
                event_id,
                data: KvCacheEventData::Stored(KvCacheStoreData {
                    parent_hash: None,
                    blocks: vec![KvCacheStoredBlockData {
                        block_hash,
                        mm_extra_info: None,
                        tokens_hash: LocalBlockHash(0),
                    }],
                }),
                dp_rank: worker.dp_rank,
            },
        })
        .collect()
}
/// Counts the `(worker, block)` pairs held in the index.
///
/// A block stored by several workers is counted once per worker.
pub fn current_size(&self) -> usize {
    self.worker_to_blocks.values().map(HashSet::len).sum()
}
}
impl Default for FlatHashMap {
fn default() -> Self {
Self::new()
}
}
This diff is collapsed.
......@@ -9,14 +9,19 @@
pub mod approx;
#[cfg(feature = "bench")]
pub mod bench_utils;
pub mod flat_hashmap;
pub mod concurrent_radix_tree;
pub mod indexer;
pub mod nested_map;
pub mod protocols;
pub mod radix_tree;
#[cfg(test)]
pub(crate) mod test_utils;
// Re-export key types for convenience
pub use flat_hashmap::FlatHashMap;
pub use indexer::MaybeError;
pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer};
pub use nested_map::PositionalIndexer;
pub use protocols::{
KvCacheEventError, LocalBlockHash, OverlapScores, RouterEvent, WorkerId,
compute_block_hash_for_seq,
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment